Kaynağa Gözat

save_associations:true should store related item (#2067)

* save_associations:true should store related item, save_associations priority on related objects

* code quality
Ikhtiyor 5 yıl önce
ebeveyn
işleme
d3e666a1e0
3 değiştirilmiş dosya ile 100 ekleme ve 4 silme
  1. 3 3
      callback_save.go
  2. 88 0
      main_test.go
  3. 9 1
      migration_test.go

+ 3 - 3
callback_save.go

@@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
 
 		if v, ok := value.(string); ok {
 			v = strings.ToLower(v)
-			if v == "false" || v != "skip" {
-				return false
-			}
+			return v == "true"
 		}
 
 		return true
@@ -36,9 +34,11 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
 			if value, ok := scope.Get("gorm:save_associations"); ok {
 				autoUpdate = checkTruth(value)
 				autoCreate = autoUpdate
+				saveReference = autoUpdate
 			} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
 				autoUpdate = checkTruth(value)
 				autoCreate = autoUpdate
+				saveReference = autoUpdate
 			}
 
 			if value, ok := scope.Get("gorm:association_autoupdate"); ok {

+ 88 - 0
main_test.go

@@ -933,6 +933,94 @@ func TestOpenWithOneParameter(t *testing.T) {
 	}
 }
 
+func TestSaveAssociations(t *testing.T) {
+	db := DB.New()
+	deltaAddressCount := 0
+	if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
+		t.Errorf("failed to fetch address count")
+		t.FailNow()
+	}
+
+	placeAddress := &Address{
+		Address1: "somewhere on earth",
+	}
+	ownerAddress1 := &Address{
+		Address1: "near place address",
+	}
+	ownerAddress2 := &Address{
+		Address1: "address2",
+	}
+	db.Create(placeAddress)
+
+	addressCountShouldBe := func(t *testing.T, expectedCount int) {
+		countFromDB := 0
+		t.Helper()
+		err := db.Model(&Address{}).Count(&countFromDB).Error
+		if err != nil {
+			t.Error("failed to fetch address count")
+		}
+		if countFromDB != expectedCount {
+			t.Errorf("address count mismatch: %d", countFromDB)
+		}
+	}
+	addressCountShouldBe(t, deltaAddressCount+1)
+
+	// owner address should be created, place address should be reused
+	place1 := &Place{
+		PlaceAddressID: placeAddress.ID,
+		PlaceAddress:   placeAddress,
+		OwnerAddress:   ownerAddress1,
+	}
+	err := db.Create(place1).Error
+	if err != nil {
+		t.Errorf("failed to store place: %s", err.Error())
+	}
+	addressCountShouldBe(t, deltaAddressCount+2)
+
+	// owner address should be created again, place address should be reused
+	place2 := &Place{
+		PlaceAddressID: placeAddress.ID,
+		PlaceAddress: &Address{
+			ID:       777,
+			Address1: "address1",
+		},
+		OwnerAddress:   ownerAddress2,
+		OwnerAddressID: 778,
+	}
+	err = db.Create(place2).Error
+	if err != nil {
+		t.Errorf("failed to store place: %s", err.Error())
+	}
+	addressCountShouldBe(t, deltaAddressCount+3)
+
+	count := 0
+	db.Model(&Place{}).Where(&Place{
+		PlaceAddressID: placeAddress.ID,
+		OwnerAddressID: ownerAddress1.ID,
+	}).Count(&count)
+	if count != 1 {
+		t.Errorf("only one instance of (%d, %d) should be available, found: %d",
+			placeAddress.ID, ownerAddress1.ID, count)
+	}
+
+	db.Model(&Place{}).Where(&Place{
+		PlaceAddressID: placeAddress.ID,
+		OwnerAddressID: ownerAddress2.ID,
+	}).Count(&count)
+	if count != 1 {
+		t.Errorf("only one instance of (%d, %d) should be available, found: %d",
+			placeAddress.ID, ownerAddress2.ID, count)
+	}
+
+	db.Model(&Place{}).Where(&Place{
+		PlaceAddressID: placeAddress.ID,
+	}).Count(&count)
+	if count != 2 {
+		t.Errorf("two instances of (%d) should be available, found: %d",
+			placeAddress.ID, count)
+	}
+}
+
 func TestBlockGlobalUpdate(t *testing.T) {
 	db := DB.New()
 	db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})

+ 9 - 1
migration_test.go

@@ -118,6 +118,14 @@ type Company struct {
 	Owner *User `sql:"-"`
 }
 
+type Place struct {
+	Id             int64
+	PlaceAddressID int
+	PlaceAddress   *Address `gorm:"save_associations:false"`
+	OwnerAddressID int
+	OwnerAddress   *Address `gorm:"save_associations:true"`
+}
+
 type EncryptedData []byte
 
 func (data *EncryptedData) Scan(value interface{}) error {
@@ -284,7 +292,7 @@ func runMigration() {
 		DB.Exec(fmt.Sprintf("drop table %v;", table))
 	}
 
-	values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
+	values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
 	for _, value := range values {
 		DB.DropTable(value)
 	}