Browse Source

Add clause tests

Jinzhu 4 years ago
parent
commit
0160bab7dc

+ 1 - 1
chainable_api.go

@@ -80,7 +80,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
 func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
 	tx = db.getInstance()
 	tx.Statement.AddClause(clause.Where{
-		ORConditions: []clause.ORConditions{
+		OrConditions: []clause.OrConditions{
 			tx.Statement.BuildCondtion(query, args...),
 		},
 	})

+ 8 - 6
clause/clause_test.go

@@ -12,17 +12,19 @@ import (
 	"github.com/jinzhu/gorm/tests"
 )
 
-func TestClause(t *testing.T) {
+func TestClauses(t *testing.T) {
 	var (
-		db, _   = gorm.Open(nil, nil)
+		db, _   = gorm.Open(tests.DummyDialector{}, nil)
 		results = []struct {
 			Clauses []clause.Interface
 			Result  string
 			Vars    []interface{}
-		}{{
-			[]clause.Interface{clause.Select{}, clause.From{}},
-			"SELECT * FROM users", []interface{}{},
-		}}
+		}{
+			{
+				[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}},
+				"SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"},
+			},
+		}
 	)
 
 	for idx, result := range results {

+ 5 - 0
clause/expression.go

@@ -5,6 +5,11 @@ const (
 	CurrentTable string = "@@@table@@@"
 )
 
+var PrimaryColumn = Column{
+	Table: CurrentTable,
+	Name:  PrimaryKey,
+}
+
 // Expression expression interface
 type Expression interface {
 	Build(builder Builder)

+ 10 - 2
clause/query.go

@@ -6,6 +6,14 @@ import "strings"
 // Query Expressions
 ////////////////////////////////////////////////////////////////////////////////
 
+func Add(exprs ...Expression) AddConditions {
+	return AddConditions(exprs)
+}
+
+func Or(exprs ...Expression) OrConditions {
+	return OrConditions(exprs)
+}
+
 type AddConditions []Expression
 
 func (cs AddConditions) Build(builder Builder) {
@@ -17,9 +25,9 @@ func (cs AddConditions) Build(builder Builder) {
 	}
 }
 
-type ORConditions []Expression
+type OrConditions []Expression
 
-func (cs ORConditions) Build(builder Builder) {
+func (cs OrConditions) Build(builder Builder) {
 	for idx, c := range cs {
 		if idx > 0 {
 			builder.Write(" OR ")

+ 4 - 4
clause/where.go

@@ -3,7 +3,7 @@ package clause
 // Where where clause
 type Where struct {
 	AndConditions AddConditions
-	ORConditions  []ORConditions
+	OrConditions  []OrConditions
 	builders      []Expression
 }
 
@@ -31,8 +31,8 @@ func (where Where) Build(builder Builder) {
 		}
 	}
 
-	var singleOrConditions []ORConditions
-	for _, or := range where.ORConditions {
+	var singleOrConditions []OrConditions
+	for _, or := range where.OrConditions {
 		if len(or) == 1 {
 			if withConditions {
 				builder.Write(" OR ")
@@ -69,7 +69,7 @@ func (where Where) Build(builder Builder) {
 func (where Where) MergeExpression(expr Expression) {
 	if w, ok := expr.(Where); ok {
 		where.AndConditions = append(where.AndConditions, w.AndConditions...)
-		where.ORConditions = append(where.ORConditions, w.ORConditions...)
+		where.OrConditions = append(where.OrConditions, w.OrConditions...)
 		where.builders = append(where.builders, w.builders...)
 	} else {
 		where.builders = append(where.builders, expr)

+ 4 - 0
dialects/mysql/mysql.go

@@ -27,3 +27,7 @@ func (Dialector) Migrator() gorm.Migrator {
 func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
 	return "?"
 }
+
+func (Dialector) QuoteChars() [2]byte {
+	return [2]byte{'`', '`'} // `name`
+}

+ 4 - 0
dialects/postgres/postgres.go

@@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
 func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
 	return "?"
 }
+
+func (Dialector) QuoteChars() [2]byte {
+	return [2]byte{'"', '"'} // "name"
+}

+ 4 - 0
dialects/sqlite/sqlite.go

@@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
 func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
 	return "?"
 }
+
+func (Dialector) QuoteChars() [2]byte {
+	return [2]byte{'`', '`'} // `name`
+}

+ 3 - 2
go.mod

@@ -3,7 +3,8 @@ module github.com/jinzhu/gorm
 go 1.13
 
 require (
+	github.com/go-sql-driver/mysql v1.5.0 // indirect
 	github.com/jinzhu/inflection v1.0.0
-	github.com/lib/pq v1.3.0
-	github.com/mattn/go-sqlite3 v2.0.3+incompatible
+	github.com/lib/pq v1.3.0 // indirect
+	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
 )

+ 13 - 6
gorm.go

@@ -23,16 +23,21 @@ type Config struct {
 	NowFunc func() time.Time
 }
 
+type shared struct {
+	callbacks  *callbacks
+	cacheStore *sync.Map
+	quoteChars [2]byte
+}
+
 // DB GORM DB definition
 type DB struct {
 	*Config
 	Dialector
 	Instance
-	DB             CommonDB
 	ClauseBuilders map[string]clause.ClauseBuilder
+	DB             CommonDB
 	clone          bool
-	callbacks      *callbacks
-	cacheStore     *sync.Map
+	*shared
 }
 
 // Session session config when create session with Session() method
@@ -65,13 +70,16 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
 		Dialector:      dialector,
 		ClauseBuilders: map[string]clause.ClauseBuilder{},
 		clone:          true,
-		cacheStore:     &sync.Map{},
+		shared: &shared{
+			cacheStore: &sync.Map{},
+		},
 	}
 
 	db.callbacks = initializeCallbacks(db)
 
 	if dialector != nil {
 		err = dialector.Initialize(db)
+		db.quoteChars = dialector.QuoteChars()
 	}
 	return
 }
@@ -146,8 +154,7 @@ func (db *DB) getInstance() *DB {
 			Dialector:      db.Dialector,
 			ClauseBuilders: db.ClauseBuilders,
 			DB:             db.DB,
-			callbacks:      db.callbacks,
-			cacheStore:     db.cacheStore,
+			shared:         db.shared,
 		}
 	}
 

+ 1 - 0
interfaces.go

@@ -10,6 +10,7 @@ type Dialector interface {
 	Initialize(*DB) error
 	Migrator() Migrator
 	BindVar(stmt *Statement, v interface{}) string
+	QuoteChars() [2]byte
 }
 
 // CommonDB common db interface

+ 11 - 0
statement.go

@@ -81,6 +81,7 @@ func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
 // Quote returns quoted value
 func (stmt Statement) Quote(field interface{}) string {
 	var str strings.Builder
+	str.WriteByte(stmt.DB.quoteChars[0])
 
 	switch v := field.(type) {
 	case clause.Table:
@@ -91,8 +92,11 @@ func (stmt Statement) Quote(field interface{}) string {
 		}
 
 		if v.Alias != "" {
+			str.WriteByte(stmt.DB.quoteChars[1])
 			str.WriteString(" AS ")
+			str.WriteByte(stmt.DB.quoteChars[0])
 			str.WriteString(v.Alias)
+			str.WriteByte(stmt.DB.quoteChars[1])
 		}
 	case clause.Column:
 		if v.Table != "" {
@@ -101,7 +105,9 @@ func (stmt Statement) Quote(field interface{}) string {
 			} else {
 				str.WriteString(v.Table)
 			}
+			str.WriteByte(stmt.DB.quoteChars[1])
 			str.WriteByte('.')
+			str.WriteByte(stmt.DB.quoteChars[0])
 		}
 
 		if v.Name == clause.PrimaryKey {
@@ -111,14 +117,19 @@ func (stmt Statement) Quote(field interface{}) string {
 		} else {
 			str.WriteString(v.Name)
 		}
+
 		if v.Alias != "" {
+			str.WriteByte(stmt.DB.quoteChars[1])
 			str.WriteString(" AS ")
+			str.WriteByte(stmt.DB.quoteChars[0])
 			str.WriteString(v.Alias)
+			str.WriteByte(stmt.DB.quoteChars[1])
 		}
 	default:
 		fmt.Sprint(field)
 	}
 
+	str.WriteByte(stmt.DB.quoteChars[1])
 	return str.String()
 }
 

+ 24 - 0
tests/dummy_dialecter.go

@@ -0,0 +1,24 @@
+package tests
+
+import (
+	"github.com/jinzhu/gorm"
+)
+
+type DummyDialector struct {
+}
+
+func (DummyDialector) Initialize(*gorm.DB) error {
+	return nil
+}
+
+func (DummyDialector) Migrator() gorm.Migrator {
+	return nil
+}
+
+func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
+	return "?"
+}
+
+func (DummyDialector) QuoteChars() [2]byte {
+	return [2]byte{'`', '`'} // `name`
+}