Kaynağa Gözat

Add migrator tests for mssql

Jinzhu 4 yıl önce
ebeveyn
işleme
1d803dfdd9

+ 11 - 0
dialects/mssql/migrator.go

@@ -9,6 +9,17 @@ type Migrator struct {
 	migrator.Migrator
 }
 
+func (m Migrator) HasTable(value interface{}) bool {
+	var count int
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw(
+			"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?",
+			stmt.Table, m.CurrentDatabase(),
+		).Row().Scan(&count)
+	})
+	return count > 0
+}
+
 func (m Migrator) HasIndex(value interface{}, name string) bool {
 	var count int
 	m.RunWithValue(value, func(stmt *gorm.Statement) error {

+ 12 - 6
dialects/mssql/mssql.go

@@ -3,6 +3,7 @@ package mssql
 import (
 	"database/sql"
 	"fmt"
+	"strconv"
 
 	_ "github.com/denisenkom/go-mssqldb"
 	"github.com/jinzhu/gorm"
@@ -29,17 +30,18 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
 	return Migrator{migrator.Migrator{Config: migrator.Config{
-		DB:        db,
-		Dialector: dialector,
+		DB:                          db,
+		Dialector:                   dialector,
+		CreateIndexAfterCreateTable: true,
 	}}}
 }
 
 func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
-	return "?"
+	return "@p" + strconv.Itoa(len(stmt.Vars))
 }
 
 func (dialector Dialector) QuoteChars() [2]byte {
-	return [2]byte{'[', ']'} // `name`
+	return [2]byte{'"', '"'} // `name`
 }
 
 func (dialector Dialector) DataTypeOf(field *schema.Field) string {
@@ -64,8 +66,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 	case schema.Float:
 		return "decimal"
 	case schema.String:
-		if field.Size > 0 && field.Size <= 4000 {
-			return fmt.Sprintf("nvarchar(%d)", field.Size)
+		size := field.Size
+		if field.PrimaryKey {
+			size = 256
+		}
+		if size > 0 && size <= 4000 {
+			return fmt.Sprintf("nvarchar(%d)", size)
 		}
 		return "ntext"
 	case schema.Time:

+ 29 - 0
dialects/mssql/mssql_test.go

@@ -0,0 +1,29 @@
+package mssql_test
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/dialects/mssql"
+	"github.com/jinzhu/gorm/tests"
+)
+
+var (
+	DB  *gorm.DB
+	err error
+)
+
+func init() {
+	if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil {
+		panic(fmt.Sprintf("failed to initialize database, got error %v", err))
+	}
+}
+
+func TestCURD(t *testing.T) {
+	tests.RunTestsSuit(t, DB)
+}
+
+func TestMigrate(t *testing.T) {
+	tests.TestMigrate(t, DB)
+}

+ 7 - 5
migrator/migrator.go

@@ -189,11 +189,13 @@ func (m Migrator) DropTable(values ...interface{}) error {
 	values = m.ReorderModels(values, false)
 	for i := len(values) - 1; i >= 0; i-- {
 		value := values[i]
-		tx := m.DB.Session(&gorm.Session{})
-		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
-			return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
-		}); err != nil {
-			return err
+		if m.DB.Migrator().HasTable(value) {
+			tx := m.DB.Session(&gorm.Session{})
+			if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
+				return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
+			}); err != nil {
+				return err
+			}
 		}
 	}
 	return nil