瀏覽代碼

Merge pull request #303 from jnfeinstein/dev_poly

Support polymorphic has-one and has-many associations
Jinzhu 9 年之前
父節點
當前提交
6d13ae4ead
共有 8 個文件被更改,包括 170 次插入23 次删除
  1. 27 0
      README.md
  2. 11 6
      association.go
  3. 65 0
      association_test.go
  4. 17 0
      callback_shared.go
  5. 1 0
      field.go
  6. 2 1
      main.go
  7. 9 1
      scope.go
  8. 38 15
      scope_private.go

+ 27 - 0
README.md

@@ -17,6 +17,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
 * Iteration Support via [Rows](#row--rows)
 * Scopes
 * sql.Scanner support
+* Polymorphism
 * Every feature comes with tests
 * Convention Over Configuration
 * Developer Friendly
@@ -507,6 +508,32 @@ db.Model(&user).Association("Languages").Clear()
 // Remove all relations between the user and languages
 ```
 
+### Polymorphism
+
+Supports polymorphic has-many and has-one associations.
+
+```go
+  type Cat struct {
+    Id    int
+    Name  string
+    Toy   Toy `gorm:"polymorphic:Owner;"`
+  }
+
+  type Dog struct {
+    Id   int
+    Name string
+    Toy  Toy `gorm:"polymorphic:Owner;"`
+  }
+
+  type Toy struct {
+    Id        int
+    Name      string
+    OwnerId   int
+    OwnerType int
+  }
+```
+Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
+
 ## Advanced Usage
 
 ## FirstOrInit

+ 11 - 6
association.go

@@ -7,11 +7,12 @@ import (
 )
 
 type Association struct {
-	Scope      *Scope
-	PrimaryKey interface{}
-	Column     string
-	Error      error
-	Field      *Field
+	Scope       *Scope
+	PrimaryKey  interface{}
+	PrimaryType interface{}
+	Column      string
+	Error       error
+	Field       *Field
 }
 
 func (association *Association) err(err error) *Association {
@@ -172,7 +173,11 @@ func (association *Association) Count() int {
 		scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count)
 	} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 		whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey)))
-		scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count)
+		countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
+		if relationship.ForeignType != "" {
+			countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
+		}
+		countScope.Count(&count)
 	} else if relationship.Kind == "belongs_to" {
 		if v, err := scope.FieldValueByName(association.Column); err == nil {
 			whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey)))

+ 65 - 0
association_test.go

@@ -1,6 +1,29 @@
 package gorm_test
 
 import "testing"
+import "github.com/jinzhu/gorm"
+
+type Cat struct {
+	Id   int
+	Name string
+	Toy  Toy `gorm:"polymorphic:Owner;"`
+}
+
+type Dog struct {
+	Id   int
+	Name string
+	Toys []Toy `gorm:"polymorphic:Owner;"`
+}
+
+type Toy struct {
+	Id        int
+	Name      string
+	OwnerId   int
+	OwnerType string
+
+	// Define the owner type as a belongs_to so we can ensure it throws an error
+	Owner Dog `gorm:"foreignkey:owner_id; foreigntype:owner_type;"`
+}
 
 func TestHasOneAndHasManyAssociation(t *testing.T) {
 	DB.DropTable(Category{})
@@ -208,3 +231,45 @@ func TestManyToMany(t *testing.T) {
 		t.Errorf("Relations should be cleared")
 	}
 }
+
+func TestPolymorphic(t *testing.T) {
+	DB.DropTableIfExists(Cat{})
+	DB.DropTableIfExists(Dog{})
+	DB.DropTableIfExists(Toy{})
+
+	DB.AutoMigrate(&Cat{})
+	DB.AutoMigrate(&Dog{})
+	DB.AutoMigrate(&Toy{})
+
+	cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}}
+	dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}}
+	DB.Save(&cat).Save(&dog)
+
+	var catToys []Toy
+	if err := DB.Model(&cat).Related(&catToys, "Toy").Error; err == gorm.RecordNotFound {
+		t.Errorf("Did not find any has one polymorphic association")
+	} else if len(catToys) != 1 {
+		t.Errorf("Should have found only one polymorphic has one association")
+	} else if catToys[0].Name != cat.Toy.Name {
+		t.Errorf("Should have found the proper has one polymorphic association")
+	}
+
+	var dogToys []Toy
+	if err := DB.Model(&dog).Related(&dogToys, "Toys").Error; err == gorm.RecordNotFound {
+		t.Errorf("Did not find any polymorphic has many associations")
+	} else if len(dogToys) != len(dog.Toys) {
+		t.Errorf("Should have found all polymorphic has many associations")
+	}
+
+	if DB.Model(&cat).Association("Toy").Count() != 1 {
+		t.Errorf("Should return one polymorphic has one association")
+	}
+
+	if DB.Model(&dog).Association("Toys").Count() != 2 {
+		t.Errorf("Should return two polymorphic has many associations")
+	}
+
+	if DB.Model(&Toy{OwnerId: dog.Id, OwnerType: "dog"}).Related(&dog, "Owner").Error == nil {
+		t.Errorf("Should have thrown unsupported belongs_to error")
+	}
+}

+ 17 - 0
callback_shared.go

@@ -35,6 +35,10 @@ func SaveBeforeAssociations(scope *Scope) {
 				if relationship.ForeignKey != "" {
 					scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
 				}
+				if relationship.ForeignType != "" {
+					scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
+					return
+				}
 			}
 		}
 	}
@@ -57,10 +61,17 @@ func SaveAfterAssociations(scope *Scope) {
 						if relationship.JoinTable == "" && relationship.ForeignKey != "" {
 							newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
 						}
+						if relationship.ForeignType != "" {
+							newDB.NewScope(elem).SetColumn(relationship.ForeignType, scope.TableName())
+						}
 
 						scope.Err(newDB.Save(elem).Error)
 
 						if relationship.JoinTable != "" {
+							if relationship.ForeignType != "" {
+								scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
+							}
+
 							newScope := scope.New(elem)
 							joinTable := relationship.JoinTable
 							foreignKey := ToSnake(relationship.ForeignKey)
@@ -89,6 +100,9 @@ func SaveAfterAssociations(scope *Scope) {
 						if relationship.ForeignKey != "" {
 							newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
 						}
+						if relationship.ForeignType != "" {
+							newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName())
+						}
 						scope.Err(newDB.Save(value.Addr().Interface()).Error)
 					} else {
 						destValue := reflect.New(field.Field.Type()).Elem()
@@ -101,6 +115,9 @@ func SaveAfterAssociations(scope *Scope) {
 						if relationship.ForeignKey != "" {
 							newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
 						}
+						if relationship.ForeignType != "" {
+							newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName())
+						}
 						scope.Err(newDB.Save(elem).Error)
 						scope.SetColumn(field.Name, destValue.Interface())
 					}

+ 1 - 0
field.go

@@ -10,6 +10,7 @@ import (
 type relationship struct {
 	JoinTable             string
 	ForeignKey            string
+	ForeignType           string
 	AssociationForeignKey string
 	Kind                  string
 }

+ 2 - 1
main.go

@@ -406,6 +406,7 @@ func (s *DB) Association(column string) *Association {
 	scope := s.clone().NewScope(s.Value)
 
 	primaryKey := scope.PrimaryKeyValue()
+	primaryType := scope.TableName()
 	if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
 		scope.Err(errors.New("primary key can't be nil"))
 	}
@@ -420,7 +421,7 @@ func (s *DB) Association(column string) *Association {
 		scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
 	}
 
-	return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
+	return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
 }
 
 // Set set value by name

+ 9 - 1
scope.go

@@ -334,8 +334,15 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
 		scopeTyp := scope.IndirectValue().Type()
 
 		foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
+		foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"])
 		associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
 		many2many := settings["MANY2MANY"]
+		polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"])
+
+		if polymorphic != "" {
+			foreignKey = polymorphic + "Id"
+			foreignType = polymorphic + "Type"
+		}
 
 		switch indirectValue.Kind() {
 		case reflect.Slice:
@@ -359,6 +366,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
 				field.Relationship = &relationship{
 					JoinTable:             many2many,
 					ForeignKey:            foreignKey,
+					ForeignType:           foreignType,
 					AssociationForeignKey: associationForeignKey,
 					Kind: "has_many",
 				}
@@ -400,7 +408,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
 					kind = "has_one"
 				}
 
-				field.Relationship = &relationship{ForeignKey: foreignKey, Kind: kind}
+				field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind}
 			}
 		default:
 			field.IsNormal = true

+ 38 - 15
scope_private.go

@@ -489,29 +489,52 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 			foreignKey = keys[1]
 		}
 
+		var relationship *relationship
+		var field *Field
+		var scopeHasField bool
+		if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField {
+			relationship = field.Relationship
+		}
+
 		if scopeType == "" || scopeType == fromScopeType {
-			if field, ok := scope.FieldByName(foreignKey); ok {
-				relationship := field.Relationship
+			if scopeHasField {
 				if relationship != nil && relationship.ForeignKey != "" {
 					foreignKey = relationship.ForeignKey
+				}
 
-					if relationship.Kind == "many_to_many" {
-						joinSql := fmt.Sprintf(
-							"INNER JOIN %v ON %v.%v = %v.%v",
-							scope.Quote(relationship.JoinTable),
-							scope.Quote(relationship.JoinTable),
-							scope.Quote(ToSnake(relationship.AssociationForeignKey)),
-							toScope.QuotedTableName(),
-							scope.Quote(toScope.PrimaryKey()))
-						whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey)))
-						toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
-						return scope
+				if relationship != nil && relationship.Kind == "many_to_many" {
+					if relationship.ForeignType != "" {
+						scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
 					}
+					joinSql := fmt.Sprintf(
+						"INNER JOIN %v ON %v.%v = %v.%v",
+						scope.Quote(relationship.JoinTable),
+						scope.Quote(relationship.JoinTable),
+						scope.Quote(ToSnake(relationship.AssociationForeignKey)),
+						toScope.QuotedTableName(),
+						scope.Quote(toScope.PrimaryKey()))
+					whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey)))
+					toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
+					return scope
 				}
 
-				// has one
+				// has many or has one
+				if toScope.HasColumn(foreignKey) {
+					toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue())
+					if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
+						toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
+					}
+					toScope.callCallbacks(scope.db.parent.callback.queries)
+					return scope
+				}
+
+				// belongs to
 				if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil {
 					sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
+					if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) {
+						scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
+						return scope
+					}
 					toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries)
 					return scope
 				}
@@ -519,7 +542,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 		}
 
 		if scopeType == "" || scopeType == toScopeType {
-			// has many
+			// has many or has one in foreign scope
 			if toScope.HasColumn(foreignKey) {
 				sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
 				return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)