소스 검색

Test many to many relation with customized column

Jinzhu 8 년 전
부모
커밋
d87a960248
4개의 변경된 파일56개의 추가작업 그리고 12개의 파일을 삭제
  1. 1 1
      association.go
  2. 39 0
      customize_column_test.go
  3. 15 10
      join_table_handler.go
  4. 1 1
      model_struct.go

+ 1 - 1
association.go

@@ -394,7 +394,7 @@ func toQueryCondition(scope *Scope, columns []string) string {
 	if len(columns) > 1 {
 		return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
 	} else {
-		return strings.Join(columns, ",")
+		return strings.Join(newColumns, ",")
 	}
 }
 

+ 39 - 0
customize_column_test.go

@@ -63,3 +63,42 @@ func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
 		t.Errorf("Should not raise error: %s", err)
 	}
 }
+
+type CustomizePerson struct {
+	IdPerson string             `gorm:"column:idPerson;primary_key:true"`
+	Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
+}
+
+type CustomizeAccount struct {
+	IdAccount string `gorm:"column:idAccount;primary_key:true"`
+	Name      string
+}
+
+func TestManyToManyWithCustomizedColumn(t *testing.T) {
+	DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
+	DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
+
+	account := CustomizeAccount{IdAccount: "account", Name: "id1"}
+	person := CustomizePerson{
+		IdPerson: "person",
+		Accounts: []CustomizeAccount{account},
+	}
+
+	if err := DB.Create(&account).Error; err != nil {
+		t.Errorf("no error should happen, but got %v", err)
+	}
+
+	if err := DB.Create(&person).Error; err != nil {
+		t.Errorf("no error should happen, but got %v", err)
+	}
+
+	var person1 CustomizePerson
+	scope := DB.NewScope(nil)
+	if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
+		t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
+	}
+
+	if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
+		t.Errorf("should preload correct accounts")
+	}
+}

+ 15 - 10
join_table_handler.go

@@ -92,7 +92,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
 	var assignColumns, binVars, conditions []string
 	var values []interface{}
 	for key, value := range searchMap {
-		assignColumns = append(assignColumns, key)
+		assignColumns = append(assignColumns, scope.Quote(key))
 		binVars = append(binVars, `?`)
 		conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
 		values = append(values, value)
@@ -102,7 +102,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
 		values = append(values, value)
 	}
 
-	quotedTable := handler.Table(db)
+	quotedTable := scope.Quote(handler.Table(db))
 	sql := fmt.Sprintf(
 		"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
 		quotedTable,
@@ -117,11 +117,14 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
 }
 
 func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
-	var conditions []string
-	var values []interface{}
+	var (
+		scope      = db.NewScope(nil)
+		conditions []string
+		values     []interface{}
+	)
 
 	for key, value := range s.GetSearchMap(db, sources...) {
-		conditions = append(conditions, fmt.Sprintf("%v = ?", key))
+		conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
 		values = append(values, value)
 	}
 
@@ -129,12 +132,14 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
 }
 
 func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
-	quotedTable := handler.Table(db)
+	var (
+		scope          = db.NewScope(source)
+		modelType      = scope.GetModelStruct().ModelType
+		quotedTable    = scope.Quote(handler.Table(db))
+		joinConditions []string
+		values         []interface{}
+	)
 
-	scope := db.NewScope(source)
-	modelType := scope.GetModelStruct().ModelType
-	var joinConditions []string
-	var values []interface{}
 	if s.Source.ModelType == modelType {
 		destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
 		for _, foreignKey := range s.Destination.ForeignKeys {

+ 1 - 1
model_struct.go

@@ -99,7 +99,7 @@ type Relationship struct {
 
 func getForeignField(column string, fields []*StructField) *StructField {
 	for _, field := range fields {
-		if field.Name == column || field.DBName == ToDBName(column) {
+		if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
 			return field
 		}
 	}