Przeglądaj źródła

Add Field Valuer, Setter

Jinzhu 4 lat temu
rodzic
commit
2cb88dc7c5
3 zmienionych plików z 423 dodań i 0 usunięć
  1. 357 0
      schema/field.go
  2. 64 0
      schema/field_test.go
  3. 2 0
      schema/schema.go

+ 357 - 0
schema/field.go

@@ -1,11 +1,15 @@
 package schema
 
 import (
+	"database/sql"
 	"database/sql/driver"
+	"fmt"
 	"reflect"
 	"strconv"
 	"sync"
 	"time"
+
+	"github.com/jinzhu/now"
 )
 
 type DataType string
@@ -43,6 +47,9 @@ type Field struct {
 	TagSettings     map[string]string
 	Schema          *Schema
 	EmbeddedSchema  *Schema
+	ReflectValuer   func(reflect.Value) reflect.Value
+	Valuer          func(reflect.Value) interface{}
+	Setter          func(reflect.Value, interface{}) error
 }
 
 func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
@@ -186,6 +193,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		for _, ef := range field.EmbeddedSchema.Fields {
 			ef.Schema = schema
 			ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
+			// index is negative means is pointer
+			if field.FieldType.Kind() == reflect.Struct {
+				ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
+			} else {
+				ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...)
+			}
 
 			if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
 				ef.DBName = prefix + ef.DBName
@@ -199,3 +212,347 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 
 	return field
 }
+
+// ValueOf field value of
+func (field *Field) ValueOf(value reflect.Value) interface{} {
+	if field != nil {
+		return field.Valuer(value)
+	}
+	return nil
+}
+
+func (field *Field) Set(value reflect.Value, v interface{}) error {
+	if field != nil {
+		return field.Setter(value, v)
+	}
+
+	return fmt.Errorf("failed to set field value: %v", field.Name)
+}
+
+// create valuer, setter when parse struct
+func (field *Field) setupValuerAndSetter() {
+	// Valuer
+	switch {
+	case len(field.StructField.Index) == 1:
+		field.Valuer = func(value reflect.Value) interface{} {
+			return value.Field(field.StructField.Index[0]).Interface()
+		}
+	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
+		field.Valuer = func(value reflect.Value) interface{} {
+			return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
+		}
+	default:
+		field.Valuer = func(value reflect.Value) interface{} {
+			v := value.Field(field.StructField.Index[0])
+			for _, idx := range field.StructField.Index[1:] {
+				if v.Kind() == reflect.Ptr {
+					if v.Type().Elem().Kind() == reflect.Struct {
+						if !v.IsNil() {
+							v = v.Elem().Field(-idx)
+							continue
+						}
+					}
+					return nil
+				} else {
+					v = v.Field(idx)
+				}
+			}
+			return v.Interface()
+		}
+	}
+
+	// ReflectValuer
+	switch {
+	case len(field.StructField.Index) == 1:
+		if field.FieldType.Kind() == reflect.Ptr {
+			field.ReflectValuer = func(value reflect.Value) reflect.Value {
+				fieldValue := value.Field(field.StructField.Index[0])
+				if fieldValue.IsNil() {
+					fieldValue.Set(reflect.New(field.FieldType.Elem()))
+				}
+				return fieldValue
+			}
+		} else {
+			field.ReflectValuer = func(value reflect.Value) reflect.Value {
+				return value.Field(field.StructField.Index[0])
+			}
+		}
+	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
+		field.Valuer = func(value reflect.Value) interface{} {
+			return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
+		}
+	default:
+		field.ReflectValuer = func(value reflect.Value) reflect.Value {
+			v := value.Field(field.StructField.Index[0])
+			for _, idx := range field.StructField.Index[1:] {
+				if v.Kind() == reflect.Ptr {
+					if v.Type().Elem().Kind() == reflect.Struct {
+						if v.IsNil() {
+							v.Set(reflect.New(v.Type().Elem()))
+						}
+
+						if idx >= 0 {
+							v = v.Elem().Field(idx)
+						} else {
+							v = v.Elem().Field(-idx)
+						}
+					}
+				} else {
+					v = v.Field(idx)
+				}
+			}
+			return v
+		}
+	}
+
+	// Setter
+	switch field.FieldType.Kind() {
+	case reflect.Bool:
+		field.Setter = func(value reflect.Value, v interface{}) error {
+			switch data := v.(type) {
+			case bool:
+				field.ReflectValuer(value).SetBool(data)
+			case *bool:
+				field.ReflectValuer(value).SetBool(*data)
+			default:
+				reflectV := reflect.ValueOf(v)
+				if reflectV.Type().ConvertibleTo(field.FieldType) {
+					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+				} else {
+					field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero())
+				}
+			}
+			return nil
+		}
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		field.Setter = func(value reflect.Value, v interface{}) error {
+			switch data := v.(type) {
+			case int64:
+				field.ReflectValuer(value).SetInt(data)
+			case int:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case int8:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case int16:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case int32:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case uint:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case uint8:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case uint16:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case uint32:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case uint64:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case float32:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case float64:
+				field.ReflectValuer(value).SetInt(int64(data))
+			case []byte:
+				return field.Setter(value, string(data))
+			case string:
+				if i, err := strconv.ParseInt(data, 0, 64); err == nil {
+					field.ReflectValuer(value).SetInt(i)
+				} else {
+					return err
+				}
+			default:
+				reflectV := reflect.ValueOf(v)
+				if reflectV.Type().ConvertibleTo(field.FieldType) {
+					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+				} else if reflectV.Kind() == reflect.Ptr {
+					return field.Setter(value, reflectV.Elem().Interface())
+				} else {
+					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+				}
+			}
+			return nil
+		}
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		field.Setter = func(value reflect.Value, v interface{}) error {
+			switch data := v.(type) {
+			case uint64:
+				field.ReflectValuer(value).SetUint(data)
+			case uint:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case uint8:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case uint16:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case uint32:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case int64:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case int:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case int8:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case int16:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case int32:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case float32:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case float64:
+				field.ReflectValuer(value).SetUint(uint64(data))
+			case []byte:
+				return field.Setter(value, string(data))
+			case string:
+				if i, err := strconv.ParseUint(data, 0, 64); err == nil {
+					field.ReflectValuer(value).SetUint(i)
+				} else {
+					return err
+				}
+			default:
+				reflectV := reflect.ValueOf(v)
+				if reflectV.Type().ConvertibleTo(field.FieldType) {
+					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+				} else if reflectV.Kind() == reflect.Ptr {
+					return field.Setter(value, reflectV.Elem().Interface())
+				} else {
+					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+				}
+			}
+			return nil
+		}
+	case reflect.Float32, reflect.Float64:
+		field.Setter = func(value reflect.Value, v interface{}) error {
+			switch data := v.(type) {
+			case float64:
+				field.ReflectValuer(value).SetFloat(data)
+			case float32:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case int64:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case int:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case int8:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case int16:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case int32:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case uint:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case uint8:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case uint16:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case uint32:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case uint64:
+				field.ReflectValuer(value).SetFloat(float64(data))
+			case []byte:
+				return field.Setter(value, string(data))
+			case string:
+				if i, err := strconv.ParseFloat(data, 64); err == nil {
+					field.ReflectValuer(value).SetFloat(i)
+				} else {
+					return err
+				}
+			default:
+				reflectV := reflect.ValueOf(v)
+				if reflectV.Type().ConvertibleTo(field.FieldType) {
+					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+				} else if reflectV.Kind() == reflect.Ptr {
+					return field.Setter(value, reflectV.Elem().Interface())
+				} else {
+					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+				}
+			}
+			return nil
+		}
+	case reflect.String:
+		field.Setter = func(value reflect.Value, v interface{}) error {
+			switch data := v.(type) {
+			case string:
+				field.ReflectValuer(value).SetString(data)
+			case []byte:
+				field.ReflectValuer(value).SetString(string(data))
+			case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
+				field.ReflectValuer(value).SetString(fmt.Sprint(data))
+			case float64, float32:
+				field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
+			default:
+				reflectV := reflect.ValueOf(v)
+				if reflectV.Type().ConvertibleTo(field.FieldType) {
+					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+				} else if reflectV.Kind() == reflect.Ptr {
+					return field.Setter(value, reflectV.Elem().Interface())
+				} else {
+					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+				}
+			}
+			return nil
+		}
+	default:
+		fieldValue := reflect.New(field.FieldType)
+		switch fieldValue.Interface().(type) {
+		case time.Time:
+			field.Setter = func(value reflect.Value, v interface{}) error {
+				switch data := v.(type) {
+				case time.Time:
+					field.ReflectValuer(value).Set(reflect.ValueOf(v))
+				case *time.Time:
+					field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem())
+				case string:
+					if t, err := now.Parse(data); err == nil {
+						field.ReflectValuer(value).Set(reflect.ValueOf(t))
+					} else {
+						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
+					}
+				default:
+					return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
+				}
+				return nil
+			}
+		case *time.Time:
+			field.Setter = func(value reflect.Value, v interface{}) error {
+				switch data := v.(type) {
+				case time.Time:
+					field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v))
+				case *time.Time:
+					field.ReflectValuer(value).Set(reflect.ValueOf(v))
+				case string:
+					if t, err := now.Parse(data); err == nil {
+						field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t))
+					} else {
+						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
+					}
+				default:
+					return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
+				}
+				return nil
+			}
+		default:
+			if fieldValue.CanAddr() {
+				if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
+					field.Setter = func(value reflect.Value, v interface{}) (err error) {
+						if valuer, ok := v.(driver.Valuer); ok {
+							if v, err = valuer.Value(); err == nil {
+								err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
+							}
+						} else {
+							err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
+						}
+						return
+					}
+					return
+				}
+			}
+
+			field.Setter = func(value reflect.Value, v interface{}) (err error) {
+				reflectV := reflect.ValueOf(v)
+				if reflectV.Type().ConvertibleTo(field.FieldType) {
+					field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
+				} else {
+					return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
+				}
+				return nil
+			}
+		}
+	}
+}

+ 64 - 0
schema/field_test.go

@@ -0,0 +1,64 @@
+package schema_test
+
+import (
+	"reflect"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/schema"
+	"github.com/jinzhu/gorm/tests"
+)
+
+func TestFieldValuerAndSetter(t *testing.T) {
+	var (
+		cacheMap      = sync.Map{}
+		userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
+		user          = tests.User{
+			Model: gorm.Model{
+				ID:        10,
+				CreatedAt: time.Now(),
+				DeletedAt: tests.Now(),
+			},
+			Name:     "valuer_and_setter",
+			Age:      18,
+			Birthday: tests.Now(),
+		}
+		reflectValue = reflect.ValueOf(user)
+	)
+
+	values := map[string]interface{}{
+		"name":       user.Name,
+		"id":         user.ID,
+		"created_at": user.CreatedAt,
+		"deleted_at": user.DeletedAt,
+		"age":        user.Age,
+		"birthday":   user.Birthday,
+	}
+
+	for k, v := range values {
+		if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v {
+			t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv)
+		}
+	}
+
+	newValues := map[string]interface{}{
+		"name":       "valuer_and_setter_2",
+		"id":         "2",
+		"created_at": time.Now(),
+		"deleted_at": tests.Now(),
+		"age":        20,
+		"birthday":   time.Now(),
+	}
+
+	for k, v := range newValues {
+		if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
+			t.Errorf("no error should happen when assign value to field %v", k)
+		}
+
+		if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v {
+			t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv)
+		}
+	}
+}

+ 2 - 0
schema/schema.go

@@ -128,6 +128,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 		if _, ok := schema.FieldsByName[field.Name]; !ok {
 			schema.FieldsByName[field.Name] = field
 		}
+
+		field.setupValuerAndSetter()
 	}
 
 	if f := schema.LookUpField("id"); f != nil {