statement.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "reflect"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "github.com/jinzhu/gorm/clause"
  12. "github.com/jinzhu/gorm/schema"
  13. )
// Instance db instance
//
// Instance holds the state of a single database operation: the error(s)
// accumulated through AddError, the affected-row count, the context the
// operation runs under, and the Statement being built (rendered by ToSQL).
type Instance struct {
	// Error is the first error passed to AddError; later non-nil errors are
	// appended to it as "<previous>; <new>".
	Error error
	// RowsAffected is not written anywhere in this file; presumably set by
	// the code that executes the statement — verify against callers.
	RowsAffected int64
	// Context is the context associated with this operation.
	Context context.Context
	// Statement is the SQL statement under construction for this operation.
	Statement *Statement
}
  21. func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
  22. if len(clauses) > 0 {
  23. instance.Statement.Build(clauses...)
  24. }
  25. return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
  26. }
  27. // AddError add error to instance
  28. func (inst *Instance) AddError(err error) {
  29. if inst.Error == nil {
  30. inst.Error = err
  31. } else if err != nil {
  32. inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
  33. }
  34. }
// Statement statement
//
// Statement accumulates what is needed to render one SQL statement: the
// target table and schema, the named clauses, the SQL text being written,
// and its positional and named bind variables.
//
// It contains a sync.Map and a strings.Builder, neither of which may be
// copied after first use, so a Statement should be passed as *Statement.
type Statement struct {
	// Table is the table name; Parse fills it from the schema when empty,
	// and Quote substitutes it for clause.CurrentTable.
	Table        string
	Model        interface{}
	Dest         interface{}
	ReflectValue reflect.Value
	// Clauses maps a clause name (e.g. "WHERE") to its merged clause; see
	// AddClause and Build.
	Clauses  map[string]clause.Clause
	Selects  []string // selected columns
	Omits    []string // omit columns
	Settings sync.Map
	DB       *DB
	// Schema is the parsed model schema set by Parse; Quote uses its
	// prioritized primary field to resolve clause.PrimaryKey.
	Schema *schema.Schema

	// SQL Builder
	// SQL is the statement text written by Write, WriteByte, WriteQuoted
	// and Build.
	SQL strings.Builder
	// Vars holds positional bind values in the order AddVar emitted their
	// placeholders.
	Vars []interface{}
	// NamedVars holds named arguments that AddVar rendered as "@name".
	NamedVars []sql.NamedArg
}
  52. // StatementOptimizer statement optimizer interface
  53. type StatementOptimizer interface {
  54. OptimizeStatement(*Statement)
  55. }
  56. // Write write string
  57. func (stmt *Statement) Write(sql ...string) (err error) {
  58. for _, s := range sql {
  59. _, err = stmt.SQL.WriteString(s)
  60. }
  61. return
  62. }
  63. // Write write string
  64. func (stmt *Statement) WriteByte(c byte) (err error) {
  65. return stmt.SQL.WriteByte(c)
  66. }
  67. // WriteQuoted write quoted field
  68. func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
  69. _, err = stmt.SQL.WriteString(stmt.Quote(field))
  70. return
  71. }
  72. // Quote returns quoted value
  73. func (stmt Statement) Quote(field interface{}) string {
  74. var str strings.Builder
  75. str.WriteByte(stmt.DB.quoteChars[0])
  76. switch v := field.(type) {
  77. case clause.Table:
  78. if v.Name == clause.CurrentTable {
  79. str.WriteString(stmt.Table)
  80. } else {
  81. str.WriteString(v.Name)
  82. }
  83. if v.Alias != "" {
  84. str.WriteByte(stmt.DB.quoteChars[1])
  85. str.WriteString(" AS ")
  86. str.WriteByte(stmt.DB.quoteChars[0])
  87. str.WriteString(v.Alias)
  88. str.WriteByte(stmt.DB.quoteChars[1])
  89. }
  90. case clause.Column:
  91. if v.Table != "" {
  92. if v.Table == clause.CurrentTable {
  93. str.WriteString(stmt.Table)
  94. } else {
  95. str.WriteString(v.Table)
  96. }
  97. str.WriteByte(stmt.DB.quoteChars[1])
  98. str.WriteByte('.')
  99. str.WriteByte(stmt.DB.quoteChars[0])
  100. }
  101. if v.Name == clause.PrimaryKey {
  102. if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
  103. str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
  104. }
  105. } else {
  106. str.WriteString(v.Name)
  107. }
  108. if v.Alias != "" {
  109. str.WriteByte(stmt.DB.quoteChars[1])
  110. str.WriteString(" AS ")
  111. str.WriteByte(stmt.DB.quoteChars[0])
  112. str.WriteString(v.Alias)
  113. str.WriteByte(stmt.DB.quoteChars[1])
  114. }
  115. default:
  116. str.WriteString(fmt.Sprint(field))
  117. }
  118. str.WriteByte(stmt.DB.quoteChars[1])
  119. return str.String()
  120. }
  121. // Write write string
  122. func (stmt *Statement) AddVar(vars ...interface{}) string {
  123. var placeholders strings.Builder
  124. for idx, v := range vars {
  125. if idx > 0 {
  126. placeholders.WriteByte(',')
  127. }
  128. switch v := v.(type) {
  129. case sql.NamedArg:
  130. if len(v.Name) > 0 {
  131. stmt.NamedVars = append(stmt.NamedVars, v)
  132. placeholders.WriteByte('@')
  133. placeholders.WriteString(v.Name)
  134. } else {
  135. stmt.Vars = append(stmt.Vars, v.Value)
  136. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
  137. }
  138. case clause.Column, clause.Table:
  139. placeholders.WriteString(stmt.Quote(v))
  140. case clause.Expr:
  141. placeholders.WriteString(v.SQL)
  142. stmt.Vars = append(stmt.Vars, v.Vars...)
  143. case []interface{}:
  144. if len(v) > 0 {
  145. placeholders.WriteByte('(')
  146. placeholders.WriteString(stmt.AddVar(v...))
  147. placeholders.WriteByte(')')
  148. } else {
  149. placeholders.WriteString("(NULL)")
  150. }
  151. default:
  152. stmt.Vars = append(stmt.Vars, v)
  153. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
  154. }
  155. }
  156. return placeholders.String()
  157. }
  158. // AddClause add clause
  159. func (stmt *Statement) AddClause(v clause.Interface) {
  160. if optimizer, ok := v.(StatementOptimizer); ok {
  161. optimizer.OptimizeStatement(stmt)
  162. }
  163. c, ok := stmt.Clauses[v.Name()]
  164. if !ok {
  165. c.Name = v.Name()
  166. }
  167. v.MergeClause(&c)
  168. stmt.Clauses[v.Name()] = c
  169. }
  170. // AddClauseIfNotExists add clause if not exists
  171. func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
  172. if _, ok := stmt.Clauses[v.Name()]; !ok {
  173. stmt.AddClause(v)
  174. }
  175. }
  176. // BuildCondtion build condition
  177. func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
  178. if sql, ok := query.(string); ok {
  179. if i, err := strconv.Atoi(sql); err == nil {
  180. query = i
  181. } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
  182. return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
  183. }
  184. }
  185. args = append([]interface{}{query}, args...)
  186. for _, arg := range args {
  187. if valuer, ok := arg.(driver.Valuer); ok {
  188. arg, _ = valuer.Value()
  189. }
  190. switch v := arg.(type) {
  191. case clause.Expression:
  192. conditions = append(conditions, v)
  193. case *DB:
  194. if v.Statement == nil {
  195. if cs, ok := v.Statement.Clauses["WHERE"]; ok {
  196. conditions = append(conditions, cs.Expression)
  197. }
  198. }
  199. case map[interface{}]interface{}:
  200. var clauseMap = clause.Map{}
  201. for i, j := range v {
  202. clauseMap[i] = j
  203. }
  204. conditions = append(conditions, clauseMap)
  205. case map[string]string:
  206. var clauseMap = clause.Map{}
  207. for i, j := range v {
  208. clauseMap[i] = j
  209. }
  210. conditions = append(conditions, clauseMap)
  211. case map[string]interface{}:
  212. var clauseMap = clause.Map{}
  213. for i, j := range v {
  214. clauseMap[i] = j
  215. }
  216. conditions = append(conditions, clauseMap)
  217. default:
  218. // TODO check is struct
  219. // struct, slice -> ids
  220. }
  221. }
  222. if len(conditions) == 0 {
  223. conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args})
  224. }
  225. return conditions
  226. }
  227. // Build build sql with clauses names
  228. func (stmt *Statement) Build(clauses ...string) {
  229. var firstClauseWritten bool
  230. for _, name := range clauses {
  231. if c, ok := stmt.Clauses[name]; ok {
  232. if firstClauseWritten {
  233. stmt.WriteByte(' ')
  234. }
  235. firstClauseWritten = true
  236. if b, ok := stmt.DB.ClauseBuilders[name]; ok {
  237. b.Build(c, stmt)
  238. } else {
  239. c.Build(stmt)
  240. }
  241. }
  242. }
  243. // TODO handle named vars
  244. }
  245. func (stmt *Statement) Parse(value interface{}) (err error) {
  246. if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
  247. stmt.Table = stmt.Schema.Table
  248. }
  249. return err
  250. }