Browse Source

Add Raw, Row, Rows

Jinzhu 4 years ago
parent
commit
215f5e7765

+ 3 - 0
callbacks/callbacks.go

@@ -38,4 +38,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
 	updateCallback.Register("gorm:after_update", AfterUpdate)
 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
+
+	db.Callback().Row().Register("gorm:raw", RowQuery)
+	db.Callback().Raw().Register("gorm:raw", RawExec)
 }

+ 11 - 0
callbacks/raw.go

@@ -0,0 +1,11 @@
+package callbacks
+
+import "github.com/jinzhu/gorm"
+
+func RawExec(db *gorm.DB) {
+	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	db.RowsAffected, _ = result.RowsAffected()
+	if err != nil {
+		db.AddError(err)
+	}
+}

+ 19 - 0
callbacks/row.go

@@ -0,0 +1,19 @@
+package callbacks
+
+import (
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
+)
+
+func RowQuery(db *gorm.DB) {
+	db.Statement.AddClauseIfNotExists(clause.Select{})
+	db.Statement.AddClauseIfNotExists(clause.From{})
+
+	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
+
+	if _, ok := db.Get("rows"); ok {
+		db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	} else {
+		db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	}
+}

+ 3 - 0
chainable_api.go

@@ -222,5 +222,8 @@ func (db *DB) Unscoped() (tx *DB) {
 
 func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
 	tx = db.getInstance()
+	stmt := tx.Statement
+	stmt.SQL = strings.Builder{}
+	clause.Expr{SQL: sql, Vars: values}.Build(stmt)
 	return
 }

+ 4 - 1
dialects/mssql/mssql.go

@@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 }
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
-	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
+	return Migrator{migrator.Migrator{Config: migrator.Config{
+		DB:        db,
+		Dialector: dialector,
+	}}}
 }
 
 func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

+ 4 - 1
dialects/mysql/mysql.go

@@ -29,7 +29,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 }
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
-	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
+	return Migrator{migrator.Migrator{Config: migrator.Config{
+		DB:        db,
+		Dialector: dialector,
+	}}}
 }
 
 func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

+ 4 - 1
dialects/postgres/postgres.go

@@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 }
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
-	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
+	return Migrator{migrator.Migrator{Config: migrator.Config{
+		DB:        db,
+		Dialector: dialector,
+	}}}
 }
 
 func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

+ 4 - 1
dialects/sqlite/sqlite.go

@@ -27,7 +27,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 }
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
-	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
+	return Migrator{migrator.Migrator{Config: migrator.Config{
+		DB:        db,
+		Dialector: dialector,
+	}}}
 }
 
 func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

+ 5 - 1
dialects/sqlite/sqlite_test.go

@@ -22,6 +22,10 @@ func init() {
 	}
 }
 
-func TestSqlite(t *testing.T) {
+func TestCURD(t *testing.T) {
 	tests.RunTestsSuit(t, DB)
 }
+
+func TestMigrate(t *testing.T) {
+	tests.TestMigrate(t, DB)
+}

+ 7 - 2
finisher_api.go

@@ -108,11 +108,15 @@ func (db *DB) Count(value interface{}) (tx *DB) {
 }
 
 func (db *DB) Row() *sql.Row {
-	return nil
+	tx := db.getInstance()
+	tx.callbacks.Row().Execute(tx)
+	return tx.Statement.Dest.(*sql.Row)
 }
 
 func (db *DB) Rows() (*sql.Rows, error) {
-	return nil, nil
+	tx := db.Set("rows", true)
+	tx.callbacks.Row().Execute(tx)
+	return tx.Statement.Dest.(*sql.Rows), tx.Error
 }
 
 // Scan scan value to a struct
@@ -162,5 +166,6 @@ func (db *DB) Rollback() (tx *DB) {
 
 func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
 	tx = db.getInstance()
+	tx.callbacks.Raw().Execute(tx)
 	return
 }

+ 5 - 0
gorm.go

@@ -138,6 +138,11 @@ func (db *DB) Callback() *callbacks {
 	return db.callbacks
 }
 
+// AutoMigrate run auto migration for given models
+func (db *DB) AutoMigrate(dst ...interface{}) error {
+	return db.Migrator().AutoMigrate(dst...)
+}
+
 func (db *DB) getInstance() *DB {
 	if db.clone {
 		ctx := db.Instance.Context

+ 9 - 2
migrator/migrator.go

@@ -265,8 +265,15 @@ func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
 	})
 }
 
-func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
-	return nil, gorm.ErrNotImplemented
+func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
+	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
+		if err == nil {
+			columnTypes, err = rows.ColumnTypes()
+		}
+		return err
+	})
+	return
 }
 
 func (m Migrator) CreateView(name string, option gorm.ViewOption) error {

+ 4 - 1
schema/check.go

@@ -17,9 +17,12 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check {
 	for _, field := range schema.FieldsByDBName {
 		if chk := field.TagSettings["CHECK"]; chk != "" {
 			names := strings.Split(chk, ",")
-			if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) {
+			if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) {
 				checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
 			} else {
+				if names[0] == "" {
+					chk = strings.Join(names[1:], ",")
+				}
 				name := schema.namer.CheckerName(schema.Table, field.DBName)
 				checks[name] = Check{Name: name, Constraint: chk, Field: field}
 			}

+ 55 - 0
schema/check_test.go

@@ -0,0 +1,55 @@
+package schema_test
+
+import (
+	"reflect"
+	"sync"
+	"testing"
+
+	"github.com/jinzhu/gorm/schema"
+)
+
+type UserCheck struct {
+	Name  string `gorm:"check:name_checker,name <> 'jinzhu'"`
+	Name2 string `gorm:"check:name <> 'jinzhu'"`
+	Name3 string `gorm:"check:,name <> 'jinzhu'"`
+}
+
+func TestParseCheck(t *testing.T) {
+	user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse user check, got error %v", err)
+	}
+
+	results := map[string]schema.Check{
+		"name_checker": {
+			Name:       "name_checker",
+			Constraint: "name <> 'jinzhu'",
+		},
+		"chk_user_checks_name2": {
+			Name:       "chk_user_checks_name2",
+			Constraint: "name <> 'jinzhu'",
+		},
+		"chk_user_checks_name3": {
+			Name:       "chk_user_checks_name3",
+			Constraint: "name <> 'jinzhu'",
+		},
+	}
+
+	checks := user.ParseCheckConstraints()
+
+	for k, result := range results {
+		v, ok := checks[k]
+		if !ok {
+			t.Errorf("Failed to found check %v from parsed checks %+v", k, checks)
+		}
+
+		for _, name := range []string{"Name", "Constraint"} {
+			if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
+				t.Errorf(
+					"check %v %v should equal, expects %v, got %v",
+					k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
+				)
+			}
+		}
+	}
+}

+ 1 - 1
schema/index_test.go

@@ -21,7 +21,7 @@ type UserIndex struct {
 func TestParseIndex(t *testing.T) {
 	user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
 	if err != nil {
-		t.Fatalf("failed to parse user index index, got error %v", err)
+		t.Fatalf("failed to parse user index, got error %v", err)
 	}
 
 	results := map[string]schema.Index{

+ 5 - 1
schema/relationship.go

@@ -317,7 +317,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
 		settings = ParseTagSetting(str, ",")
 	)
 
-	if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) {
+	if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) {
 		name = str[0:idx]
 	} else {
 		name = rel.Schema.namer.RelationshipFKName(*rel)
@@ -339,5 +339,9 @@ func (rel *Relationship) ParseConstraint() *Constraint {
 		}
 	}
 
+	if constraint.ReferenceSchema == nil {
+		return nil
+	}
+
 	return &constraint
 }

+ 19 - 0
tests/migrate.go

@@ -0,0 +1,19 @@
+package tests
+
+import (
+	"testing"
+
+	"github.com/jinzhu/gorm"
+)
+
+func TestMigrate(t *testing.T, db *gorm.DB) {
+	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
+
+	db.AutoMigrate(allModels...)
+
+	for _, m := range allModels {
+		if !db.Migrator().HasTable(m) {
+			t.Errorf("Failed to create table for %+v", m)
+		}
+	}
+}