Browse Source

scope.Fields() return slice of *Field

Jinzhu 8 years ago
parent
commit
3055bad1e8
11 changed files with 99 additions and 84 deletions
  1. 12 11
      association.go
  2. 2 3
      callback_create.go
  3. 1 1
      callback_query.go
  4. 2 2
      callback_query_preload.go
  5. 2 3
      callback_update.go
  6. 25 20
      field.go
  7. 7 3
      field_test.go
  8. 9 3
      join_table_handler.go
  9. 1 1
      main.go
  10. 34 22
      scope.go
  11. 4 15
      scope_private.go

+ 12 - 11
association.go

@@ -63,9 +63,11 @@ func (association *Association) Replace(values ...interface{}) *Association {
 			var associationForeignFieldNames []string
 			if relationship.Kind == "many_to_many" {
 				// if many to many relations, get association fields name from association foreign keys
-				associationFields := scope.New(reflect.New(field.Type()).Interface()).Fields()
+				associationScope := scope.New(reflect.New(field.Type()).Interface())
 				for _, dbName := range relationship.AssociationForeignFieldNames {
-					associationForeignFieldNames = append(associationForeignFieldNames, associationFields[dbName].Name)
+					if field, ok := associationScope.FieldByName(dbName); ok {
+						associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
+					}
 				}
 			} else {
 				// If other relations, use primary keys
@@ -84,15 +86,12 @@ func (association *Association) Replace(values ...interface{}) *Association {
 
 		if relationship.Kind == "many_to_many" {
 			// if many to many relations, delete related relations from join table
-
-			// get source fields name from source foreign keys
-			var (
-				sourceFields            = scope.Fields()
-				sourceForeignFieldNames []string
-			)
+			var sourceForeignFieldNames []string
 
 			for _, dbName := range relationship.ForeignFieldNames {
-				sourceForeignFieldNames = append(sourceForeignFieldNames, sourceFields[dbName].Name)
+				if field, ok := scope.FieldByName(dbName); ok {
+					sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
+				}
 			}
 
 			if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
@@ -147,10 +146,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
 		}
 
 		// get association's foreign fields name
-		var associationFields = scope.New(reflect.New(field.Type()).Interface()).Fields()
+		var associationScope = scope.New(reflect.New(field.Type()).Interface())
 		var associationForeignFieldNames []string
 		for _, associationDBName := range relationship.AssociationForeignFieldNames {
-			associationForeignFieldNames = append(associationForeignFieldNames, associationFields[associationDBName].Name)
+			if field, ok := associationScope.FieldByName(associationDBName); ok {
+				associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
+			}
 		}
 
 		// association value's foreign keys

+ 2 - 3
callback_create.go

@@ -45,10 +45,9 @@ func createCallback(scope *Scope) {
 		var (
 			columns, placeholders        []string
 			blankColumnsWithDefaultValue []string
-			fields                       = scope.Fields()
 		)
 
-		for _, field := range fields {
+		for _, field := range scope.Fields() {
 			if scope.changeableField(field) {
 				if field.IsNormal {
 					if !field.IsPrimaryKey || !field.IsBlank {
@@ -62,7 +61,7 @@ func createCallback(scope *Scope) {
 					}
 				} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
 					for _, foreignKey := range field.Relationship.ForeignDBNames {
-						if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
+						if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
 							columns = append(columns, scope.Quote(foreignField.DBName))
 							placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
 						}

+ 1 - 1
callback_query.go

@@ -68,7 +68,7 @@ func queryCallback(scope *Scope) {
 					elem = reflect.New(resultType).Elem()
 				}
 
-				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
+				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap())
 
 				if isSlice {
 					if isPtr {

+ 2 - 2
callback_query_preload.go

@@ -255,7 +255,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 	for rows.Next() {
 		var (
 			elem   = reflect.New(fieldType).Elem()
-			fields = scope.New(elem.Addr().Interface()).Fields()
+			fields = scope.New(elem.Addr().Interface()).fieldsMap()
 		)
 
 		// register foreign keys in join tables
@@ -284,7 +284,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 		indirectScopeValue = scope.IndirectValue()
 		fieldsSourceMap    = map[string]reflect.Value{}
 		foreignFieldNames  = []string{}
-		fields             = scope.Fields()
+		fields             = scope.fieldsMap()
 	)
 
 	for _, dbName := range relation.ForeignFieldNames {

+ 2 - 3
callback_update.go

@@ -60,14 +60,13 @@ func updateCallback(scope *Scope) {
 				sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
 			}
 		} else {
-			fields := scope.Fields()
-			for _, field := range fields {
+			for _, field := range scope.Fields() {
 				if scope.changeableField(field) {
 					if !field.IsPrimaryKey && field.IsNormal {
 						sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 					} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 						for _, foreignKey := range relationship.ForeignDBNames {
-							if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
+							if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
 								sqls = append(sqls,
 									fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
 							}

+ 25 - 20
field.go

@@ -56,29 +56,34 @@ func (field *Field) Set(value interface{}) (err error) {
 }
 
 // Fields get value's fields
-func (scope *Scope) Fields() map[string]*Field {
-	if scope.fields == nil {
-		var (
-			fields             = map[string]*Field{}
-			indirectScopeValue = scope.IndirectValue()
-			isStruct           = indirectScopeValue.Kind() == reflect.Struct
-		)
+func (scope *Scope) Fields() []*Field {
+	var (
+		fields             []*Field
+		indirectScopeValue = scope.IndirectValue()
+		isStruct           = indirectScopeValue.Kind() == reflect.Struct
+	)
 
-		for _, structField := range scope.GetModelStruct().StructFields {
-			if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
-				if isStruct {
-					fieldValue := indirectScopeValue
-					for _, name := range structField.Names {
-						fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
-					}
-					fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
-				} else {
-					fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
-				}
+	for _, structField := range scope.GetModelStruct().StructFields {
+		if isStruct {
+			fieldValue := indirectScopeValue
+			for _, name := range structField.Names {
+				fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
 			}
+			fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
+		} else {
+			fields = append(fields, &Field{StructField: structField, IsBlank: true})
 		}
+	}
 
-		scope.fields = fields
+	return fields
+}
+
+func (scope *Scope) fieldsMap() map[string]*Field {
+	var results = map[string]*Field{}
+	for _, field := range scope.Fields() {
+		if field.IsNormal {
+			results[field.DBName] = field
+		}
 	}
-	return scope.fields
+	return results
 }

+ 7 - 3
field_test.go

@@ -32,12 +32,16 @@ type CalculateFieldCategory struct {
 
 func TestCalculateField(t *testing.T) {
 	var field CalculateField
-	fields := DB.NewScope(&field).Fields()
-	if fields["children"].Relationship == nil || fields["category"].Relationship == nil {
+	var scope = DB.NewScope(&field)
+	if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
 		t.Errorf("Should calculate fields correctly for the first time")
 	}
 
-	if field, ok := fields["embedded_name"]; !ok {
+	if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
+		t.Errorf("Should calculate fields correctly for the first time")
+	}
+
+	if field, ok := scope.FieldByName("embedded_name"); !ok {
 		t.Errorf("should find embedded field")
 	} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
 		t.Errorf("should find embedded field's tag settings")

+ 9 - 3
join_table_handler.go

@@ -74,11 +74,15 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
 
 		if s.Source.ModelType == modelType {
 			for _, foreignKey := range s.Source.ForeignKeys {
-				values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
+				if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
+					values[foreignKey.DBName] = field.Field.Interface()
+				}
 			}
 		} else if s.Destination.ModelType == modelType {
 			for _, foreignKey := range s.Destination.ForeignKeys {
-				values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
+				if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
+					values[foreignKey.DBName] = field.Field.Interface()
+				}
 			}
 		}
 	}
@@ -151,7 +155,9 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
 
 		for _, foreignKey := range s.Source.ForeignKeys {
 			foreignDBNames = append(foreignDBNames, foreignKey.DBName)
-			foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
+			if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
+				foreignFieldNames = append(foreignFieldNames, field.Name)
+			}
 		}
 
 		foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)

+ 1 - 1
main.go

@@ -232,7 +232,7 @@ func (s *DB) ScanRows(rows *sql.Rows, value interface{}) error {
 	)
 
 	if clone.AddError(err) == nil {
-		scope.scan(rows, columns, scope.Fields())
+		scope.scan(rows, columns, scope.fieldsMap())
 	}
 
 	return clone.Error

+ 34 - 22
scope.go

@@ -100,10 +100,11 @@ func (scope *Scope) HasError() bool {
 	return scope.db.Error != nil
 }
 
-func (scope *Scope) PrimaryFields() []*Field {
-	var fields = []*Field{}
-	for _, field := range scope.GetModelStruct().PrimaryFields {
-		fields = append(fields, scope.Fields()[field.DBName])
+func (scope *Scope) PrimaryFields() (fields []*Field) {
+	for _, field := range scope.Fields() {
+		if field.IsPrimaryKey {
+			fields = append(fields, field)
+		}
 	}
 	return fields
 }
@@ -111,11 +112,11 @@ func (scope *Scope) PrimaryFields() []*Field {
 func (scope *Scope) PrimaryField() *Field {
 	if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
 		if len(primaryFields) > 1 {
-			if field, ok := scope.Fields()["id"]; ok {
+			if field, ok := scope.FieldByName("id"); ok {
 				return field
 			}
 		}
-		return scope.Fields()[primaryFields[0].DBName]
+		return scope.PrimaryFields()[0]
 	}
 	return nil
 }
@@ -164,20 +165,23 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
 		updateAttrs[field.DBName] = value
 		return field.Set(value)
 	} else if name, ok := column.(string); ok {
-		if field, ok := scope.Fields()[name]; ok {
-			updateAttrs[field.DBName] = value
-			return field.Set(value)
-		}
-
-		dbName := ToDBName(name)
-		if field, ok := scope.Fields()[dbName]; ok {
-			updateAttrs[field.DBName] = value
-			return field.Set(value)
+		var (
+			dbName           = ToDBName(name)
+			mostMatchedField *Field
+		)
+		for _, field := range scope.Fields() {
+			if field.DBName == value {
+				updateAttrs[field.DBName] = value
+				return field.Set(value)
+			}
+			if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
+				mostMatchedField = field
+			}
 		}
 
-		if field, ok := scope.FieldByName(name); ok {
-			updateAttrs[field.DBName] = value
-			return field.Set(value)
+		if mostMatchedField != nil {
+			updateAttrs[mostMatchedField.DBName] = value
+			return mostMatchedField.Set(value)
 		}
 	}
 	return errors.New("could not convert column to field")
@@ -286,12 +290,20 @@ func (scope *Scope) CombinedConditionSql() string {
 
 // FieldByName find gorm.Field with name and db name
 func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
+	var (
+		dbName           = ToDBName(name)
+		mostMatchedField *Field
+	)
+
 	for _, field := range scope.Fields() {
 		if field.Name == name || field.DBName == name {
 			return field, true
 		}
+		if field.DBName == dbName {
+			mostMatchedField = field
+		}
 	}
-	return nil, false
+	return mostMatchedField, mostMatchedField != nil
 }
 
 // Raw set sql
@@ -390,12 +402,12 @@ func (scope *Scope) OmitAttrs() []string {
 	return scope.Search.omits
 }
 
-func (scope *Scope) scan(rows *sql.Rows, columns []string, fields map[string]*Field) {
+func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
 	var values = make([]interface{}, len(columns))
 	var ignored interface{}
 
 	for index, column := range columns {
-		if field, ok := fields[column]; ok {
+		if field, ok := fieldsMap[column]; ok {
 			if field.Field.Kind() == reflect.Ptr {
 				values[index] = field.Field.Addr().Interface()
 			} else {
@@ -411,7 +423,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields map[string]*Fi
 	scope.Err(rows.Scan(values...))
 
 	for index, column := range columns {
-		if field, ok := fields[column]; ok {
+		if field, ok := fieldsMap[column]; ok {
 			if field.Field.Kind() != reflect.Ptr {
 				if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
 					field.Field.Set(v)

+ 4 - 15
scope_private.go

@@ -437,21 +437,10 @@ func (scope *Scope) shouldSaveAssociations() bool {
 
 func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 	toScope := scope.db.NewScope(value)
-	fromFields := scope.Fields()
-	toFields := toScope.Fields()
 
 	for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
-		var fromField, toField *Field
-		if field, ok := scope.FieldByName(foreignKey); ok {
-			fromField = field
-		} else {
-			fromField = fromFields[ToDBName(foreignKey)]
-		}
-		if field, ok := toScope.FieldByName(foreignKey); ok {
-			toField = field
-		} else {
-			toField = toFields[ToDBName(foreignKey)]
-		}
+		fromField, _ := scope.FieldByName(foreignKey)
+		toField, _ := toScope.FieldByName(foreignKey)
 
 		if fromField != nil {
 			if relationship := fromField.Relationship; relationship != nil {
@@ -515,7 +504,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
 
 			var sqlTypes, primaryKeys []string
 			for idx, fieldName := range relationship.ForeignFieldNames {
-				if field, ok := scope.Fields()[fieldName]; ok {
+				if field, ok := scope.FieldByName(fieldName); ok {
 					foreignKeyStruct := field.clone()
 					foreignKeyStruct.IsPrimaryKey = false
 					foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
@@ -525,7 +514,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
 			}
 
 			for idx, fieldName := range relationship.AssociationForeignFieldNames {
-				if field, ok := toScope.Fields()[fieldName]; ok {
+				if field, ok := toScope.FieldByName(fieldName); ok {
 					foreignKeyStruct := field.clone()
 					foreignKeyStruct.IsPrimaryKey = false
 					foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"