callback_update.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "sort"
  6. "strings"
  7. )
  8. // Define callbacks for updating
  9. func init() {
  10. DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
  11. DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
  12. DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
  13. DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
  14. DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
  15. DefaultCallback.Update().Register("gorm:update", updateCallback)
  16. DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
  17. DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
  18. DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
  19. }
  20. // assignUpdatingAttributesCallback assign updating attributes to model
  21. func assignUpdatingAttributesCallback(scope *Scope) {
  22. if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
  23. if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
  24. scope.InstanceSet("gorm:update_attrs", updateMaps)
  25. } else {
  26. scope.SkipLeft()
  27. }
  28. }
  29. }
  30. // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
  31. func beforeUpdateCallback(scope *Scope) {
  32. if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
  33. scope.Err(errors.New("missing WHERE clause while updating"))
  34. return
  35. }
  36. if _, ok := scope.Get("gorm:update_column"); !ok {
  37. if !scope.HasError() {
  38. scope.CallMethod("BeforeSave")
  39. }
  40. if !scope.HasError() {
  41. scope.CallMethod("BeforeUpdate")
  42. }
  43. }
  44. }
  45. // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
  46. func updateTimeStampForUpdateCallback(scope *Scope) {
  47. if _, ok := scope.Get("gorm:update_column"); !ok {
  48. scope.SetColumn("UpdatedAt", scope.db.nowFunc())
  49. }
  50. }
  51. // updateCallback the callback used to update data to database
  52. func updateCallback(scope *Scope) {
  53. if !scope.HasError() {
  54. var sqls []string
  55. if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
  56. // Sort the column names so that the generated SQL is the same every time.
  57. updateMap := updateAttrs.(map[string]interface{})
  58. var columns []string
  59. for c := range updateMap {
  60. columns = append(columns, c)
  61. }
  62. sort.Strings(columns)
  63. for _, column := range columns {
  64. value := updateMap[column]
  65. sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
  66. }
  67. } else {
  68. for _, field := range scope.Fields() {
  69. if scope.changeableField(field) {
  70. if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) {
  71. if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
  72. sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
  73. }
  74. } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
  75. for _, foreignKey := range relationship.ForeignDBNames {
  76. if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
  77. sqls = append(sqls,
  78. fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
  79. }
  80. }
  81. }
  82. }
  83. }
  84. }
  85. var extraOption string
  86. if str, ok := scope.Get("gorm:update_option"); ok {
  87. extraOption = fmt.Sprint(str)
  88. }
  89. if len(sqls) > 0 {
  90. scope.Raw(fmt.Sprintf(
  91. "UPDATE %v SET %v%v%v",
  92. scope.QuotedTableName(),
  93. strings.Join(sqls, ", "),
  94. addExtraSpaceIfExist(scope.CombinedConditionSql()),
  95. addExtraSpaceIfExist(extraOption),
  96. )).Exec()
  97. }
  98. }
  99. }
  100. // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
  101. func afterUpdateCallback(scope *Scope) {
  102. if _, ok := scope.Get("gorm:update_column"); !ok {
  103. if !scope.HasError() {
  104. scope.CallMethod("AfterUpdate")
  105. }
  106. if !scope.HasError() {
  107. scope.CallMethod("AfterSave")
  108. }
  109. }
  110. }