Przeglądaj źródła

Implement parse many2many relation

Jinzhu 4 lat temu
rodzic
commit
fd9b688084
5 zmienionych plików z 133 dodań i 69 usunięć
  1. 3 3
      schema/field.go
  2. 0 6
      schema/naming.go
  3. 102 60
      schema/relationship.go
  4. 5 0
      schema/utils.go
  5. 23 0
      schema/utils_test.go

+ 3 - 3
schema/field.go

@@ -103,11 +103,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		field.DBName = dbName
 	}
 
-	if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) {
+	if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) {
 		field.PrimaryKey = true
 	}
 
-	if val, ok := field.TagSettings["AUTO_INCREMENT"]; ok && checkTruth(val) {
+	if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) {
 		field.AutoIncrement = true
 		field.HasDefaultValue = true
 	}
@@ -180,7 +180,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		for _, ef := range field.EmbeddedSchema.Fields {
 			ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
 
-			if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
+			if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
 				ef.DBName = prefix + ef.DBName
 			}
 

+ 0 - 6
schema/naming.go

@@ -13,7 +13,6 @@ type Namer interface {
 	TableName(table string) string
 	ColumnName(column string) string
 	JoinTableName(table string) string
-	JoinTableColumnName(table, column string) string
 }
 
 // NamingStrategy tables, columns naming strategy
@@ -40,11 +39,6 @@ func (ns NamingStrategy) JoinTableName(str string) string {
 	return ns.TablePrefix + toDBName(str)
 }
 
-// JoinTableColumnName convert string to join table column name
-func (ns NamingStrategy) JoinTableColumnName(referenceTable, referenceColumn string) string {
-	return inflection.Singular(toDBName(referenceTable)) + toDBName(referenceColumn)
-}
-
 var (
 	smap sync.Map
 	// https://github.com/golang/lint/blob/master/lint.go#L770

+ 102 - 60
schema/relationship.go

@@ -57,7 +57,7 @@ func (schema *Schema) parseRelation(field *Field) {
 			Field:       field,
 			Schema:      schema,
 			ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
-			PrimaryKeys: toColumns(field.TagSettings["PRIMARYKEY"]),
+			PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]),
 		}
 	)
 
@@ -65,63 +65,13 @@ func (schema *Schema) parseRelation(field *Field) {
 		return
 	}
 
-	// Parse Polymorphic relations
-	//
-	// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
-	//     type User struct {
-	//       Toys []Toy `gorm:"polymorphic:Owner;"`
-	//     }
-	//     type Pet struct {
-	//       Toy Toy `gorm:"polymorphic:Owner;"`
-	//     }
-	//     type Toy struct {
-	//       OwnerID   int
-	//       OwnerType string
-	//     }
 	if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
-		relation.Polymorphic = &Polymorphic{
-			Value:           schema.Table,
-			PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
-			PolymorphicID:   relation.FieldSchema.FieldsByName[polymorphic+"ID"],
-		}
-
-		if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
-			relation.Polymorphic.Value = strings.TrimSpace(value)
-		}
-
-		if relation.Polymorphic.PolymorphicType == nil {
-			schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
-		}
-
-		if relation.Polymorphic.PolymorphicID == nil {
-			schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
-		}
-
-		if schema.err == nil {
-			relation.References = append(relation.References, Reference{
-				PriamryValue: relation.Polymorphic.Value,
-				ForeignKey:   relation.Polymorphic.PolymorphicType,
-			})
-
-			primaryKeyField := schema.PrioritizedPrimaryField
-			if len(relation.ForeignKeys) > 0 {
-				if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 {
-					schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name)
-				}
-			}
-			relation.References = append(relation.References, Reference{
-				PriamryKey:    primaryKeyField,
-				ForeignKey:    relation.Polymorphic.PolymorphicType,
-				OwnPriamryKey: true,
-			})
-		}
-
-		relation.Type = "has"
+		schema.buildPolymorphicRelation(relation, field, polymorphic)
+	} else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" {
+		schema.buildMany2ManyRelation(relation, field, many2many)
 	} else {
 		switch field.FieldType.Kind() {
-		case reflect.Struct:
-			schema.guessRelation(relation, field, true)
-		case reflect.Slice:
+		case reflect.Struct, reflect.Slice:
 			schema.guessRelation(relation, field, true)
 		default:
 			schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
@@ -138,6 +88,102 @@ func (schema *Schema) parseRelation(field *Field) {
 	}
 }
 
+// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
+//     type User struct {
+//       Toys []Toy `gorm:"polymorphic:Owner;"`
+//     }
+//     type Pet struct {
+//       Toy Toy `gorm:"polymorphic:Owner;"`
+//     }
+//     type Toy struct {
+//       OwnerID   int
+//       OwnerType string
+//     }
+func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
+	relation.Polymorphic = &Polymorphic{
+		Value:           schema.Table,
+		PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
+		PolymorphicID:   relation.FieldSchema.FieldsByName[polymorphic+"ID"],
+	}
+
+	if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
+		relation.Polymorphic.Value = strings.TrimSpace(value)
+	}
+
+	if relation.Polymorphic.PolymorphicType == nil {
+		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
+	}
+
+	if relation.Polymorphic.PolymorphicID == nil {
+		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
+	}
+
+	if schema.err == nil {
+		relation.References = append(relation.References, Reference{
+			PriamryValue: relation.Polymorphic.Value,
+			ForeignKey:   relation.Polymorphic.PolymorphicType,
+		})
+
+		primaryKeyField := schema.PrioritizedPrimaryField
+		if len(relation.ForeignKeys) > 0 {
+			if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 {
+				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name)
+			}
+		}
+		relation.References = append(relation.References, Reference{
+			PriamryKey:    primaryKeyField,
+			ForeignKey:    relation.Polymorphic.PolymorphicType,
+			OwnPriamryKey: true,
+		})
+	}
+
+	relation.Type = "has"
+}
+
+func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
+	relation.Type = Many2Many
+
+	var (
+		joinTableFields []reflect.StructField
+		fieldsMap       = map[string]*Field{}
+	)
+
+	for _, s := range []*Schema{schema, relation.Schema} {
+		for _, primaryField := range s.PrimaryFields {
+			fieldName := s.Name + primaryField.Name
+			if _, ok := fieldsMap[fieldName]; ok {
+				if field.Name != s.Name {
+					fieldName = field.Name + primaryField.Name
+				} else {
+					fieldName = s.Name + primaryField.Name + "Reference"
+				}
+			}
+
+			fieldsMap[fieldName] = primaryField
+			joinTableFields = append(joinTableFields, reflect.StructField{
+				Name:    fieldName,
+				PkgPath: primaryField.StructField.PkgPath,
+				Type:    primaryField.StructField.Type,
+				Tag:     removeSettingFromTag(primaryField.StructField.Tag, "column"),
+			})
+		}
+	}
+
+	relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer)
+	relation.JoinTable.Name = many2many
+	relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
+
+	// build references
+	for _, f := range relation.JoinTable.Fields {
+		relation.References = append(relation.References, Reference{
+			PriamryKey:    fieldsMap[f.Name],
+			ForeignKey:    f,
+			OwnPriamryKey: schema == fieldsMap[f.Name].Schema,
+		})
+	}
+	return
+}
+
 func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
 	var (
 		primaryFields, foreignFields []*Field
@@ -214,10 +260,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
 	if guessHas {
 		relation.Type = "has"
 	} else {
-		relation.Type = "belongs_to"
+		relation.Type = BelongsTo
 	}
 }
-
-func (schema *Schema) parseMany2ManyRelation(relation *Relationship, field *Field) error {
-	return nil
-}

+ 5 - 0
schema/utils.go

@@ -2,6 +2,7 @@ package schema
 
 import (
 	"reflect"
+	"regexp"
 	"strings"
 )
 
@@ -38,3 +39,7 @@ func toColumns(val string) (results []string) {
 	}
 	return
 }
+
+func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag {
+	return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}"))
+}

+ 23 - 0
schema/utils_test.go

@@ -0,0 +1,23 @@
+package schema
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestRemoveSettingFromTag(t *testing.T) {
+	tags := map[string]string{
+		`gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`:  `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`,
+		`gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`:             `gorm:"before:value;" other:"before:value;column:db;after:value"`,
+		`gorm:"before:value;column:db" other:"before:value;column:db;after:value"`:              `gorm:"before:value;" other:"before:value;column:db;after:value"`,
+		`gorm:"column:db" other:"before:value;column:db;after:value"`:                           `gorm:"" other:"before:value;column:db;after:value"`,
+		`gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`,
+		`gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`,
+	}
+
+	for k, v := range tags {
+		if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v {
+			t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column"))
+		}
+	}
+}