瀏覽代碼

Refactor getColumnAsArray

Jinzhu 8 年之前
父節點
當前提交
822e895d4d
共有 4 個文件被更改,包括 27 次插入21 次删除
  1. 3 3
      association.go
  2. 2 2
      join_table_handler.go
  3. 3 3
      preload.go
  4. 19 13
      scope_utils.go

+ 3 - 3
association.go

@@ -117,7 +117,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
 		}
 	}
 
-	deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...)
+	deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
 
 	if relationship.Kind == "many_to_many" {
 		// source value's foreign keys
@@ -141,7 +141,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
 
 		if relationship.Kind == "belongs_to" {
 			// find with deleting relation's foreign keys
-			primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
+			primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
 			newDB = newDB.Where(
 				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 				toQueryValues(primaryKeys)...,
@@ -158,7 +158,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
 			}
 		} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
 			// find all relations
-			primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value)
+			primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
 			newDB = newDB.Where(
 				fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
 				toQueryValues(primaryKeys)...,

+ 2 - 2
join_table_handler.go

@@ -154,7 +154,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
 			foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
 		}
 
-		foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
+		foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
 
 		var condString string
 		if len(foreignFieldValues) > 0 {
@@ -165,7 +165,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
 
 			condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
 
-			keys := scope.getColumnAsArray(foreignFieldNames)
+			keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
 			values = append(values, toQueryValues(keys))
 		} else {
 			condString = fmt.Sprintf("1 <> 1")

+ 3 - 3
preload.go

@@ -77,7 +77,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
 	relation := field.Relationship
 
 	// get relations's primary keys
-	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
+	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
 	if len(primaryKeys) == 0 {
 		return
 	}
@@ -112,7 +112,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
 	relation := field.Relationship
 
 	// get relations's primary keys
-	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
+	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
 	if len(primaryKeys) == 0 {
 		return
 	}
@@ -149,7 +149,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
 	relation := field.Relationship
 
 	// get relations's primary keys
-	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
+	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
 	if len(primaryKeys) == 0 {
 		return
 	}

+ 19 - 13
scope_utils.go

@@ -2,24 +2,30 @@ package gorm
 
 import "reflect"
 
-func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
-	indirectScopeValue := scope.IndirectValue()
-	switch indirectScopeValue.Kind() {
-	case reflect.Slice:
-		for i := 0; i < indirectScopeValue.Len(); i++ {
+func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
+	for _, value := range values {
+		indirectValue := reflect.ValueOf(value)
+		for indirectValue.Kind() == reflect.Ptr {
+			indirectValue = indirectValue.Elem()
+		}
+
+		switch indirectValue.Kind() {
+		case reflect.Slice:
+			for i := 0; i < indirectValue.Len(); i++ {
+				var result []interface{}
+				var object = reflect.Indirect(indirectValue.Index(i))
+				for _, column := range columns {
+					result = append(result, object.FieldByName(column).Interface())
+				}
+				results = append(results, result)
+			}
+		case reflect.Struct:
 			var result []interface{}
-			var object = reflect.Indirect(indirectScopeValue.Index(i))
 			for _, column := range columns {
-				result = append(result, object.FieldByName(column).Interface())
+				result = append(result, indirectValue.FieldByName(column).Interface())
 			}
 			results = append(results, result)
 		}
-	case reflect.Struct:
-		var result []interface{}
-		for _, column := range columns {
-			result = append(result, indirectScopeValue.FieldByName(column).Interface())
-		}
-		return [][]interface{}{result}
 	}
 	return
 }