瀏覽代碼

Keep Refactoring Association Mode

Jinzhu 8 年之前
父節點
當前提交
67874f9232
共有 2 個文件被更改,包括 64 次插入76 次删除
  1. 64 40
      association.go
  2. 0 36
      association_utils.go

+ 64 - 40
association.go

@@ -57,16 +57,23 @@ func (association *Association) Replace(values ...interface{}) *Association {
 			newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
 		}
 
-		// Relations except new created
+		// Delete Relations except new created
 		if len(values) > 0 {
 			var associationForeignFieldNames []string
 			if relationship.Kind == "many_to_many" {
-				associationForeignFieldNames = relationship.AssociationForeignFieldNames
+				// if many to many relations, get association fields name from association foreign keys
+				associationFields := scope.New(reflect.New(field.Type()).Interface()).Fields()
+				for _, dbName := range relationship.AssociationForeignFieldNames {
+					associationForeignFieldNames = append(associationForeignFieldNames, associationFields[dbName].Name)
+				}
 			} else {
-				associationForeignFieldNames = relationship.AssociationForeignDBNames
+				// If other relations, use primary keys
+				for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
+					associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
+				}
 			}
 
-			newPrimaryKeys := association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
+			newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
 
 			if len(newPrimaryKeys) > 0 {
 				sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
@@ -75,12 +82,25 @@ func (association *Association) Replace(values ...interface{}) *Association {
 		}
 
 		if relationship.Kind == "many_to_many" {
-			if sourcePrimaryKeys := association.getPrimaryKeys(relationship.ForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
+			// if many to many relations, delete related relations from join table
+
+			// get source fields name from source foreign keys
+			var (
+				sourceFields            = scope.Fields()
+				sourceForeignFieldNames []string
+			)
+
+			for _, dbName := range relationship.ForeignFieldNames {
+				sourceForeignFieldNames = append(sourceForeignFieldNames, sourceFields[dbName].Name)
+			}
+
+			if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
 				newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
 
 				association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
 			}
 		} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
+			// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
 			var foreignKeyMap = map[string]interface{}{}
 			for idx, foreignKey := range relationship.ForeignDBNames {
 				foreignKeyMap[foreignKey] = nil
@@ -110,11 +130,9 @@ func (association *Association) Delete(values ...interface{}) *Association {
 	}
 
 	var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
-	for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
-		if field.IsPrimaryKey {
-			deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
-			deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
-		}
+	for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
+		deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
+		deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
 	}
 
 	deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
@@ -127,8 +145,15 @@ func (association *Association) Delete(values ...interface{}) *Association {
 			}
 		}
 
+		// get association's foreign fields name
+		var associationFields = scope.New(reflect.New(field.Type()).Interface()).Fields()
+		var associationForeignFieldNames []string
+		for _, associationDBName := range relationship.AssociationForeignFieldNames {
+			associationForeignFieldNames = append(associationForeignFieldNames, associationFields[associationDBName].Name)
+		}
+
 		// association value's foreign keys
-		deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
+		deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
 		sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
 		newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
 
@@ -147,7 +172,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
 				toQueryValues(primaryKeys)...,
 			)
 
-			// set foreign key to be null
+			// set foreign key to be null if there are some records affected
 			modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
 			if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
 				if results.RowsAffected > 0 {
@@ -176,28 +201,29 @@ func (association *Association) Delete(values ...interface{}) *Association {
 		}
 	}
 
-	// Remove deleted records from field
+	// Remove deleted records from source's field
 	if association.Error == nil {
 		if association.Field.Field.Kind() == reflect.Slice {
 			leftValues := reflect.Zero(association.Field.Field.Type())
 
 			for i := 0; i < association.Field.Field.Len(); i++ {
 				reflectValue := association.Field.Field.Index(i)
-				primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
-				var included = false
+				primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
+				var isDeleted = false
 				for _, pk := range deletingPrimaryKeys {
 					if equalAsString(primaryKey, pk) {
-						included = true
+						isDeleted = true
+						break
 					}
 				}
-				if !included {
+				if !isDeleted {
 					leftValues = reflect.Append(leftValues, reflectValue)
 				}
 			}
 
 			association.Field.Set(leftValues)
 		} else if association.Field.Field.Kind() == reflect.Struct {
-			primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
+			primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
 			for _, pk := range deletingPrimaryKeys {
 				if equalAsString(primaryKey, pk) {
 					association.Field.Set(reflect.Zero(association.Field.Field.Type()))
@@ -222,34 +248,32 @@ func (association *Association) Count() int {
 		relationship = association.Field.Relationship
 		scope        = association.Scope
 		fieldValue   = association.Field.Field.Interface()
-		newScope     = scope.New(fieldValue)
+		query        = scope.DB()
 	)
 
 	if relationship.Kind == "many_to_many" {
-		relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
+		query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value)
 	} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
-		query := scope.DB()
-		for idx, foreignKey := range relationship.ForeignDBNames {
-			if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
-				query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
-					field.Field.Interface())
-			}
-		}
-
-		if relationship.PolymorphicType != "" {
-			query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
-		}
-		query.Model(fieldValue).Count(&count)
+		primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
+		query = query.Where(
+			fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
+			toQueryValues(primaryKeys)...,
+		)
 	} else if relationship.Kind == "belongs_to" {
-		query := scope.DB()
-		for idx, primaryKey := range relationship.AssociationForeignDBNames {
-			if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
-				query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),
-					field.Field.Interface())
-			}
-		}
-		query.Model(fieldValue).Count(&count)
+		primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
+		query = query.Where(
+			fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
+			toQueryValues(primaryKeys)...,
+		)
+	}
+
+	if relationship.PolymorphicType != "" {
+		query = query.Where(
+			fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
+			scope.TableName(),
+		)
 	}
 
+	query.Model(fieldValue).Count(&count)
 	return count
 }

+ 0 - 36
association_utils.go

@@ -82,42 +82,6 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
 	return association
 }
 
-func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
-	scope := association.Scope
-
-	for _, value := range values {
-		reflectValue := reflect.Indirect(reflect.ValueOf(value))
-		if reflectValue.Kind() == reflect.Slice {
-			for i := 0; i < reflectValue.Len(); i++ {
-				primaryKeys := []interface{}{}
-				newScope := scope.New(reflectValue.Index(i).Interface())
-				for _, column := range columns {
-					if field, ok := newScope.FieldByName(column); ok {
-						primaryKeys = append(primaryKeys, field.Field.Interface())
-					} else {
-						primaryKeys = append(primaryKeys, "")
-					}
-				}
-				results = append(results, primaryKeys)
-			}
-		} else if reflectValue.Kind() == reflect.Struct {
-			newScope := scope.New(value)
-			var primaryKeys []interface{}
-			for _, column := range columns {
-				if field, ok := newScope.FieldByName(column); ok {
-					primaryKeys = append(primaryKeys, field.Field.Interface())
-				} else {
-					primaryKeys = append(primaryKeys, "")
-				}
-			}
-
-			results = append(results, primaryKeys)
-		}
-	}
-
-	return
-}
-
 func toQueryMarks(primaryValues [][]interface{}) string {
 	var results []string