Parcourir la source

Add support for polymorphic relationships using the POLYMORPHIC setting.

This commit adds support for two settings:

FOREIGNTYPE - A field that is used to store the type of the owner.

POLYMORPHIC - A shortcut to set FOREIGNKEY and FOREIGNTYPE to the same
value suffixed by "Id" and "Type" respectively.

The type is stored as the table name, which I thought might be useful
for other queries.

The biggest gotcha of this commit is that I flipped the definition of
has_one and belongs_to. gorm is very flexible such that it didn't
really care if it was a has_one or belongs_to, and can pretty much
determine it at runtime. For the sake of the error, I had to define
one of them as belongs_to, and I chose the one with the fields as
the belongs_to, like ActiveRecord. The error could probably be
genericized to "gorm cannot determine type", but I think it's nicer
to tell people DONT DO PATTERN XYZ CAUSE IT WONT WORK. Functionally,
it doesn't matter.
jnfeinstein il y a 9 ans
Parent
commit
8b451f0084
6 fichiers modifiés avec 78 ajouts et 23 suppressions
  1. 11 6
      association.go
  2. 17 0
      callback_shared.go
  3. 1 0
      field.go
  4. 2 1
      main.go
  5. 9 1
      scope.go
  6. 38 15
      scope_private.go

+ 11 - 6
association.go

@@ -7,11 +7,12 @@ import (
 )
 
 type Association struct {
-	Scope      *Scope
-	PrimaryKey interface{}
-	Column     string
-	Error      error
-	Field      *Field
+	Scope       *Scope
+	PrimaryKey  interface{}
+	PrimaryType interface{}
+	Column      string
+	Error       error
+	Field       *Field
 }
 
 func (association *Association) err(err error) *Association {
@@ -172,7 +173,11 @@ func (association *Association) Count() int {
 		scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count)
 	} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 		whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey)))
-		scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count)
+		countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
+		if relationship.ForeignType != "" {
+			countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
+		}
+		countScope.Count(&count)
 	} else if relationship.Kind == "belongs_to" {
 		if v, err := scope.FieldValueByName(association.Column); err == nil {
 			whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey)))

+ 17 - 0
callback_shared.go

@@ -35,6 +35,10 @@ func SaveBeforeAssociations(scope *Scope) {
 				if relationship.ForeignKey != "" {
 					scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
 				}
+				if relationship.ForeignType != "" {
+					scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
+					return
+				}
 			}
 		}
 	}
@@ -57,10 +61,17 @@ func SaveAfterAssociations(scope *Scope) {
 						if relationship.JoinTable == "" && relationship.ForeignKey != "" {
 							newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
 						}
+						if relationship.ForeignType != "" {
+							newDB.NewScope(elem).SetColumn(relationship.ForeignType, scope.TableName())
+						}
 
 						scope.Err(newDB.Save(elem).Error)
 
 						if relationship.JoinTable != "" {
+							if relationship.ForeignType != "" {
+								scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
+							}
+
 							newScope := scope.New(elem)
 							joinTable := relationship.JoinTable
 							foreignKey := ToSnake(relationship.ForeignKey)
@@ -89,6 +100,9 @@ func SaveAfterAssociations(scope *Scope) {
 						if relationship.ForeignKey != "" {
 							newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
 						}
+						if relationship.ForeignType != "" {
+							newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName())
+						}
 						scope.Err(newDB.Save(value.Addr().Interface()).Error)
 					} else {
 						destValue := reflect.New(field.Field.Type()).Elem()
@@ -101,6 +115,9 @@ func SaveAfterAssociations(scope *Scope) {
 						if relationship.ForeignKey != "" {
 							newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
 						}
+						if relationship.ForeignType != "" {
+							newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName())
+						}
 						scope.Err(newDB.Save(elem).Error)
 						scope.SetColumn(field.Name, destValue.Interface())
 					}

+ 1 - 0
field.go

@@ -10,6 +10,7 @@ import (
 type relationship struct {
 	JoinTable             string
 	ForeignKey            string
+	ForeignType           string
 	AssociationForeignKey string
 	Kind                  string
 }

+ 2 - 1
main.go

@@ -406,6 +406,7 @@ func (s *DB) Association(column string) *Association {
 	scope := s.clone().NewScope(s.Value)
 
 	primaryKey := scope.PrimaryKeyValue()
+	primaryType := scope.TableName()
 	if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
 		scope.Err(errors.New("primary key can't be nil"))
 	}
@@ -420,7 +421,7 @@ func (s *DB) Association(column string) *Association {
 		scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
 	}
 
-	return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
+	return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
 }
 
 // Set set value by name

+ 9 - 1
scope.go

@@ -334,8 +334,15 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
 		scopeTyp := scope.IndirectValue().Type()
 
 		foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
+		foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"])
 		associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
 		many2many := settings["MANY2MANY"]
+		polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"])
+
+		if polymorphic != "" {
+			foreignKey = polymorphic + "Id"
+			foreignType = polymorphic + "Type"
+		}
 
 		switch indirectValue.Kind() {
 		case reflect.Slice:
@@ -359,6 +366,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
 				field.Relationship = &relationship{
 					JoinTable:             many2many,
 					ForeignKey:            foreignKey,
+					ForeignType:           foreignType,
 					AssociationForeignKey: associationForeignKey,
 					Kind: "has_many",
 				}
@@ -400,7 +408,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
 					kind = "has_one"
 				}
 
-				field.Relationship = &relationship{ForeignKey: foreignKey, Kind: kind}
+				field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind}
 			}
 		default:
 			field.IsNormal = true

+ 38 - 15
scope_private.go

@@ -489,29 +489,52 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 			foreignKey = keys[1]
 		}
 
+		var relationship *relationship
+		var field *Field
+		var scopeHasField bool
+		if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField {
+			relationship = field.Relationship
+		}
+
 		if scopeType == "" || scopeType == fromScopeType {
-			if field, ok := scope.FieldByName(foreignKey); ok {
-				relationship := field.Relationship
+			if scopeHasField {
 				if relationship != nil && relationship.ForeignKey != "" {
 					foreignKey = relationship.ForeignKey
+				}
 
-					if relationship.Kind == "many_to_many" {
-						joinSql := fmt.Sprintf(
-							"INNER JOIN %v ON %v.%v = %v.%v",
-							scope.Quote(relationship.JoinTable),
-							scope.Quote(relationship.JoinTable),
-							scope.Quote(ToSnake(relationship.AssociationForeignKey)),
-							toScope.QuotedTableName(),
-							scope.Quote(toScope.PrimaryKey()))
-						whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey)))
-						toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
-						return scope
+				if relationship != nil && relationship.Kind == "many_to_many" {
+					if relationship.ForeignType != "" {
+						scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
 					}
+					joinSql := fmt.Sprintf(
+						"INNER JOIN %v ON %v.%v = %v.%v",
+						scope.Quote(relationship.JoinTable),
+						scope.Quote(relationship.JoinTable),
+						scope.Quote(ToSnake(relationship.AssociationForeignKey)),
+						toScope.QuotedTableName(),
+						scope.Quote(toScope.PrimaryKey()))
+					whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey)))
+					toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
+					return scope
 				}
 
-				// has one
+				// has many or has one
+				if toScope.HasColumn(foreignKey) {
+					toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue())
+					if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
+						toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
+					}
+					toScope.callCallbacks(scope.db.parent.callback.queries)
+					return scope
+				}
+
+				// belongs to
 				if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil {
 					sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
+					if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) {
+						scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
+						return scope
+					}
 					toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries)
 					return scope
 				}
@@ -519,7 +542,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 		}
 
 		if scopeType == "" || scopeType == toScopeType {
-			// has many
+			// has many or has one in foreign scope
 			if toScope.HasColumn(foreignKey) {
 				sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
 				return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)