statement.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "github.com/jinzhu/gorm/clause"
  11. "github.com/jinzhu/gorm/schema"
  12. )
  13. // Instance db instance
  14. type Instance struct {
  15. Error error
  16. RowsAffected int64
  17. Context context.Context
  18. Statement *Statement
  19. }
  20. // AddError add error to instance
  21. func (inst Instance) AddError(err error) {
  22. if inst.Error == nil {
  23. inst.Error = err
  24. } else {
  25. inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
  26. }
  27. }
  28. // Statement statement
  29. type Statement struct {
  30. Table string
  31. Model interface{}
  32. Dest interface{}
  33. Clauses map[string]clause.Clause
  34. Settings sync.Map
  35. DB *DB
  36. Schema *schema.Schema
  37. // SQL Builder
  38. SQL strings.Builder
  39. Vars []interface{}
  40. NamedVars []sql.NamedArg
  41. }
  42. // StatementOptimizer statement optimizer interface
  43. type StatementOptimizer interface {
  44. OptimizeStatement(Statement)
  45. }
  46. // Write write string
  47. func (stmt Statement) Write(sql ...string) (err error) {
  48. for _, s := range sql {
  49. _, err = stmt.SQL.WriteString(s)
  50. }
  51. return
  52. }
  53. // Write write string
  54. func (stmt Statement) WriteByte(c byte) (err error) {
  55. return stmt.SQL.WriteByte(c)
  56. }
  57. // WriteQuoted write quoted field
  58. func (stmt Statement) WriteQuoted(field interface{}) (err error) {
  59. _, err = stmt.SQL.WriteString(stmt.Quote(field))
  60. return
  61. }
  62. // Quote returns quoted value
  63. func (stmt Statement) Quote(field interface{}) string {
  64. var str strings.Builder
  65. switch v := field.(type) {
  66. case clause.Table:
  67. str.WriteString(v.Table)
  68. if v.Alias != "" {
  69. str.WriteString(" AS ")
  70. str.WriteString(v.Alias)
  71. }
  72. case clause.Column:
  73. if v.Table != "" {
  74. str.WriteString(v.Table)
  75. str.WriteByte('.')
  76. }
  77. str.WriteString(v.Name)
  78. if v.Alias != "" {
  79. str.WriteString(" AS ")
  80. str.WriteString(v.Alias)
  81. }
  82. default:
  83. fmt.Sprint(field)
  84. }
  85. return str.String()
  86. }
  87. // Write write string
  88. func (stmt Statement) AddVar(vars ...interface{}) string {
  89. var placeholders strings.Builder
  90. for idx, v := range vars {
  91. if idx > 0 {
  92. placeholders.WriteByte(',')
  93. }
  94. if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 {
  95. stmt.NamedVars = append(stmt.NamedVars, namedArg)
  96. placeholders.WriteByte('@')
  97. placeholders.WriteString(namedArg.Name)
  98. } else if arrs, ok := v.([]interface{}); ok {
  99. placeholders.WriteByte('(')
  100. if len(arrs) > 0 {
  101. placeholders.WriteString(stmt.AddVar(arrs...))
  102. } else {
  103. placeholders.WriteString("NULL")
  104. }
  105. placeholders.WriteByte(')')
  106. } else {
  107. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
  108. }
  109. }
  110. return placeholders.String()
  111. }
  112. // AddClause add clause
  113. func (stmt Statement) AddClause(v clause.Interface) {
  114. if optimizer, ok := v.(StatementOptimizer); ok {
  115. optimizer.OptimizeStatement(stmt)
  116. }
  117. c, _ := stmt.Clauses[v.Name()]
  118. if namer, ok := v.(clause.OverrideNameInterface); ok {
  119. c.Name = namer.OverrideName()
  120. } else {
  121. c.Name = v.Name()
  122. }
  123. if c.Expression != nil {
  124. v.MergeExpression(c.Expression)
  125. }
  126. c.Expression = v
  127. stmt.Clauses[v.Name()] = c
  128. }
  129. // BuildCondtion build condition
  130. func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
  131. if sql, ok := query.(string); ok {
  132. if i, err := strconv.Atoi(sql); err != nil {
  133. query = i
  134. } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
  135. return []clause.Expression{clause.String{SQL: sql, Values: args}}
  136. }
  137. }
  138. args = append([]interface{}{query}, args...)
  139. for _, arg := range args {
  140. if valuer, ok := arg.(driver.Valuer); ok {
  141. arg, _ = valuer.Value()
  142. }
  143. switch v := arg.(type) {
  144. case clause.Expression:
  145. conditions = append(conditions, v)
  146. case *DB:
  147. if v.Statement == nil {
  148. if cs, ok := v.Statement.Clauses["WHERE"]; ok {
  149. conditions = append(conditions, cs.Expression)
  150. }
  151. }
  152. case map[interface{}]interface{}:
  153. var clauseMap = clause.Map{}
  154. for i, j := range v {
  155. clauseMap[i] = j
  156. }
  157. conditions = append(conditions, clauseMap)
  158. case map[string]string:
  159. var clauseMap = clause.Map{}
  160. for i, j := range v {
  161. clauseMap[i] = j
  162. }
  163. conditions = append(conditions, clauseMap)
  164. case map[string]interface{}:
  165. var clauseMap = clause.Map{}
  166. for i, j := range v {
  167. clauseMap[i] = j
  168. }
  169. conditions = append(conditions, clauseMap)
  170. default:
  171. // TODO check is struct
  172. // struct, slice -> ids
  173. }
  174. }
  175. if len(conditions) == 0 {
  176. conditions = append(conditions, clause.ID{Value: args})
  177. }
  178. return conditions
  179. }
  180. // Build build sql with clauses names
  181. func (stmt Statement) Build(clauses ...string) {
  182. var includeSpace bool
  183. for _, name := range clauses {
  184. if c, ok := stmt.Clauses[name]; ok {
  185. if includeSpace {
  186. stmt.WriteByte(' ')
  187. }
  188. includeSpace = true
  189. c.Build(stmt)
  190. }
  191. }
  192. }