Browse Source

Fix parse stmt ReflectValue

Jinzhu 4 years ago
parent
commit
04adbaf7f6

+ 3 - 3
callbacks.go

@@ -3,6 +3,7 @@ package gorm
 import (
 	"errors"
 	"fmt"
+	"reflect"
 	"time"
 
 	"github.com/jinzhu/gorm/logger"
@@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) {
 		}
 
 		if stmt.Model != nil {
-			err := stmt.Parse(stmt.Model)
-
-			if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
+			if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
 				db.AddError(err)
 			}
 		}
+		stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
 	}
 
 	for _, f := range p.fns {

+ 1 - 1
logger/sql.go

@@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
 	} else {
 		sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
 		for idx, v := range vars {
-			sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1)
+			sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
 		}
 	}
 

+ 1 - 1
schema/callbacks_test.go

@@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) {
 }
 
 func TestCallback(t *testing.T) {
-	user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
+	user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user with callback, got error %v", err)
 	}

+ 1 - 1
schema/check_test.go

@@ -15,7 +15,7 @@ type UserCheck struct {
 }
 
 func TestParseCheck(t *testing.T) {
-	user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
+	user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user check, got error %v", err)
 	}

+ 1 - 1
schema/field.go

@@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		var err error
 		field.Creatable = false
 		field.Updatable = false
-		if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
+		if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
 			schema.err = err
 		}
 		for _, ef := range field.EmbeddedSchema.Fields {

+ 12 - 12
schema/field_test.go

@@ -14,8 +14,8 @@ import (
 
 func TestFieldValuerAndSetter(t *testing.T) {
 	var (
-		userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
-		user             = tests.User{
+		userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
+		user          = tests.User{
 			Model: gorm.Model{
 				ID:        10,
 				CreatedAt: time.Now(),
@@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
 
 func TestPointerFieldValuerAndSetter(t *testing.T) {
 	var (
-		userSchema, _, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
-		name                  = "pointer_field_valuer_and_setter"
-		age              uint = 18
-		active                = true
-		user                  = User{
+		userSchema, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+		name               = "pointer_field_valuer_and_setter"
+		age           uint = 18
+		active             = true
+		user               = User{
 			Model: &gorm.Model{
 				ID:        10,
 				CreatedAt: time.Now(),
@@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
 
 func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
 	var (
-		userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
-		name             = "advanced_data_type_valuer_and_setter"
-		deletedAt        = mytime(time.Now())
-		isAdmin          = mybool(false)
-		user             = AdvancedDataTypeUser{
+		userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
+		name          = "advanced_data_type_valuer_and_setter"
+		deletedAt     = mytime(time.Now())
+		isAdmin       = mybool(false)
+		user          = AdvancedDataTypeUser{
 			ID:           sql.NullInt64{Int64: 10, Valid: true},
 			Name:         &sql.NullString{String: name, Valid: true},
 			Birthday:     sql.NullTime{Time: time.Now(), Valid: true},

+ 1 - 1
schema/index_test.go

@@ -19,7 +19,7 @@ type UserIndex struct {
 }
 
 func TestParseIndex(t *testing.T) {
-	user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
+	user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user index, got error %v", err)
 	}

+ 2 - 2
schema/relationship.go

@@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
 		}
 	)
 
-	if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
+	if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
 		schema.err = err
 		return
 	}
@@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 		}
 	}
 
-	if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
+	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
 		schema.err = err
 	}
 	relation.JoinTable.Name = many2many

+ 8 - 8
schema/schema.go

@@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field {
 }
 
 // get data type from dialector
-func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
-	reflectValue := reflect.ValueOf(dest)
-	modelType := reflectValue.Type()
+func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
+	modelType := reflect.ValueOf(dest).Type()
 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
 		modelType = modelType.Elem()
 	}
 
 	if modelType.Kind() != reflect.Struct {
 		if modelType.PkgPath() == "" {
-			return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
+			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 		}
-		return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
+		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 	}
 
 	if v, ok := cacheStore.Load(modelType); ok {
-		return v.(*Schema), reflectValue, nil
+		return v.(*Schema), nil
 	}
 
 	schema := &Schema{
@@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
 		}
 	}
 
+	reflectValue := reflect.Indirect(reflect.New(modelType))
 	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
 	for _, name := range callbacks {
 		if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
@@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
 	for _, field := range schema.Fields {
 		if field.DataType == "" && field.Creatable {
 			if schema.parseRelation(field); schema.err != nil {
-				return schema, reflectValue, schema.err
+				return schema, schema.err
 			}
 		}
 	}
 
-	return schema, reflectValue, schema.err
+	return schema, schema.err
 }

+ 3 - 3
schema/schema_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestParseSchema(t *testing.T) {
-	user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
+	user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user, got error %v", err)
 	}
@@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
 }
 
 func TestParseSchemaWithPointerFields(t *testing.T) {
-	user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse pointer user, got error %v", err)
 	}
@@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
 }
 
 func TestParseSchemaWithAdvancedDataType(t *testing.T) {
-	user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
+	user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse pointer user, got error %v", err)
 	}

+ 2 - 6
statement.go

@@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) {
 }
 
 func (stmt *Statement) Parse(value interface{}) (err error) {
-	if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
-		stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
-
-		if stmt.Table == "" {
-			stmt.Table = stmt.Schema.Table
-		}
+	if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
+		stmt.Table = stmt.Schema.Table
 	}
 	return err
 }