Kaynağa Gözat

Add more clauses

Jinzhu 4 yıl önce
ebeveyn
işleme
46b1c85f88

+ 14 - 8
callbacks.go

@@ -69,14 +69,20 @@ func (cs *callbacks) Raw() *processor {
 }
 
 func (p *processor) Execute(db *DB) {
-	if stmt := db.Statement; stmt != nil && stmt.Dest != nil {
-		var err error
-		stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy)
-
-		if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) {
-			db.AddError(err)
-		} else if stmt.Table == "" && stmt.Schema != nil {
-			stmt.Table = stmt.Schema.Table
+	if stmt := db.Statement; stmt != nil {
+		if stmt.Model == nil {
+			stmt.Model = stmt.Dest
+		}
+
+		if stmt.Model != nil {
+			var err error
+			stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy)
+
+			if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
+				db.AddError(err)
+			} else if stmt.Table == "" && stmt.Schema != nil {
+				stmt.Table = stmt.Schema.Table
+			}
 		}
 	}
 

+ 4 - 2
callbacks/callbacks.go

@@ -1,6 +1,8 @@
 package callbacks
 
-import "github.com/jinzhu/gorm"
+import (
+	"github.com/jinzhu/gorm"
+)
 
 func RegisterDefaultCallbacks(db *gorm.DB) {
 	enableTransaction := func(db *gorm.DB) bool {
@@ -17,7 +19,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 
 	queryCallback := db.Callback().Query()
-	queryCallback.Register("gorm:query", BeforeCreate)
+	queryCallback.Register("gorm:query", Query)
 	queryCallback.Register("gorm:preload", Preload)
 	queryCallback.Register("gorm:after_query", AfterQuery)
 

+ 1 - 1
callbacks/create.go

@@ -22,7 +22,7 @@ func Create(db *gorm.DB) {
 		Table: clause.Table{Table: db.Statement.Table},
 	})
 
-	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
+	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 	fmt.Println(err)
 	fmt.Println(result)

+ 16 - 1
callbacks/query.go

@@ -1,8 +1,23 @@
 package callbacks
 
-import "github.com/jinzhu/gorm"
+import (
+	"fmt"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
+)
 
 func Query(db *gorm.DB) {
+	db.Statement.AddClauseIfNotExists(clause.Select{})
+	db.Statement.AddClauseIfNotExists(clause.From{
+		Tables: []clause.Table{{Table: clause.CurrentTable}},
+	})
+
+	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
+	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	fmt.Println(err)
+	fmt.Println(result)
+	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
 }
 
 func Preload(db *gorm.DB) {

+ 18 - 1
chainable_api.go

@@ -1,6 +1,10 @@
 package gorm
 
-import "github.com/jinzhu/gorm/clause"
+import (
+	"fmt"
+
+	"github.com/jinzhu/gorm/clause"
+)
 
 // Model specify the model you would like to run db operations
 //    // update all users's name to `hello`
@@ -107,6 +111,19 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
 //     db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
 func (db *DB) Order(value interface{}) (tx *DB) {
 	tx = db.getInstance()
+
+	switch v := value.(type) {
+	case clause.OrderBy:
+		db.Statement.AddClause(clause.OrderByClause{
+			Columns: []clause.OrderBy{v},
+		})
+	default:
+		db.Statement.AddClause(clause.OrderByClause{
+			Columns: []clause.OrderBy{{
+				Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
+			}},
+		})
+	}
 	return
 }
 

+ 11 - 20
clause/clause.go

@@ -11,11 +11,6 @@ type Clause struct {
 	Builder              ClauseBuilder
 }
 
-// ClauseBuilder clause builder, allows to custmize how to build clause
-type ClauseBuilder interface {
-	Build(Clause, Builder)
-}
-
 // Build build clause
 func (c Clause) Build(builder Builder) {
 	if c.Builder != nil {
@@ -47,25 +42,21 @@ type Interface interface {
 	MergeExpression(Expression)
 }
 
+// OverrideNameInterface override name interface
 type OverrideNameInterface interface {
 	OverrideName() string
 }
 
-// Column quote with name
-type Column struct {
-	Table string
-	Name  string
-	Alias string
-	Raw   bool
-}
-
-func ToColumns(value ...interface{}) []Column {
-	return nil
+// ClauseBuilder clause builder, allows to custmize how to build clause
+type ClauseBuilder interface {
+	Build(Clause, Builder)
 }
 
-// Table quote with name
-type Table struct {
-	Table string
-	Alias string
-	Raw   bool
+// Builder builder interface
+type Builder interface {
+	WriteByte(byte) error
+	Write(sql ...string) error
+	WriteQuoted(field interface{}) error
+	AddVar(vars ...interface{}) string
+	Quote(field interface{}) string
 }

+ 18 - 7
clause/expression.go

@@ -1,5 +1,10 @@
 package clause
 
+const (
+	PrimaryKey   string = "@@@priamry_key@@@"
+	CurrentTable string = "@@@table@@@"
+)
+
 // Expression expression interface
 type Expression interface {
 	Build(builder Builder)
@@ -10,13 +15,19 @@ type NegationExpressionBuilder interface {
 	NegationBuild(builder Builder)
 }
 
-// Builder builder interface
-type Builder interface {
-	WriteByte(byte) error
-	Write(sql ...string) error
-	WriteQuoted(field interface{}) error
-	AddVar(vars ...interface{}) string
-	Quote(field interface{}) string
+// Column quote with name
+type Column struct {
+	Table string
+	Name  string
+	Alias string
+	Raw   bool
+}
+
+// Table quote with name
+type Table struct {
+	Table string
+	Alias string
+	Raw   bool
 }
 
 // Expr raw expression

+ 7 - 0
clause/from.go

@@ -20,3 +20,10 @@ func (from From) Build(builder Builder) {
 		builder.WriteQuoted(table)
 	}
 }
+
+// MergeExpression merge order by clauses
+func (from From) MergeExpression(expr Expression) {
+	if v, ok := expr.(From); ok {
+		from.Tables = append(v.Tables, from.Tables...)
+	}
+}

+ 6 - 0
clause/on_conflict.go

@@ -0,0 +1,6 @@
+package clause
+
+type OnConflict struct {
+	ON     string  // duplicate key
+	Values *Values // update c=c+1
+}

+ 34 - 0
clause/order_by.go

@@ -1,4 +1,38 @@
 package clause
 
 type OrderBy struct {
+	Column  Column
+	Desc    bool
+	Reorder bool
+}
+
+type OrderByClause struct {
+	Columns []OrderBy
+}
+
+// Name where clause name
+func (orderBy OrderByClause) Name() string {
+	return "ORDER BY"
+}
+
+// Build build where clause
+func (orderBy OrderByClause) Build(builder Builder) {
+	for i := len(orderBy.Columns) - 1; i >= 0; i-- {
+		builder.WriteQuoted(orderBy.Columns[i].Column)
+
+		if orderBy.Columns[i].Desc {
+			builder.Write(" DESC")
+		}
+
+		if orderBy.Columns[i].Reorder {
+			break
+		}
+	}
+}
+
+// MergeExpression merge order by clauses
+func (orderBy OrderByClause) MergeExpression(expr Expression) {
+	if v, ok := expr.(OrderByClause); ok {
+		orderBy.Columns = append(v.Columns, orderBy.Columns...)
+	}
 }

+ 8 - 4
clause/select.go

@@ -1,15 +1,19 @@
 package clause
 
+// SelectInterface select clause interface
+type SelectInterface interface {
+	Selects() []Column
+	Omits() []Column
+}
+
 // Select select attrs when querying, updating, creating
 type Select struct {
 	SelectColumns []Column
 	OmitColumns   []Column
 }
 
-// SelectInterface select clause interface
-type SelectInterface interface {
-	Selects() []Column
-	Omits() []Column
+func (s Select) Name() string {
+	return "SELECT"
 }
 
 func (s Select) Selects() []Column {

+ 6 - 2
finisher_api.go

@@ -2,6 +2,8 @@ package gorm
 
 import (
 	"database/sql"
+
+	"github.com/jinzhu/gorm/clause"
 )
 
 // Create insert the value into database
@@ -20,9 +22,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
 
 // First find first record that match given conditions, order by primary key
 func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
-	tx = db.getInstance()
+	tx = db.getInstance().Limit(1).Order(clause.OrderBy{
+		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
+		Desc:   true,
+	})
 	tx.Statement.Dest = out
-	tx.Limit(1)
 	tx.callbacks.Query().Execute(tx)
 	return
 }

+ 5 - 4
gorm.go

@@ -61,10 +61,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
 	}
 
 	db = &DB{
-		Config:     config,
-		Dialector:  dialector,
-		clone:      true,
-		cacheStore: &sync.Map{},
+		Config:         config,
+		Dialector:      dialector,
+		ClauseBuilders: map[string]clause.ClauseBuilder{},
+		clone:          true,
+		cacheStore:     &sync.Map{},
 	}
 
 	db.callbacks = initializeCallbacks(db)

+ 13 - 3
statement.go

@@ -84,18 +84,28 @@ func (stmt Statement) Quote(field interface{}) string {
 
 	switch v := field.(type) {
 	case clause.Table:
-		str.WriteString(v.Table)
+
 		if v.Alias != "" {
 			str.WriteString(" AS ")
 			str.WriteString(v.Alias)
 		}
 	case clause.Column:
 		if v.Table != "" {
-			str.WriteString(v.Table)
+			if v.Table == clause.CurrentTable {
+				str.WriteString(stmt.Table)
+			} else {
+				str.WriteString(v.Table)
+			}
 			str.WriteByte('.')
 		}
 
-		str.WriteString(v.Name)
+		if v.Name == clause.PrimaryKey {
+			if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
+				str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
+			}
+		} else {
+			str.WriteString(v.Name)
+		}
 		if v.Alias != "" {
 			str.WriteString(" AS ")
 			str.WriteString(v.Alias)