Browse Source

Add sqlite migration tests

Jinzhu 4 years ago
parent
commit
6d58b62fd4

+ 6 - 3
callbacks/query.go

@@ -8,10 +8,13 @@ import (
 )
 
 func Query(db *gorm.DB) {
-	db.Statement.AddClauseIfNotExists(clause.Select{})
-	db.Statement.AddClauseIfNotExists(clause.From{})
+	if db.Statement.SQL.String() == "" {
+		db.Statement.AddClauseIfNotExists(clause.Select{})
+		db.Statement.AddClauseIfNotExists(clause.From{})
+
+		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
+	}
 
-	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 	fmt.Println(err)
 	fmt.Println(result)

+ 5 - 2
callbacks/raw.go

@@ -1,11 +1,14 @@
 package callbacks
 
-import "github.com/jinzhu/gorm"
+import (
+	"github.com/jinzhu/gorm"
+)
 
 func RawExec(db *gorm.DB) {
 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
-	db.RowsAffected, _ = result.RowsAffected()
 	if err != nil {
 		db.AddError(err)
+	} else {
+		db.RowsAffected, _ = result.RowsAffected()
 	}
 }

+ 5 - 3
callbacks/row.go

@@ -6,10 +6,12 @@ import (
 )
 
 func RowQuery(db *gorm.DB) {
-	db.Statement.AddClauseIfNotExists(clause.Select{})
-	db.Statement.AddClauseIfNotExists(clause.From{})
+	if db.Statement.SQL.String() == "" {
+		db.Statement.AddClauseIfNotExists(clause.Select{})
+		db.Statement.AddClauseIfNotExists(clause.From{})
 
-	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
+		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
+	}
 
 	if _, ok := db.Get("rows"); ok {
 		db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)

+ 2 - 3
chainable_api.go

@@ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) {
 
 func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
 	tx = db.getInstance()
-	stmt := tx.Statement
-	stmt.SQL = strings.Builder{}
-	clause.Expr{SQL: sql, Vars: values}.Build(stmt)
+	tx.Statement.SQL = strings.Builder{}
+	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
 	return
 }

+ 4 - 2
clause/expression.go

@@ -1,6 +1,8 @@
 package clause
 
-import "strings"
+import (
+	"strings"
+)
 
 // Expression expression interface
 type Expression interface {
@@ -22,7 +24,7 @@ type Expr struct {
 func (expr Expr) Build(builder Builder) {
 	sql := expr.SQL
 	for _, v := range expr.Vars {
-		sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1)
+		sql = strings.Replace(sql, "?", builder.AddVar(v), 1)
 	}
 	builder.Write(sql)
 }

+ 35 - 0
clause/expression_test.go

@@ -0,0 +1,35 @@
+package clause_test
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
+	"github.com/jinzhu/gorm/schema"
+	"github.com/jinzhu/gorm/tests"
+)
+
+func TestExpr(t *testing.T) {
+	results := []struct {
+		SQL    string
+		Result string
+		Vars   []interface{}
+	}{{
+		SQL:    "create table ? (? ?, ? ?)",
+		Vars:   []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
+		Result: "create table `users` (`id` int, `name` text)",
+	}}
+
+	for idx, result := range results {
+		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
+			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
+			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
+			clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
+			if stmt.SQL.String() != result.Result {
+				t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
+			}
+		})
+	}
+}

+ 2 - 2
dialects/sqlite/migrator.go

@@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
 		}
 
 		return m.DB.Raw(
-			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)",
-			stmt.Table, `%"`+name+`" %`, `%`+name+` %`,
+			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
+			stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
 		).Row().Scan(&count)
 	})
 	return count > 0

+ 9 - 8
dialects/sqlite/sqlite.go

@@ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
 	return Migrator{migrator.Migrator{Config: migrator.Config{
-		DB:        db,
-		Dialector: dialector,
+		DB:                          db,
+		Dialector:                   dialector,
+		CreateIndexAfterCreateTable: true,
 	}}}
 }
 
@@ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte {
 func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 	switch field.DataType {
 	case schema.Bool:
-		return "NUMERIC"
+		return "numeric"
 	case schema.Int, schema.Uint:
 		if field.AutoIncrement {
 			// https://www.sqlite.org/autoinc.html
-			return "INTEGER PRIMARY KEY AUTOINCREMENT"
+			return "integer PRIMARY KEY AUTOINCREMENT"
 		} else {
-			return "INTEGER"
+			return "integer"
 		}
 	case schema.Float:
-		return "REAL"
+		return "real"
 	case schema.String, schema.Time:
-		return "TEXT"
+		return "text"
 	case schema.Bytes:
-		return "BLOB"
+		return "blob"
 	}
 
 	return ""

+ 3 - 0
finisher_api.go

@@ -2,6 +2,7 @@ package gorm
 
 import (
 	"database/sql"
+	"strings"
 
 	"github.com/jinzhu/gorm/clause"
 )
@@ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) {
 
 func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
 	tx = db.getInstance()
+	tx.Statement.SQL = strings.Builder{}
+	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
 	tx.callbacks.Raw().Execute(tx)
 	return
 }

+ 1 - 0
go.mod

@@ -5,4 +5,5 @@ go 1.13
 require (
 	github.com/jinzhu/inflection v1.0.0
 	github.com/jinzhu/now v1.1.1
+	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
 )

+ 21 - 12
migrator/migrator.go

@@ -18,7 +18,8 @@ type Migrator struct {
 
 // Config schema config
 type Config struct {
-	DB *gorm.DB
+	CreateIndexAfterCreateTable bool
+	DB                          *gorm.DB
 	gorm.Dialector
 }
 
@@ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
 					}
 
 					// create join table
-					joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
-					if !m.DB.Migrator().HasTable(joinValue) {
-						defer m.DB.Migrator().CreateTable(joinValue)
+					if rel.JoinTable != nil {
+						joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
+						if !m.DB.Migrator().HasTable(joinValue) {
+							defer m.DB.Migrator().CreateTable(joinValue)
+						}
 					}
 				}
 				return nil
@@ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 			}
 
 			for _, idx := range stmt.Schema.ParseIndexes() {
-				createTableSQL += "INDEX ? ?,"
-				values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
+				if m.CreateIndexAfterCreateTable {
+					m.DB.Migrator().CreateIndex(value, idx.Name)
+				} else {
+					createTableSQL += "INDEX ? ?,"
+					values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
+				}
 			}
 
 			for _, rel := range stmt.Schema.Relationships.Relations {
@@ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 				}
 
 				// create join table
-				joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
-				if !m.DB.Migrator().HasTable(joinValue) {
-					defer m.DB.Migrator().CreateTable(joinValue)
+				if rel.JoinTable != nil {
+					joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
+					if !m.DB.Migrator().HasTable(joinValue) {
+						defer m.DB.Migrator().CreateTable(joinValue)
+					}
 				}
 			}
 
@@ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
 	for _, field := range constraint.References {
 		references = append(references, clause.Column{Name: field.DBName})
 	}
-	results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
+	results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
 	return
 }
 
@@ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
 		err := fmt.Errorf("failed to create constraint with name %v", name)
 		if field := stmt.Schema.LookUpField(name); field != nil {
 			for _, cc := range checkConstraints {
-				if err = m.CreateIndex(value, cc.Name); err != nil {
+				if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil {
 					return err
 				}
 			}
 
 			for _, rel := range stmt.Schema.Relationships.Relations {
 				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
-					if err = m.CreateIndex(value, constraint.Name); err != nil {
+					if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil {
 						return err
 					}
 				}

+ 1 - 1
schema/naming.go

@@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string {
 
 // RelationshipFKName generate fk name for relation
 func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
-	return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table)
+	return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name))
 }
 
 // CheckerName generate checker name

+ 1 - 1
schema/relationship.go

@@ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
 		}
 	}
 
-	if constraint.ReferenceSchema == nil {
+	if rel.JoinTable != nil || constraint.ReferenceSchema == nil {
 		return nil
 	}
 

+ 4 - 1
statement.go

@@ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string {
 				stmt.Vars = append(stmt.Vars, v.Value)
 				placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
 			}
-		case clause.Column:
+		case clause.Column, clause.Table:
 			placeholders.WriteString(stmt.Quote(v))
+		case clause.Expr:
+			placeholders.WriteString(v.SQL)
+			stmt.Vars = append(stmt.Vars, v.Vars...)
 		case []interface{}:
 			if len(v) > 0 {
 				placeholders.WriteByte('(')

+ 6 - 1
tests/dummy_dialecter.go

@@ -2,6 +2,7 @@ package tests
 
 import (
 	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/schema"
 )
 
 type DummyDialector struct {
@@ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error {
 	return nil
 }
 
-func (DummyDialector) Migrator() gorm.Migrator {
+func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator {
 	return nil
 }
 
@@ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
 func (DummyDialector) QuoteChars() [2]byte {
 	return [2]byte{'`', '`'} // `name`
 }
+
+func (DummyDialector) DataTypeOf(*schema.Field) string {
+	return ""
+}

+ 12 - 2
tests/migrate.go

@@ -9,11 +9,21 @@ import (
 func TestMigrate(t *testing.T, db *gorm.DB) {
 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
 
-	db.AutoMigrate(allModels...)
+	for _, m := range allModels {
+		if db.Migrator().HasTable(m) {
+			if err := db.Migrator().DropTable(m); err != nil {
+				t.Errorf("Failed to drop table, got error %v", err)
+			}
+		}
+	}
+
+	if err := db.AutoMigrate(allModels...); err != nil {
+		t.Errorf("Failed to auto migrate, but got error %v", err)
+	}
 
 	for _, m := range allModels {
 		if !db.Migrator().HasTable(m) {
-			t.Errorf("Failed to create table for %+v", m)
+			t.Errorf("Failed to create table for %#v", m)
 		}
 	}
 }