Browse Source

Setup Transaction

Jinzhu 4 years ago
parent
commit
5ccd76f76c
5 changed files with 57 additions and 18 deletions
  1. 4 0
      association.go
  2. 3 2
      callbacks/query.go
  3. 40 16
      finisher_api.go
  4. 9 0
      interfaces.go
  5. 1 0
      logger/logger.go

+ 4 - 0
association.go

@@ -3,3 +3,7 @@ package gorm
 // Association Mode contains some helper methods to handle relationship things easily.
 type Association struct {
 }
+
+func (db *DB) Association(column string) *Association {
+	return nil
+}

+ 3 - 2
callbacks/query.go

@@ -11,12 +11,13 @@ func Query(db *gorm.DB) {
 	if db.Statement.SQL.String() == "" {
 		db.Statement.AddClauseIfNotExists(clause.Select{})
 		db.Statement.AddClauseIfNotExists(clause.From{})
-
 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
 	}
 
-	_, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 	db.AddError(err)
+	_ = rows
+	// scan rows
 }
 
 func Preload(db *gorm.DB) {

+ 40 - 16
finisher_api.go

@@ -23,6 +23,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
 
 // First find first record that match given conditions, order by primary key
 func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
+	// TODO handle where
 	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
 		Desc:   true,
@@ -35,12 +36,18 @@ func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
 // Take return a record that match given conditions, the order will depend on the database implementation
 func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) {
 	tx = db.getInstance()
+	tx.Statement.Dest = out
+	tx.callbacks.Query().Execute(tx)
 	return
 }
 
 // Last find last record that match given conditions, order by primary key
 func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) {
-	tx = db.getInstance()
+	tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
+		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
+	})
+	tx.Statement.Dest = out
+	tx.callbacks.Query().Execute(tx)
 	return
 }
 
@@ -88,21 +95,12 @@ func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) {
 	return
 }
 
-func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) {
-	tx = db.getInstance()
-	return
-}
-
 //Preloads only preloads relations, don`t touch out
 func (db *DB) Preloads(out interface{}) (tx *DB) {
 	tx = db.getInstance()
 	return
 }
 
-func (db *DB) Association(column string) *Association {
-	return nil
-}
-
 func (db *DB) Count(value interface{}) (tx *DB) {
 	tx = db.getInstance()
 	return
@@ -130,6 +128,7 @@ func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
 	return nil
 }
 
+// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
 	panicked := true
 	tx := db.Begin(opts...)
@@ -150,21 +149,46 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
 	return
 }
 
+// Begin begins a transaction
 func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
 	tx = db.getInstance()
+	if beginner, ok := tx.DB.(TxBeginner); ok {
+		var opt *sql.TxOptions
+		var err error
+		if len(opts) > 0 {
+			opt = opts[0]
+		}
+
+		if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil {
+			tx.AddError(err)
+		}
+	} else {
+		tx.AddError(ErrInvalidTransaction)
+	}
 	return
 }
 
-func (db *DB) Commit() (tx *DB) {
-	tx = db.getInstance()
-	return
+// Commit commit a transaction
+func (db *DB) Commit() *DB {
+	if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil {
+		db.AddError(comminter.Commit())
+	} else {
+		db.AddError(ErrInvalidTransaction)
+	}
+	return db
 }
 
-func (db *DB) Rollback() (tx *DB) {
-	tx = db.getInstance()
-	return
+// Rollback rollback a transaction
+func (db *DB) Rollback() *DB {
+	if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil {
+		db.AddError(comminter.Rollback())
+	} else {
+		db.AddError(ErrInvalidTransaction)
+	}
+	return db
 }
 
+// Exec execute raw sql
 func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
 	tx = db.getInstance()
 	tx.Statement.SQL = strings.Builder{}

+ 9 - 0
interfaces.go

@@ -25,6 +25,15 @@ type CommonDB interface {
 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
 }
 
+type TxBeginner interface {
+	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
+}
+
+type TxCommiter interface {
+	Commit() error
+	Rollback() error
+}
+
 type BeforeCreateInterface interface {
 	BeforeCreate(*DB)
 }

+ 1 - 0
logger/logger.go

@@ -53,6 +53,7 @@ type Interface interface {
 
 var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
 	SlowThreshold: 100 * time.Millisecond,
+	LogLevel:      Warn,
 	Colorful:      true,
 })