callback_query_preload.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strconv"
  7. "strings"
  8. )
  9. // preloadCallback used to preload associations
  10. func preloadCallback(scope *Scope) {
  11. if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
  12. return
  13. }
  14. if ap, ok := scope.Get("gorm:auto_preload"); ok {
  15. // If gorm:auto_preload IS NOT a bool then auto preload.
  16. // Else if it IS a bool, use the value
  17. if apb, ok := ap.(bool); !ok {
  18. autoPreload(scope)
  19. } else if apb {
  20. autoPreload(scope)
  21. }
  22. }
  23. if scope.Search.preload == nil || scope.HasError() {
  24. return
  25. }
  26. var (
  27. preloadedMap = map[string]bool{}
  28. fields = scope.Fields()
  29. )
  30. for _, preload := range scope.Search.preload {
  31. var (
  32. preloadFields = strings.Split(preload.schema, ".")
  33. currentScope = scope
  34. currentFields = fields
  35. )
  36. for idx, preloadField := range preloadFields {
  37. var currentPreloadConditions []interface{}
  38. if currentScope == nil {
  39. continue
  40. }
  41. // if not preloaded
  42. if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
  43. // assign search conditions to last preload
  44. if idx == len(preloadFields)-1 {
  45. currentPreloadConditions = preload.conditions
  46. }
  47. for _, field := range currentFields {
  48. if field.Name != preloadField || field.Relationship == nil {
  49. continue
  50. }
  51. switch field.Relationship.Kind {
  52. case "has_one":
  53. currentScope.handleHasOnePreload(field, currentPreloadConditions)
  54. case "has_many":
  55. currentScope.handleHasManyPreload(field, currentPreloadConditions)
  56. case "belongs_to":
  57. currentScope.handleBelongsToPreload(field, currentPreloadConditions)
  58. case "many_to_many":
  59. currentScope.handleManyToManyPreload(field, currentPreloadConditions)
  60. default:
  61. scope.Err(errors.New("unsupported relation"))
  62. }
  63. preloadedMap[preloadKey] = true
  64. break
  65. }
  66. if !preloadedMap[preloadKey] {
  67. scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
  68. return
  69. }
  70. }
  71. // preload next level
  72. if idx < len(preloadFields)-1 {
  73. currentScope = currentScope.getColumnAsScope(preloadField)
  74. if currentScope != nil {
  75. currentFields = currentScope.Fields()
  76. }
  77. }
  78. }
  79. }
  80. }
  81. func autoPreload(scope *Scope) {
  82. for _, field := range scope.Fields() {
  83. if field.Relationship == nil {
  84. continue
  85. }
  86. if val, ok := field.TagSettingsGet("PRELOAD"); ok {
  87. if preload, err := strconv.ParseBool(val); err != nil {
  88. scope.Err(errors.New("invalid preload option"))
  89. return
  90. } else if !preload {
  91. continue
  92. }
  93. }
  94. scope.Search.Preload(field.Name)
  95. }
  96. }
  97. func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
  98. var (
  99. preloadDB = scope.NewDB()
  100. preloadConditions []interface{}
  101. )
  102. for _, condition := range conditions {
  103. if scopes, ok := condition.(func(*DB) *DB); ok {
  104. preloadDB = scopes(preloadDB)
  105. } else {
  106. preloadConditions = append(preloadConditions, condition)
  107. }
  108. }
  109. return preloadDB, preloadConditions
  110. }
  111. // handleHasOnePreload used to preload has one associations
  112. func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
  113. relation := field.Relationship
  114. // get relations's primary keys
  115. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  116. if len(primaryKeys) == 0 {
  117. return
  118. }
  119. // preload conditions
  120. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  121. // find relations
  122. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  123. values := toQueryValues(primaryKeys)
  124. if relation.PolymorphicType != "" {
  125. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  126. values = append(values, relation.PolymorphicValue)
  127. }
  128. results := makeSlice(field.Struct.Type)
  129. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  130. // assign find results
  131. var (
  132. resultsValue = indirect(reflect.ValueOf(results))
  133. indirectScopeValue = scope.IndirectValue()
  134. )
  135. if indirectScopeValue.Kind() == reflect.Slice {
  136. foreignValuesToResults := make(map[string]reflect.Value)
  137. for i := 0; i < resultsValue.Len(); i++ {
  138. result := resultsValue.Index(i)
  139. foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
  140. foreignValuesToResults[foreignValues] = result
  141. }
  142. for j := 0; j < indirectScopeValue.Len(); j++ {
  143. indirectValue := indirect(indirectScopeValue.Index(j))
  144. valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
  145. if result, found := foreignValuesToResults[valueString]; found {
  146. indirectValue.FieldByName(field.Name).Set(result)
  147. }
  148. }
  149. } else {
  150. for i := 0; i < resultsValue.Len(); i++ {
  151. result := resultsValue.Index(i)
  152. scope.Err(field.Set(result))
  153. }
  154. }
  155. }
  156. // handleHasManyPreload used to preload has many associations
  157. func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
  158. relation := field.Relationship
  159. // get relations's primary keys
  160. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  161. if len(primaryKeys) == 0 {
  162. return
  163. }
  164. // preload conditions
  165. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  166. // find relations
  167. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  168. values := toQueryValues(primaryKeys)
  169. if relation.PolymorphicType != "" {
  170. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  171. values = append(values, relation.PolymorphicValue)
  172. }
  173. results := makeSlice(field.Struct.Type)
  174. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  175. // assign find results
  176. var (
  177. resultsValue = indirect(reflect.ValueOf(results))
  178. indirectScopeValue = scope.IndirectValue()
  179. )
  180. if indirectScopeValue.Kind() == reflect.Slice {
  181. preloadMap := make(map[string][]reflect.Value)
  182. for i := 0; i < resultsValue.Len(); i++ {
  183. result := resultsValue.Index(i)
  184. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  185. preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
  186. }
  187. for j := 0; j < indirectScopeValue.Len(); j++ {
  188. object := indirect(indirectScopeValue.Index(j))
  189. objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
  190. f := object.FieldByName(field.Name)
  191. if results, ok := preloadMap[toString(objectRealValue)]; ok {
  192. f.Set(reflect.Append(f, results...))
  193. } else {
  194. f.Set(reflect.MakeSlice(f.Type(), 0, 0))
  195. }
  196. }
  197. } else {
  198. scope.Err(field.Set(resultsValue))
  199. }
  200. }
  201. // handleBelongsToPreload used to preload belongs to associations
  202. func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
  203. relation := field.Relationship
  204. // preload conditions
  205. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  206. // get relations's primary keys
  207. primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
  208. if len(primaryKeys) == 0 {
  209. return
  210. }
  211. // find relations
  212. results := makeSlice(field.Struct.Type)
  213. scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
  214. // assign find results
  215. var (
  216. resultsValue = indirect(reflect.ValueOf(results))
  217. indirectScopeValue = scope.IndirectValue()
  218. )
  219. foreignFieldToObjects := make(map[string][]*reflect.Value)
  220. if indirectScopeValue.Kind() == reflect.Slice {
  221. for j := 0; j < indirectScopeValue.Len(); j++ {
  222. object := indirect(indirectScopeValue.Index(j))
  223. valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
  224. foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
  225. }
  226. }
  227. for i := 0; i < resultsValue.Len(); i++ {
  228. result := resultsValue.Index(i)
  229. if indirectScopeValue.Kind() == reflect.Slice {
  230. valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
  231. if objects, found := foreignFieldToObjects[valueString]; found {
  232. for _, object := range objects {
  233. object.FieldByName(field.Name).Set(result)
  234. }
  235. }
  236. } else {
  237. scope.Err(field.Set(result))
  238. }
  239. }
  240. }
  241. // handleManyToManyPreload used to preload many to many associations
  242. func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
  243. var (
  244. relation = field.Relationship
  245. joinTableHandler = relation.JoinTableHandler
  246. fieldType = field.Struct.Type.Elem()
  247. foreignKeyValue interface{}
  248. foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
  249. linkHash = map[string][]reflect.Value{}
  250. isPtr bool
  251. )
  252. if fieldType.Kind() == reflect.Ptr {
  253. isPtr = true
  254. fieldType = fieldType.Elem()
  255. }
  256. var sourceKeys = []string{}
  257. for _, key := range joinTableHandler.SourceForeignKeys() {
  258. sourceKeys = append(sourceKeys, key.DBName)
  259. }
  260. // preload conditions
  261. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  262. // generate query with join table
  263. newScope := scope.New(reflect.New(fieldType).Interface())
  264. preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
  265. if len(preloadDB.search.selects) == 0 {
  266. preloadDB = preloadDB.Select("*")
  267. }
  268. preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
  269. // preload inline conditions
  270. if len(preloadConditions) > 0 {
  271. preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
  272. }
  273. rows, err := preloadDB.Rows()
  274. if scope.Err(err) != nil {
  275. return
  276. }
  277. defer rows.Close()
  278. columns, _ := rows.Columns()
  279. for rows.Next() {
  280. var (
  281. elem = reflect.New(fieldType).Elem()
  282. fields = scope.New(elem.Addr().Interface()).Fields()
  283. )
  284. // register foreign keys in join tables
  285. var joinTableFields []*Field
  286. for _, sourceKey := range sourceKeys {
  287. joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
  288. }
  289. scope.scan(rows, columns, append(fields, joinTableFields...))
  290. scope.New(elem.Addr().Interface()).
  291. InstanceSet("gorm:skip_query_callback", true).
  292. callCallbacks(scope.db.parent.callbacks.queries)
  293. var foreignKeys = make([]interface{}, len(sourceKeys))
  294. // generate hashed forkey keys in join table
  295. for idx, joinTableField := range joinTableFields {
  296. if !joinTableField.Field.IsNil() {
  297. foreignKeys[idx] = joinTableField.Field.Elem().Interface()
  298. }
  299. }
  300. hashedSourceKeys := toString(foreignKeys)
  301. if isPtr {
  302. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
  303. } else {
  304. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
  305. }
  306. }
  307. if err := rows.Err(); err != nil {
  308. scope.Err(err)
  309. }
  310. // assign find results
  311. var (
  312. indirectScopeValue = scope.IndirectValue()
  313. fieldsSourceMap = map[string][]reflect.Value{}
  314. foreignFieldNames = []string{}
  315. )
  316. for _, dbName := range relation.ForeignFieldNames {
  317. if field, ok := scope.FieldByName(dbName); ok {
  318. foreignFieldNames = append(foreignFieldNames, field.Name)
  319. }
  320. }
  321. if indirectScopeValue.Kind() == reflect.Slice {
  322. for j := 0; j < indirectScopeValue.Len(); j++ {
  323. object := indirect(indirectScopeValue.Index(j))
  324. key := toString(getValueFromFields(object, foreignFieldNames))
  325. fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
  326. }
  327. } else if indirectScopeValue.IsValid() {
  328. key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
  329. fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
  330. }
  331. for source, fields := range fieldsSourceMap {
  332. for _, f := range fields {
  333. //If not 0 this means Value is a pointer and we already added preloaded models to it
  334. if f.Len() != 0 {
  335. continue
  336. }
  337. v := reflect.MakeSlice(f.Type(), 0, 0)
  338. if len(linkHash[source]) > 0 {
  339. v = reflect.Append(f, linkHash[source]...)
  340. }
  341. f.Set(v)
  342. }
  343. }
  344. }