Prechádzať zdrojové kódy

Add Selects, Omits for statement

Jinzhu 4 rokov pred
rodič
commit
98ad29f2c2
7 zmenil súbory, kde vykonal 73 pridanie a 31 odobranie
  1. 58 14
      chainable_api.go
  2. 6 6
      clause/select.go
  3. 1 1
      clause/select_test.go
  4. 0 7
      dialects/mysql/go.mod
  5. 1 3
      go.mod
  6. 5 0
      helpers.go
  7. 2 0
      statement.go

+ 58 - 14
chainable_api.go

@@ -2,6 +2,7 @@ package gorm
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/jinzhu/gorm/clause"
 )
@@ -31,9 +32,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
 	}
 
 	if len(whereConds) > 0 {
-		tx.Statement.AddClause(&clause.Where{
-			tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...),
-		})
+		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...)})
 	}
 	return
 }
@@ -48,38 +47,83 @@ func (db *DB) Table(name string) (tx *DB) {
 // Select specify fields that you want when querying, creating, updating
 func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
 	tx = db.getInstance()
+
+	switch v := query.(type) {
+	case []string:
+		tx.Statement.Selects = v
+
+		for _, arg := range args {
+			switch arg := arg.(type) {
+			case string:
+				tx.Statement.Selects = append(tx.Statement.Selects, arg)
+			case []string:
+				tx.Statement.Selects = append(tx.Statement.Selects, arg...)
+			default:
+				tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
+				return
+			}
+		}
+	case string:
+		fields := strings.FieldsFunc(v, isChar)
+
+		// normal field names
+		if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
+			tx.Statement.Selects = fields
+
+			for _, arg := range args {
+				switch arg := arg.(type) {
+				case string:
+					tx.Statement.Selects = append(tx.Statement.Selects, arg)
+				case []string:
+					tx.Statement.Selects = append(tx.Statement.Selects, arg...)
+				default:
+					tx.Statement.AddClause(clause.Select{
+						Expression: clause.Expr{SQL: v, Vars: args},
+					})
+					return
+				}
+			}
+		} else {
+			tx.Statement.AddClause(clause.Select{
+				Expression: clause.Expr{SQL: v, Vars: args},
+			})
+		}
+	default:
+		tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
+	}
+
 	return
 }
 
 // Omit specify fields that you want to ignore when creating, updating and querying
 func (db *DB) Omit(columns ...string) (tx *DB) {
 	tx = db.getInstance()
+
+	if len(columns) == 1 && strings.Contains(columns[0], ",") {
+		tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar)
+	} else {
+		tx.Statement.Omits = columns
+	}
 	return
 }
 
 func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
 	tx = db.getInstance()
-	tx.Statement.AddClause(&clause.Where{
-		tx.Statement.BuildCondtion(query, args...),
-	})
+	tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)})
 	return
 }
 
 // Not add NOT condition
 func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
 	tx = db.getInstance()
-	tx.Statement.AddClause(&clause.Where{
-		[]clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)},
-	})
+	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}})
 	return
 }
 
 // Or add OR conditions
 func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
 	tx = db.getInstance()
-	tx.Statement.AddClause(&clause.Where{
-		[]clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)},
-	})
+	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}})
 	return
 }
 
@@ -110,11 +154,11 @@ func (db *DB) Order(value interface{}) (tx *DB) {
 
 	switch v := value.(type) {
 	case clause.OrderByColumn:
-		db.Statement.AddClause(clause.OrderBy{
+		tx.Statement.AddClause(clause.OrderBy{
 			Columns: []clause.OrderByColumn{v},
 		})
 	default:
-		db.Statement.AddClause(clause.OrderBy{
+		tx.Statement.AddClause(clause.OrderBy{
 			Columns: []clause.OrderByColumn{{
 				Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
 			}},

+ 6 - 6
clause/select.go

@@ -2,8 +2,8 @@ package clause
 
 // Select select attrs when querying, updating, creating
 type Select struct {
-	Columns []Column
-	Omits   []Column
+	Columns    []Column
+	Expression Expression
 }
 
 func (s Select) Name() string {
@@ -24,9 +24,9 @@ func (s Select) Build(builder Builder) {
 }
 
 func (s Select) MergeClause(clause *Clause) {
-	if v, ok := clause.Expression.(Select); ok {
-		s.Columns = append(v.Columns, s.Columns...)
-		s.Omits = append(v.Omits, s.Omits...)
+	if s.Expression != nil {
+		clause.Expression = s.Expression
+	} else {
+		clause.Expression = s
 	}
-	clause.Expression = s
 }

+ 1 - 1
clause/select_test.go

@@ -29,7 +29,7 @@ func TestSelect(t *testing.T) {
 			}, clause.Select{
 				Columns: []clause.Column{{Name: "name"}},
 			}, clause.From{}},
-			"SELECT `users`.`id`,`name` FROM `users`", nil,
+			"SELECT `name` FROM `users`", nil,
 		},
 	}
 

+ 0 - 7
dialects/mysql/go.mod

@@ -1,7 +0,0 @@
-module github.com/jinzhu/gorm/dialects/mysql
-
-go 1.13
-
-require (
-	github.com/go-sql-driver/mysql v1.5.0
-)

+ 1 - 3
go.mod

@@ -3,8 +3,6 @@ module github.com/jinzhu/gorm
 go 1.13
 
 require (
-	github.com/go-sql-driver/mysql v1.5.0 // indirect
 	github.com/jinzhu/inflection v1.0.0
-	github.com/lib/pq v1.3.0 // indirect
-	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
+	github.com/jinzhu/now v1.1.1
 )

+ 5 - 0
helpers.go

@@ -3,6 +3,7 @@ package gorm
 import (
 	"errors"
 	"time"
+	"unicode"
 )
 
 var (
@@ -27,3 +28,7 @@ type Model struct {
 	UpdatedAt time.Time
 	DeletedAt *time.Time `gorm:"index"`
 }
+
+func isChar(c rune) bool {
+	return !unicode.IsLetter(c) && !unicode.IsNumber(c)
+}

+ 2 - 0
statement.go

@@ -43,6 +43,8 @@ type Statement struct {
 	Model    interface{}
 	Dest     interface{}
 	Clauses  map[string]clause.Clause
+	Selects  []string // selected columns
+	Omits    []string // omit columns
 	Settings sync.Map
 	DB       *DB
 	Schema   *schema.Schema