statement.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "github.com/jinzhu/gorm/clause"
  11. "github.com/jinzhu/gorm/schema"
  12. )
// Instance groups the result fields of a database operation (Error,
// RowsAffected) with the Context and the Statement it operates on.
type Instance struct {
	// Error is the error recorded so far; AddError chains further errors onto it.
	Error error
	// RowsAffected is presumably the affected-row count reported by the
	// driver; nothing in this file writes it — TODO confirm against callers.
	RowsAffected int64
	// Context is the context carried with this instance; this file does not
	// read it.
	Context context.Context
	// Statement is the statement whose SQL text and Vars ToSQL returns.
	Statement *Statement
}
  20. func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
  21. if len(clauses) > 0 {
  22. instance.Statement.Build(clauses...)
  23. }
  24. return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
  25. }
  26. // AddError add error to instance
  27. func (inst *Instance) AddError(err error) {
  28. if inst.Error == nil {
  29. inst.Error = err
  30. } else {
  31. inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
  32. }
  33. }
// Statement holds the state used to build one SQL statement: the target
// table and model, the clauses collected so far, and the SQL text and bind
// variables generated from them.
//
// NOTE(review): the struct contains a sync.Map and a strings.Builder by value.
// The standard library documents both as not to be copied after first use,
// yet Quote and BuildCondtion are declared on a value receiver and therefore
// copy the whole struct on every call — confirm this is intended.
type Statement struct {
	// Table is the table name. Quote substitutes it for clause.CurrentTable,
	// and Parse fills it from the parsed schema when it is empty.
	Table string
	// Model and Dest are not read anywhere in this file; presumably the model
	// value and the destination for results — TODO confirm against callers.
	Model interface{}
	Dest  interface{}
	// Clauses maps a clause name to its merged clause (see AddClause).
	Clauses map[string]clause.Clause
	Selects []string // selected columns
	Omits   []string // omit columns
	// Settings is not read anywhere in this file.
	Settings sync.Map
	// DB supplies the quote characters, dialector, custom clause builders,
	// schema cache and naming strategy used by the methods below.
	DB *DB
	// Schema is assigned by Parse; Quote reads its prioritized primary field.
	Schema *schema.Schema

	// SQL Builder
	SQL strings.Builder
	// Vars collects positional bind values in the order AddVar receives them.
	Vars []interface{}
	// NamedVars collects sql.NamedArg values that carry a non-empty name.
	NamedVars []sql.NamedArg
}
// StatementOptimizer is an optional interface for clauses. When a value
// passed to Statement.AddClause implements it, AddClause calls
// OptimizeStatement with the statement before merging the clause.
type StatementOptimizer interface {
	// OptimizeStatement receives the statement the clause is being added to.
	OptimizeStatement(*Statement)
}
  54. // Write write string
  55. func (stmt *Statement) Write(sql ...string) (err error) {
  56. for _, s := range sql {
  57. _, err = stmt.SQL.WriteString(s)
  58. }
  59. return
  60. }
  61. // Write write string
  62. func (stmt *Statement) WriteByte(c byte) (err error) {
  63. return stmt.SQL.WriteByte(c)
  64. }
  65. // WriteQuoted write quoted field
  66. func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
  67. _, err = stmt.SQL.WriteString(stmt.Quote(field))
  68. return
  69. }
  70. // Quote returns quoted value
  71. func (stmt Statement) Quote(field interface{}) string {
  72. var str strings.Builder
  73. str.WriteByte(stmt.DB.quoteChars[0])
  74. switch v := field.(type) {
  75. case clause.Table:
  76. if v.Name == clause.CurrentTable {
  77. str.WriteString(stmt.Table)
  78. } else {
  79. str.WriteString(v.Name)
  80. }
  81. if v.Alias != "" {
  82. str.WriteByte(stmt.DB.quoteChars[1])
  83. str.WriteString(" AS ")
  84. str.WriteByte(stmt.DB.quoteChars[0])
  85. str.WriteString(v.Alias)
  86. str.WriteByte(stmt.DB.quoteChars[1])
  87. }
  88. case clause.Column:
  89. if v.Table != "" {
  90. if v.Table == clause.CurrentTable {
  91. str.WriteString(stmt.Table)
  92. } else {
  93. str.WriteString(v.Table)
  94. }
  95. str.WriteByte(stmt.DB.quoteChars[1])
  96. str.WriteByte('.')
  97. str.WriteByte(stmt.DB.quoteChars[0])
  98. }
  99. if v.Name == clause.PrimaryKey {
  100. if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
  101. str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
  102. }
  103. } else {
  104. str.WriteString(v.Name)
  105. }
  106. if v.Alias != "" {
  107. str.WriteByte(stmt.DB.quoteChars[1])
  108. str.WriteString(" AS ")
  109. str.WriteByte(stmt.DB.quoteChars[0])
  110. str.WriteString(v.Alias)
  111. str.WriteByte(stmt.DB.quoteChars[1])
  112. }
  113. default:
  114. str.WriteString(fmt.Sprint(field))
  115. }
  116. str.WriteByte(stmt.DB.quoteChars[1])
  117. return str.String()
  118. }
  119. // Write write string
  120. func (stmt *Statement) AddVar(vars ...interface{}) string {
  121. var placeholders strings.Builder
  122. for idx, v := range vars {
  123. if idx > 0 {
  124. placeholders.WriteByte(',')
  125. }
  126. switch v := v.(type) {
  127. case sql.NamedArg:
  128. if len(v.Name) > 0 {
  129. stmt.NamedVars = append(stmt.NamedVars, v)
  130. placeholders.WriteByte('@')
  131. placeholders.WriteString(v.Name)
  132. } else {
  133. stmt.Vars = append(stmt.Vars, v.Value)
  134. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
  135. }
  136. case clause.Column, clause.Table:
  137. placeholders.WriteString(stmt.Quote(v))
  138. case clause.Expr:
  139. placeholders.WriteString(v.SQL)
  140. stmt.Vars = append(stmt.Vars, v.Vars...)
  141. case []interface{}:
  142. if len(v) > 0 {
  143. placeholders.WriteByte('(')
  144. placeholders.WriteString(stmt.AddVar(v...))
  145. placeholders.WriteByte(')')
  146. } else {
  147. placeholders.WriteString("(NULL)")
  148. }
  149. default:
  150. stmt.Vars = append(stmt.Vars, v)
  151. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
  152. }
  153. }
  154. return placeholders.String()
  155. }
  156. // AddClause add clause
  157. func (stmt *Statement) AddClause(v clause.Interface) {
  158. if optimizer, ok := v.(StatementOptimizer); ok {
  159. optimizer.OptimizeStatement(stmt)
  160. }
  161. c, ok := stmt.Clauses[v.Name()]
  162. if !ok {
  163. c.Name = v.Name()
  164. }
  165. v.MergeClause(&c)
  166. stmt.Clauses[v.Name()] = c
  167. }
  168. // AddClauseIfNotExists add clause if not exists
  169. func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
  170. if _, ok := stmt.Clauses[v.Name()]; !ok {
  171. stmt.AddClause(v)
  172. }
  173. }
  174. // BuildCondtion build condition
  175. func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
  176. if sql, ok := query.(string); ok {
  177. if i, err := strconv.Atoi(sql); err != nil {
  178. query = i
  179. } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
  180. return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
  181. }
  182. }
  183. args = append([]interface{}{query}, args...)
  184. for _, arg := range args {
  185. if valuer, ok := arg.(driver.Valuer); ok {
  186. arg, _ = valuer.Value()
  187. }
  188. switch v := arg.(type) {
  189. case clause.Expression:
  190. conditions = append(conditions, v)
  191. case *DB:
  192. if v.Statement == nil {
  193. if cs, ok := v.Statement.Clauses["WHERE"]; ok {
  194. conditions = append(conditions, cs.Expression)
  195. }
  196. }
  197. case map[interface{}]interface{}:
  198. var clauseMap = clause.Map{}
  199. for i, j := range v {
  200. clauseMap[i] = j
  201. }
  202. conditions = append(conditions, clauseMap)
  203. case map[string]string:
  204. var clauseMap = clause.Map{}
  205. for i, j := range v {
  206. clauseMap[i] = j
  207. }
  208. conditions = append(conditions, clauseMap)
  209. case map[string]interface{}:
  210. var clauseMap = clause.Map{}
  211. for i, j := range v {
  212. clauseMap[i] = j
  213. }
  214. conditions = append(conditions, clauseMap)
  215. default:
  216. // TODO check is struct
  217. // struct, slice -> ids
  218. }
  219. }
  220. if len(conditions) == 0 {
  221. conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args})
  222. }
  223. return conditions
  224. }
  225. // Build build sql with clauses names
  226. func (stmt *Statement) Build(clauses ...string) {
  227. var firstClauseWritten bool
  228. for _, name := range clauses {
  229. if c, ok := stmt.Clauses[name]; ok {
  230. if firstClauseWritten {
  231. stmt.WriteByte(' ')
  232. }
  233. firstClauseWritten = true
  234. if b, ok := stmt.DB.ClauseBuilders[name]; ok {
  235. b.Build(c, stmt)
  236. } else {
  237. c.Build(stmt)
  238. }
  239. }
  240. }
  241. // TODO handle named vars
  242. }
  243. func (stmt *Statement) Parse(value interface{}) (err error) {
  244. if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
  245. stmt.Table = stmt.Schema.Table
  246. }
  247. return err
  248. }