Pārlūkot izejas kodu

Support custom preloading SQL, close #598, #793, #824

Jinzhu 8 gadi atpakaļ
vecāks
revīzija
5883c70478
3 mainītis faili ar 52 papildinājumiem un 13 dzēšanām
  1. 12 0
      README.md
  2. 37 12
      callback_query_preload.go
  3. 3 1
      preload_test.go

+ 12 - 0
README.md

@@ -432,6 +432,18 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
 //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
 ```
 
+#### Custom Preloading SQL
+
+You could custom preloading SQL by passing in `func(db *gorm.DB) *gorm.DB` (same type as the one used for [Scopes](#scopes)), for example:
+
+```go
+db.Preload("Orders", func(db *gorm.DB) *gorm.DB {
+    return db.Order("orders.amount DESC")
+}).Find(&users)
+//// SELECT * FROM users;
+//// SELECT * FROM orders WHERE user_id IN (1,2,3,4) order by orders.amount DESC;
+```
+
 #### Nested Preloading
 
 ```go

+ 37 - 12
callback_query_preload.go

@@ -73,6 +73,23 @@ func preloadCallback(scope *Scope) {
 	}
 }
 
+func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
+	var (
+		preloadDB         = scope.NewDB()
+		preloadConditions []interface{}
+	)
+
+	for _, condition := range conditions {
+		if scopes, ok := condition.(func(*DB) *DB); ok {
+			preloadDB = scopes(preloadDB)
+		} else {
+			preloadConditions = append(preloadConditions, condition)
+		}
+	}
+
+	return preloadDB, preloadConditions
+}
+
 // handleHasOnePreload used to preload has one associations
 func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
 	relation := field.Relationship
@@ -83,9 +100,12 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
 		return
 	}
 
+	// preload conditions
+	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
 	// find relations
 	results := makeSlice(field.Struct.Type)
-	scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
+	scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 
 	// assign find results
 	var (
@@ -119,9 +139,12 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
 		return
 	}
 
+	// preload conditions
+	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
 	// find relations
 	results := makeSlice(field.Struct.Type)
-	scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
+	scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 
 	// assign find results
 	var (
@@ -151,6 +174,9 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
 func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
 	relation := field.Relationship
 
+	// preload conditions
+	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
 	// get relations's primary keys
 	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
 	if len(primaryKeys) == 0 {
@@ -159,7 +185,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
 
 	// find relations
 	results := makeSlice(field.Struct.Type)
-	scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
+	scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
 
 	// assign find results
 	var (
@@ -205,21 +231,20 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
 		sourceKeys = append(sourceKeys, key.DBName)
 	}
 
+	// preload conditions
+	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
 	// generate query with join table
 	newScope := scope.New(reflect.New(fieldType).Interface())
-	preloadJoinDB := scope.NewDB().Table(newScope.TableName()).Select("*")
-	preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value)
-
-	if primaryField := newScope.PrimaryField(); primaryField != nil {
-		preloadJoinDB = preloadJoinDB.Order(fmt.Sprintf("%v.%v %v", newScope.QuotedTableName(), newScope.Quote(primaryField.DBName), "ASC"))
-	}
+	preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
+	preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
 
 	// preload inline conditions
-	if len(conditions) > 0 {
-		preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
+	if len(preloadConditions) > 0 {
+		preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
 	}
 
-	rows, err := preloadJoinDB.Rows()
+	rows, err := preloadDB.Rows()
 
 	if scope.Err(err) != nil {
 		return

+ 3 - 1
preload_test.go

@@ -1107,7 +1107,9 @@ func TestNestedManyToManyPreload3(t *testing.T) {
 	}
 
 	var gots []*Level3
-	if err := DB.Preload("Level2.Level1s").Find(&gots).Error; err != nil {
+	if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
+		return db.Order("level1.id ASC")
+	}).Find(&gots).Error; err != nil {
 		t.Error(err)
 	}