callbacks.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/jinzhu/gorm/logger"
  6. "github.com/jinzhu/gorm/schema"
  7. "github.com/jinzhu/gorm/utils"
  8. )
  9. func initializeCallbacks(db *DB) *callbacks {
  10. return &callbacks{
  11. processors: map[string]*processor{
  12. "create": &processor{db: db},
  13. "query": &processor{db: db},
  14. "update": &processor{db: db},
  15. "delete": &processor{db: db},
  16. "row": &processor{db: db},
  17. "raw": &processor{db: db},
  18. },
  19. }
  20. }
  21. // callbacks gorm callbacks manager
  22. type callbacks struct {
  23. processors map[string]*processor
  24. }
  25. type processor struct {
  26. db *DB
  27. fns []func(*DB)
  28. callbacks []*callback
  29. }
  30. type callback struct {
  31. name string
  32. before string
  33. after string
  34. remove bool
  35. replace bool
  36. match func(*DB) bool
  37. handler func(*DB)
  38. processor *processor
  39. }
  40. func (cs *callbacks) Create() *processor {
  41. return cs.processors["create"]
  42. }
  43. func (cs *callbacks) Query() *processor {
  44. return cs.processors["query"]
  45. }
  46. func (cs *callbacks) Update() *processor {
  47. return cs.processors["update"]
  48. }
  49. func (cs *callbacks) Delete() *processor {
  50. return cs.processors["delete"]
  51. }
  52. func (cs *callbacks) Row() *processor {
  53. return cs.processors["row"]
  54. }
  55. func (cs *callbacks) Raw() *processor {
  56. return cs.processors["raw"]
  57. }
  58. func (p *processor) Execute(db *DB) {
  59. if stmt := db.Statement; stmt != nil {
  60. if stmt.Model == nil {
  61. stmt.Model = stmt.Dest
  62. }
  63. if stmt.Model != nil {
  64. err := stmt.Parse(stmt.Model)
  65. if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
  66. db.AddError(err)
  67. }
  68. }
  69. }
  70. for _, f := range p.fns {
  71. f(db)
  72. }
  73. }
  74. func (p *processor) Get(name string) func(*DB) {
  75. for i := len(p.callbacks) - 1; i >= 0; i-- {
  76. if v := p.callbacks[i]; v.name == name && !v.remove {
  77. return v.handler
  78. }
  79. }
  80. return nil
  81. }
  82. func (p *processor) Before(name string) *callback {
  83. return &callback{before: name, processor: p}
  84. }
  85. func (p *processor) After(name string) *callback {
  86. return &callback{after: name, processor: p}
  87. }
  88. func (p *processor) Match(fc func(*DB) bool) *callback {
  89. return &callback{match: fc, processor: p}
  90. }
  91. func (p *processor) Register(name string, fn func(*DB)) error {
  92. return (&callback{processor: p}).Register(name, fn)
  93. }
  94. func (p *processor) Remove(name string) error {
  95. return (&callback{processor: p}).Remove(name)
  96. }
  97. func (p *processor) Replace(name string, fn func(*DB)) error {
  98. return (&callback{processor: p}).Replace(name, fn)
  99. }
  100. func (p *processor) compile() (err error) {
  101. var callbacks []*callback
  102. for _, callback := range p.callbacks {
  103. if callback.match == nil || callback.match(p.db) {
  104. callbacks = append(callbacks, callback)
  105. }
  106. }
  107. if p.fns, err = sortCallbacks(p.callbacks); err != nil {
  108. logger.Default.Error("Got error when compile callbacks, got %v", err)
  109. }
  110. return
  111. }
  112. func (c *callback) Before(name string) *callback {
  113. c.before = name
  114. return c
  115. }
  116. func (c *callback) After(name string) *callback {
  117. c.after = name
  118. return c
  119. }
  120. func (c *callback) Register(name string, fn func(*DB)) error {
  121. c.name = name
  122. c.handler = fn
  123. c.processor.callbacks = append(c.processor.callbacks, c)
  124. return c.processor.compile()
  125. }
  126. func (c *callback) Remove(name string) error {
  127. logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
  128. c.name = name
  129. c.remove = true
  130. c.processor.callbacks = append(c.processor.callbacks, c)
  131. return c.processor.compile()
  132. }
  133. func (c *callback) Replace(name string, fn func(*DB)) error {
  134. logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
  135. c.name = name
  136. c.handler = fn
  137. c.replace = true
  138. c.processor.callbacks = append(c.processor.callbacks, c)
  139. return c.processor.compile()
  140. }
  141. // getRIndex get right index from string slice
  142. func getRIndex(strs []string, str string) int {
  143. for i := len(strs) - 1; i >= 0; i-- {
  144. if strs[i] == str {
  145. return i
  146. }
  147. }
  148. return -1
  149. }
  150. func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
  151. var (
  152. names, sorted []string
  153. sortCallback func(*callback) error
  154. )
  155. for _, c := range cs {
  156. // show warning message the callback name already exists
  157. if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
  158. logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
  159. }
  160. names = append(names, c.name)
  161. }
  162. sortCallback = func(c *callback) error {
  163. if c.before != "" { // if defined before callback
  164. if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
  165. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  166. // if before callback already sorted, append current callback just after it
  167. sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
  168. } else if curIdx > sortedIdx {
  169. return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
  170. }
  171. } else if idx := getRIndex(names, c.before); idx != -1 {
  172. // if before callback exists
  173. cs[idx].after = c.name
  174. }
  175. }
  176. if c.after != "" { // if defined after callback
  177. if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
  178. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  179. // if after callback sorted, append current callback to last
  180. sorted = append(sorted, c.name)
  181. } else if curIdx < sortedIdx {
  182. return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
  183. }
  184. } else if idx := getRIndex(names, c.after); idx != -1 {
  185. // if after callback exists but haven't sorted
  186. // set after callback's before callback to current callback
  187. after := cs[idx]
  188. if after.before == "" {
  189. after.before = c.name
  190. }
  191. if err := sortCallback(after); err != nil {
  192. return err
  193. }
  194. if err := sortCallback(c); err != nil {
  195. return err
  196. }
  197. }
  198. }
  199. // if current callback haven't been sorted, append it to last
  200. if getRIndex(sorted, c.name) == -1 {
  201. sorted = append(sorted, c.name)
  202. }
  203. return nil
  204. }
  205. for _, c := range cs {
  206. if err = sortCallback(c); err != nil {
  207. return
  208. }
  209. }
  210. for _, name := range sorted {
  211. if idx := getRIndex(names, name); !cs[idx].remove {
  212. fns = append(fns, cs[idx].handler)
  213. }
  214. }
  215. return
  216. }