join_table_handler.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. )
  8. // JoinTableHandlerInterface is an interface for how to handle many2many relations
  9. type JoinTableHandlerInterface interface {
  10. // initialize join table handler
  11. Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
  12. // Table return join table's table name
  13. Table(db *DB) string
  14. // Add create relationship in join table for source and destination
  15. Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
  16. // Delete delete relationship in join table for sources
  17. Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
  18. // JoinWith query with `Join` conditions
  19. JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
  20. // SourceForeignKeys return source foreign keys
  21. SourceForeignKeys() []JoinTableForeignKey
  22. // DestinationForeignKeys return destination foreign keys
  23. DestinationForeignKeys() []JoinTableForeignKey
  24. }
  // JoinTableForeignKey describes one foreign key of a join table.
  //
  // DBName is the name used as the column key on the join table side, and
  // AssociationDBName is the name used to look the matching field up on the
  // associated model.
  type JoinTableForeignKey struct {
  	DBName, AssociationDBName string
  }
  30. // JoinTableSource is a struct that contains model type and foreign keys
  31. type JoinTableSource struct {
  32. ModelType reflect.Type
  33. ForeignKeys []JoinTableForeignKey
  34. }
  35. // JoinTableHandler default join table handler
  36. type JoinTableHandler struct {
  37. TableName string `sql:"-"`
  38. Source JoinTableSource `sql:"-"`
  39. Destination JoinTableSource `sql:"-"`
  40. }
  41. // SourceForeignKeys return source foreign keys
  42. func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
  43. return s.Source.ForeignKeys
  44. }
  45. // DestinationForeignKeys return destination foreign keys
  46. func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
  47. return s.Destination.ForeignKeys
  48. }
  49. // Setup initialize a default join table handler
  50. func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
  51. s.TableName = tableName
  52. s.Source = JoinTableSource{ModelType: source}
  53. s.Source.ForeignKeys = []JoinTableForeignKey{}
  54. for idx, dbName := range relationship.ForeignFieldNames {
  55. s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
  56. DBName: relationship.ForeignDBNames[idx],
  57. AssociationDBName: dbName,
  58. })
  59. }
  60. s.Destination = JoinTableSource{ModelType: destination}
  61. s.Destination.ForeignKeys = []JoinTableForeignKey{}
  62. for idx, dbName := range relationship.AssociationForeignFieldNames {
  63. s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
  64. DBName: relationship.AssociationForeignDBNames[idx],
  65. AssociationDBName: dbName,
  66. })
  67. }
  68. }
  69. // Table return join table's table name
  70. func (s JoinTableHandler) Table(db *DB) string {
  71. return DefaultTableNameHandler(db, s.TableName)
  72. }
  73. func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
  74. for _, source := range sources {
  75. scope := db.NewScope(source)
  76. modelType := scope.GetModelStruct().ModelType
  77. for _, joinTableSource := range joinTableSources {
  78. if joinTableSource.ModelType == modelType {
  79. for _, foreignKey := range joinTableSource.ForeignKeys {
  80. if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
  81. conditionMap[foreignKey.DBName] = field.Field.Interface()
  82. }
  83. }
  84. break
  85. }
  86. }
  87. }
  88. }
  89. // Add create relationship in join table for source and destination
  90. func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
  91. var (
  92. scope = db.NewScope("")
  93. conditionMap = map[string]interface{}{}
  94. )
  95. // Update condition map for source
  96. s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
  97. // Update condition map for destination
  98. s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
  99. var assignColumns, binVars, conditions []string
  100. var values []interface{}
  101. for key, value := range conditionMap {
  102. assignColumns = append(assignColumns, scope.Quote(key))
  103. binVars = append(binVars, `?`)
  104. conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
  105. values = append(values, value)
  106. }
  107. for _, value := range values {
  108. values = append(values, value)
  109. }
  110. quotedTable := scope.Quote(handler.Table(db))
  111. sql := fmt.Sprintf(
  112. "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
  113. quotedTable,
  114. strings.Join(assignColumns, ","),
  115. strings.Join(binVars, ","),
  116. scope.Dialect().SelectFromDummyTable(),
  117. quotedTable,
  118. strings.Join(conditions, " AND "),
  119. )
  120. return db.Exec(sql, values...).Error
  121. }
  122. // Delete delete relationship in join table for sources
  123. func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
  124. var (
  125. scope = db.NewScope(nil)
  126. conditions []string
  127. values []interface{}
  128. conditionMap = map[string]interface{}{}
  129. )
  130. s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
  131. for key, value := range conditionMap {
  132. conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
  133. values = append(values, value)
  134. }
  135. return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
  136. }
  137. // JoinWith query with `Join` conditions
  138. func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
  139. var (
  140. scope = db.NewScope(source)
  141. tableName = handler.Table(db)
  142. quotedTableName = scope.Quote(tableName)
  143. joinConditions []string
  144. values []interface{}
  145. )
  146. if s.Source.ModelType == scope.GetModelStruct().ModelType {
  147. destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
  148. for _, foreignKey := range s.Destination.ForeignKeys {
  149. joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
  150. }
  151. var foreignDBNames []string
  152. var foreignFieldNames []string
  153. for _, foreignKey := range s.Source.ForeignKeys {
  154. foreignDBNames = append(foreignDBNames, foreignKey.DBName)
  155. if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
  156. foreignFieldNames = append(foreignFieldNames, field.Name)
  157. }
  158. }
  159. foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
  160. var condString string
  161. if len(foreignFieldValues) > 0 {
  162. var quotedForeignDBNames []string
  163. for _, dbName := range foreignDBNames {
  164. quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
  165. }
  166. condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
  167. keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
  168. values = append(values, toQueryValues(keys))
  169. } else {
  170. condString = fmt.Sprintf("1 <> 1")
  171. }
  172. return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
  173. Where(condString, toQueryValues(foreignFieldValues)...)
  174. }
  175. db.Error = errors.New("wrong source type for join table handler")
  176. return db
  177. }