callback_query.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. )
  7. // Define callbacks for querying
  8. func init() {
  9. DefaultCallback.Query().Register("gorm:query", queryCallback)
  10. DefaultCallback.Query().Register("gorm:preload", preloadCallback)
  11. DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
  12. }
  13. // queryCallback used to query data from database
  14. func queryCallback(scope *Scope) {
  15. if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
  16. return
  17. }
  18. //we are only preloading relations, dont touch base model
  19. if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
  20. return
  21. }
  22. defer scope.trace(NowFunc())
  23. var (
  24. isSlice, isPtr bool
  25. resultType reflect.Type
  26. results = scope.IndirectValue()
  27. )
  28. if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
  29. if primaryField := scope.PrimaryField(); primaryField != nil {
  30. scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
  31. }
  32. }
  33. if value, ok := scope.Get("gorm:query_destination"); ok {
  34. results = indirect(reflect.ValueOf(value))
  35. }
  36. if kind := results.Kind(); kind == reflect.Slice {
  37. isSlice = true
  38. resultType = results.Type().Elem()
  39. results.Set(reflect.MakeSlice(results.Type(), 0, 0))
  40. if resultType.Kind() == reflect.Ptr {
  41. isPtr = true
  42. resultType = resultType.Elem()
  43. }
  44. } else if kind != reflect.Struct {
  45. scope.Err(errors.New("unsupported destination, should be slice or struct"))
  46. return
  47. }
  48. scope.prepareQuerySQL()
  49. if !scope.HasError() {
  50. scope.db.RowsAffected = 0
  51. if str, ok := scope.Get("gorm:query_option"); ok {
  52. scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
  53. }
  54. if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
  55. defer rows.Close()
  56. columns, _ := rows.Columns()
  57. for rows.Next() {
  58. scope.db.RowsAffected++
  59. elem := results
  60. if isSlice {
  61. elem = reflect.New(resultType).Elem()
  62. }
  63. scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
  64. if isSlice {
  65. if isPtr {
  66. results.Set(reflect.Append(results, elem.Addr()))
  67. } else {
  68. results.Set(reflect.Append(results, elem))
  69. }
  70. }
  71. }
  72. if err := rows.Err(); err != nil {
  73. scope.Err(err)
  74. } else if scope.db.RowsAffected == 0 && !isSlice {
  75. scope.Err(ErrRecordNotFound)
  76. }
  77. }
  78. }
  79. }
  80. // afterQueryCallback will invoke `AfterFind` method after querying
  81. func afterQueryCallback(scope *Scope) {
  82. if !scope.HasError() {
  83. scope.CallMethod("AfterFind")
  84. }
  85. }