callbacks.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "time"
  6. "github.com/jinzhu/gorm/logger"
  7. "github.com/jinzhu/gorm/schema"
  8. "github.com/jinzhu/gorm/utils"
  9. )
  10. func initializeCallbacks(db *DB) *callbacks {
  11. return &callbacks{
  12. processors: map[string]*processor{
  13. "create": &processor{db: db},
  14. "query": &processor{db: db},
  15. "update": &processor{db: db},
  16. "delete": &processor{db: db},
  17. "row": &processor{db: db},
  18. "raw": &processor{db: db},
  19. },
  20. }
  21. }
  22. // callbacks gorm callbacks manager
  23. type callbacks struct {
  24. processors map[string]*processor
  25. }
  26. type processor struct {
  27. db *DB
  28. fns []func(*DB)
  29. callbacks []*callback
  30. }
  31. type callback struct {
  32. name string
  33. before string
  34. after string
  35. remove bool
  36. replace bool
  37. match func(*DB) bool
  38. handler func(*DB)
  39. processor *processor
  40. }
  41. func (cs *callbacks) Create() *processor {
  42. return cs.processors["create"]
  43. }
  44. func (cs *callbacks) Query() *processor {
  45. return cs.processors["query"]
  46. }
  47. func (cs *callbacks) Update() *processor {
  48. return cs.processors["update"]
  49. }
  50. func (cs *callbacks) Delete() *processor {
  51. return cs.processors["delete"]
  52. }
  53. func (cs *callbacks) Row() *processor {
  54. return cs.processors["row"]
  55. }
  56. func (cs *callbacks) Raw() *processor {
  57. return cs.processors["raw"]
  58. }
  59. func (p *processor) Execute(db *DB) {
  60. curTime := time.Now()
  61. if stmt := db.Statement; stmt != nil {
  62. if stmt.Model == nil {
  63. stmt.Model = stmt.Dest
  64. }
  65. if stmt.Model != nil {
  66. err := stmt.Parse(stmt.Model)
  67. if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
  68. db.AddError(err)
  69. }
  70. }
  71. }
  72. for _, f := range p.fns {
  73. f(db)
  74. }
  75. if stmt := db.Statement; stmt != nil {
  76. db.Logger.RunWith(logger.Info, func() {
  77. db.Logger.Info(db.Dialector.Explain(stmt.SQL.String(), stmt.Vars))
  78. })
  79. }
  80. }
  81. func (p *processor) Get(name string) func(*DB) {
  82. for i := len(p.callbacks) - 1; i >= 0; i-- {
  83. if v := p.callbacks[i]; v.name == name && !v.remove {
  84. return v.handler
  85. }
  86. }
  87. return nil
  88. }
  89. func (p *processor) Before(name string) *callback {
  90. return &callback{before: name, processor: p}
  91. }
  92. func (p *processor) After(name string) *callback {
  93. return &callback{after: name, processor: p}
  94. }
  95. func (p *processor) Match(fc func(*DB) bool) *callback {
  96. return &callback{match: fc, processor: p}
  97. }
  98. func (p *processor) Register(name string, fn func(*DB)) error {
  99. return (&callback{processor: p}).Register(name, fn)
  100. }
  101. func (p *processor) Remove(name string) error {
  102. return (&callback{processor: p}).Remove(name)
  103. }
  104. func (p *processor) Replace(name string, fn func(*DB)) error {
  105. return (&callback{processor: p}).Replace(name, fn)
  106. }
  107. func (p *processor) compile() (err error) {
  108. var callbacks []*callback
  109. for _, callback := range p.callbacks {
  110. if callback.match == nil || callback.match(p.db) {
  111. callbacks = append(callbacks, callback)
  112. }
  113. }
  114. if p.fns, err = sortCallbacks(p.callbacks); err != nil {
  115. logger.Default.Error("Got error when compile callbacks, got %v", err)
  116. }
  117. return
  118. }
  119. func (c *callback) Before(name string) *callback {
  120. c.before = name
  121. return c
  122. }
  123. func (c *callback) After(name string) *callback {
  124. c.after = name
  125. return c
  126. }
  127. func (c *callback) Register(name string, fn func(*DB)) error {
  128. c.name = name
  129. c.handler = fn
  130. c.processor.callbacks = append(c.processor.callbacks, c)
  131. return c.processor.compile()
  132. }
  133. func (c *callback) Remove(name string) error {
  134. logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
  135. c.name = name
  136. c.remove = true
  137. c.processor.callbacks = append(c.processor.callbacks, c)
  138. return c.processor.compile()
  139. }
  140. func (c *callback) Replace(name string, fn func(*DB)) error {
  141. logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
  142. c.name = name
  143. c.handler = fn
  144. c.replace = true
  145. c.processor.callbacks = append(c.processor.callbacks, c)
  146. return c.processor.compile()
  147. }
  // getRIndex returns the index of the last occurrence of str in strs, or -1
// when str does not occur (including for a nil or empty slice).
func getRIndex(strs []string, str string) int {
	idx := len(strs)
	for idx > 0 {
		idx--
		if strs[idx] == str {
			return idx
		}
	}
	return -1
}
  157. func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
  158. var (
  159. names, sorted []string
  160. sortCallback func(*callback) error
  161. )
  162. for _, c := range cs {
  163. // show warning message the callback name already exists
  164. if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
  165. logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
  166. }
  167. names = append(names, c.name)
  168. }
  169. sortCallback = func(c *callback) error {
  170. if c.before != "" { // if defined before callback
  171. if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
  172. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  173. // if before callback already sorted, append current callback just after it
  174. sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
  175. } else if curIdx > sortedIdx {
  176. return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
  177. }
  178. } else if idx := getRIndex(names, c.before); idx != -1 {
  179. // if before callback exists
  180. cs[idx].after = c.name
  181. }
  182. }
  183. if c.after != "" { // if defined after callback
  184. if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
  185. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  186. // if after callback sorted, append current callback to last
  187. sorted = append(sorted, c.name)
  188. } else if curIdx < sortedIdx {
  189. return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
  190. }
  191. } else if idx := getRIndex(names, c.after); idx != -1 {
  192. // if after callback exists but haven't sorted
  193. // set after callback's before callback to current callback
  194. after := cs[idx]
  195. if after.before == "" {
  196. after.before = c.name
  197. }
  198. if err := sortCallback(after); err != nil {
  199. return err
  200. }
  201. if err := sortCallback(c); err != nil {
  202. return err
  203. }
  204. }
  205. }
  206. // if current callback haven't been sorted, append it to last
  207. if getRIndex(sorted, c.name) == -1 {
  208. sorted = append(sorted, c.name)
  209. }
  210. return nil
  211. }
  212. for _, c := range cs {
  213. if err = sortCallback(c); err != nil {
  214. return
  215. }
  216. }
  217. for _, name := range sorted {
  218. if idx := getRIndex(names, name); !cs[idx].remove {
  219. fns = append(fns, cs[idx].handler)
  220. }
  221. }
  222. return
  223. }