Pārlūkot izejas kodu

Add tag association_autoupdate, association_autocreate, association_save_reference support

Jinzhu 6 gadi atpakaļ
vecāks
revīzija
b2b568daa8
4 mainītis faili ar 241 papildinājumiem un 59 dzēšanām
  1. 138 9
      association_test.go
  2. 102 37
      callback_save.go
  3. 1 1
      query_test.go
  4. 0 12
      scope.go

+ 138 - 9
association_test.go

@@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) {
 	DB.Save(&category)
 }
 
-func TestSkipSaveAssociation(t *testing.T) {
+func TestAutoSaveBelongsToAssociation(t *testing.T) {
 	type Company struct {
 		gorm.Model
 		Name string
@@ -895,27 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) {
 		gorm.Model
 		Name      string
 		CompanyID uint
-		Company   Company `gorm:"save_associations:false"`
+		Company   Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
 	}
+
+	DB.Where("name = ?", "auto_save_association").Delete(&Company{})
 	DB.AutoMigrate(&Company{}, &User{})
 
-	DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}})
+	DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}})
 
-	if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() {
-		t.Errorf("Company skip_save_association should not have been saved")
+	if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company auto_save_association should not have been saved when autosave is false")
 	}
 
 	// if foreign key is set, this should be saved even if association isn't
-	company := Company{Name: "skip_save_association"}
+	company := Company{Name: "auto_save_association"}
 	DB.Save(&company)
-	company.Name = "skip_save_association_modified"
+
+	company.Name = "auto_save_association_new_name"
 	user := User{Name: "jinzhu", Company: company}
+
 	DB.Save(&user)
 
-	if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() {
-		t.Errorf("Company skip_save_association should not have been updated")
+	if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should not have been updated")
 	}
+
 	if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() {
 		t.Errorf("User's foreign key should have been saved")
 	}
+
+	user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}}
+	DB.Set("gorm:association_autocreate", true).Save(&user2)
+	if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company auto_save_association_2 should been created when autocreate is true")
+	}
+
+	user2.Company.Name = "auto_save_association_2_newname"
+	DB.Set("gorm:association_autoupdate", true).Save(&user2)
+
+	if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should been updated")
+	}
+}
+
+func TestAutoSaveHasOneAssociation(t *testing.T) {
+	type Company struct {
+		gorm.Model
+		UserID uint
+		Name   string
+	}
+
+	type User struct {
+		gorm.Model
+		Name    string
+		Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
+	}
+
+	DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{})
+	DB.AutoMigrate(&Company{}, &User{})
+
+	DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}})
+
+	if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false")
+	}
+
+	company := Company{Name: "auto_save_has_one_association"}
+	DB.Save(&company)
+
+	company.Name = "auto_save_has_one_association_new_name"
+	user := User{Name: "jinzhu", Company: company}
+
+	DB.Save(&user)
+
+	if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should not have been updated")
+	}
+
+	if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should not have been updated")
+	}
+
+	if user.Company.UserID == 0 {
+		t.Errorf("UserID should be assigned")
+	}
+
+	company.Name = "auto_save_has_one_association_2_new_name"
+	DB.Set("gorm:association_autoupdate", true).Save(&user)
+
+	if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should been updated")
+	}
+
+	user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}}
+	DB.Set("gorm:association_autocreate", true).Save(&user2)
+	if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true")
+	}
+}
+
+func TestAutoSaveMany2ManyAssociation(t *testing.T) {
+	type Company struct {
+		gorm.Model
+		Name string
+	}
+
+	type User struct {
+		gorm.Model
+		Name      string
+		Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"`
+	}
+
+	DB.AutoMigrate(&Company{}, &User{})
+
+	DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}})
+
+	if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false")
+	}
+
+	company := Company{Name: "auto_save_m2m_association"}
+	DB.Save(&company)
+
+	company.Name = "auto_save_m2m_association_new_name"
+	user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}}
+
+	DB.Save(&user)
+
+	if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should not have been updated")
+	}
+
+	if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should not been created")
+	}
+
+	if DB.Model(&user).Association("Companies").Count() != 1 {
+		t.Errorf("Relationship should been saved")
+	}
+
+	DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user)
+
+	if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should been updated")
+	}
+
+	if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
+		t.Errorf("Company should been created")
+	}
+
+	if DB.Model(&user).Association("Companies").Count() != 2 {
+		t.Errorf("Relationship should been updated")
+	}
 }

+ 102 - 37
callback_save.go

@@ -1,6 +1,9 @@
 package gorm
 
-import "reflect"
+import (
+	"reflect"
+	"strings"
+)
 
 func beginTransactionCallback(scope *Scope) {
 	scope.Begin()
@@ -10,33 +13,79 @@ func commitOrRollbackTransactionCallback(scope *Scope) {
 	scope.CommitOrRollback()
 }
 
-func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) {
+func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
+	checkTruth := func(value interface{}) bool {
+		if v, ok := value.(bool); ok && !v {
+			return false
+		}
+
+		if v, ok := value.(string); ok {
+			v = strings.ToLower(v)
+			if v == "false" || v != "skip" {
+				return false
+			}
+		}
+
+		return true
+	}
+
 	if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
-		if field.Relationship != nil {
-			if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() {
-				return true, field.Relationship
+		if r = field.Relationship; r != nil {
+			autoUpdate, autoCreate, saveReference = true, true, true
+
+			if value, ok := scope.Get("gorm:save_associations"); ok {
+				autoUpdate = checkTruth(value)
+				autoCreate = autoUpdate
+			} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
+				autoUpdate = checkTruth(value)
+				autoCreate = autoUpdate
+			}
+
+			if value, ok := scope.Get("gorm:association_autoupdate"); ok {
+				autoUpdate = checkTruth(value)
+			} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
+				autoUpdate = checkTruth(value)
+			}
+
+			if value, ok := scope.Get("gorm:association_autocreate"); ok {
+				autoCreate = checkTruth(value)
+			} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
+				autoCreate = checkTruth(value)
+			}
+
+			if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
+				saveReference = checkTruth(value)
 			}
-			return false, field.Relationship
 		}
 	}
-	return false, nil
+
+	return
 }
 
 func saveBeforeAssociationsCallback(scope *Scope) {
 	for _, field := range scope.Fields() {
-		if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" {
+		autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
+
+		if relationship != nil && relationship.Kind == "belongs_to" {
 			fieldValue := field.Field.Addr().Interface()
+			newScope := scope.New(fieldValue)
 
-			if allowSaveAssociation {
+			if newScope.PrimaryKeyZero() {
+				if autoCreate {
+					scope.Err(scope.NewDB().Save(fieldValue).Error)
+				}
+			} else if autoUpdate {
 				scope.Err(scope.NewDB().Save(fieldValue).Error)
 			}
 
-			if len(relationship.ForeignFieldNames) != 0 {
-				// set value's foreign key
-				for idx, fieldName := range relationship.ForeignFieldNames {
-					associationForeignName := relationship.AssociationForeignDBNames[idx]
-					if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
-						scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
+			if saveReference {
+				if len(relationship.ForeignFieldNames) != 0 {
+					// set value's foreign key
+					for idx, fieldName := range relationship.ForeignFieldNames {
+						associationForeignName := relationship.AssociationForeignDBNames[idx]
+						if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
+							scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
+						}
 					}
 				}
 			}
@@ -46,8 +95,9 @@ func saveBeforeAssociationsCallback(scope *Scope) {
 
 func saveAfterAssociationsCallback(scope *Scope) {
 	for _, field := range scope.Fields() {
-		if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil &&
-			(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
+		autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
+
+		if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
 			value := field.Field
 
 			switch value.Kind() {
@@ -57,44 +107,59 @@ func saveAfterAssociationsCallback(scope *Scope) {
 					elem := value.Index(i).Addr().Interface()
 					newScope := newDB.NewScope(elem)
 
-					if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
-						for idx, fieldName := range relationship.ForeignFieldNames {
-							associationForeignName := relationship.AssociationForeignDBNames[idx]
-							if f, ok := scope.FieldByName(associationForeignName); ok {
-								scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
+					if saveReference {
+						if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
+							for idx, fieldName := range relationship.ForeignFieldNames {
+								associationForeignName := relationship.AssociationForeignDBNames[idx]
+								if f, ok := scope.FieldByName(associationForeignName); ok {
+									scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
+								}
 							}
 						}
-					}
 
-					if relationship.PolymorphicType != "" {
-						scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
+						if relationship.PolymorphicType != "" {
+							scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
+						}
 					}
 
-					if allowSaveAssociation {
+					if newScope.PrimaryKeyZero() {
+						if autoCreate {
+							scope.Err(newDB.Save(elem).Error)
+						}
+					} else if autoUpdate {
 						scope.Err(newDB.Save(elem).Error)
 					}
 
-					if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() {
-						scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
+					if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
+						if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
+							scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
+						}
 					}
 				}
 			default:
 				elem := value.Addr().Interface()
 				newScope := scope.New(elem)
-				if len(relationship.ForeignFieldNames) != 0 {
-					for idx, fieldName := range relationship.ForeignFieldNames {
-						associationForeignName := relationship.AssociationForeignDBNames[idx]
-						if f, ok := scope.FieldByName(associationForeignName); ok {
-							scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
+
+				if saveReference {
+					if len(relationship.ForeignFieldNames) != 0 {
+						for idx, fieldName := range relationship.ForeignFieldNames {
+							associationForeignName := relationship.AssociationForeignDBNames[idx]
+							if f, ok := scope.FieldByName(associationForeignName); ok {
+								scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
+							}
 						}
 					}
-				}
 
-				if relationship.PolymorphicType != "" {
-					scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
+					if relationship.PolymorphicType != "" {
+						scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
+					}
 				}
 
-				if allowSaveAssociation {
+				if newScope.PrimaryKeyZero() {
+					if autoCreate {
+						scope.Err(scope.NewDB().Save(elem).Error)
+					}
+				} else if autoUpdate {
 					scope.Err(scope.NewDB().Save(elem).Error)
 				}
 			}

+ 1 - 1
query_test.go

@@ -389,7 +389,7 @@ func TestOffset(t *testing.T) {
 		DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
 	}
 	var users1, users2, users3, users4 []User
-	DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
+	DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
 
 	if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
 		t.Errorf("Offset should work")

+ 0 - 12
scope.go

@@ -993,18 +993,6 @@ func (scope *Scope) changeableField(field *Field) bool {
 	return true
 }
 
-func (scope *Scope) allowSaveAssociations() bool {
-	if saveAssociations, ok := scope.Get("gorm:save_associations"); ok {
-		if v, ok := saveAssociations.(bool); ok && !v {
-			return false
-		}
-		if v, ok := saveAssociations.(string); ok && (v != "skip") {
-			return false
-		}
-	}
-	return true && !scope.HasError()
-}
-
 func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 	toScope := scope.db.NewScope(value)
 	tx := scope.db.Set("gorm:association:source", scope.Value)