Sfoglia il codice sorgente

Test Field Valuer, Setter

Jinzhu 4 anni fa
parent
commit
faee069a9f
6 ha cambiato i file con 224 aggiunte e 89 eliminazioni
  1. 108 74
      schema/field.go
  2. 77 10
      schema/field_test.go
  3. 3 3
      schema/relationship.go
  4. 32 0
      schema/schema_helper_test.go
  5. 2 1
      schema/schema_test.go
  6. 2 1
      tests/model.go

+ 108 - 74
schema/field.go

@@ -25,52 +25,53 @@ const (
 )
 
 type Field struct {
-	Name            string
-	DBName          string
-	BindNames       []string
-	DataType        DataType
-	DBDataType      string
-	PrimaryKey      bool
-	AutoIncrement   bool
-	Creatable       bool
-	Updatable       bool
-	HasDefaultValue bool
-	DefaultValue    string
-	NotNull         bool
-	Unique          bool
-	Comment         string
-	Size            int
-	Precision       int
-	FieldType       reflect.Type
-	StructField     reflect.StructField
-	Tag             reflect.StructTag
-	TagSettings     map[string]string
-	Schema          *Schema
-	EmbeddedSchema  *Schema
-	ReflectValuer   func(reflect.Value) reflect.Value
-	Valuer          func(reflect.Value) interface{}
-	Setter          func(reflect.Value, interface{}) error
+	Name              string
+	DBName            string
+	BindNames         []string
+	DataType          DataType
+	DBDataType        string
+	PrimaryKey        bool
+	AutoIncrement     bool
+	Creatable         bool
+	Updatable         bool
+	HasDefaultValue   bool
+	DefaultValue      string
+	NotNull           bool
+	Unique            bool
+	Comment           string
+	Size              int
+	Precision         int
+	FieldType         reflect.Type
+	IndirectFieldType reflect.Type
+	StructField       reflect.StructField
+	Tag               reflect.StructTag
+	TagSettings       map[string]string
+	Schema            *Schema
+	EmbeddedSchema    *Schema
+	ReflectValuer     func(reflect.Value) reflect.Value
+	Valuer            func(reflect.Value) interface{}
+	Setter            func(reflect.Value, interface{}) error
 }
 
 func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 	field := &Field{
-		Name:        fieldStruct.Name,
-		BindNames:   []string{fieldStruct.Name},
-		FieldType:   fieldStruct.Type,
-		StructField: fieldStruct,
-		Creatable:   true,
-		Updatable:   true,
-		Tag:         fieldStruct.Tag,
-		TagSettings: ParseTagSetting(fieldStruct.Tag),
-		Schema:      schema,
+		Name:              fieldStruct.Name,
+		BindNames:         []string{fieldStruct.Name},
+		FieldType:         fieldStruct.Type,
+		IndirectFieldType: fieldStruct.Type,
+		StructField:       fieldStruct,
+		Creatable:         true,
+		Updatable:         true,
+		Tag:               fieldStruct.Tag,
+		TagSettings:       ParseTagSetting(fieldStruct.Tag),
+		Schema:            schema,
 	}
 
-	for field.FieldType.Kind() == reflect.Ptr {
-		field.FieldType = field.FieldType.Elem()
+	for field.IndirectFieldType.Kind() == reflect.Ptr {
+		field.IndirectFieldType = field.IndirectFieldType.Elem()
 	}
 
-	fieldValue := reflect.New(field.FieldType)
-
+	fieldValue := reflect.New(field.IndirectFieldType)
 	// if field is valuer, used its value or first fields as data type
 	if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer {
 		var overrideFieldValue bool
@@ -79,10 +80,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 			fieldValue = reflect.ValueOf(v)
 		}
 
-		if field.FieldType.Kind() == reflect.Struct {
-			for i := 0; i < field.FieldType.NumField(); i++ {
+		if field.IndirectFieldType.Kind() == reflect.Struct {
+			for i := 0; i < field.IndirectFieldType.NumField(); i++ {
 				if !overrideFieldValue {
-					newFieldType := field.FieldType.Field(i).Type
+					newFieldType := field.IndirectFieldType.Field(i).Type
 					for newFieldType.Kind() == reflect.Ptr {
 						newFieldType = newFieldType.Elem()
 					}
@@ -92,7 +93,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 				}
 
 				// copy tag settings from valuer
-				for key, value := range ParseTagSetting(field.FieldType.Field(i).Tag) {
+				for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) {
 					if _, ok := field.TagSettings[key]; !ok {
 						field.TagSettings[key] = value
 					}
@@ -197,7 +198,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 			if field.FieldType.Kind() == reflect.Struct {
 				ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
 			} else {
-				ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...)
+				ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
 			}
 
 			if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
@@ -235,26 +236,29 @@ func (field *Field) setupValuerAndSetter() {
 	switch {
 	case len(field.StructField.Index) == 1:
 		field.Valuer = func(value reflect.Value) interface{} {
-			return value.Field(field.StructField.Index[0]).Interface()
+			return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface()
 		}
 	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
 		field.Valuer = func(value reflect.Value) interface{} {
-			return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
+			return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
 		}
 	default:
 		field.Valuer = func(value reflect.Value) interface{} {
-			v := value.Field(field.StructField.Index[0])
-			for _, idx := range field.StructField.Index[1:] {
-				if v.Kind() == reflect.Ptr {
+			v := reflect.Indirect(value)
+
+			for _, idx := range field.StructField.Index {
+				if idx >= 0 {
+					v = v.Field(idx)
+				} else {
+					v = v.Field(-idx - 1)
+
 					if v.Type().Elem().Kind() == reflect.Struct {
 						if !v.IsNil() {
-							v = v.Elem().Field(-idx)
-							continue
+							v = v.Elem()
 						}
+					} else {
+						return nil
 					}
-					return nil
-				} else {
-					v = v.Field(idx)
 				}
 			}
 			return v.Interface()
@@ -266,7 +270,7 @@ func (field *Field) setupValuerAndSetter() {
 	case len(field.StructField.Index) == 1:
 		if field.FieldType.Kind() == reflect.Ptr {
 			field.ReflectValuer = func(value reflect.Value) reflect.Value {
-				fieldValue := value.Field(field.StructField.Index[0])
+				fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
 				if fieldValue.IsNil() {
 					fieldValue.Set(reflect.New(field.FieldType.Elem()))
 				}
@@ -274,31 +278,33 @@ func (field *Field) setupValuerAndSetter() {
 			}
 		} else {
 			field.ReflectValuer = func(value reflect.Value) reflect.Value {
-				return value.Field(field.StructField.Index[0])
+				return reflect.Indirect(value).Field(field.StructField.Index[0])
 			}
 		}
 	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
-		field.Valuer = func(value reflect.Value) interface{} {
-			return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
+		field.ReflectValuer = func(value reflect.Value) reflect.Value {
+			return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
 		}
 	default:
 		field.ReflectValuer = func(value reflect.Value) reflect.Value {
-			v := value.Field(field.StructField.Index[0])
-			for _, idx := range field.StructField.Index[1:] {
+			v := reflect.Indirect(value)
+			for _, idx := range field.StructField.Index {
+				if idx >= 0 {
+					v = v.Field(idx)
+				} else {
+					v = v.Field(-idx - 1)
+				}
+
 				if v.Kind() == reflect.Ptr {
 					if v.Type().Elem().Kind() == reflect.Struct {
 						if v.IsNil() {
 							v.Set(reflect.New(v.Type().Elem()))
 						}
+					}
 
-						if idx >= 0 {
-							v = v.Elem().Field(idx)
-						} else {
-							v = v.Elem().Field(-idx)
-						}
+					if idx < len(field.StructField.Index)-1 {
+						v = v.Elem()
 					}
-				} else {
-					v = v.Field(idx)
 				}
 			}
 			return v
@@ -490,7 +496,7 @@ func (field *Field) setupValuerAndSetter() {
 		}
 	default:
 		fieldValue := reflect.New(field.FieldType)
-		switch fieldValue.Interface().(type) {
+		switch fieldValue.Elem().Interface().(type) {
 		case time.Time:
 			field.Setter = func(value reflect.Value, v interface{}) error {
 				switch data := v.(type) {
@@ -528,6 +534,20 @@ func (field *Field) setupValuerAndSetter() {
 				return nil
 			}
 		default:
+			if _, ok := fieldValue.Interface().(sql.Scanner); ok {
+				field.Setter = func(value reflect.Value, v interface{}) (err error) {
+					if valuer, ok := v.(driver.Valuer); ok {
+						if v, err = valuer.Value(); err == nil {
+							err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
+						}
+					} else {
+						err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
+					}
+					return
+				}
+				return
+			}
+
 			if fieldValue.CanAddr() {
 				if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
 					field.Setter = func(value reflect.Value, v interface{}) (err error) {
@@ -544,14 +564,28 @@ func (field *Field) setupValuerAndSetter() {
 				}
 			}
 
-			field.Setter = func(value reflect.Value, v interface{}) (err error) {
-				reflectV := reflect.ValueOf(v)
-				if reflectV.Type().ConvertibleTo(field.FieldType) {
-					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-				} else {
-					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+			if field.FieldType.Kind() == reflect.Ptr {
+				field.Setter = func(value reflect.Value, v interface{}) (err error) {
+					reflectV := reflect.ValueOf(v)
+					if reflectV.Type().ConvertibleTo(field.FieldType) {
+						field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+					} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
+						field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
+					} else {
+						return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+					}
+					return nil
+				}
+			} else {
+				field.Setter = func(value reflect.Value, v interface{}) (err error) {
+					reflectV := reflect.ValueOf(v)
+					if reflectV.Type().ConvertibleTo(field.FieldType) {
+						field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+					} else {
+						return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+					}
+					return nil
 				}
-				return nil
 			}
 		}
 	}

+ 77 - 10
schema/field_test.go

@@ -24,10 +24,12 @@ func TestFieldValuerAndSetter(t *testing.T) {
 			Name:     "valuer_and_setter",
 			Age:      18,
 			Birthday: tests.Now(),
+			Active:   true,
 		}
-		reflectValue = reflect.ValueOf(user)
+		reflectValue = reflect.ValueOf(&user)
 	)
 
+	// test valuer
 	values := map[string]interface{}{
 		"name":       user.Name,
 		"id":         user.ID,
@@ -35,30 +37,95 @@ func TestFieldValuerAndSetter(t *testing.T) {
 		"deleted_at": user.DeletedAt,
 		"age":        user.Age,
 		"birthday":   user.Birthday,
+		"active":     true,
 	}
+	checkField(t, userSchema, reflectValue, values)
 
-	for k, v := range values {
-		if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v {
-			t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv)
-		}
-	}
-
+	// test setter
 	newValues := map[string]interface{}{
 		"name":       "valuer_and_setter_2",
-		"id":         "2",
+		"id":         2,
 		"created_at": time.Now(),
 		"deleted_at": tests.Now(),
 		"age":        20,
 		"birthday":   time.Now(),
+		"active":     false,
 	}
 
 	for k, v := range newValues {
 		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
 			t.Errorf("no error should happen when assign value to field %v", k)
 		}
+	}
+	checkField(t, userSchema, reflectValue, newValues)
+}
+
+func TestPointerFieldValuerAndSetter(t *testing.T) {
+	var (
+		cacheMap      = sync.Map{}
+		userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
+		name          = "pointer_field_valuer_and_setter"
+		age           = 18
+		active        = true
+		user          = User{
+			Model: &gorm.Model{
+				ID:        10,
+				CreatedAt: time.Now(),
+				DeletedAt: tests.Now(),
+			},
+			Name:     &name,
+			Age:      &age,
+			Birthday: tests.Now(),
+			Active:   &active,
+		}
+		reflectValue = reflect.ValueOf(&user)
+	)
+
+	// test valuer
+	values := map[string]interface{}{
+		"name":       user.Name,
+		"id":         user.ID,
+		"created_at": user.CreatedAt,
+		"deleted_at": user.DeletedAt,
+		"age":        user.Age,
+		"birthday":   user.Birthday,
+		"active":     true,
+	}
+	checkField(t, userSchema, reflectValue, values)
+
+	// test setter
+	newValues := map[string]interface{}{
+		"name":       "valuer_and_setter_2",
+		"id":         2,
+		"created_at": time.Now(),
+		"deleted_at": tests.Now(),
+		"age":        20,
+		"birthday":   time.Now(),
+		"active":     false,
+	}
 
-		if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v {
-			t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv)
+	for k, v := range newValues {
+		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
+			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
 		}
 	}
+	checkField(t, userSchema, reflectValue, newValues)
+}
+
+type User struct {
+	*gorm.Model
+	Name      *string
+	Age       *int
+	Birthday  *time.Time
+	Account   *tests.Account
+	Pets      []*tests.Pet
+	Toys      []tests.Toy `gorm:"polymorphic:Owner"`
+	CompanyID *int
+	Company   *tests.Company
+	ManagerID *int
+	Manager   *User
+	Team      []User           `gorm:"foreignkey:ManagerID"`
+	Languages []tests.Language `gorm:"many2many:UserSpeak"`
+	Friends   []*User          `gorm:"many2many:user_friends"`
+	Active    *bool
 }

+ 3 - 3
schema/relationship.go

@@ -54,7 +54,7 @@ type Reference struct {
 func (schema *Schema) parseRelation(field *Field) {
 	var (
 		err        error
-		fieldValue = reflect.New(field.FieldType).Interface()
+		fieldValue = reflect.New(field.IndirectFieldType).Interface()
 		relation   = &Relationship{
 			Name:        field.Name,
 			Field:       field,
@@ -74,7 +74,7 @@ func (schema *Schema) parseRelation(field *Field) {
 	} else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" {
 		schema.buildMany2ManyRelation(relation, field, many2many)
 	} else {
-		switch field.FieldType.Kind() {
+		switch field.IndirectFieldType.Kind() {
 		case reflect.Struct, reflect.Slice:
 			schema.guessRelation(relation, field, true)
 		default:
@@ -83,7 +83,7 @@ func (schema *Schema) parseRelation(field *Field) {
 	}
 
 	if relation.Type == "has" {
-		switch field.FieldType.Kind() {
+		switch field.IndirectFieldType.Kind() {
 		case reflect.Struct:
 			relation.Type = HasOne
 		case reflect.Slice:

+ 32 - 0
schema/schema_helper_test.go

@@ -2,6 +2,7 @@ package schema_test
 
 import (
 	"fmt"
+	"reflect"
 	"strings"
 	"testing"
 
@@ -189,3 +190,34 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
 		}
 	})
 }
+
+func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
+	for k, v := range values {
+		t.Run("CheckField/"+k, func(t *testing.T) {
+			field := s.FieldsByDBName[k]
+			fv := field.ValueOf(value)
+
+			if reflect.ValueOf(fv).Kind() == reflect.Ptr {
+				if reflect.ValueOf(v).Kind() == reflect.Ptr {
+					if fv != v {
+						t.Errorf("pointer expects: %p, but got %p", v, fv)
+					}
+				} else if fv == nil {
+					if v != nil {
+						t.Errorf("expects: %+v, but got nil", v)
+					}
+				} else if reflect.ValueOf(fv).Elem().Interface() != v {
+					t.Errorf("expects: %+v, but got %+v", v, fv)
+				}
+			} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
+				if reflect.ValueOf(v).Elem().Interface() != fv {
+					t.Errorf("expects: %+v, but got %+v", v, fv)
+				}
+			} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
+				if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
+					t.Errorf("expects: %+v, but got %+v", v, fv)
+				}
+			}
+		})
+	}
+}

+ 2 - 1
schema/schema_test.go

@@ -29,7 +29,8 @@ func TestParseSchema(t *testing.T) {
 		{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
 		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
 		{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
-		{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint},
+		{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Int},
+		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
 	}
 
 	for _, f := range fields {

+ 2 - 1
tests/model.go

@@ -21,11 +21,12 @@ type User struct {
 	Toys      []Toy `gorm:"polymorphic:Owner"`
 	CompanyID *int
 	Company   Company
-	ManagerID uint
+	ManagerID int
 	Manager   *User
 	Team      []User     `gorm:"foreignkey:ManagerID"`
 	Languages []Language `gorm:"many2many:UserSpeak"`
 	Friends   []*User    `gorm:"many2many:user_friends"`
+	Active    bool
 }
 
 type Account struct {