Ver Fonte

Handle create with default db values

Jinzhu há 4 anos atrás
pai
commit
43ce0b8af2
2 ficheiros alterados com 63 adições e 32 exclusões
  1. 35 19
      callbacks/create.go
  2. 28 13
      schema/schema.go

+ 35 - 19
callbacks/create.go

@@ -59,8 +59,10 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 		)
 
 		for _, db := range stmt.Schema.DBNames {
-			if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
-				values.Columns = append(values.Columns, clause.Column{Name: db})
+			if stmt.Schema.FieldsWithDefaultDBValue[db] == nil {
+				if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
+					values.Columns = append(values.Columns, clause.Column{Name: db})
+				}
 			}
 		}
 
@@ -68,6 +70,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 		switch reflectValue.Kind() {
 		case reflect.Slice, reflect.Array:
 			values.Values = make([][]interface{}, reflectValue.Len())
+			defaultValueFieldsHavingValue := map[string][]interface{}{}
 			for i := 0; i < reflectValue.Len(); i++ {
 				rv := reflect.Indirect(reflectValue.Index(i))
 				values.Values[i] = make([]interface{}, len(values.Columns))
@@ -80,44 +83,57 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
 							field.Set(rv, curTime)
 							values.Values[i][idx], _ = field.ValueOf(rv)
-						} else if field.HasDefaultValue {
-							if len(returnningValues) == 0 {
-								returnningValues = make([]map[string]interface{}, reflectValue.Len())
-							}
+						}
+					}
+				}
 
-							if returnningValues[i] == nil {
-								returnningValues[i] = map[string]interface{}{}
+				for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
+					if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
+						if v, isZero := field.ValueOf(rv); !isZero {
+							if len(defaultValueFieldsHavingValue[db]) == 0 {
+								defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len())
 							}
-
-							// FIXME
-							returnningValues[i][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface()
+							defaultValueFieldsHavingValue[db][i] = v
 						}
 					}
 				}
 			}
+
+			for db, vs := range defaultValueFieldsHavingValue {
+				values.Columns = append(values.Columns, clause.Column{Name: db})
+				for idx := range values.Values {
+					if vs[idx] == nil {
+						values.Values[idx] = append(values.Values[idx], clause.Expr{SQL: "DEFAULT"})
+					} else {
+						values.Values[idx] = append(values.Values[idx], vs[idx])
+					}
+				}
+			}
 		case reflect.Struct:
 			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
 			for idx, column := range values.Columns {
 				field := stmt.Schema.FieldsByDBName[column.Name]
-				if values.Values[0][idx], _ = field.ValueOf(reflectValue); isZero {
+				if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero {
 					if field.DefaultValueInterface != nil {
 						values.Values[0][idx] = field.DefaultValueInterface
 						field.Set(reflectValue, field.DefaultValueInterface)
 					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
 						field.Set(reflectValue, curTime)
 						values.Values[0][idx], _ = field.ValueOf(reflectValue)
-					} else if field.HasDefaultValue {
-						if len(returnningValues) == 0 {
-							returnningValues = make([]map[string]interface{}, 1)
-						}
+					}
+				}
+			}
 
-						values.Values[0][idx] = clause.Expr{SQL: "DEFAULT"}
-						returnningValues[0][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface()
-					} else if field.PrimaryKey {
+			for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
+				if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
+					if v, isZero := field.ValueOf(reflectValue); !isZero {
+						values.Columns = append(values.Columns, clause.Column{Name: db})
+						values.Values[0] = append(values.Values[0], v)
 					}
 				}
 			}
 		}
+
 		return values, returnningValues
 	}
 }

+ 28 - 13
schema/schema.go

@@ -14,19 +14,20 @@ import (
 var ErrUnsupportedDataType = errors.New("unsupported data type")
 
 type Schema struct {
-	Name                    string
-	ModelType               reflect.Type
-	Table                   string
-	PrioritizedPrimaryField *Field
-	DBNames                 []string
-	PrimaryFields           []*Field
-	Fields                  []*Field
-	FieldsByName            map[string]*Field
-	FieldsByDBName          map[string]*Field
-	Relationships           Relationships
-	err                     error
-	namer                   Namer
-	cacheStore              *sync.Map
+	Name                     string
+	ModelType                reflect.Type
+	Table                    string
+	PrioritizedPrimaryField  *Field
+	DBNames                  []string
+	PrimaryFields            []*Field
+	Fields                   []*Field
+	FieldsByName             map[string]*Field
+	FieldsByDBName           map[string]*Field
+	FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
+	Relationships            Relationships
+	err                      error
+	namer                    Namer
+	cacheStore               *sync.Map
 }
 
 func (schema Schema) String() string {
@@ -146,6 +147,20 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 		}
 	}
 
+	schema.FieldsWithDefaultDBValue = map[string]*Field{}
+	for db, field := range schema.FieldsByDBName {
+		if field.HasDefaultValue && field.DefaultValueInterface == nil {
+			schema.FieldsWithDefaultDBValue[db] = field
+		}
+	}
+
+	if schema.PrioritizedPrimaryField != nil {
+		switch schema.PrioritizedPrimaryField.DataType {
+		case Int, Uint:
+			schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField
+		}
+	}
+
 	cacheStore.Store(modelType, schema)
 
 	// parse relations for unidentified fields