Browse Source

Explain SQL for dialects

Jinzhu 4 years ago
parent
commit
bc5ceff82f

+ 8 - 0
callbacks.go

@@ -3,6 +3,7 @@ package gorm
 import (
 	"errors"
 	"fmt"
+	"time"
 
 	"github.com/jinzhu/gorm/logger"
 	"github.com/jinzhu/gorm/schema"
@@ -69,6 +70,7 @@ func (cs *callbacks) Raw() *processor {
 }
 
 func (p *processor) Execute(db *DB) {
+	curTime := time.Now()
 	if stmt := db.Statement; stmt != nil {
 		if stmt.Model == nil {
 			stmt.Model = stmt.Dest
@@ -86,6 +88,12 @@ func (p *processor) Execute(db *DB) {
 	for _, f := range p.fns {
 		f(db)
 	}
+
+	if stmt := db.Statement; stmt != nil {
+		db.Logger.RunWith(logger.Info, func() {
+			db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars))
+		})
+	}
 }
 
 func (p *processor) Get(name string) func(*DB) {

+ 8 - 0
dialects/mssql/mssql.go

@@ -3,11 +3,13 @@ package mssql
 import (
 	"database/sql"
 	"fmt"
+	"regexp"
 	"strconv"
 
 	_ "github.com/denisenkom/go-mssqldb"
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/logger"
 	"github.com/jinzhu/gorm/migrator"
 	"github.com/jinzhu/gorm/schema"
 )
@@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'"', '"'} // `name`
 }
 
+var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
+
+func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
+	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
+}
+
 func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 	switch field.DataType {
 	case schema.Bool:

+ 7 - 1
dialects/mssql/mssql_test.go

@@ -2,6 +2,7 @@ package mssql_test
 
 import (
 	"fmt"
+	"os"
 	"testing"
 
 	"github.com/jinzhu/gorm"
@@ -15,7 +16,12 @@ var (
 )
 
 func init() {
-	if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil {
+	dsn := "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
+	if os.Getenv("GORM_DSN") != "" {
+		dsn = os.Getenv("GORM_DSN")
+	}
+
+	if DB, err = gorm.Open(mssql.Open(dsn), &gorm.Config{}); err != nil {
 		panic(fmt.Sprintf("failed to initialize database, got error %v", err))
 	}
 }

+ 5 - 0
dialects/mysql/mysql.go

@@ -8,6 +8,7 @@ import (
 	_ "github.com/go-sql-driver/mysql"
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/logger"
 	"github.com/jinzhu/gorm/migrator"
 	"github.com/jinzhu/gorm/schema"
 )
@@ -42,6 +43,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'`', '`'} // `name`
 }
 
+func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
+	return logger.ExplainSQL(sql, nil, `"`, vars...)
+}
+
 func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 	switch field.DataType {
 	case schema.Bool:

+ 7 - 1
dialects/mysql/mysql_test.go

@@ -2,6 +2,7 @@ package mysql_test
 
 import (
 	"fmt"
+	"os"
 	"testing"
 
 	"github.com/jinzhu/gorm"
@@ -15,7 +16,12 @@ var (
 )
 
 func init() {
-	if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil {
+	dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
+	if os.Getenv("GORM_DSN") != "" {
+		dsn = os.Getenv("GORM_DSN")
+	}
+
+	if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
 		panic(fmt.Sprintf("failed to initialize database, got error %v", err))
 	}
 }

+ 8 - 0
dialects/postgres/postgres.go

@@ -3,10 +3,12 @@ package postgres
 import (
 	"database/sql"
 	"fmt"
+	"regexp"
 	"strconv"
 
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/logger"
 	"github.com/jinzhu/gorm/migrator"
 	"github.com/jinzhu/gorm/schema"
 	_ "github.com/lib/pq"
@@ -44,6 +46,12 @@ func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'"', '"'} // "name"
 }
 
+var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
+
+func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
+	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
+}
+
 func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 	switch field.DataType {
 	case schema.Bool:

+ 7 - 1
dialects/postgres/postgres_test.go

@@ -2,6 +2,7 @@ package postgres_test
 
 import (
 	"fmt"
+	"os"
 	"testing"
 
 	"github.com/jinzhu/gorm"
@@ -15,7 +16,12 @@ var (
 )
 
 func init() {
-	if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil {
+	dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
+	if os.Getenv("GORM_DSN") != "" {
+		dsn = os.Getenv("GORM_DSN")
+	}
+
+	if DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil {
 		panic(fmt.Sprintf("failed to initialize database, got error %v", err))
 	}
 }

+ 5 - 0
dialects/sqlite/sqlite.go

@@ -5,6 +5,7 @@ import (
 
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/logger"
 	"github.com/jinzhu/gorm/migrator"
 	"github.com/jinzhu/gorm/schema"
 	_ "github.com/mattn/go-sqlite3"
@@ -41,6 +42,10 @@ func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'`', '`'} // `name`
 }
 
+func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
+	return logger.ExplainSQL(sql, nil, `"`, vars...)
+}
+
 func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 	switch field.DataType {
 	case schema.Bool:

+ 1 - 0
interfaces.go

@@ -14,6 +14,7 @@ type Dialector interface {
 	DataTypeOf(*schema.Field) string
 	BindVar(stmt *Statement, v interface{}) string
 	QuoteChars() [2]byte
+	Explain(sql string, vars ...interface{}) string
 }
 
 // CommonDB common db interface

+ 12 - 5
logger/logger.go

@@ -11,9 +11,9 @@ type LogLevel int
 var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)}
 
 const (
-	Info LogLevel = iota + 1
+	Error LogLevel = iota + 1
 	Warn
-	Error
+	Info
 )
 
 // Interface logger interface
@@ -22,6 +22,7 @@ type Interface interface {
 	Info(string, ...interface{})
 	Warn(string, ...interface{})
 	Error(string, ...interface{})
+	RunWith(LogLevel, func())
 }
 
 // Writer log writer interface
@@ -40,21 +41,27 @@ func (logger Logger) LogMode(level LogLevel) Interface {
 
 // Info print info
 func (logger Logger) Info(msg string, data ...interface{}) {
-	if logger.logLevel <= Info {
+	if logger.logLevel >= Info {
 		logger.Print("[info] " + fmt.Sprintf(msg, data...))
 	}
 }
 
 // Warn print warn messages
 func (logger Logger) Warn(msg string, data ...interface{}) {
-	if logger.logLevel <= Warn {
+	if logger.logLevel >= Warn {
 		logger.Print("[warn] " + fmt.Sprintf(msg, data...))
 	}
 }
 
 // Error print error messages
 func (logger Logger) Error(msg string, data ...interface{}) {
-	if logger.logLevel <= Error {
+	if logger.logLevel >= Error {
 		logger.Print("[error] " + fmt.Sprintf(msg, data...))
 	}
 }
+
+func (logger Logger) RunWith(logLevel LogLevel, fc func()) {
+	if logger.logLevel >= logLevel {
+		fc()
+	}
+}

+ 68 - 0
logger/sql.go

@@ -0,0 +1,68 @@
+package logger
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"regexp"
+	"strconv"
+	"strings"
+	"time"
+	"unicode"
+)
+
+func isPrintable(s []byte) bool {
+	for _, r := range s {
+		if !unicode.IsPrint(rune(r)) {
+			return false
+		}
+	}
+	return true
+}
+
+func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
+	for idx, v := range vars {
+		if valuer, ok := v.(driver.Valuer); ok {
+			v, _ = valuer.Value()
+		}
+
+		switch v := v.(type) {
+		case bool:
+			vars[idx] = fmt.Sprint(v)
+		case time.Time:
+			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
+		case *time.Time:
+			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
+		case []byte:
+			if isPrintable(v) {
+				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
+			} else {
+				vars[idx] = escaper + "<binary>" + escaper
+			}
+		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
+			vars[idx] = fmt.Sprintf("%d", v)
+		case float64, float32:
+			vars[idx] = fmt.Sprintf("%.6f", v)
+		case string:
+			vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
+		default:
+			if v == nil {
+				vars[idx] = "NULL"
+			} else {
+				vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
+			}
+		}
+	}
+
+	if numericPlaceholder == nil {
+		for _, v := range vars {
+			sql = strings.Replace(sql, "?", v.(string), 1)
+		}
+	} else {
+		sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1")
+		for idx, v := range vars {
+			sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1)
+		}
+	}
+
+	return sql
+}

+ 45 - 0
logger/sql_test.go

@@ -0,0 +1,45 @@
+package logger_test
+
+import (
+	"regexp"
+	"testing"
+
+	"github.com/jinzhu/gorm/logger"
+	"github.com/jinzhu/now"
+)
+
+func TestExplainSQL(t *testing.T) {
+	tt := now.MustParse("2020-02-23 11:10:10")
+
+	results := []struct {
+		SQL           string
+		NumericRegexp *regexp.Regexp
+		Vars          []interface{}
+		Result        string
+	}{
+		{
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)",
+			NumericRegexp: nil,
+			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
+		},
+		{
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)",
+			NumericRegexp: regexp.MustCompile("@p(\\d+)"),
+			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
+		},
+		{
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)",
+			NumericRegexp: regexp.MustCompile("\\$(\\d+)"),
+			Vars:          []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
+		},
+	}
+
+	for idx, r := range results {
+		if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result {
+			t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result)
+		}
+	}
+}

+ 1 - 1
tests/tests_all.sh

@@ -1,4 +1,4 @@
-dialects=("postgres" "mysql" "mssql" "sqlite")
+dialects=("sqlite" "mysql" "postgres" "mssql")
 
 if [[ $(pwd) == *"gorm/tests"* ]]; then
   cd ..