ソースを参照

Add more tests for setter, valuer

Jinzhu 4 年 前
コミット
18236fa3d7
5 ファイル変更275 行追加127 行削除
  1. 51 80
      schema/field.go
  2. 112 25
      schema/field_test.go
  3. 41 0
      schema/model_test.go
  4. 29 19
      schema/schema_helper_test.go
  5. 42 3
      schema/schema_test.go

+ 51 - 80
schema/field.go

@@ -164,6 +164,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 	case reflect.Struct:
 		if _, ok := fieldValue.Interface().(*time.Time); ok {
 			field.DataType = Time
+		} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
+			field.DataType = Time
 		}
 	case reflect.Array, reflect.Slice:
 		if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
@@ -311,6 +313,24 @@ func (field *Field) setupValuerAndSetter() {
 		}
 	}
 
+	recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
+		reflectV := reflect.ValueOf(v)
+		if reflectV.Type().ConvertibleTo(field.FieldType) {
+			field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+		} else if valuer, ok := v.(driver.Valuer); ok {
+			if v, err = valuer.Value(); err == nil {
+				return setter(value, v)
+			}
+		} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
+			field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
+		} else if reflectV.Kind() == reflect.Ptr {
+			return field.Setter(value, reflectV.Elem().Interface())
+		} else {
+			return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+		}
+		return err
+	}
+
 	// Setter
 	switch field.FieldType.Kind() {
 	case reflect.Bool:
@@ -321,17 +341,12 @@ func (field *Field) setupValuerAndSetter() {
 			case *bool:
 				field.ReflectValuer(value).SetBool(*data)
 			default:
-				reflectV := reflect.ValueOf(v)
-				if reflectV.Type().ConvertibleTo(field.FieldType) {
-					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-				} else {
-					field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero())
-				}
+				return recoverFunc(value, v, field.Setter)
 			}
 			return nil
 		}
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-		field.Setter = func(value reflect.Value, v interface{}) error {
+		field.Setter = func(value reflect.Value, v interface{}) (err error) {
 			switch data := v.(type) {
 			case int64:
 				field.ReflectValuer(value).SetInt(data)
@@ -366,19 +381,12 @@ func (field *Field) setupValuerAndSetter() {
 					return err
 				}
 			default:
-				reflectV := reflect.ValueOf(v)
-				if reflectV.Type().ConvertibleTo(field.FieldType) {
-					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-				} else if reflectV.Kind() == reflect.Ptr {
-					return field.Setter(value, reflectV.Elem().Interface())
-				} else {
-					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
-				}
+				return recoverFunc(value, v, field.Setter)
 			}
-			return nil
+			return err
 		}
 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
-		field.Setter = func(value reflect.Value, v interface{}) error {
+		field.Setter = func(value reflect.Value, v interface{}) (err error) {
 			switch data := v.(type) {
 			case uint64:
 				field.ReflectValuer(value).SetUint(data)
@@ -413,19 +421,12 @@ func (field *Field) setupValuerAndSetter() {
 					return err
 				}
 			default:
-				reflectV := reflect.ValueOf(v)
-				if reflectV.Type().ConvertibleTo(field.FieldType) {
-					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-				} else if reflectV.Kind() == reflect.Ptr {
-					return field.Setter(value, reflectV.Elem().Interface())
-				} else {
-					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
-				}
+				return recoverFunc(value, v, field.Setter)
 			}
-			return nil
+			return err
 		}
 	case reflect.Float32, reflect.Float64:
-		field.Setter = func(value reflect.Value, v interface{}) error {
+		field.Setter = func(value reflect.Value, v interface{}) (err error) {
 			switch data := v.(type) {
 			case float64:
 				field.ReflectValuer(value).SetFloat(data)
@@ -460,19 +461,12 @@ func (field *Field) setupValuerAndSetter() {
 					return err
 				}
 			default:
-				reflectV := reflect.ValueOf(v)
-				if reflectV.Type().ConvertibleTo(field.FieldType) {
-					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-				} else if reflectV.Kind() == reflect.Ptr {
-					return field.Setter(value, reflectV.Elem().Interface())
-				} else {
-					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
-				}
+				return recoverFunc(value, v, field.Setter)
 			}
-			return nil
+			return err
 		}
 	case reflect.String:
-		field.Setter = func(value reflect.Value, v interface{}) error {
+		field.Setter = func(value reflect.Value, v interface{}) (err error) {
 			switch data := v.(type) {
 			case string:
 				field.ReflectValuer(value).SetString(data)
@@ -483,16 +477,9 @@ func (field *Field) setupValuerAndSetter() {
 			case float64, float32:
 				field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
 			default:
-				reflectV := reflect.ValueOf(v)
-				if reflectV.Type().ConvertibleTo(field.FieldType) {
-					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-				} else if reflectV.Kind() == reflect.Ptr {
-					return field.Setter(value, reflectV.Elem().Interface())
-				} else {
-					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
-				}
+				return recoverFunc(value, v, field.Setter)
 			}
-			return nil
+			return err
 		}
 	default:
 		fieldValue := reflect.New(field.FieldType)
@@ -511,7 +498,7 @@ func (field *Field) setupValuerAndSetter() {
 						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
 					}
 				default:
-					return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
+					return recoverFunc(value, v, field.Setter)
 				}
 				return nil
 			}
@@ -529,62 +516,46 @@ func (field *Field) setupValuerAndSetter() {
 						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
 					}
 				default:
-					return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
+					return recoverFunc(value, v, field.Setter)
 				}
 				return nil
 			}
 		default:
 			if _, ok := fieldValue.Interface().(sql.Scanner); ok {
+				// struct scanner
 				field.Setter = func(value reflect.Value, v interface{}) (err error) {
-					if valuer, ok := v.(driver.Valuer); ok {
+					reflectV := reflect.ValueOf(v)
+					if reflectV.Type().ConvertibleTo(field.FieldType) {
+						field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+					} else if valuer, ok := v.(driver.Valuer); ok {
 						if v, err = valuer.Value(); err == nil {
-							err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
-						}
-					} else {
-						err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
-					}
-					return
-				}
-				return
-			}
-
-			if fieldValue.CanAddr() {
-				if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
-					field.Setter = func(value reflect.Value, v interface{}) (err error) {
-						if valuer, ok := v.(driver.Valuer); ok {
-							if v, err = valuer.Value(); err == nil {
-								err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
-							}
-						} else {
 							err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
 						}
-						return
+					} else {
+						err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
 					}
 					return
 				}
-			}
-
-			if field.FieldType.Kind() == reflect.Ptr {
+			} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
+				// pointer scanner
 				field.Setter = func(value reflect.Value, v interface{}) (err error) {
 					reflectV := reflect.ValueOf(v)
 					if reflectV.Type().ConvertibleTo(field.FieldType) {
 						field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
 					} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
 						field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
+					} else if valuer, ok := v.(driver.Valuer); ok {
+						if v, err = valuer.Value(); err == nil {
+							err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
+						}
 					} else {
-						return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+						err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
 					}
-					return nil
+					return
 				}
 			} else {
 				field.Setter = func(value reflect.Value, v interface{}) (err error) {
-					reflectV := reflect.ValueOf(v)
-					if reflectV.Type().ConvertibleTo(field.FieldType) {
-						field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
-					} else {
-						return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
-					}
-					return nil
+					return recoverFunc(value, v, field.Setter)
 				}
 			}
 		}

+ 112 - 25
schema/field_test.go

@@ -1,6 +1,7 @@
 package schema_test
 
 import (
+	"database/sql"
 	"reflect"
 	"sync"
 	"testing"
@@ -13,8 +14,7 @@ import (
 
 func TestFieldValuerAndSetter(t *testing.T) {
 	var (
-		cacheMap      = sync.Map{}
-		userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
+		userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
 		user          = tests.User{
 			Model: gorm.Model{
 				ID:        10,
@@ -54,20 +54,38 @@ func TestFieldValuerAndSetter(t *testing.T) {
 
 	for k, v := range newValues {
 		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
-			t.Errorf("no error should happen when assign value to field %v", k)
+			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
 		}
 	}
 	checkField(t, userSchema, reflectValue, newValues)
+
+	// test valuer and other type
+	age := myint(10)
+	newValues2 := map[string]interface{}{
+		"name":       sql.NullString{String: "valuer_and_setter_3", Valid: true},
+		"id":         &sql.NullInt64{Int64: 3, Valid: true},
+		"created_at": tests.Now(),
+		"deleted_at": time.Now(),
+		"age":        &age,
+		"birthday":   mytime(time.Now()),
+		"active":     mybool(true),
+	}
+
+	for k, v := range newValues2 {
+		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
+			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
+		}
+	}
+	checkField(t, userSchema, reflectValue, newValues2)
 }
 
 func TestPointerFieldValuerAndSetter(t *testing.T) {
 	var (
-		cacheMap      = sync.Map{}
-		userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
-		name          = "pointer_field_valuer_and_setter"
-		age           = 18
-		active        = true
-		user          = User{
+		userSchema, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+		name               = "pointer_field_valuer_and_setter"
+		age           uint = 18
+		active             = true
+		user               = User{
 			Model: &gorm.Model{
 				ID:        10,
 				CreatedAt: time.Now(),
@@ -110,22 +128,91 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
 		}
 	}
 	checkField(t, userSchema, reflectValue, newValues)
+
+	// test valuer and other type
+	age2 := myint(10)
+	newValues2 := map[string]interface{}{
+		"name":       sql.NullString{String: "valuer_and_setter_3", Valid: true},
+		"id":         &sql.NullInt64{Int64: 3, Valid: true},
+		"created_at": tests.Now(),
+		"deleted_at": time.Now(),
+		"age":        &age2,
+		"birthday":   mytime(time.Now()),
+		"active":     mybool(true),
+	}
+
+	for k, v := range newValues2 {
+		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
+			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
+		}
+	}
+	checkField(t, userSchema, reflectValue, newValues2)
 }
 
-type User struct {
-	*gorm.Model
-	Name      *string
-	Age       *int
-	Birthday  *time.Time
-	Account   *tests.Account
-	Pets      []*tests.Pet
-	Toys      []tests.Toy `gorm:"polymorphic:Owner"`
-	CompanyID *int
-	Company   *tests.Company
-	ManagerID *int
-	Manager   *User
-	Team      []User           `gorm:"foreignkey:ManagerID"`
-	Languages []tests.Language `gorm:"many2many:UserSpeak"`
-	Friends   []*User          `gorm:"many2many:user_friends"`
-	Active    *bool
+func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
+	var (
+		userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
+		name          = "advanced_data_type_valuer_and_setter"
+		deletedAt     = mytime(time.Now())
+		isAdmin       = mybool(false)
+		user          = AdvancedDataTypeUser{
+			ID:           sql.NullInt64{Int64: 10, Valid: true},
+			Name:         &sql.NullString{String: name, Valid: true},
+			Birthday:     sql.NullTime{Time: time.Now(), Valid: true},
+			RegisteredAt: mytime(time.Now()),
+			DeletedAt:    &deletedAt,
+			Active:       mybool(true),
+			Admin:        &isAdmin,
+		}
+		reflectValue = reflect.ValueOf(&user)
+	)
+
+	// test valuer
+	values := map[string]interface{}{
+		"id":            user.ID,
+		"name":          user.Name,
+		"birthday":      user.Birthday,
+		"registered_at": user.RegisteredAt,
+		"deleted_at":    user.DeletedAt,
+		"active":        user.Active,
+		"admin":         user.Admin,
+	}
+	checkField(t, userSchema, reflectValue, values)
+
+	// test setter
+	newDeletedAt := mytime(time.Now())
+	newIsAdmin := mybool(true)
+	newValues := map[string]interface{}{
+		"id":            sql.NullInt64{Int64: 1, Valid: true},
+		"name":          &sql.NullString{String: name + "rename", Valid: true},
+		"birthday":      time.Now(),
+		"registered_at": mytime(time.Now()),
+		"deleted_at":    &newDeletedAt,
+		"active":        mybool(false),
+		"admin":         &newIsAdmin,
+	}
+
+	for k, v := range newValues {
+		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
+			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
+		}
+	}
+	checkField(t, userSchema, reflectValue, newValues)
+
+	newValues2 := map[string]interface{}{
+		"id":            5,
+		"name":          name + "rename2",
+		"birthday":      time.Now(),
+		"registered_at": time.Now(),
+		"deleted_at":    time.Now(),
+		"active":        true,
+		"admin":         false,
+	}
+
+	for k, v := range newValues2 {
+		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
+			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
+		}
+	}
+	checkField(t, userSchema, reflectValue, newValues2)
 }

+ 41 - 0
schema/model_test.go

@@ -0,0 +1,41 @@
+package schema_test
+
+import (
+	"database/sql"
+	"time"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/tests"
+)
+
+type User struct {
+	*gorm.Model
+	Name      *string
+	Age       *uint
+	Birthday  *time.Time
+	Account   *tests.Account
+	Pets      []*tests.Pet
+	Toys      []*tests.Toy `gorm:"polymorphic:Owner"`
+	CompanyID *int
+	Company   *tests.Company
+	ManagerID *int
+	Manager   *User
+	Team      []*User           `gorm:"foreignkey:ManagerID"`
+	Languages []*tests.Language `gorm:"many2many:UserSpeak"`
+	Friends   []*User           `gorm:"many2many:user_friends"`
+	Active    *bool
+}
+
+type mytime time.Time
+type myint int
+type mybool = bool
+
+type AdvancedDataTypeUser struct {
+	ID           sql.NullInt64
+	Name         *sql.NullString
+	Birthday     sql.NullTime
+	RegisteredAt mytime
+	DeletedAt    *mytime
+	Active       mybool
+	Admin        *mybool
+}

+ 29 - 19
schema/schema_helper_test.go

@@ -1,6 +1,7 @@
 package schema_test
 
 import (
+	"database/sql/driver"
 	"fmt"
 	"reflect"
 	"strings"
@@ -194,30 +195,39 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
 func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
 	for k, v := range values {
 		t.Run("CheckField/"+k, func(t *testing.T) {
-			field := s.FieldsByDBName[k]
-			fv := field.ValueOf(value)
-
-			if reflect.ValueOf(fv).Kind() == reflect.Ptr {
-				if reflect.ValueOf(v).Kind() == reflect.Ptr {
-					if fv != v {
-						t.Errorf("pointer expects: %p, but got %p", v, fv)
+			var (
+				checker func(fv interface{}, v interface{})
+				field   = s.FieldsByDBName[k]
+				fv      = field.ValueOf(value)
+			)
+
+			checker = func(fv interface{}, v interface{}) {
+				if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v {
+					t.Errorf("expects: %p, but got %p", v, fv)
+				} else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) {
+					if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv {
+						t.Errorf("expects: %p, but got %p", v, fv)
 					}
-				} else if fv == nil {
-					if v != nil {
-						t.Errorf("expects: %+v, but got nil", v)
+				} else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) {
+					if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v {
+						t.Errorf("expects: %p, but got %p", v, fv)
 					}
-				} else if reflect.ValueOf(fv).Elem().Interface() != v {
-					t.Errorf("expects: %+v, but got %+v", v, fv)
-				}
-			} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
-				if reflect.ValueOf(v).Elem().Interface() != fv {
-					t.Errorf("expects: %+v, but got %+v", v, fv)
-				}
-			} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
-				if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
+				} else if valuer, isValuer := fv.(driver.Valuer); isValuer {
+					valuerv, _ := valuer.Value()
+					checker(valuerv, v)
+				} else if valuer, isValuer := v.(driver.Valuer); isValuer {
+					valuerv, _ := valuer.Value()
+					checker(fv, valuerv)
+				} else if reflect.ValueOf(fv).Kind() == reflect.Ptr {
+					checker(reflect.ValueOf(fv).Elem().Interface(), v)
+				} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
+					checker(fv, reflect.ValueOf(v).Elem().Interface())
+				} else {
 					t.Errorf("expects: %+v, but got %+v", v, fv)
 				}
 			}
+
+			checker(fv, v)
 		})
 	}
 }

+ 42 - 3
schema/schema_test.go

@@ -9,13 +9,24 @@ import (
 )
 
 func TestParseSchema(t *testing.T) {
-	cacheMap := sync.Map{}
-
-	user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
+	user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user, got error %v", err)
 	}
 
+	checkUserSchema(t, user)
+}
+
+func TestParseSchemaWithPointerFields(t *testing.T) {
+	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse pointer user, got error %v", err)
+	}
+
+	checkUserSchema(t, user)
+}
+
+func checkUserSchema(t *testing.T, user *schema.Schema) {
 	// check schema
 	checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"})
 
@@ -101,3 +112,31 @@ func TestParseSchema(t *testing.T) {
 		checkSchemaRelation(t, user, relation)
 	}
 }
+
+func TestParseSchemaWithAdvancedDataType(t *testing.T) {
+	user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse pointer user, got error %v", err)
+	}
+
+	// check schema
+	checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"})
+
+	// check fields
+	fields := []schema.Field{
+		{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true},
+		{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
+		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
+		{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
+		{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time},
+		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
+		{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
+	}
+
+	for _, f := range fields {
+		checkSchemaField(t, user, &f, func(f *schema.Field) {
+			f.Creatable = true
+			f.Updatable = true
+		})
+	}
+}