Ver Fonte

Test parse schema relations

Jinzhu há 4 anos atrás
pai
commit
a4a0895a85
7 ficheiros alterados com 239 adições e 92 exclusões
  1. 4 4
      logger/logger.go
  2. 6 1
      schema/field.go
  3. 41 17
      schema/relationship.go
  4. 41 17
      schema/schema.go
  5. 123 0
      schema/schema_helper_test.go
  6. 23 52
      schema/schema_test.go
  7. 1 1
      tests/model.go

+ 4 - 4
logger/logger.go

@@ -8,7 +8,7 @@ import (
 
 type LogLevel int
 
-var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)}
+var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)}
 
 const (
 	Info LogLevel = iota + 1
@@ -40,21 +40,21 @@ func (logger Logger) LogMode(level LogLevel) Interface {
 
 // Info print info
 func (logger Logger) Info(msg string, data ...interface{}) {
-	if logger.logLevel >= Info {
+	if logger.logLevel <= Info {
 		logger.Print("[info] " + fmt.Sprintf(msg, data...))
 	}
 }
 
 // Warn print warn messages
 func (logger Logger) Warn(msg string, data ...interface{}) {
-	if logger.logLevel >= Warn {
+	if logger.logLevel <= Warn {
 		logger.Print("[warn] " + fmt.Sprintf(msg, data...))
 	}
 }
 
 // Error print error messages
 func (logger Logger) Error(msg string, data ...interface{}) {
-	if logger.logLevel >= Error {
+	if logger.logLevel <= Error {
 		logger.Print("[error] " + fmt.Sprintf(msg, data...))
 	}
 }

+ 6 - 1
schema/field.go

@@ -176,7 +176,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 	}
 
 	if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
-		field.EmbeddedSchema, schema.err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer)
+		var err error
+		field.Creatable = false
+		field.Updatable = false
+		if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
+			schema.err = err
+		}
 		for _, ef := range field.EmbeddedSchema.Fields {
 			ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
 

+ 41 - 17
schema/relationship.go

@@ -33,7 +33,7 @@ type Relationship struct {
 	Schema                   *Schema
 	FieldSchema              *Schema
 	JoinTable                *Schema
-	ForeignKeys, PrimaryKeys []string
+	foreignKeys, primaryKeys []string
 }
 
 type Polymorphic struct {
@@ -51,17 +51,19 @@ type Reference struct {
 
 func (schema *Schema) parseRelation(field *Field) {
 	var (
+		err        error
 		fieldValue = reflect.New(field.FieldType).Interface()
 		relation   = &Relationship{
 			Name:        field.Name,
 			Field:       field,
 			Schema:      schema,
-			ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
-			PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]),
+			foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
+			primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
 		}
 	)
 
-	if relation.FieldSchema, schema.err = Parse(fieldValue, schema.cacheStore, schema.namer); schema.err != nil {
+	if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
+		schema.err = err
 		return
 	}
 
@@ -86,6 +88,20 @@ func (schema *Schema) parseRelation(field *Field) {
 			relation.Type = HasMany
 		}
 	}
+
+	if schema.err == nil {
+		schema.Relationships.Relations[relation.Name] = relation
+		switch relation.Type {
+		case HasOne:
+			schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
+		case HasMany:
+			schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
+		case BelongsTo:
+			schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
+		case Many2Many:
+			schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
+		}
+	}
 }
 
 // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
@@ -125,9 +141,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
 		})
 
 		primaryKeyField := schema.PrioritizedPrimaryField
-		if len(relation.ForeignKeys) > 0 {
-			if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 {
-				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name)
+		if len(relation.foreignKeys) > 0 {
+			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
+				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
 			}
 		}
 		relation.References = append(relation.References, Reference{
@@ -144,6 +160,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 	relation.Type = Many2Many
 
 	var (
+		err             error
 		joinTableFields []reflect.StructField
 		fieldsMap       = map[string]*Field{}
 	)
@@ -169,7 +186,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 		}
 	}
 
-	relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer)
+	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
+		schema.err = err
+	}
 	relation.JoinTable.Name = many2many
 	relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
 
@@ -202,18 +221,23 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
 		}
 	}
 
-	if len(relation.ForeignKeys) > 0 {
-		for _, foreignKey := range relation.ForeignKeys {
+	if len(relation.foreignKeys) > 0 {
+		for _, foreignKey := range relation.foreignKeys {
 			if f := foreignSchema.LookUpField(foreignKey); f != nil {
 				foreignFields = append(foreignFields, f)
 			} else {
-				reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.ForeignKeys)
+				reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys)
 				return
 			}
 		}
 	} else {
 		for _, primaryField := range primarySchema.PrimaryFields {
-			if f := foreignSchema.LookUpField(field.Name + primaryField.Name); f != nil {
+			lookUpName := schema.Name + primaryField.Name
+			if !guessHas {
+				lookUpName = field.Name + primaryField.Name
+			}
+
+			if f := foreignSchema.LookUpField(lookUpName); f != nil {
 				foreignFields = append(foreignFields, f)
 				primaryFields = append(primaryFields, primaryField)
 			}
@@ -221,19 +245,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
 	}
 
 	if len(foreignFields) == 0 {
-		reguessOrErr("failed to guess %v's relations with %v's field %v", relation.FieldSchema, schema, field.Name)
+		reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas)
 		return
-	} else if len(relation.PrimaryKeys) > 0 {
-		for idx, primaryKey := range relation.PrimaryKeys {
+	} else if len(relation.primaryKeys) > 0 {
+		for idx, primaryKey := range relation.primaryKeys {
 			if f := primarySchema.LookUpField(primaryKey); f != nil {
 				if len(primaryFields) < idx+1 {
 					primaryFields = append(primaryFields, f)
 				} else if f != primaryFields[idx] {
-					reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys)
+					reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
 					return
 				}
 			} else {
-				reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys)
+				reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys)
 				return
 			}
 		}

+ 41 - 17
schema/schema.go

@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"go/ast"
 	"reflect"
-	"strings"
 	"sync"
 
 	"github.com/jinzhu/gorm/logger"
@@ -26,7 +25,7 @@ type Schema struct {
 }
 
 func (schema Schema) String() string {
-	return schema.ModelType.PkgPath()
+	return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
 }
 
 func (schema Schema) LookUpField(name string) *Field {
@@ -63,6 +62,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 		Table:          namer.TableName(modelType.Name()),
 		FieldsByName:   map[string]*Field{},
 		FieldsByDBName: map[string]*Field{},
+		Relationships:  Relationships{Relations: map[string]*Relationship{}},
 		cacheStore:     cacheStore,
 		namer:          namer,
 	}
@@ -76,10 +76,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 
 	for i := 0; i < modelType.NumField(); i++ {
 		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
-			field := schema.ParseField(fieldStruct)
-			schema.Fields = append(schema.Fields, field)
-			if field.EmbeddedSchema != nil {
+			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
 				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
+			} else {
+				schema.Fields = append(schema.Fields, field)
 			}
 		}
 	}
@@ -94,6 +94,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
 				schema.FieldsByDBName[field.DBName] = field
 				schema.FieldsByName[field.Name] = field
+
+				if v != nil && v.PrimaryKey {
+					if schema.PrioritizedPrimaryField == v {
+						schema.PrioritizedPrimaryField = nil
+					}
+
+					for idx, f := range schema.PrimaryFields {
+						if f == v {
+							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
+						} else if schema.PrioritizedPrimaryField == nil {
+							schema.PrioritizedPrimaryField = f
+						}
+					}
+				}
+
+				if field.PrimaryKey {
+					if schema.PrioritizedPrimaryField == nil {
+						schema.PrioritizedPrimaryField = field
+					}
+					schema.PrimaryFields = append(schema.PrimaryFields, field)
+				}
 			}
 		}
 
@@ -102,23 +123,26 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 		}
 	}
 
-	for db, field := range schema.FieldsByDBName {
-		if strings.ToLower(db) == "id" {
-			schema.PrioritizedPrimaryField = field
+	if f := schema.LookUpField("id"); f != nil {
+		if f.PrimaryKey {
+			schema.PrioritizedPrimaryField = f
+		} else if len(schema.PrimaryFields) == 0 {
+			f.PrimaryKey = true
+			schema.PrioritizedPrimaryField = f
+			schema.PrimaryFields = append(schema.PrimaryFields, f)
 		}
+	}
 
-		if field.PrimaryKey {
-			if schema.PrioritizedPrimaryField == nil {
-				schema.PrioritizedPrimaryField = field
-			}
-			schema.PrimaryFields = append(schema.PrimaryFields, field)
-		}
+	cacheStore.Store(modelType, schema)
 
-		if field.DataType == "" {
-			defer schema.parseRelation(field)
+	// parse relations for unidentified fields
+	for _, field := range schema.Fields {
+		if field.DataType == "" && field.Creatable {
+			if schema.parseRelation(field); schema.err != nil {
+				return schema, schema.err
+			}
 		}
 	}
 
-	cacheStore.Store(modelType, schema)
 	return schema, schema.err
 }

+ 123 - 0
schema/schema_helper_test.go

@@ -0,0 +1,123 @@
+package schema_test
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/jinzhu/gorm/schema"
+)
+
+func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
+	equalFieldNames := []string{"Name", "Table"}
+
+	for _, name := range equalFieldNames {
+		got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
+		expects := reflect.ValueOf(v).FieldByName(name).Interface()
+		if !reflect.DeepEqual(got, expects) {
+			t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
+		}
+	}
+
+	for idx, field := range primaryFields {
+		var found bool
+		for _, f := range s.PrimaryFields {
+			if f.Name == field {
+				found = true
+			}
+		}
+
+		if idx == 0 {
+			if field != s.PrioritizedPrimaryField.Name {
+				t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name)
+			}
+		}
+
+		if !found {
+			t.Errorf("schema %v failed to found priamry key: %v", s, field)
+		}
+	}
+}
+
+func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) {
+	if fc != nil {
+		fc(f)
+	}
+
+	if f.TagSettings == nil {
+		if f.Tag != "" {
+			f.TagSettings = schema.ParseTagSetting(f.Tag)
+		} else {
+			f.TagSettings = map[string]string{}
+		}
+	}
+
+	if parsedField, ok := s.FieldsByName[f.Name]; !ok {
+		t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
+	} else {
+		equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
+
+		for _, name := range equalFieldNames {
+			got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
+			expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
+			if !reflect.DeepEqual(got, expects) {
+				t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
+			}
+		}
+
+		if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
+			t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
+		}
+
+		for _, name := range []string{f.DBName, f.Name} {
+			if field := s.LookUpField(name); field == nil || parsedField != field {
+				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
+			}
+		}
+
+		if f.PrimaryKey {
+			var found bool
+			for _, primaryField := range s.PrimaryFields {
+				if primaryField == parsedField {
+					found = true
+				}
+			}
+
+			if !found {
+				t.Errorf("schema %v doesn't include field %v", s, f.Name)
+			}
+		}
+	}
+}
+
+type Relation struct {
+	Name            string
+	Type            schema.RelationshipType
+	Polymorphic     schema.Polymorphic
+	Schema          string
+	FieldSchema     string
+	JoinTable       string
+	JoinTableFields []schema.Field
+	References      []Reference
+}
+
+type Reference struct {
+	PrimaryKey    string
+	PrimarySchema string
+	ForeignKey    string
+	ForeignSchema string
+	OwnPriamryKey bool
+}
+
+func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
+	if r, ok := s.Relationships.Relations[relation.Name]; ok {
+		if r.Name != relation.Name {
+			t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Name, r.Name)
+		}
+
+		if r.Type != relation.Type {
+			t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Type, r.Type)
+		}
+	} else {
+		t.Errorf("schema %v failed to find relations by name %v", s, relation.Name)
+	}
+}

+ 23 - 52
schema/schema_test.go

@@ -1,7 +1,6 @@
 package schema_test
 
 import (
-	"reflect"
 	"sync"
 	"testing"
 
@@ -11,68 +10,40 @@ import (
 
 func TestParseSchema(t *testing.T) {
 	cacheMap := sync.Map{}
-	user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
 
+	user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user, got error %v", err)
 	}
 
-	checkSchemaFields(t, user)
-}
+	// check schema
+	checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"})
 
-func checkSchemaFields(t *testing.T, s *schema.Schema) {
+	// check fields
 	fields := []schema.Field{
-		schema.Field{
-			Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint,
-			PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"},
-		},
-		schema.Field{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
-		schema.Field{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
-		schema.Field{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
-		schema.Field{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
-		schema.Field{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
-		schema.Field{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
-		schema.Field{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
-		schema.Field{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint},
+		{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
+		{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
+		{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
+		{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
+		{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
+		{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
+		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
+		{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
+		{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint},
 	}
 
 	for _, f := range fields {
-		f.Creatable = true
-		f.Updatable = true
-		if f.TagSettings == nil {
-			if f.Tag != "" {
-				f.TagSettings = schema.ParseTagSetting(f.Tag)
-			} else {
-				f.TagSettings = map[string]string{}
-			}
-		}
-
-		if foundField, ok := s.FieldsByName[f.Name]; !ok {
-			t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
-		} else {
-			checkSchemaField(t, foundField, f)
-
-			if field, ok := s.FieldsByDBName[f.DBName]; !ok || foundField != field {
-				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
-			}
-
-			for _, name := range []string{f.DBName, f.Name} {
-				if field := s.LookUpField(name); field == nil || foundField != field {
-					t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
-				}
-			}
-		}
+		checkSchemaField(t, user, &f, func(f *schema.Field) {
+			f.Creatable = true
+			f.Updatable = true
+		})
 	}
-}
 
-func checkSchemaField(t *testing.T, parsedField *schema.Field, field schema.Field) {
-	equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
-
-	for _, name := range equalFieldNames {
-		got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
-		expects := reflect.ValueOf(field).FieldByName(name).Interface()
-		if !reflect.DeepEqual(got, expects) {
-			t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
-		}
+	// check relations
+	relations := []Relation{
+		{Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", References: []Reference{{"ID", "User", "UserID", "Pet", true}}},
+	}
+	for _, relation := range relations {
+		checkSchemaRelation(t, user, relation)
 	}
 }

+ 1 - 1
tests/model.go

@@ -23,7 +23,7 @@ type User struct {
 	Company   Company
 	ManagerID uint
 	Manager   *User
-	Team      []User     `foreignkey:ManagerID`
+	Team      []User     `gorm:"foreignkey:ManagerID"`
 	Friends   []*User    `gorm:"many2many:user_friends"`
 	Languages []Language `gorm:"many2many:user_speaks"`
 }