statement.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "github.com/jinzhu/gorm/clause"
  11. "github.com/jinzhu/gorm/schema"
  12. )
  13. // Instance db instance
  14. type Instance struct {
  15. Error error
  16. RowsAffected int64
  17. Context context.Context
  18. Statement *Statement
  19. }
  20. func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
  21. if len(clauses) > 0 {
  22. instance.Statement.Build(clauses...)
  23. }
  24. return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
  25. }
  26. // AddError add error to instance
  27. func (inst *Instance) AddError(err error) {
  28. if inst.Error == nil {
  29. inst.Error = err
  30. } else {
  31. inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
  32. }
  33. }
  34. // Statement statement
  35. type Statement struct {
  36. Table string
  37. Model interface{}
  38. Dest interface{}
  39. Clauses map[string]clause.Clause
  40. Selects []string // selected columns
  41. Omits []string // omit columns
  42. Settings sync.Map
  43. DB *DB
  44. Schema *schema.Schema
  45. // SQL Builder
  46. SQL strings.Builder
  47. Vars []interface{}
  48. NamedVars []sql.NamedArg
  49. }
  50. // StatementOptimizer statement optimizer interface
  51. type StatementOptimizer interface {
  52. OptimizeStatement(*Statement)
  53. }
  54. // Write write string
  55. func (stmt *Statement) Write(sql ...string) (err error) {
  56. for _, s := range sql {
  57. _, err = stmt.SQL.WriteString(s)
  58. }
  59. return
  60. }
  61. // Write write string
  62. func (stmt *Statement) WriteByte(c byte) (err error) {
  63. return stmt.SQL.WriteByte(c)
  64. }
  65. // WriteQuoted write quoted field
  66. func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
  67. _, err = stmt.SQL.WriteString(stmt.Quote(field))
  68. return
  69. }
  70. // Quote returns quoted value
  71. func (stmt Statement) Quote(field interface{}) string {
  72. var str strings.Builder
  73. str.WriteByte(stmt.DB.quoteChars[0])
  74. switch v := field.(type) {
  75. case clause.Table:
  76. if v.Name == clause.CurrentTable {
  77. str.WriteString(stmt.Table)
  78. } else {
  79. str.WriteString(v.Name)
  80. }
  81. if v.Alias != "" {
  82. str.WriteByte(stmt.DB.quoteChars[1])
  83. str.WriteString(" AS ")
  84. str.WriteByte(stmt.DB.quoteChars[0])
  85. str.WriteString(v.Alias)
  86. str.WriteByte(stmt.DB.quoteChars[1])
  87. }
  88. case clause.Column:
  89. if v.Table != "" {
  90. if v.Table == clause.CurrentTable {
  91. str.WriteString(stmt.Table)
  92. } else {
  93. str.WriteString(v.Table)
  94. }
  95. str.WriteByte(stmt.DB.quoteChars[1])
  96. str.WriteByte('.')
  97. str.WriteByte(stmt.DB.quoteChars[0])
  98. }
  99. if v.Name == clause.PrimaryKey {
  100. if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
  101. str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
  102. }
  103. } else {
  104. str.WriteString(v.Name)
  105. }
  106. if v.Alias != "" {
  107. str.WriteByte(stmt.DB.quoteChars[1])
  108. str.WriteString(" AS ")
  109. str.WriteByte(stmt.DB.quoteChars[0])
  110. str.WriteString(v.Alias)
  111. str.WriteByte(stmt.DB.quoteChars[1])
  112. }
  113. default:
  114. str.WriteString(fmt.Sprint(field))
  115. }
  116. str.WriteByte(stmt.DB.quoteChars[1])
  117. return str.String()
  118. }
  119. // Write write string
  120. func (stmt *Statement) AddVar(vars ...interface{}) string {
  121. var placeholders strings.Builder
  122. for idx, v := range vars {
  123. if idx > 0 {
  124. placeholders.WriteByte(',')
  125. }
  126. switch v := v.(type) {
  127. case sql.NamedArg:
  128. if len(v.Name) > 0 {
  129. stmt.NamedVars = append(stmt.NamedVars, v)
  130. placeholders.WriteByte('@')
  131. placeholders.WriteString(v.Name)
  132. } else {
  133. stmt.Vars = append(stmt.Vars, v.Value)
  134. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
  135. }
  136. case clause.Column:
  137. placeholders.WriteString(stmt.Quote(v))
  138. case []interface{}:
  139. if len(v) > 0 {
  140. placeholders.WriteByte('(')
  141. placeholders.WriteString(stmt.AddVar(v...))
  142. placeholders.WriteByte(')')
  143. } else {
  144. placeholders.WriteString("(NULL)")
  145. }
  146. default:
  147. stmt.Vars = append(stmt.Vars, v)
  148. placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
  149. }
  150. }
  151. return placeholders.String()
  152. }
  153. // AddClause add clause
  154. func (stmt *Statement) AddClause(v clause.Interface) {
  155. if optimizer, ok := v.(StatementOptimizer); ok {
  156. optimizer.OptimizeStatement(stmt)
  157. }
  158. c, ok := stmt.Clauses[v.Name()]
  159. if !ok {
  160. c.Name = v.Name()
  161. }
  162. v.MergeClause(&c)
  163. stmt.Clauses[v.Name()] = c
  164. }
  165. // AddClauseIfNotExists add clause if not exists
  166. func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
  167. if _, ok := stmt.Clauses[v.Name()]; !ok {
  168. stmt.AddClause(v)
  169. }
  170. }
  171. // BuildCondtion build condition
  172. func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
  173. if sql, ok := query.(string); ok {
  174. if i, err := strconv.Atoi(sql); err != nil {
  175. query = i
  176. } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
  177. return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
  178. }
  179. }
  180. args = append([]interface{}{query}, args...)
  181. for _, arg := range args {
  182. if valuer, ok := arg.(driver.Valuer); ok {
  183. arg, _ = valuer.Value()
  184. }
  185. switch v := arg.(type) {
  186. case clause.Expression:
  187. conditions = append(conditions, v)
  188. case *DB:
  189. if v.Statement == nil {
  190. if cs, ok := v.Statement.Clauses["WHERE"]; ok {
  191. conditions = append(conditions, cs.Expression)
  192. }
  193. }
  194. case map[interface{}]interface{}:
  195. var clauseMap = clause.Map{}
  196. for i, j := range v {
  197. clauseMap[i] = j
  198. }
  199. conditions = append(conditions, clauseMap)
  200. case map[string]string:
  201. var clauseMap = clause.Map{}
  202. for i, j := range v {
  203. clauseMap[i] = j
  204. }
  205. conditions = append(conditions, clauseMap)
  206. case map[string]interface{}:
  207. var clauseMap = clause.Map{}
  208. for i, j := range v {
  209. clauseMap[i] = j
  210. }
  211. conditions = append(conditions, clauseMap)
  212. default:
  213. // TODO check is struct
  214. // struct, slice -> ids
  215. }
  216. }
  217. if len(conditions) == 0 {
  218. conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args})
  219. }
  220. return conditions
  221. }
  222. // Build build sql with clauses names
  223. func (stmt *Statement) Build(clauses ...string) {
  224. var firstClauseWritten bool
  225. for _, name := range clauses {
  226. if c, ok := stmt.Clauses[name]; ok {
  227. if firstClauseWritten {
  228. stmt.WriteByte(' ')
  229. }
  230. firstClauseWritten = true
  231. if b, ok := stmt.DB.ClauseBuilders[name]; ok {
  232. b.Build(c, stmt)
  233. } else {
  234. c.Build(stmt)
  235. }
  236. }
  237. }
  238. // TODO handle named vars
  239. }
  240. func (stmt *Statement) Parse(value interface{}) (err error) {
  241. if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
  242. stmt.Table = stmt.Schema.Table
  243. }
  244. return err
  245. }