statement.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package gorm
  2. import (
  3. "bytes"
  4. "context"
  5. "database/sql"
  6. "fmt"
  7. "strings"
  8. "sync"
  9. "github.com/jinzhu/gorm/clause"
  10. )
  11. // Statement statement
  12. type Statement struct {
  13. Dest interface{}
  14. Table interface{}
  15. Clauses map[string][]clause.Interface
  16. Settings sync.Map
  17. Context context.Context
  18. DB *DB
  19. StatementBuilder
  20. }
  // StatementBuilder accumulates the text of a SQL statement together with
// the arguments bound to its placeholders. Its zero value is an empty,
// ready-to-use builder.
type StatementBuilder struct {
	// SQL is the statement text written so far.
	SQL bytes.Buffer
	// Vars presumably holds positional (unnamed) bind arguments, in
	// placeholder order — confirm against the code that fills it.
	Vars []interface{}
	// NamedVars holds named bind arguments; each is referenced in SQL as
	// "@" followed by its Name.
	NamedVars []sql.NamedArg
}
  27. // Write write string
  28. func (stmt Statement) Write(sql ...string) (err error) {
  29. for _, s := range sql {
  30. _, err = stmt.SQL.WriteString(s)
  31. }
  32. return
  33. }
  34. // WriteQuoted write quoted field
  35. func (stmt Statement) WriteQuoted(field interface{}) (err error) {
  36. _, err = stmt.SQL.WriteString(stmt.Quote(field))
  37. return
  38. }
  39. // Write write string
  40. func (stmt Statement) AddVar(vars ...interface{}) string {
  41. var placeholders []string
  42. for _, v := range vars {
  43. if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 {
  44. stmt.NamedVars = append(stmt.NamedVars, namedArg)
  45. placeholders = append(placeholders, "@"+namedArg.Name)
  46. } else {
  47. placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v))
  48. }
  49. }
  50. return strings.Join(placeholders, ",")
  51. }
  52. // Quote returns quoted value
  53. func (stmt Statement) Quote(field interface{}) (str string) {
  54. return fmt.Sprint(field)
  55. }
  56. // AddClause add clause
  57. func (s Statement) AddClause(clause clause.Interface) {
  58. s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause)
  59. }