Bladeren bron

Refactor SQL Explainer

Jinzhu 4 jaren geleden
bovenliggende
commit
c3b798aec8
2 gewijzigde bestanden met toevoegingen van 45 en 18 verwijderingen
  1. 23 8
      logger/sql.go
  2. 22 10
      logger/sql_test.go

+ 23 - 8
logger/sql.go

@@ -3,6 +3,7 @@ package logger
 import (
 	"database/sql/driver"
 	"fmt"
+	"reflect"
 	"regexp"
 	"strconv"
 	"strings"
@@ -19,19 +20,17 @@ func isPrintable(s []byte) bool {
 	return true
 }
 
+var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
+
 func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
-	for idx, v := range vars {
-		if valuer, ok := v.(driver.Valuer); ok {
-			v, _ = valuer.Value()
-		}
+	var convertParams func(interface{}, int)
 
+	convertParams = func(v interface{}, idx int) {
 		switch v := v.(type) {
 		case bool:
 			vars[idx] = fmt.Sprint(v)
 		case time.Time:
 			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
-		case *time.Time:
-			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
 		case []byte:
 			if isPrintable(v) {
 				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
@@ -48,19 +47,35 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
 			if v == nil {
 				vars[idx] = "NULL"
 			} else {
+				rv := reflect.Indirect(reflect.ValueOf(v))
+				for _, t := range convertableTypes {
+					if rv.Type().ConvertibleTo(t) {
+						convertParams(rv.Convert(t).Interface(), idx)
+						return
+					}
+				}
+
 				vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
 			}
 		}
 	}
 
+	for idx, v := range vars {
+		if valuer, ok := v.(driver.Valuer); ok {
+			v, _ = valuer.Value()
+		}
+
+		convertParams(v, idx)
+	}
+
 	if numericPlaceholder == nil {
 		for _, v := range vars {
 			sql = strings.Replace(sql, "?", v.(string), 1)
 		}
 	} else {
-		sql = numericPlaceholder.ReplaceAllString(sql, "$$$$$1")
+		sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
 		for idx, v := range vars {
-			sql = strings.Replace(sql, "$$"+strconv.Itoa(idx), v.(string), 1)
+			sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1)
 		}
 	}
 

+ 22 - 10
logger/sql_test.go

@@ -9,7 +9,13 @@ import (
 )
 
 func TestExplainSQL(t *testing.T) {
-	tt := now.MustParse("2020-02-23 11:10:10")
+	type role string
+	type password []byte
+	var (
+		tt     = now.MustParse("2020-02-23 11:10:10")
+		myrole = role("admin")
+		pwd    = password([]byte("pass"))
+	)
 
 	results := []struct {
 		SQL           string
@@ -18,22 +24,28 @@ func TestExplainSQL(t *testing.T) {
 		Result        string
 	}{
 		{
-			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (?, ?, ?, ?, ?, ?, ?, ?)",
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
 			NumericRegexp: nil,
-			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
-			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
+			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
 		},
 		{
-			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7)",
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
 			NumericRegexp: regexp.MustCompile("@p(\\d+)"),
-			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil},
-			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
+			Vars:          []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
 		},
 		{
-			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ($2, $3, $0, $1, $6, $7, $4, $5)",
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($2, $3, $0, $1, $6, $7, $4, $5, $8, $9, $10)",
 			NumericRegexp: regexp.MustCompile("\\$(\\d+)"),
-			Vars:          []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt},
-			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL)`,
+			Vars:          []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
+		},
+		{
+			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p0, @p10, @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9)",
+			NumericRegexp: regexp.MustCompile("@p(\\d+)"),
+			Vars:          []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
+			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
 		},
 	}