Răsfoiți Sursa

Improve test structure

Jinzhu 4 ani în urmă
părinte
comite
8cb15cadde

+ 12 - 0
callbacks/callbacks.go

@@ -0,0 +1,12 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+func RegisterDefaultCallbacks(db *gorm.DB) {
+	callback := db.Callback()
+	callback.Create().Register("gorm:before_create", BeforeCreate)
+	callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
+	callback.Create().Register("gorm:create", Create)
+	callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
+	callback.Create().Register("gorm:after_create", AfterCreate)
+}

+ 24 - 0
callbacks/create.go

@@ -0,0 +1,24 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+func BeforeCreate(db *gorm.DB) {
+	// before save
+	// before create
+
+	// assign timestamp
+}
+
+func SaveBeforeAssociations(db *gorm.DB) {
+}
+
+func Create(db *gorm.DB) {
+}
+
+func SaveAfterAssociations(db *gorm.DB) {
+}
+
+func AfterCreate(db *gorm.DB) {
+	// after save
+	// after create
+}

+ 11 - 0
callbacks/interface.go

@@ -0,0 +1,11 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+type beforeSaveInterface interface {
+	BeforeSave(*gorm.DB) error
+}
+
+type beforeCreateInterface interface {
+	BeforeCreate(*gorm.DB) error
+}

+ 7 - 0
dialects/mysql/go.mod

@@ -0,0 +1,7 @@
+module github.com/jinzhu/gorm/dialects/mysql
+
+go 1.13
+
+require (
+	github.com/go-sql-driver/mysql v1.5.0
+)

+ 29 - 0
dialects/mysql/mysql.go

@@ -0,0 +1,29 @@
+package mysql
+
+import (
+	_ "github.com/go-sql-driver/mysql"
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/callbacks"
+)
+
+type Dialector struct {
+}
+
+func Open(dsn string) gorm.Dialector {
+	return &Dialector{}
+}
+
+func (Dialector) Initialize(db *gorm.DB) error {
+	// register callbacks
+	callbacks.RegisterDefaultCallbacks(db)
+
+	return nil
+}
+
+func (Dialector) Migrator() gorm.Migrator {
+	return nil
+}
+
+func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
+	return "?"
+}

+ 12 - 0
dialects/mysql/mysql_test.go

@@ -0,0 +1,12 @@
+package mysql_test
+
+import (
+	"testing"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/dialects/mysql"
+)
+
+func TestOpen(t *testing.T) {
+	gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil)
+}

+ 7 - 0
dialects/sqlite/go.mod

@@ -0,0 +1,7 @@
+module github.com/jinzhu/gorm/dialects/mysql
+
+go 1.13
+
+require (
+	github.com/mattn/go-sqlite3 v2.0.3+incompatible
+)

+ 28 - 0
dialects/sqlite/sqlite.go

@@ -0,0 +1,28 @@
+package sqlite
+
+import (
+	"github.com/jinzhu/gorm/callbacks"
+	_ "github.com/mattn/go-sqlite3"
+)
+
+type Dialector struct {
+}
+
+func Open(dsn string) gorm.Dialector {
+	return &Dialector{}
+}
+
+func (Dialector) Initialize(db *gorm.DB) error {
+	// register callbacks
+	callbacks.RegisterDefaultCallbacks(db)
+
+	return nil
+}
+
+func (Dialector) Migrator() gorm.Migrator {
+	return nil
+}
+
+func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
+	return "?"
+}

+ 15 - 0
dialects/sqlite/sqlite_test.go

@@ -0,0 +1,15 @@
+package sqlite_test
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/jinzhu/gorm"
+)
+
+var DB *gorm.DB
+
+func TestOpen(t *testing.T) {
+	db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
+}

+ 1 - 0
finisher_api.go

@@ -12,6 +12,7 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
 // First find first record that match given conditions, order by primary key
 func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
 	tx = db.getInstance()
+	tx.callbacks.Create().Execute(tx.Limit(1).Order("id"))
 	return
 }
 

+ 29 - 4
gorm.go

@@ -13,7 +13,7 @@ import (
 type Config struct {
 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 	// You can cancel it by setting `SkipDefaultTransaction` to true
-	SkipDefaultTransaction bool
+	SkipDefaultTransaction bool // TODO
 
 	// NamingStrategy tables, columns naming strategy
 	NamingStrategy schema.Namer
@@ -27,6 +27,7 @@ type Config struct {
 
 // Dialector GORM database dialector
 type Dialector interface {
+	Initialize(*DB) error
 	Migrator() Migrator
 	BindVar(stmt Statement, v interface{}) string
 }
@@ -36,7 +37,8 @@ type DB struct {
 	*Config
 	Dialector
 	Instance
-	clone bool
+	clone     bool
+	callbacks *callbacks
 }
 
 // Session session config when create new session
@@ -48,15 +50,33 @@ type Session struct {
 
 // Open initialize db session based on dialector
 func Open(dialector Dialector, config *Config) (db *DB, err error) {
+	if config == nil {
+		config = &Config{}
+	}
+
 	if config.NamingStrategy == nil {
 		config.NamingStrategy = schema.NamingStrategy{}
 	}
 
-	return &DB{
+	if config.Logger == nil {
+		config.Logger = logger.Default
+	}
+
+	if config.NowFunc == nil {
+		config.NowFunc = func() time.Time { return time.Now().Local() }
+	}
+
+	db = &DB{
 		Config:    config,
 		Dialector: dialector,
 		clone:     true,
-	}, nil
+		callbacks: InitializeCallbacks(),
+	}
+
+	if dialector != nil {
+		err = dialector.Initialize(db)
+	}
+	return
 }
 
 // Session create new db session
@@ -112,6 +132,11 @@ func (db *DB) Get(key string) (interface{}, bool) {
 	return nil, false
 }
 
+// Callback returns callback manager
+func (db *DB) Callback() *callbacks {
+	return db.callbacks
+}
+
 func (db *DB) getInstance() *DB {
 	if db.clone {
 		ctx := db.Instance.Context

+ 115 - 109
schema/schema_helper_test.go

@@ -10,85 +10,89 @@ import (
 )
 
 func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
-	equalFieldNames := []string{"Name", "Table"}
+	t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
+		equalFieldNames := []string{"Name", "Table"}
 
-	for _, name := range equalFieldNames {
-		got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
-		expects := reflect.ValueOf(v).FieldByName(name).Interface()
-		if !reflect.DeepEqual(got, expects) {
-			t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
+		for _, name := range equalFieldNames {
+			got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
+			expects := reflect.ValueOf(v).FieldByName(name).Interface()
+			if !reflect.DeepEqual(got, expects) {
+				t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
+			}
 		}
-	}
 
-	for idx, field := range primaryFields {
-		var found bool
-		for _, f := range s.PrimaryFields {
-			if f.Name == field {
-				found = true
+		for idx, field := range primaryFields {
+			var found bool
+			for _, f := range s.PrimaryFields {
+				if f.Name == field {
+					found = true
+				}
 			}
-		}
 
-		if idx == 0 {
-			if field != s.PrioritizedPrimaryField.Name {
-				t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name)
+			if idx == 0 {
+				if field != s.PrioritizedPrimaryField.Name {
+					t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name)
+				}
 			}
-		}
 
-		if !found {
-			t.Errorf("schema %v failed to found priamry key: %v", s, field)
+			if !found {
+				t.Errorf("schema %v failed to found priamry key: %v", s, field)
+			}
 		}
-	}
+	})
 }
 
 func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) {
-	if fc != nil {
-		fc(f)
-	}
-
-	if f.TagSettings == nil {
-		if f.Tag != "" {
-			f.TagSettings = schema.ParseTagSetting(f.Tag)
-		} else {
-			f.TagSettings = map[string]string{}
+	t.Run("CheckField/"+f.Name, func(t *testing.T) {
+		if fc != nil {
+			fc(f)
 		}
-	}
-
-	if parsedField, ok := s.FieldsByName[f.Name]; !ok {
-		t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
-	} else {
-		equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
 
-		for _, name := range equalFieldNames {
-			got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
-			expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
-			if !reflect.DeepEqual(got, expects) {
-				t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
+		if f.TagSettings == nil {
+			if f.Tag != "" {
+				f.TagSettings = schema.ParseTagSetting(f.Tag)
+			} else {
+				f.TagSettings = map[string]string{}
 			}
 		}
 
-		if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
-			t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
-		}
+		if parsedField, ok := s.FieldsByName[f.Name]; !ok {
+			t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
+		} else {
+			equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
 
-		for _, name := range []string{f.DBName, f.Name} {
-			if field := s.LookUpField(name); field == nil || parsedField != field {
+			for _, name := range equalFieldNames {
+				got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
+				expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
+				if !reflect.DeepEqual(got, expects) {
+					t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
+				}
+			}
+
+			if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
 				t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
 			}
-		}
 
-		if f.PrimaryKey {
-			var found bool
-			for _, primaryField := range s.PrimaryFields {
-				if primaryField == parsedField {
-					found = true
+			for _, name := range []string{f.DBName, f.Name} {
+				if field := s.LookUpField(name); field == nil || parsedField != field {
+					t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
 				}
 			}
 
-			if !found {
-				t.Errorf("schema %v doesn't include field %v", s, f.Name)
+			if f.PrimaryKey {
+				var found bool
+				for _, primaryField := range s.PrimaryFields {
+					if primaryField == parsedField {
+						found = true
+					}
+				}
+
+				if !found {
+					t.Errorf("schema %v doesn't include field %v", s, f.Name)
+				}
 			}
 		}
-	}
+	})
 }
 
 type Relation struct {
@@ -123,79 +127,81 @@ type Reference struct {
 }
 
 func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
-	if r, ok := s.Relationships.Relations[relation.Name]; ok {
-		if r.Name != relation.Name {
-			t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name)
-		}
-
-		if r.Type != relation.Type {
-			t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type)
-		}
-
-		if r.Schema.Name != relation.Schema {
-			t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
-		}
-
-		if r.FieldSchema.Name != relation.FieldSchema {
-			t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
-		}
-
-		if r.Polymorphic != nil {
-			if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID {
-				t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name)
+	t.Run("CheckRelation/"+relation.Name, func(t *testing.T) {
+		if r, ok := s.Relationships.Relations[relation.Name]; ok {
+			if r.Name != relation.Name {
+				t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name)
 			}
 
-			if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type {
-				t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name)
+			if r.Type != relation.Type {
+				t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type)
 			}
 
-			if r.Polymorphic.Value != relation.Polymorphic.Value {
-				t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value)
+			if r.Schema.Name != relation.Schema {
+				t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
 			}
-		}
 
-		if r.JoinTable != nil {
-			if r.JoinTable.Name != relation.JoinTable.Name {
-				t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name)
+			if r.FieldSchema.Name != relation.FieldSchema {
+				t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
 			}
 
-			if r.JoinTable.Table != relation.JoinTable.Table {
-				t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
-			}
+			if r.Polymorphic != nil {
+				if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID {
+					t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name)
+				}
 
-			for _, f := range relation.JoinTable.Fields {
-				checkSchemaField(t, r.JoinTable, &f, nil)
+				if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type {
+					t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name)
+				}
+
+				if r.Polymorphic.Value != relation.Polymorphic.Value {
+					t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value)
+				}
 			}
-		}
 
-		if len(relation.References) != len(r.References) {
-			t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References))
-		}
+			if r.JoinTable != nil {
+				if r.JoinTable.Name != relation.JoinTable.Name {
+					t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name)
+				}
 
-		for _, ref := range relation.References {
-			var found bool
-			for _, rf := range r.References {
-				if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) {
-					found = true
+				if r.JoinTable.Table != relation.JoinTable.Table {
+					t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
+				}
+
+				for _, f := range relation.JoinTable.Fields {
+					checkSchemaField(t, r.JoinTable, &f, nil)
 				}
 			}
 
-			if !found {
-				var refs []string
+			if len(relation.References) != len(r.References) {
+				t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References))
+			}
+
+			for _, ref := range relation.References {
+				var found bool
 				for _, rf := range r.References {
-					var primaryKey, primaryKeySchema string
-					if rf.PrimaryKey != nil {
-						primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name
+					if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) {
+						found = true
 					}
-					refs = append(refs, fmt.Sprintf(
-						"{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}",
-						primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey,
-					))
 				}
-				t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", "))
+
+				if !found {
+					var refs []string
+					for _, rf := range r.References {
+						var primaryKey, primaryKeySchema string
+						if rf.PrimaryKey != nil {
+							primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name
+						}
+						refs = append(refs, fmt.Sprintf(
+							"{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}",
+							primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey,
+						))
+					}
+					t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", "))
+				}
 			}
+		} else {
+			t.Errorf("schema %v failed to find relations by name %v", s, relation.Name)
 		}
-	} else {
-		t.Errorf("schema %v failed to find relations by name %v", s, relation.Name)
-	}
+	})
 }

+ 1 - 0
tests/create_test.go

@@ -0,0 +1 @@
+package tests