Browse Source

Add Before/After callbacks

Jinzhu 4 years ago
parent
commit
e2a360b9fa

+ 60 - 4
callbacks/create.go

@@ -8,8 +8,36 @@ import (
 )
 
 func BeforeCreate(db *gorm.DB) {
-	// before save
-	// before create
+	if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
+		callMethod := func(value interface{}) bool {
+			var ok bool
+			if db.Statement.Schema.BeforeSave {
+				if i, ok := value.(gorm.BeforeSaveInterface); ok {
+					ok = true
+					i.BeforeSave(db)
+				}
+			}
+
+			if db.Statement.Schema.BeforeCreate {
+				if i, ok := value.(gorm.BeforeCreateInterface); ok {
+					ok = true
+					i.BeforeCreate(db)
+				}
+			}
+			return ok
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }
 
 func SaveBeforeAssociations(db *gorm.DB) {
@@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) {
 }
 
 func AfterCreate(db *gorm.DB) {
-	// after save
-	// after create
+	if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
+		callMethod := func(value interface{}) bool {
+			var ok bool
+			if db.Statement.Schema.AfterSave {
+				if i, ok := value.(gorm.AfterSaveInterface); ok {
+					ok = true
+					i.AfterSave(db)
+				}
+			}
+
+			if db.Statement.Schema.AfterCreate {
+				if i, ok := value.(gorm.AfterCreateInterface); ok {
+					ok = true
+					i.AfterCreate(db)
+				}
+			}
+			return ok
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }
 
 // ConvertToCreateValues convert to create values

+ 49 - 1
callbacks/delete.go

@@ -1,12 +1,60 @@
 package callbacks
 
-import "github.com/jinzhu/gorm"
+import (
+	"reflect"
+
+	"github.com/jinzhu/gorm"
+)
 
 func BeforeDelete(db *gorm.DB) {
+	if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
+		callMethod := func(value interface{}) bool {
+			if db.Statement.Schema.BeforeDelete {
+				if i, ok := value.(gorm.BeforeDeleteInterface); ok {
+					i.BeforeDelete(db)
+					return true
+				}
+			}
+			return false
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }
 
 func Delete(db *gorm.DB) {
 }
 
 func AfterDelete(db *gorm.DB) {
+	if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
+		callMethod := func(value interface{}) bool {
+			if db.Statement.Schema.AfterDelete {
+				if i, ok := value.(gorm.AfterDeleteInterface); ok {
+					i.AfterDelete(db)
+					return true
+				}
+			}
+			return false
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }

+ 25 - 2
callbacks/query.go

@@ -1,6 +1,8 @@
 package callbacks
 
 import (
+	"reflect"
+
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/clause"
 )
@@ -13,7 +15,7 @@ func Query(db *gorm.DB) {
 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
 	}
 
-	rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	_, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 	db.AddError(err)
 }
 
@@ -21,5 +23,26 @@ func Preload(db *gorm.DB) {
 }
 
 func AfterQuery(db *gorm.DB) {
-	// after find
+	if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
+		callMethod := func(value interface{}) bool {
+			if db.Statement.Schema.AfterFind {
+				if i, ok := value.(gorm.AfterFindInterface); ok {
+					i.AfterFind(db)
+					return true
+				}
+			}
+			return false
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }

+ 65 - 1
callbacks/update.go

@@ -1,12 +1,76 @@
 package callbacks
 
-import "github.com/jinzhu/gorm"
+import (
+	"reflect"
+
+	"github.com/jinzhu/gorm"
+)
 
 func BeforeUpdate(db *gorm.DB) {
+	if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
+		callMethod := func(value interface{}) bool {
+			var ok bool
+			if db.Statement.Schema.BeforeSave {
+				if i, ok := value.(gorm.BeforeSaveInterface); ok {
+					ok = true
+					i.BeforeSave(db)
+				}
+			}
+
+			if db.Statement.Schema.BeforeUpdate {
+				if i, ok := value.(gorm.BeforeUpdateInterface); ok {
+					ok = true
+					i.BeforeUpdate(db)
+				}
+			}
+			return ok
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }
 
 func Update(db *gorm.DB) {
 }
 
 func AfterUpdate(db *gorm.DB) {
+	if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
+		callMethod := func(value interface{}) bool {
+			var ok bool
+			if db.Statement.Schema.AfterSave {
+				if i, ok := value.(gorm.AfterSaveInterface); ok {
+					ok = true
+					i.AfterSave(db)
+				}
+			}
+
+			if db.Statement.Schema.AfterUpdate {
+				if i, ok := value.(gorm.AfterUpdateInterface); ok {
+					ok = true
+					i.AfterUpdate(db)
+				}
+			}
+			return ok
+		}
+
+		if ok := callMethod(db.Statement.Dest); !ok {
+			switch db.Statement.ReflectValue.Kind() {
+			case reflect.Slice, reflect.Array:
+				for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
+					callMethod(db.Statement.ReflectValue.Index(i).Interface())
+				}
+			case reflect.Struct:
+				callMethod(db.Statement.ReflectValue.Interface())
+			}
+		}
+	}
 }

+ 2 - 2
clause/benchmarks_test.go

@@ -11,7 +11,7 @@ import (
 )
 
 func BenchmarkSelect(b *testing.B) {
-	user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
+	user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 
 	for i := 0; i < b.N; i++ {
 		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
@@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) {
 }
 
 func BenchmarkComplexSelect(b *testing.B) {
-	user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
+	user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 
 	for i := 0; i < b.N; i++ {
 		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}

+ 1 - 1
clause/clause_test.go

@@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string,
 	var (
 		buildNames    []string
 		buildNamesMap = map[string]bool{}
-		user, _       = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
+		user, _, _    = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 		stmt          = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
 	)
 

+ 1 - 1
clause/expression_test.go

@@ -24,7 +24,7 @@ func TestExpr(t *testing.T) {
 
 	for idx, result := range results {
 		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
-			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
+			user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
 			clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
 			if stmt.SQL.String() != result.Result {

+ 36 - 0
interfaces.go

@@ -24,3 +24,39 @@ type CommonDB interface {
 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
 }
+
+type BeforeCreateInterface interface {
+	BeforeCreate(*DB)
+}
+
+type AfterCreateInterface interface {
+	AfterCreate(*DB)
+}
+
+type BeforeUpdateInterface interface {
+	BeforeUpdate(*DB)
+}
+
+type AfterUpdateInterface interface {
+	AfterUpdate(*DB)
+}
+
+type BeforeSaveInterface interface {
+	BeforeSave(*DB)
+}
+
+type AfterSaveInterface interface {
+	AfterSave(*DB)
+}
+
+type BeforeDeleteInterface interface {
+	BeforeDelete(*DB)
+}
+
+type AfterDeleteInterface interface {
+	AfterDelete(*DB)
+}
+
+type AfterFindInterface interface {
+	AfterFind(*DB)
+}

+ 38 - 0
schema/callbacks_test.go

@@ -0,0 +1,38 @@
+package schema_test
+
+import (
+	"reflect"
+	"sync"
+	"testing"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/schema"
+)
+
+type UserWithCallback struct {
+}
+
+func (UserWithCallback) BeforeSave(*gorm.DB) {
+}
+
+func (UserWithCallback) AfterCreate(*gorm.DB) {
+}
+
+func TestCallback(t *testing.T) {
+	user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse user with callback, got error %v", err)
+	}
+
+	for _, str := range []string{"BeforeSave", "AfterCreate"} {
+		if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
+			t.Errorf("%v should be true", str)
+		}
+	}
+
+	for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
+		if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
+			t.Errorf("%v should be false", str)
+		}
+	}
+}

+ 1 - 1
schema/check_test.go

@@ -15,7 +15,7 @@ type UserCheck struct {
 }
 
 func TestParseCheck(t *testing.T) {
-	user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
+	user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user check, got error %v", err)
 	}

+ 12 - 12
schema/field_test.go

@@ -14,8 +14,8 @@ import (
 
 func TestFieldValuerAndSetter(t *testing.T) {
 	var (
-		userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
-		user          = tests.User{
+		userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
+		user             = tests.User{
 			Model: gorm.Model{
 				ID:        10,
 				CreatedAt: time.Now(),
@@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
 
 func TestPointerFieldValuerAndSetter(t *testing.T) {
 	var (
-		userSchema, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
-		name               = "pointer_field_valuer_and_setter"
-		age           uint = 18
-		active             = true
-		user               = User{
+		userSchema, _, _      = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+		name                  = "pointer_field_valuer_and_setter"
+		age              uint = 18
+		active                = true
+		user                  = User{
 			Model: &gorm.Model{
 				ID:        10,
 				CreatedAt: time.Now(),
@@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
 
 func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
 	var (
-		userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
-		name          = "advanced_data_type_valuer_and_setter"
-		deletedAt     = mytime(time.Now())
-		isAdmin       = mybool(false)
-		user          = AdvancedDataTypeUser{
+		userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
+		name             = "advanced_data_type_valuer_and_setter"
+		deletedAt        = mytime(time.Now())
+		isAdmin          = mybool(false)
+		user             = AdvancedDataTypeUser{
 			ID:           sql.NullInt64{Int64: 10, Valid: true},
 			Name:         &sql.NullString{String: name, Valid: true},
 			Birthday:     sql.NullTime{Time: time.Now(), Valid: true},

+ 1 - 1
schema/index_test.go

@@ -19,7 +19,7 @@ type UserIndex struct {
 }
 
 func TestParseIndex(t *testing.T) {
-	user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
+	user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user index, got error %v", err)
 	}

+ 31 - 14
schema/schema.go

@@ -14,20 +14,25 @@ import (
 var ErrUnsupportedDataType = errors.New("unsupported data type")
 
 type Schema struct {
-	Name                     string
-	ModelType                reflect.Type
-	Table                    string
-	PrioritizedPrimaryField  *Field
-	DBNames                  []string
-	PrimaryFields            []*Field
-	Fields                   []*Field
-	FieldsByName             map[string]*Field
-	FieldsByDBName           map[string]*Field
-	FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
-	Relationships            Relationships
-	err                      error
-	namer                    Namer
-	cacheStore               *sync.Map
+	Name                      string
+	ModelType                 reflect.Type
+	Table                     string
+	PrioritizedPrimaryField   *Field
+	DBNames                   []string
+	PrimaryFields             []*Field
+	Fields                    []*Field
+	FieldsByName              map[string]*Field
+	FieldsByDBName            map[string]*Field
+	FieldsWithDefaultDBValue  map[string]*Field // fields with default value assigned by database
+	Relationships             Relationships
+	BeforeCreate, AfterCreate bool
+	BeforeUpdate, AfterUpdate bool
+	BeforeDelete, AfterDelete bool
+	BeforeSave, AfterSave     bool
+	AfterFind                 bool
+	err                       error
+	namer                     Namer
+	cacheStore                *sync.Map
 }
 
 func (schema Schema) String() string {
@@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
 		}
 	}
 
+	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
+	for _, name := range callbacks {
+		if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
+			switch methodValue.Type().String() {
+			case "func(*gorm.DB)": // TODO hack
+				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
+			default:
+				logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
+			}
+		}
+	}
+
 	cacheStore.Store(modelType, schema)
 
 	// parse relations for unidentified fields

+ 3 - 3
schema/schema_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestParseSchema(t *testing.T) {
-	user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
+	user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse user, got error %v", err)
 	}
@@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
 }
 
 func TestParseSchemaWithPointerFields(t *testing.T) {
-	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+	user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse pointer user, got error %v", err)
 	}
@@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
 }
 
 func TestParseSchemaWithAdvancedDataType(t *testing.T) {
-	user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
+	user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
 		t.Fatalf("failed to parse pointer user, got error %v", err)
 	}