Răsfoiți Sursa

Add callbacks

Jinzhu 4 ani în urmă
părinte
comite
728c0d4470
13 a modificat fișierele cu 101 adăugiri și 51 ștergeri
  1. 18 11
      callbacks.go
  2. 33 6
      callbacks/callbacks.go
  3. 1 15
      callbacks/create.go
  4. 12 0
      callbacks/delete.go
  5. 9 0
      callbacks/transaction.go
  6. 12 0
      callbacks/update.go
  7. 0 5
      dialects/sqlite/go.mod
  8. 0 2
      dialects/sqlite/go.sum
  9. 1 4
      go.mod
  10. 0 2
      go.sum
  11. 2 1
      gorm.go
  12. 11 3
      statement.go
  13. 2 2
      tests/callbacks_test.go

+ 18 - 11
callbacks.go

@@ -9,15 +9,15 @@ import (
 	"github.com/jinzhu/gorm/utils"
 )
 
-func InitializeCallbacks() *callbacks {
+func initializeCallbacks(db *DB) *callbacks {
 	return &callbacks{
 		processors: map[string]*processor{
-			"create": &processor{},
-			"query":  &processor{},
-			"update": &processor{},
-			"delete": &processor{},
-			"row":    &processor{},
-			"raw":    &processor{},
+			"create": &processor{db: db},
+			"query":  &processor{db: db},
+			"update": &processor{db: db},
+			"delete": &processor{db: db},
+			"row":    &processor{db: db},
+			"raw":    &processor{db: db},
 		},
 	}
 }
@@ -118,7 +118,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
 	return (&callback{processor: p}).Replace(name, fn)
 }
 
-func (p *processor) compile(db *DB) (err error) {
+func (p *processor) compile() (err error) {
+	var callbacks []*callback
+	for _, callback := range p.callbacks {
+		if callback.match == nil || callback.match(p.db) {
+			callbacks = append(callbacks, callback)
+		}
+	}
+
 	if p.fns, err = sortCallbacks(p.callbacks); err != nil {
 		logger.Default.Error("Got error when compile callbacks, got %v", err)
 	}
@@ -139,7 +146,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
 	c.name = name
 	c.handler = fn
 	c.processor.callbacks = append(c.processor.callbacks, c)
-	return c.processor.compile(c.processor.db)
+	return c.processor.compile()
 }
 
 func (c *callback) Remove(name string) error {
@@ -147,7 +154,7 @@ func (c *callback) Remove(name string) error {
 	c.name = name
 	c.remove = true
 	c.processor.callbacks = append(c.processor.callbacks, c)
-	return c.processor.compile(c.processor.db)
+	return c.processor.compile()
 }
 
 func (c *callback) Replace(name string, fn func(*DB)) error {
@@ -156,7 +163,7 @@ func (c *callback) Replace(name string, fn func(*DB)) error {
 	c.handler = fn
 	c.replace = true
 	c.processor.callbacks = append(c.processor.callbacks, c)
-	return c.processor.compile(c.processor.db)
+	return c.processor.compile()
 }
 
 // getRIndex get right index from string slice

+ 33 - 6
callbacks/callbacks.go

@@ -3,10 +3,37 @@ package callbacks
 import "github.com/jinzhu/gorm"
 
 func RegisterDefaultCallbacks(db *gorm.DB) {
-	callback := db.Callback()
-	callback.Create().Register("gorm:before_create", BeforeCreate)
-	callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
-	callback.Create().Register("gorm:create", Create)
-	callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
-	callback.Create().Register("gorm:after_create", AfterCreate)
+	enableTransaction := func(db *gorm.DB) bool {
+		return !db.SkipDefaultTransaction
+	}
+
+	createCallback := db.Callback().Create()
+	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
+	createCallback.Register("gorm:before_create", BeforeCreate)
+	createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
+	createCallback.Register("gorm:create", Create)
+	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
+	createCallback.Register("gorm:after_create", AfterCreate)
+	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
+
+	queryCallback := db.Callback().Query()
+	queryCallback.Register("gorm:query", BeforeCreate)
+	queryCallback.Register("gorm:preload", Preload)
+	queryCallback.Register("gorm:after_query", AfterQuery)
+
+	deleteCallback := db.Callback().Delete()
+	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
+	deleteCallback.Register("gorm:before_delete", BeforeDelete)
+	deleteCallback.Register("gorm:delete", Delete)
+	deleteCallback.Register("gorm:after_delete", AfterDelete)
+	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
+
+	updateCallback := db.Callback().Update()
+	updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
+	updateCallback.Register("gorm:before_update", BeforeUpdate)
+	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
+	updateCallback.Register("gorm:update", Update)
+	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
+	updateCallback.Register("gorm:after_update", AfterUpdate)
+	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 }

+ 1 - 15
callbacks/create.go

@@ -18,7 +18,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
 
 func Create(db *gorm.DB) {
 	db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
-
+	db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
 }
 
@@ -29,17 +29,3 @@ func AfterCreate(db *gorm.DB) {
 	// after save
 	// after create
 }
-
-func objectToFieldsMap(stmt *gorm.Statement) {
-	if stmt.Schema != nil {
-		if s, ok := stmt.Clauses["SELECT"]; ok {
-			s.Attrs
-		}
-
-		if s, ok := stmt.Clauses["OMIT"]; ok {
-			s.Attrs
-		}
-
-		stmt.Schema.LookUpField(s.S)
-	}
-}

+ 12 - 0
callbacks/delete.go

@@ -0,0 +1,12 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+func BeforeDelete(db *gorm.DB) {
+}
+
+func Delete(db *gorm.DB) {
+}
+
+func AfterDelete(db *gorm.DB) {
+}

+ 9 - 0
callbacks/transaction.go

@@ -0,0 +1,9 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+func BeginTransaction(db *gorm.DB) {
+}
+
+func CommitOrRollbackTransaction(db *gorm.DB) {
+}

+ 12 - 0
callbacks/update.go

@@ -0,0 +1,12 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+func BeforeUpdate(db *gorm.DB) {
+}
+
+func Update(db *gorm.DB) {
+}
+
+func AfterUpdate(db *gorm.DB) {
+}

+ 0 - 5
dialects/sqlite/go.mod

@@ -1,5 +0,0 @@
-module github.com/jinzhu/gorm/dialects/sqlite
-
-go 1.13
-
-require github.com/mattn/go-sqlite3 v2.0.3+incompatible

+ 0 - 2
dialects/sqlite/go.sum

@@ -1,2 +0,0 @@
-github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
-github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=

+ 1 - 4
go.mod

@@ -2,7 +2,4 @@ module github.com/jinzhu/gorm
 
 go 1.13
 
-require (
-	github.com/jinzhu/inflection v1.0.0
-	gopkg.in/errgo.v2 v2.1.0
-)
+require github.com/jinzhu/inflection v1.0.0

+ 0 - 2
go.sum

@@ -1,2 +0,0 @@
-github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
-github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=

+ 2 - 1
gorm.go

@@ -63,10 +63,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
 		Config:     config,
 		Dialector:  dialector,
 		clone:      true,
-		callbacks:  InitializeCallbacks(),
 		cacheStore: &sync.Map{},
 	}
 
+	db.callbacks = initializeCallbacks(db)
+
 	if dialector != nil {
 		err = dialector.Initialize(db)
 	}

+ 11 - 3
statement.go

@@ -21,6 +21,13 @@ type Instance struct {
 	Statement    *Statement
 }
 
+func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) {
+	if len(clauses) > 0 {
+		instance.Statement.Build(clauses...)
+	}
+	return instance.Statement.SQL.String(), instance.Statement.Vars
+}
+
 // AddError add error to instance
 func (inst Instance) AddError(err error) {
 	if inst.Error == nil {
@@ -205,16 +212,17 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
 
 // Build build sql with clauses names
 func (stmt Statement) Build(clauses ...string) {
-	var includeSpace bool
+	var firstClauseWritten bool
 
 	for _, name := range clauses {
 		if c, ok := stmt.Clauses[name]; ok {
-			if includeSpace {
+			if firstClauseWritten {
 				stmt.WriteByte(' ')
 			}
 
-			includeSpace = true
+			firstClauseWritten = true
 			c.Build(stmt)
 		}
 	}
+	// TODO handle named vars
 }

+ 2 - 2
tests/callbacks_test.go

@@ -99,8 +99,8 @@ func TestCallbacks(t *testing.T) {
 	}
 
 	for idx, data := range datas {
-		var err error
-		callbacks := gorm.InitializeCallbacks()
+		db, err := gorm.Open(nil, nil)
+		callbacks := db.Callback()
 
 		for _, c := range data.callbacks {
 			var v interface{} = callbacks.Create()