Kaynağa Gözat

Keep refactoring association mode

Jinzhu 8 yıl önce
ebeveyn
işleme
dc23ae63bf
4 değiştirilmiş dosya ile 149 ekleme ve 152 silme
  1. 108 29
      association.go
  2. 0 122
      association_utils.go
  3. 1 1
      main.go
  4. 40 0
      utils.go

+ 108 - 29
association.go

@@ -1,27 +1,28 @@
 package gorm
 
 import (
+	"errors"
 	"fmt"
 	"reflect"
 )
 
-// Association Association Mode contains some helper methods to handle relationship things easily.
+// Association Mode contains some helper methods to handle relationship things easily.
 type Association struct {
-	Scope  *Scope
-	Column string
 	Error  error
-	Field  *Field
+	scope  *Scope
+	column string
+	field  *Field
 }
 
 // Find find out all related associations
 func (association *Association) Find(value interface{}) *Association {
-	association.Scope.related(value, association.Column)
-	return association.setErr(association.Scope.db.Error)
+	association.scope.related(value, association.column)
+	return association.setErr(association.scope.db.Error)
 }
 
-// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to
+// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
 func (association *Association) Append(values ...interface{}) *Association {
-	if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
+	if relationship := association.field.Relationship; relationship.Kind == "has_one" {
 		return association.Replace(values...)
 	}
 	return association.saveAssociations(values...)
@@ -30,14 +31,14 @@ func (association *Association) Append(values ...interface{}) *Association {
 // Replace replace current associations with new one
 func (association *Association) Replace(values ...interface{}) *Association {
 	var (
-		relationship = association.Field.Relationship
-		scope        = association.Scope
-		field        = association.Field.Field
+		relationship = association.field.Relationship
+		scope        = association.scope
+		field        = association.field.Field
 		newDB        = scope.NewDB()
 	)
 
 	// Append new values
-	association.Field.Set(reflect.Zero(association.Field.Field.Type()))
+	association.field.Set(reflect.Zero(association.field.Field.Type()))
 	association.saveAssociations(values...)
 
 	// Belongs To
@@ -109,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
 				}
 			}
 
-			fieldValue := reflect.New(association.Field.Field.Type()).Interface()
+			fieldValue := reflect.New(association.field.Field.Type()).Interface()
 			association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
 		}
 	}
@@ -119,9 +120,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
 // Delete remove relationship between source & passed arguments, but won't delete those arguments
 func (association *Association) Delete(values ...interface{}) *Association {
 	var (
-		relationship = association.Field.Relationship
-		scope        = association.Scope
-		field        = association.Field.Field
+		relationship = association.field.Relationship
+		scope        = association.scope
+		field        = association.field.Field
 		newDB        = scope.NewDB()
 	)
 
@@ -196,18 +197,18 @@ func (association *Association) Delete(values ...interface{}) *Association {
 			)
 
 			// set matched relation's foreign key to be null
-			fieldValue := reflect.New(association.Field.Field.Type()).Interface()
+			fieldValue := reflect.New(association.field.Field.Type()).Interface()
 			association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
 		}
 	}
 
 	// Remove deleted records from source's field
 	if association.Error == nil {
-		if association.Field.Field.Kind() == reflect.Slice {
-			leftValues := reflect.Zero(association.Field.Field.Type())
+		if field.Kind() == reflect.Slice {
+			leftValues := reflect.Zero(field.Type())
 
-			for i := 0; i < association.Field.Field.Len(); i++ {
-				reflectValue := association.Field.Field.Index(i)
+			for i := 0; i < field.Len(); i++ {
+				reflectValue := field.Index(i)
 				primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
 				var isDeleted = false
 				for _, pk := range deletingPrimaryKeys {
@@ -221,12 +222,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
 				}
 			}
 
-			association.Field.Set(leftValues)
-		} else if association.Field.Field.Kind() == reflect.Struct {
-			primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
+			association.field.Set(leftValues)
+		} else if field.Kind() == reflect.Struct {
+			primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
 			for _, pk := range deletingPrimaryKeys {
 				if equalAsString(primaryKey, pk) {
-					association.Field.Set(reflect.Zero(association.Field.Field.Type()))
+					association.field.Set(reflect.Zero(field.Type()))
 					break
 				}
 			}
@@ -245,14 +246,14 @@ func (association *Association) Clear() *Association {
 func (association *Association) Count() int {
 	var (
 		count        = 0
-		relationship = association.Field.Relationship
-		scope        = association.Scope
-		fieldValue   = association.Field.Field.Interface()
+		relationship = association.field.Relationship
+		scope        = association.scope
+		fieldValue   = association.field.Field.Interface()
 		query        = scope.DB()
 	)
 
 	if relationship.Kind == "many_to_many" {
-		query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value)
+		query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
 	} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 		primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
 		query = query.Where(
@@ -277,3 +278,81 @@ func (association *Association) Count() int {
 	query.Model(fieldValue).Count(&count)
 	return count
 }
+
+// saveAssociations save passed values as associations
+func (association *Association) saveAssociations(values ...interface{}) *Association {
+	var (
+		scope        = association.scope
+		field        = association.field
+		relationship = field.Relationship
+	)
+
+	saveAssociation := func(reflectValue reflect.Value) {
+		// value has to been pointer
+		if reflectValue.Kind() != reflect.Ptr {
+			reflectPtr := reflect.New(reflectValue.Type())
+			reflectPtr.Elem().Set(reflectValue)
+			reflectValue = reflectPtr
+		}
+
+		// value has to been saved for many2many
+		if relationship.Kind == "many_to_many" {
+			if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
+				association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
+			}
+		}
+
+		// Assign Fields
+		var fieldType = field.Field.Type()
+		var setFieldBackToValue, setSliceFieldBackToValue bool
+		if reflectValue.Type().AssignableTo(fieldType) {
+			field.Set(reflectValue)
+		} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
+			// if field's type is struct, then need to set value back to argument after save
+			setFieldBackToValue = true
+			field.Set(reflectValue.Elem())
+		} else if fieldType.Kind() == reflect.Slice {
+			if reflectValue.Type().AssignableTo(fieldType.Elem()) {
+				field.Set(reflect.Append(field.Field, reflectValue))
+			} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
+				// if field's type is slice of struct, then need to set value back to argument after save
+				setSliceFieldBackToValue = true
+				field.Set(reflect.Append(field.Field, reflectValue.Elem()))
+			}
+		}
+
+		if relationship.Kind == "many_to_many" {
+			association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
+		} else {
+			association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
+
+			if setFieldBackToValue {
+				reflectValue.Elem().Set(field.Field)
+			} else if setSliceFieldBackToValue {
+				reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
+			}
+		}
+	}
+
+	for _, value := range values {
+		reflectValue := reflect.ValueOf(value)
+		indirectReflectValue := reflect.Indirect(reflectValue)
+		if indirectReflectValue.Kind() == reflect.Struct {
+			saveAssociation(reflectValue)
+		} else if indirectReflectValue.Kind() == reflect.Slice {
+			for i := 0; i < indirectReflectValue.Len(); i++ {
+				saveAssociation(indirectReflectValue.Index(i))
+			}
+		} else {
+			association.setErr(errors.New("invalid value type"))
+		}
+	}
+	return association
+}
+
+func (association *Association) setErr(err error) *Association {
+	if err != nil {
+		association.Error = err
+	}
+	return association
+}

+ 0 - 122
association_utils.go

@@ -1,122 +0,0 @@
-package gorm
-
-import (
-	"errors"
-	"fmt"
-	"reflect"
-	"strings"
-)
-
-func (association *Association) setErr(err error) *Association {
-	if err != nil {
-		association.Error = err
-	}
-	return association
-}
-
-func (association *Association) saveAssociations(values ...interface{}) *Association {
-	scope := association.Scope
-	field := association.Field
-	relationship := association.Field.Relationship
-
-	saveAssociation := func(reflectValue reflect.Value) {
-		// value has to been pointer
-		if reflectValue.Kind() != reflect.Ptr {
-			reflectPtr := reflect.New(reflectValue.Type())
-			reflectPtr.Elem().Set(reflectValue)
-			reflectValue = reflectPtr
-		}
-
-		// value has to been saved for many2many
-		if relationship.Kind == "many_to_many" {
-			if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
-				association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
-			}
-		}
-
-		// Assign Fields
-		var fieldType = field.Field.Type()
-		var setFieldBackToValue, setSliceFieldBackToValue bool
-		if reflectValue.Type().AssignableTo(fieldType) {
-			field.Set(reflectValue)
-		} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
-			// if field's type is struct, then need to set value back to argument after save
-			setFieldBackToValue = true
-			field.Set(reflectValue.Elem())
-		} else if fieldType.Kind() == reflect.Slice {
-			if reflectValue.Type().AssignableTo(fieldType.Elem()) {
-				field.Set(reflect.Append(field.Field, reflectValue))
-			} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
-				// if field's type is slice of struct, then need to set value back to argument after save
-				setSliceFieldBackToValue = true
-				field.Set(reflect.Append(field.Field, reflectValue.Elem()))
-			}
-		}
-
-		if relationship.Kind == "many_to_many" {
-			association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
-		} else {
-			association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
-
-			if setFieldBackToValue {
-				reflectValue.Elem().Set(field.Field)
-			} else if setSliceFieldBackToValue {
-				reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
-			}
-		}
-	}
-
-	for _, value := range values {
-		reflectValue := reflect.ValueOf(value)
-		indirectReflectValue := reflect.Indirect(reflectValue)
-		if indirectReflectValue.Kind() == reflect.Struct {
-			saveAssociation(reflectValue)
-		} else if indirectReflectValue.Kind() == reflect.Slice {
-			for i := 0; i < indirectReflectValue.Len(); i++ {
-				saveAssociation(indirectReflectValue.Index(i))
-			}
-		} else {
-			association.setErr(errors.New("invalid value type"))
-		}
-	}
-	return association
-}
-
-func toQueryMarks(primaryValues [][]interface{}) string {
-	var results []string
-
-	for _, primaryValue := range primaryValues {
-		var marks []string
-		for _ = range primaryValue {
-			marks = append(marks, "?")
-		}
-
-		if len(marks) > 1 {
-			results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
-		} else {
-			results = append(results, strings.Join(marks, ""))
-		}
-	}
-	return strings.Join(results, ",")
-}
-
-func toQueryCondition(scope *Scope, columns []string) string {
-	var newColumns []string
-	for _, column := range columns {
-		newColumns = append(newColumns, scope.Quote(column))
-	}
-
-	if len(columns) > 1 {
-		return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
-	}
-	return strings.Join(newColumns, ",")
-}
-
-func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
-	for _, primaryValue := range primaryValues {
-		for _, value := range primaryValue {
-			values = append(values, value)
-		}
-	}
-	return values
-}

+ 1 - 1
main.go

@@ -480,7 +480,7 @@ func (s *DB) Association(column string) *Association {
 			if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
 				err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
 			} else {
-				return &Association{Scope: scope, Column: column, Field: field}
+				return &Association{scope: scope, column: column, field: field}
 			}
 		} else {
 			err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)

+ 40 - 0
utils.go

@@ -2,6 +2,7 @@ package gorm
 
 import (
 	"bytes"
+	"fmt"
 	"strings"
 	"sync"
 )
@@ -100,3 +101,42 @@ type expr struct {
 func Expr(expression string, args ...interface{}) *expr {
 	return &expr{expr: expression, args: args}
 }
+
+func toQueryMarks(primaryValues [][]interface{}) string {
+	var results []string
+
+	for _, primaryValue := range primaryValues {
+		var marks []string
+		for _ = range primaryValue {
+			marks = append(marks, "?")
+		}
+
+		if len(marks) > 1 {
+			results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
+		} else {
+			results = append(results, strings.Join(marks, ""))
+		}
+	}
+	return strings.Join(results, ",")
+}
+
+func toQueryCondition(scope *Scope, columns []string) string {
+	var newColumns []string
+	for _, column := range columns {
+		newColumns = append(newColumns, scope.Quote(column))
+	}
+
+	if len(columns) > 1 {
+		return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
+	}
+	return strings.Join(newColumns, ",")
+}
+
+func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
+	for _, primaryValue := range primaryValues {
+		for _, value := range primaryValue {
+			values = append(values, value)
+		}
+	}
+	return values
+}