callbacks_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package gorm_test
  2. import (
  3. "errors"
  4. "reflect"
  5. "testing"
  6. "github.com/jinzhu/gorm"
  7. )
  8. func (s *Product) BeforeCreate() (err error) {
  9. if s.Code == "Invalid" {
  10. err = errors.New("invalid product")
  11. }
  12. s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
  13. return
  14. }
  15. func (s *Product) BeforeUpdate() (err error) {
  16. if s.Code == "dont_update" {
  17. err = errors.New("can't update")
  18. }
  19. s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
  20. return
  21. }
  22. func (s *Product) BeforeSave() (err error) {
  23. if s.Code == "dont_save" {
  24. err = errors.New("can't save")
  25. }
  26. s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
  27. return
  28. }
  29. func (s *Product) AfterFind() {
  30. s.AfterFindCallTimes = s.AfterFindCallTimes + 1
  31. }
  32. func (s *Product) AfterCreate(tx *gorm.DB) {
  33. tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
  34. }
  35. func (s *Product) AfterUpdate() {
  36. s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
  37. }
  38. func (s *Product) AfterSave() (err error) {
  39. if s.Code == "after_save_error" {
  40. err = errors.New("can't save")
  41. }
  42. s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
  43. return
  44. }
  45. func (s *Product) BeforeDelete() (err error) {
  46. if s.Code == "dont_delete" {
  47. err = errors.New("can't delete")
  48. }
  49. s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
  50. return
  51. }
  52. func (s *Product) AfterDelete() (err error) {
  53. if s.Code == "after_delete_error" {
  54. err = errors.New("can't delete")
  55. }
  56. s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
  57. return
  58. }
  59. func (s *Product) GetCallTimes() []int64 {
  60. return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
  61. }
  62. func TestRunCallbacks(t *testing.T) {
  63. p := Product{Code: "unique_code", Price: 100}
  64. DB.Save(&p)
  65. if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
  66. t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
  67. }
  68. DB.Where("Code = ?", "unique_code").First(&p)
  69. if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
  70. t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
  71. }
  72. p.Price = 200
  73. DB.Save(&p)
  74. if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
  75. t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
  76. }
  77. var products []Product
  78. DB.Find(&products, "code = ?", "unique_code")
  79. if products[0].AfterFindCallTimes != 2 {
  80. t.Errorf("AfterFind callbacks should work with slice")
  81. }
  82. DB.Where("Code = ?", "unique_code").First(&p)
  83. if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
  84. t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
  85. }
  86. DB.Delete(&p)
  87. if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
  88. t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
  89. }
  90. if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
  91. t.Errorf("Can't find a deleted record")
  92. }
  93. }
  94. func TestCallbacksWithErrors(t *testing.T) {
  95. p := Product{Code: "Invalid", Price: 100}
  96. if DB.Save(&p).Error == nil {
  97. t.Errorf("An error from before create callbacks happened when create with invalid value")
  98. }
  99. if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
  100. t.Errorf("Should not save record that have errors")
  101. }
  102. if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
  103. t.Errorf("An error from after create callbacks happened when create with invalid value")
  104. }
  105. p2 := Product{Code: "update_callback", Price: 100}
  106. DB.Save(&p2)
  107. p2.Code = "dont_update"
  108. if DB.Save(&p2).Error == nil {
  109. t.Errorf("An error from before update callbacks happened when update with invalid value")
  110. }
  111. if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
  112. t.Errorf("Record Should not be updated due to errors happened in before update callback")
  113. }
  114. if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
  115. t.Errorf("Record Should not be updated due to errors happened in before update callback")
  116. }
  117. p2.Code = "dont_save"
  118. if DB.Save(&p2).Error == nil {
  119. t.Errorf("An error from before save callbacks happened when update with invalid value")
  120. }
  121. p3 := Product{Code: "dont_delete", Price: 100}
  122. DB.Save(&p3)
  123. if DB.Delete(&p3).Error == nil {
  124. t.Errorf("An error from before delete callbacks happened when delete")
  125. }
  126. if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
  127. t.Errorf("An error from before delete callbacks happened")
  128. }
  129. p4 := Product{Code: "after_save_error", Price: 100}
  130. DB.Save(&p4)
  131. if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
  132. t.Errorf("Record should be reverted if get an error in after save callback")
  133. }
  134. p5 := Product{Code: "after_delete_error", Price: 100}
  135. DB.Save(&p5)
  136. if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
  137. t.Errorf("Record should be found")
  138. }
  139. DB.Delete(&p5)
  140. if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
  141. t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
  142. }
  143. }
  144. func TestGetCallback(t *testing.T) {
  145. scope := DB.NewScope(nil)
  146. if DB.Callback().Create().Get("gorm:test_callback") != nil {
  147. t.Errorf("`gorm:test_callback` should be nil")
  148. }
  149. DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
  150. callback := DB.Callback().Create().Get("gorm:test_callback")
  151. if callback == nil {
  152. t.Errorf("`gorm:test_callback` should be non-nil")
  153. }
  154. callback(scope)
  155. if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
  156. t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
  157. }
  158. DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
  159. callback = DB.Callback().Create().Get("gorm:test_callback")
  160. if callback == nil {
  161. t.Errorf("`gorm:test_callback` should be non-nil")
  162. }
  163. callback(scope)
  164. if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
  165. t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
  166. }
  167. DB.Callback().Create().Remove("gorm:test_callback")
  168. if DB.Callback().Create().Get("gorm:test_callback") != nil {
  169. t.Errorf("`gorm:test_callback` should be nil")
  170. }
  171. DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
  172. callback = DB.Callback().Create().Get("gorm:test_callback")
  173. if callback == nil {
  174. t.Errorf("`gorm:test_callback` should be non-nil")
  175. }
  176. callback(scope)
  177. if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
  178. t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
  179. }
  180. }