Selaa lähdekoodia

Implement callbacks

Jinzhu 4 vuotta sitten
vanhempi
säilyke
e509b3100d
6 muutettua tiedostoa jossa 422 lisäystä ja 22 poistoa
  1. 211 0
      callbacks.go
  2. 131 0
      callbacks_test.go
  3. 14 7
      helpers.go
  4. 46 0
      logger/logger.go
  5. 0 15
      model.go
  6. 20 0
      utils/utils.go

+ 211 - 0
callbacks.go

@@ -0,0 +1,211 @@
+package gorm
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/jinzhu/gorm/logger"
+	"github.com/jinzhu/gorm/utils"
+)
+
+// Callbacks gorm callbacks manager
+type Callbacks struct {
+	creates    []func(*DB)
+	queries    []func(*DB)
+	updates    []func(*DB)
+	deletes    []func(*DB)
+	row        []func(*DB)
+	raw        []func(*DB)
+	db         *DB
+	processors []*processor
+}
+
+type processor struct {
+	kind      string
+	name      string
+	before    string
+	after     string
+	remove    bool
+	replace   bool
+	match     func(*DB) bool
+	handler   func(*DB)
+	callbacks *Callbacks
+}
+
+func (cs *Callbacks) Create() *processor {
+	return &processor{callbacks: cs, kind: "create"}
+}
+
+func (cs *Callbacks) Query() *processor {
+	return &processor{callbacks: cs, kind: "query"}
+}
+
+func (cs *Callbacks) Update() *processor {
+	return &processor{callbacks: cs, kind: "update"}
+}
+
+func (cs *Callbacks) Delete() *processor {
+	return &processor{callbacks: cs, kind: "delete"}
+}
+
+func (cs *Callbacks) Row() *processor {
+	return &processor{callbacks: cs, kind: "row"}
+}
+
+func (cs *Callbacks) Raw() *processor {
+	return &processor{callbacks: cs, kind: "raw"}
+}
+
+func (p *processor) Before(name string) *processor {
+	p.before = name
+	return p
+}
+
+func (p *processor) After(name string) *processor {
+	p.after = name
+	return p
+}
+
+func (p *processor) Match(fc func(*DB) bool) *processor {
+	p.match = fc
+	return p
+}
+
+func (p *processor) Get(name string) func(*DB) {
+	for i := len(p.callbacks.processors) - 1; i >= 0; i-- {
+		if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove {
+			return v.handler
+		}
+	}
+	return nil
+}
+
+func (p *processor) Register(name string, fn func(*DB)) {
+	p.name = name
+	p.handler = fn
+	p.callbacks.processors = append(p.callbacks.processors, p)
+	p.callbacks.compile(p.callbacks.db)
+}
+
+func (p *processor) Remove(name string) {
+	logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
+	p.name = name
+	p.remove = true
+	p.callbacks.processors = append(p.callbacks.processors, p)
+	p.callbacks.compile(p.callbacks.db)
+}
+
+func (p *processor) Replace(name string, fn func(*DB)) {
+	logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
+	p.name = name
+	p.handler = fn
+	p.replace = true
+	p.callbacks.processors = append(p.callbacks.processors, p)
+	p.callbacks.compile(p.callbacks.db)
+}
+
+// getRIndex get right index from string slice
+func getRIndex(strs []string, str string) int {
+	for i := len(strs) - 1; i >= 0; i-- {
+		if strs[i] == str {
+			return i
+		}
+	}
+	return -1
+}
+
+func sortProcessors(ps []*processor) []func(*DB) {
+	var (
+		allNames, sortedNames []string
+		sortProcessor         func(*processor) error
+	)
+
+	for _, p := range ps {
+		// show warning message the callback name already exists
+		if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove {
+			log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum())
+		}
+		allNames = append(allNames, p.name)
+	}
+
+	sortProcessor = func(p *processor) error {
+		if getRIndex(sortedNames, p.name) == -1 { // if not sorted
+			if p.before != "" { // if defined before callback
+				if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 {
+					if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true {
+						// if before callback already sorted, append current callback just after it
+						sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...)
+					} else if curIdx > sortedIdx {
+						return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before)
+					}
+				} else if idx := getRIndex(allNames, p.before); idx != -1 {
+					// if before callback exists
+					ps[idx].after = p.name
+				}
+			}
+
+			if p.after != "" { // if defined after callback
+				if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 {
+					// if after callback sorted, append current callback to last
+					sortedNames = append(sortedNames, p.name)
+				} else if idx := getRIndex(allNames, p.after); idx != -1 {
+					// if after callback exists but haven't sorted
+					// set after callback's before callback to current callback
+					if after := ps[idx]; after.before == "" {
+						after.before = p.name
+						sortProcessor(after)
+					}
+				}
+			}
+
+			// if current callback haven't been sorted, append it to last
+			if getRIndex(sortedNames, p.name) == -1 {
+				sortedNames = append(sortedNames, p.name)
+			}
+		}
+
+		return nil
+	}
+
+	for _, p := range ps {
+		sortProcessor(p)
+	}
+
+	var fns []func(*DB)
+	for _, name := range sortedNames {
+		if idx := getRIndex(allNames, name); !ps[idx].remove {
+			fns = append(fns, ps[idx].handler)
+		}
+	}
+
+	return fns
+}
+
+// compile processors
+func (cs *Callbacks) compile(db *DB) {
+	processors := map[string][]*processor{}
+	for _, p := range cs.processors {
+		if p.name != "" {
+			if p.match == nil || p.match(db) {
+				processors[p.kind] = append(processors[p.kind], p)
+			}
+		}
+	}
+
+	for name, ps := range processors {
+		switch name {
+		case "create":
+			cs.creates = sortProcessors(ps)
+		case "query":
+			cs.queries = sortProcessors(ps)
+		case "update":
+			cs.updates = sortProcessors(ps)
+		case "delete":
+			cs.deletes = sortProcessors(ps)
+		case "row":
+			cs.row = sortProcessors(ps)
+		case "raw":
+			cs.raw = sortProcessors(ps)
+		}
+	}
+}

+ 131 - 0
callbacks_test.go

@@ -0,0 +1,131 @@
+package gorm
+
+import (
+	"fmt"
+	"reflect"
+	"runtime"
+	"strings"
+	"testing"
+)
+
+func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) {
+	var got []string
+
+	for _, f := range funcs {
+		got = append(got, getFuncName(f))
+	}
+
+	return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
+}
+
+func getFuncName(fc func(*DB)) string {
+	fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".")
+	return fnames[len(fnames)-1]
+}
+
+func c1(*DB) {}
+func c2(*DB) {}
+func c3(*DB) {}
+func c4(*DB) {}
+func c5(*DB) {}
+
+func TestCallbacks(t *testing.T) {
+	type callback struct {
+		name    string
+		before  string
+		after   string
+		remove  bool
+		replace bool
+		err     error
+		match   func(*DB) bool
+		h       func(*DB)
+	}
+
+	datas := []struct {
+		callbacks []callback
+		results   []string
+	}{
+		{
+			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
+			results:   []string{"c1", "c2", "c3", "c4", "c5"},
+		},
+		{
+			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
+			results:   []string{"c1", "c2", "c3", "c5", "c4"},
+		},
+		{
+			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
+			results:   []string{"c1", "c2", "c3", "c5", "c4"},
+		},
+		{
+			callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
+			results:   []string{"c1", "c2", "c3", "c5", "c4"},
+		},
+		{
+			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
+			results:   []string{"c1", "c5", "c2", "c3", "c4"},
+		},
+		{
+			callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
+			results:   []string{"c1", "c3", "c5", "c2", "c4"},
+		},
+		{
+			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
+			results:   []string{"c1", "c5", "c3", "c4"},
+		},
+		{
+			callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
+			results:   []string{"c1", "c4", "c3"},
+		},
+	}
+
+	// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
+	// 	var callback2 = &Callback{logger: defaultLogger}
+
+	// 	callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
+	// 	callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
+	// 	callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
+	// 	callback2.Delete().Register("after_create1", afterCreate1)
+	// 	callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
+
+	// 	if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
+	// 		t.Errorf("register callback with order")
+	// 	}
+	// }
+
+	for idx, data := range datas {
+		callbacks := &Callbacks{}
+
+		for _, c := range data.callbacks {
+			p := callbacks.Create()
+
+			if c.name == "" {
+				c.name = getFuncName(c.h)
+			}
+
+			if c.before != "" {
+				p = p.Before(c.before)
+			}
+
+			if c.after != "" {
+				p = p.After(c.after)
+			}
+
+			if c.match != nil {
+				p = p.Match(c.match)
+			}
+
+			if c.remove {
+				p.Remove(c.name)
+			} else if c.replace {
+				p.Replace(c.name, c.h)
+			} else {
+				p.Register(c.name, c.h)
+			}
+		}
+
+		if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok {
+			t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
+		}
+	}
+}

+ 14 - 7
errors.go → helpers.go

@@ -1,6 +1,9 @@
 package gorm
 
-import "errors"
+import (
+	"errors"
+	"time"
+)
 
 var (
 	// ErrRecordNotFound record not found error
@@ -13,10 +16,14 @@ var (
 	ErrUnaddressable = errors.New("using unaddressable value")
 )
 
-type Error struct {
-	Err error
-}
-
-func (e Error) Unwrap() error {
-	return e.Err
+// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
+// It may be embeded into your model or you may build your own model without it
+//    type User struct {
+//      gorm.Model
+//    }
+type Model struct {
+	ID        uint `gorm:"primary_key"`
+	CreatedAt time.Time
+	UpdatedAt time.Time
+	DeletedAt *time.Time `gorm:"index"`
 }

+ 46 - 0
logger/logger.go

@@ -1,7 +1,15 @@
 package logger
 
+import (
+	"fmt"
+	"log"
+	"os"
+)
+
 type LogLevel int
 
+var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)}
+
 const (
 	Info LogLevel = iota + 1
 	Warn
@@ -11,4 +19,42 @@ const (
 // Interface logger interface
 type Interface interface {
 	LogMode(LogLevel) Interface
+	Info(string, ...interface{})
+	Warn(string, ...interface{})
+	Error(string, ...interface{})
+}
+
+// Writer log writer interface
+type Writer interface {
+	Print(...interface{})
+}
+
+type Logger struct {
+	Writer
+	logLevel LogLevel
+}
+
+func (logger Logger) LogMode(level LogLevel) Interface {
+	return Logger{Writer: logger.Writer, logLevel: level}
+}
+
+// Info print info
+func (logger Logger) Info(msg string, data ...interface{}) {
+	if logger.logLevel >= Info {
+		logger.Print("[info] " + fmt.Sprintf(msg, data...))
+	}
+}
+
+// Warn print warn messages
+func (logger Logger) Warn(msg string, data ...interface{}) {
+	if logger.logLevel >= Warn {
+		logger.Print("[warn] " + fmt.Sprintf(msg, data...))
+	}
+}
+
+// Error print error messages
+func (logger Logger) Error(msg string, data ...interface{}) {
+	if logger.logLevel >= Error {
+		logger.Print("[error] " + fmt.Sprintf(msg, data...))
+	}
 }

+ 0 - 15
model.go

@@ -1,15 +0,0 @@
-package gorm
-
-import "time"
-
-// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
-// It may be embeded into your model or you may build your own model without it
-//    type User struct {
-//      gorm.Model
-//    }
-type Model struct {
-	ID        uint `gorm:"primary_key"`
-	CreatedAt time.Time
-	UpdatedAt time.Time
-	DeletedAt *time.Time `gorm:"index"`
-}

+ 20 - 0
utils/utils.go

@@ -0,0 +1,20 @@
+package utils
+
+import (
+	"fmt"
+	"regexp"
+	"runtime"
+)
+
+var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
+var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`)
+
+func FileWithLineNum() string {
+	for i := 2; i < 15; i++ {
+		_, file, line, ok := runtime.Caller(i)
+		if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
+			return fmt.Sprintf("%v:%v", file, line)
+		}
+	}
+	return ""
+}