Ver Fonte

Implement build conditions

Jinzhu há 4 anos atrás
pai
commit
85bfd175c6
5 ficheiros alterados com 154 adições e 15 exclusões
  1. 2 0
      chainable_api.go
  2. 5 0
      clause/clause.go
  3. 59 7
      clause/operators.go
  4. 7 1
      gorm.go
  5. 81 7
      statement.go

+ 2 - 0
chainable_api.go

@@ -7,12 +7,14 @@ package gorm
 //    db.Model(&user).Update("name", "hello")
 func (db *DB) Model(value interface{}) (tx *DB) {
 	tx = db.getInstance()
+	tx.Statement.Model = value
 	return
 }
 
 // Table specify the table you would like to run db operations
 func (db *DB) Table(name string) (tx *DB) {
 	tx = db.getInstance()
+	tx.Statement.Table = name
 	return
 }
 

+ 5 - 0
clause/clause.go

@@ -11,6 +11,11 @@ type BuilderInterface interface {
 // Interface clause interface
 type Interface interface {
 	Name() string
+	Builder
+}
+
+// Builder condition builder
+type Builder interface {
 	Build(builder BuilderInterface)
 }
 

+ 59 - 7
clause/operators.go

@@ -2,7 +2,8 @@ package clause
 
 import "strings"
 
-type AddConditions []Interface
+type Condition Builder
+type AddConditions []Condition
 
 func (cs AddConditions) Build(builder BuilderInterface) {
 	for idx, c := range cs {
@@ -13,7 +14,7 @@ func (cs AddConditions) Build(builder BuilderInterface) {
 	}
 }
 
-type ORConditions []Interface
+type ORConditions []Condition
 
 func (cs ORConditions) Build(builder BuilderInterface) {
 	for idx, c := range cs {
@@ -24,7 +25,7 @@ func (cs ORConditions) Build(builder BuilderInterface) {
 	}
 }
 
-type NotConditions []Interface
+type NotConditions []Condition
 
 func (cs NotConditions) Build(builder BuilderInterface) {
 	for idx, c := range cs {
@@ -64,16 +65,22 @@ type IN struct {
 func (in IN) Build(builder BuilderInterface) {
 	builder.WriteQuoted(in.Column)
 
-	if len(in.Values) == 0 {
+	switch len(in.Values) {
+	case 0:
 		builder.Write(" IN (NULL)")
-	} else {
+	case 1:
+		builder.Write(" = ", builder.AddVar(in.Values...))
+	default:
 		builder.Write(" IN (", builder.AddVar(in.Values...), ")")
 	}
 }
 
 func (in IN) NegationBuild(builder BuilderInterface) {
-	if len(in.Values) != 0 {
-		builder.WriteQuoted(in.Column)
+	switch len(in.Values) {
+	case 0:
+	case 1:
+		builder.Write(" <> ", builder.AddVar(in.Values...))
+	default:
 		builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")")
 	}
 }
@@ -193,3 +200,48 @@ func (like Like) NegationBuild(builder BuilderInterface) {
 	builder.WriteQuoted(like.Column)
 	builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
 }
+
+// Map
+type Map map[interface{}]interface{}
+
+func (m Map) Build(builder BuilderInterface) {
+	// TODO
+}
+
+func (m Map) NegationBuild(builder BuilderInterface) {
+	// TODO
+}
+
+// Attrs
+type Attrs struct {
+	Value  interface{}
+	Select []string
+	Omit   []string
+}
+
+func (attrs Attrs) Build(builder BuilderInterface) {
+	// TODO
+	// builder.WriteQuoted(like.Column)
+	// builder.Write(" LIKE ", builder.AddVar(like.Value))
+}
+
+func (attrs Attrs) NegationBuild(builder BuilderInterface) {
+	// TODO
+}
+
+// ID
+type ID struct {
+	Value []interface{}
+}
+
+func (id ID) Build(builder BuilderInterface) {
+	if len(id.Value) == 1 {
+	}
+	// TODO
+	// builder.WriteQuoted(like.Column)
+	// builder.Write(" LIKE ", builder.AddVar(like.Value))
+}
+
+func (id ID) NegationBuild(builder BuilderInterface) {
+	// TODO
+}

+ 7 - 1
gorm.go

@@ -93,7 +93,7 @@ func (db *DB) getInstance() *DB {
 			Dialector: db.Dialector,
 			Context:   context.Background(),
 			Result: Result{
-				Statement: &Statement{DB: db, Clauses: map[string][]clause.Interface{}},
+				Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}},
 			},
 		}
 	}
@@ -106,3 +106,9 @@ func (db *DB) Debug() (tx *DB) {
 	tx = db.getInstance()
 	return
 }
+
+// Session start session mode
+func (db *DB) Session() (tx *DB) {
+	tx = db.getInstance()
+	return
+}

+ 81 - 7
statement.go

@@ -4,7 +4,9 @@ import (
 	"bytes"
 	"context"
 	"database/sql"
+	"database/sql/driver"
 	"fmt"
+	"strconv"
 	"strings"
 	"sync"
 
@@ -13,9 +15,10 @@ import (
 
 // Statement statement
 type Statement struct {
+	Model    interface{}
 	Dest     interface{}
-	Table    interface{}
-	Clauses  map[string][]clause.Interface
+	Table    string
+	Clauses  map[string][]clause.Condition
 	Settings sync.Map
 	Context  context.Context
 	DB       *DB
@@ -45,16 +48,29 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
 
 // Write write string
 func (stmt Statement) AddVar(vars ...interface{}) string {
-	var placeholders []string
-	for _, v := range vars {
+	var placeholders strings.Builder
+	for idx, v := range vars {
+		if idx > 0 {
+			placeholders.WriteByte(',')
+		}
+
 		if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 {
 			stmt.NamedVars = append(stmt.NamedVars, namedArg)
-			placeholders = append(placeholders, "@"+namedArg.Name)
+			placeholders.WriteByte('@')
+			placeholders.WriteString(namedArg.Name)
+		} else if arrs, ok := v.([]interface{}); ok {
+			placeholders.WriteByte('(')
+			if len(arrs) > 0 {
+				placeholders.WriteString(stmt.AddVar(arrs...))
+			} else {
+				placeholders.WriteString("NULL")
+			}
+			placeholders.WriteByte(')')
 		} else {
-			placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v))
+			placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
 		}
 	}
-	return strings.Join(placeholders, ",")
+	return placeholders.String()
 }
 
 // Quote returns quoted value
@@ -66,3 +82,61 @@ func (stmt Statement) Quote(field interface{}) (str string) {
 func (s Statement) AddClause(clause clause.Interface) {
 	s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause)
 }
+
+// BuildCondtions build conditions
+func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) {
+	if sql, ok := query.(string); ok {
+		if i, err := strconv.Atoi(sql); err != nil {
+			query = i
+		} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
+			return []clause.Condition{clause.Raw{SQL: sql, Values: args}}
+		}
+	}
+
+	args = append([]interface{}{query}, args...)
+	for _, arg := range args {
+		if valuer, ok := arg.(driver.Valuer); ok {
+			arg, _ = valuer.Value()
+		}
+
+		switch v := arg.(type) {
+		case clause.Builder:
+			conditions = append(conditions, v)
+		case *DB:
+			if v.Statement == nil {
+				if cs, ok := v.Statement.Clauses["WHERE"]; ok {
+					conditions = append(conditions, cs...)
+				}
+			}
+		case map[interface{}]interface{}:
+			var clauseMap = clause.Map{}
+			for i, j := range v {
+				clauseMap[i] = j
+			}
+			conditions = append(conditions, clauseMap)
+		case map[string]string:
+			var clauseMap = clause.Map{}
+			for i, j := range v {
+				clauseMap[i] = j
+			}
+			conditions = append(conditions, clauseMap)
+		case map[string]interface{}:
+			var clauseMap = clause.Map{}
+			for i, j := range v {
+				clauseMap[i] = j
+			}
+			conditions = append(conditions, clauseMap)
+		default:
+			// TODO check is struct
+			// struct, slice -> ids
+		}
+	}
+
+	if len(conditions) == 0 {
+		conditions = append(conditions, clause.ID{Value: args})
+	}
+	return conditions
+}
+
+func (s Statement) AddError(err error) {
+}