|
@@ -4,7 +4,9 @@ import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
"database/sql"
|
|
|
+ "database/sql/driver"
|
|
|
"fmt"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
|
|
@@ -13,9 +15,10 @@ import (
|
|
|
|
|
|
// Statement statement
|
|
|
type Statement struct {
|
|
|
+ Model interface{}
|
|
|
Dest interface{}
|
|
|
- Table interface{}
|
|
|
- Clauses map[string][]clause.Interface
|
|
|
+ Table string
|
|
|
+ Clauses map[string][]clause.Condition
|
|
|
Settings sync.Map
|
|
|
Context context.Context
|
|
|
DB *DB
|
|
@@ -45,16 +48,29 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
|
|
|
|
|
|
// Write write string
|
|
|
func (stmt Statement) AddVar(vars ...interface{}) string {
|
|
|
- var placeholders []string
|
|
|
- for _, v := range vars {
|
|
|
+ var placeholders strings.Builder
|
|
|
+ for idx, v := range vars {
|
|
|
+ if idx > 0 {
|
|
|
+ placeholders.WriteByte(',')
|
|
|
+ }
|
|
|
+
|
|
|
if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 {
|
|
|
stmt.NamedVars = append(stmt.NamedVars, namedArg)
|
|
|
- placeholders = append(placeholders, "@"+namedArg.Name)
|
|
|
+ placeholders.WriteByte('@')
|
|
|
+ placeholders.WriteString(namedArg.Name)
|
|
|
+ } else if arrs, ok := v.([]interface{}); ok {
|
|
|
+ placeholders.WriteByte('(')
|
|
|
+ if len(arrs) > 0 {
|
|
|
+ placeholders.WriteString(stmt.AddVar(arrs...))
|
|
|
+ } else {
|
|
|
+ placeholders.WriteString("NULL")
|
|
|
+ }
|
|
|
+ placeholders.WriteByte(')')
|
|
|
} else {
|
|
|
- placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v))
|
|
|
+ placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
|
|
|
}
|
|
|
}
|
|
|
- return strings.Join(placeholders, ",")
|
|
|
+ return placeholders.String()
|
|
|
}
|
|
|
|
|
|
// Quote returns quoted value
|
|
@@ -66,3 +82,61 @@ func (stmt Statement) Quote(field interface{}) (str string) {
|
|
|
func (s Statement) AddClause(clause clause.Interface) {
|
|
|
s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause)
|
|
|
}
|
|
|
+
|
|
|
+// BuildCondtions build conditions
|
|
|
+func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) {
|
|
|
+ if sql, ok := query.(string); ok {
|
|
|
+ if i, err := strconv.Atoi(sql); err != nil {
|
|
|
+ query = i
|
|
|
+ } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
|
|
+ return []clause.Condition{clause.Raw{SQL: sql, Values: args}}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ args = append([]interface{}{query}, args...)
|
|
|
+ for _, arg := range args {
|
|
|
+ if valuer, ok := arg.(driver.Valuer); ok {
|
|
|
+ arg, _ = valuer.Value()
|
|
|
+ }
|
|
|
+
|
|
|
+ switch v := arg.(type) {
|
|
|
+ case clause.Builder:
|
|
|
+ conditions = append(conditions, v)
|
|
|
+ case *DB:
|
|
|
+ if v.Statement == nil {
|
|
|
+ if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
|
|
+ conditions = append(conditions, cs...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case map[interface{}]interface{}:
|
|
|
+ var clauseMap = clause.Map{}
|
|
|
+ for i, j := range v {
|
|
|
+ clauseMap[i] = j
|
|
|
+ }
|
|
|
+ conditions = append(conditions, clauseMap)
|
|
|
+ case map[string]string:
|
|
|
+ var clauseMap = clause.Map{}
|
|
|
+ for i, j := range v {
|
|
|
+ clauseMap[i] = j
|
|
|
+ }
|
|
|
+ conditions = append(conditions, clauseMap)
|
|
|
+ case map[string]interface{}:
|
|
|
+ var clauseMap = clause.Map{}
|
|
|
+ for i, j := range v {
|
|
|
+ clauseMap[i] = j
|
|
|
+ }
|
|
|
+ conditions = append(conditions, clauseMap)
|
|
|
+ default:
|
|
|
+ // TODO check is struct
|
|
|
+ // struct, slice -> ids
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(conditions) == 0 {
|
|
|
+ conditions = append(conditions, clause.ID{Value: args})
|
|
|
+ }
|
|
|
+ return conditions
|
|
|
+}
|
|
|
+
|
|
|
+func (s Statement) AddError(err error) {
|
|
|
+}
|