Browse Source

Add DataTypeOf for dialector

Jinzhu 4 years ago
parent
commit
fab7d96da5

+ 37 - 0
dialects/mssql/migrator.go

@@ -0,0 +1,37 @@
+package mssql
+
+import (
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/migrator"
+)
+
+type Migrator struct {
+	migrator.Migrator
+}
+
+func (m Migrator) HasIndex(value interface{}, name string) bool {
+	var count int
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw(
+			"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)",
+			name, stmt.Table,
+		).Row().Scan(&count)
+	})
+	return count > 0
+}
+
+func (m Migrator) HasConstraint(value interface{}, name string) bool {
+	var count int64
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw(
+			`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ?  AND T.Name = ? AND I.TABLE_CATALOG = ?;`,
+			name, stmt.Table, m.CurrentDatabase(),
+		).Row().Scan(&count)
+	})
+	return count > 0
+}
+
+func (m Migrator) CurrentDatabase() (name string) {
+	m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
+	return
+}

+ 75 - 0
dialects/mssql/mssql.go

@@ -0,0 +1,75 @@
+package mssql
+
+import (
+	"database/sql"
+	"fmt"
+
+	_ "github.com/denisenkom/go-mssqldb"
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/migrator"
+	"github.com/jinzhu/gorm/schema"
+)
+
+type Dialector struct {
+	DSN string
+}
+
+func Open(dsn string) gorm.Dialector {
+	return &Dialector{DSN: dsn}
+}
+
+func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
+	// register callbacks
+	callbacks.RegisterDefaultCallbacks(db)
+
+	db.DB, err = sql.Open("sqlserver", dialector.DSN)
+	return
+}
+
+func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
+	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
+}
+
+func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
+	return "?"
+}
+
+func (dialector Dialector) QuoteChars() [2]byte {
+	return [2]byte{'[', ']'} // `name`
+}
+
+func (dialector Dialector) DataTypeOf(field *schema.Field) string {
+	switch field.DataType {
+	case schema.Bool:
+		return "bit"
+	case schema.Int, schema.Uint:
+		var sqlType string
+		switch {
+		case field.Size < 16:
+			sqlType = "smallint"
+		case field.Size < 31:
+			sqlType = "int"
+		default:
+			sqlType = "bigint"
+		}
+
+		if field.AutoIncrement {
+			return sqlType + " IDENTITY(1,1)"
+		}
+		return sqlType
+	case schema.Float:
+		return "decimal"
+	case schema.String:
+		if field.Size > 0 && field.Size <= 4000 {
+			return fmt.Sprintf("nvarchar(%d)", field.Size)
+		}
+		return "ntext"
+	case schema.Time:
+		return "datetimeoffset"
+	case schema.Bytes:
+		return "binary"
+	}
+
+	return ""
+}

+ 43 - 0
dialects/mysql/migrator.go

@@ -0,0 +1,43 @@
+package mysql
+
+import (
+	"fmt"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
+	"github.com/jinzhu/gorm/migrator"
+)
+
+type Migrator struct {
+	migrator.Migrator
+}
+
+func (m Migrator) AlterColumn(value interface{}, field string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			return m.DB.Exec(
+				"ALTER TABLE ? MODIFY COLUMN ? TYPE ?",
+				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
+			).Error
+		}
+		return fmt.Errorf("failed to look up field with name: %s", field)
+	})
+}
+
+func (m Migrator) DropConstraint(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		for _, chk := range stmt.Schema.ParseCheckConstraints() {
+			if chk.Name == name {
+				return m.DB.Exec(
+					"ALTER TABLE ? DROP CHECK ?",
+					clause.Table{Name: stmt.Table}, clause.Column{Name: name},
+				).Error
+			}
+		}
+
+		return m.DB.Exec(
+			"ALTER TABLE ? DROP FOREIGN KEY ?",
+			clause.Table{Name: stmt.Table}, clause.Column{Name: name},
+		).Error
+	})
+}

+ 77 - 6
dialects/mysql/mysql.go

@@ -1,33 +1,104 @@
 package mysql
 
 import (
+	"database/sql"
+	"fmt"
+	"math"
+
 	_ "github.com/go-sql-driver/mysql"
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/migrator"
+	"github.com/jinzhu/gorm/schema"
 )
 
 type Dialector struct {
+	DSN string
 }
 
 func Open(dsn string) gorm.Dialector {
-	return &Dialector{}
+	return &Dialector{DSN: dsn}
 }
 
-func (Dialector) Initialize(db *gorm.DB) error {
+func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 	// register callbacks
 	callbacks.RegisterDefaultCallbacks(db)
+	db.DB, err = sql.Open("sqlite3", dialector.DSN)
 
 	return nil
 }
 
-func (Dialector) Migrator() gorm.Migrator {
-	return nil
+func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
+	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
 }
 
-func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
+func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
 	return "?"
 }
 
-func (Dialector) QuoteChars() [2]byte {
+func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'`', '`'} // `name`
 }
+
+func (dialector Dialector) DataTypeOf(field *schema.Field) string {
+	switch field.DataType {
+	case schema.Bool:
+		return "boolean"
+	case schema.Int, schema.Uint:
+		sqlType := "int"
+		switch {
+		case field.Size <= 8:
+			sqlType = "tinyint"
+		case field.Size <= 16:
+			sqlType = "smallint"
+		case field.Size <= 32:
+			sqlType = "int"
+		default:
+			sqlType = "bigint"
+		}
+
+		if field.DataType == schema.Uint {
+			sqlType += " unsigned"
+		}
+
+		if field.AutoIncrement {
+			sqlType += " AUTO_INCREMENT"
+		}
+		return sqlType
+	case schema.Float:
+		if field.Size <= 32 {
+			return "float"
+		}
+		return "double"
+	case schema.String:
+		size := field.Size
+		if size >= 65536 && size <= int(math.Pow(2, 24)) {
+			return "mediumtext"
+		} else if size > int(math.Pow(2, 24)) || size < 0 {
+			return "longtext"
+		}
+		return fmt.Sprintf("varchar(%d)", size)
+	case schema.Time:
+		precision := ""
+		if field.Precision > 0 {
+			precision = fmt.Sprintf("(%d)", field.Precision)
+		}
+
+		if field.NotNull || field.PrimaryKey {
+			return "datetime" + precision
+		}
+		return "datetime" + precision + " NULL"
+	case schema.Bytes:
+		if field.Size > 0 && field.Size < 65536 {
+			return fmt.Sprintf("varbinary(%d)", field.Size)
+		}
+
+		if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
+			return "mediumblob"
+		}
+
+		return "longblob"
+	}
+
+	return ""
+}

+ 89 - 0
dialects/postgres/migrator.go

@@ -0,0 +1,89 @@
+package postgres
+
+import (
+	"fmt"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
+	"github.com/jinzhu/gorm/migrator"
+	"github.com/jinzhu/gorm/schema"
+)
+
+type Migrator struct {
+	migrator.Migrator
+}
+
+func (m Migrator) CurrentDatabase() (name string) {
+	m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
+	return
+}
+
+func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
+	for _, opt := range opts {
+		str := stmt.Quote(opt.DBName)
+		if opt.Expression != "" {
+			str = opt.Expression
+		}
+
+		if opt.Collate != "" {
+			str += " COLLATE " + opt.Collate
+		}
+
+		if opt.Sort != "" {
+			str += " " + opt.Sort
+		}
+		results = append(results, clause.Expr{SQL: str})
+	}
+	return
+}
+
+func (m Migrator) HasIndex(value interface{}, indexName string) bool {
+	var count int64
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw(
+			"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName,
+		).Row().Scan(&count)
+	})
+
+	return count > 0
+}
+
+func (m Migrator) CreateIndex(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		err := fmt.Errorf("failed to create index with name %v", name)
+		indexes := stmt.Schema.ParseIndexes()
+
+		if idx, ok := indexes[name]; ok {
+			opts := m.BuildIndexOptions(idx.Fields, stmt)
+			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
+
+			createIndexSQL := "CREATE "
+			if idx.Class != "" {
+				createIndexSQL += idx.Class + " "
+			}
+			createIndexSQL += "INDEX ?"
+
+			if idx.Type != "" {
+				createIndexSQL += " USING " + idx.Type
+			}
+			createIndexSQL += " ON ??"
+
+			if idx.Where != "" {
+				createIndexSQL += " WHERE " + idx.Where
+			}
+
+			return m.DB.Exec(createIndexSQL, values...).Error
+		} else if field := stmt.Schema.LookUpField(name); field != nil {
+			for _, idx := range indexes {
+				for _, idxOpt := range idx.Fields {
+					if idxOpt.Field == field {
+						if err = m.CreateIndex(value, idx.Name); err != nil {
+							return err
+						}
+					}
+				}
+			}
+		}
+		return err
+	})
+}

+ 47 - 4
dialects/postgres/postgres.go

@@ -2,9 +2,12 @@ package postgres
 
 import (
 	"database/sql"
+	"fmt"
 
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/migrator"
+	"github.com/jinzhu/gorm/schema"
 	_ "github.com/lib/pq"
 )
 
@@ -24,14 +27,54 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 	return
 }
 
-func (Dialector) Migrator() gorm.Migrator {
-	return nil
+func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
+	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
 }
 
-func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
+func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
 	return "?"
 }
 
-func (Dialector) QuoteChars() [2]byte {
+func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'"', '"'} // "name"
 }
+
+func (dialector Dialector) DataTypeOf(field *schema.Field) string {
+	switch field.DataType {
+	case schema.Bool:
+		return "boolean"
+	case schema.Int, schema.Uint:
+		if field.AutoIncrement {
+			switch {
+			case field.Size < 16:
+				return "smallserial"
+			case field.Size < 31:
+				return "serial"
+			default:
+				return "bigserial"
+			}
+		} else {
+			switch {
+			case field.Size < 16:
+				return "smallint"
+			case field.Size < 31:
+				return "integer"
+			default:
+				return "bigint"
+			}
+		}
+	case schema.Float:
+		return "decimal"
+	case schema.String:
+		if field.Size > 0 {
+			return fmt.Sprintf("varchar(%d)", field.Size)
+		}
+		return "text"
+	case schema.Time:
+		return "timestamp with time zone"
+	case schema.Bytes:
+		return "bytea"
+	}
+
+	return ""
+}

+ 122 - 0
dialects/sqlite/migrator.go

@@ -0,0 +1,122 @@
+package sqlite
+
+import (
+	"fmt"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
+	"github.com/jinzhu/gorm/migrator"
+	"github.com/jinzhu/gorm/schema"
+)
+
+type Migrator struct {
+	migrator.Migrator
+}
+
+func (m Migrator) HasTable(value interface{}) bool {
+	var count int
+	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
+	})
+	return count > 0
+}
+
+func (m Migrator) HasColumn(value interface{}, field string) bool {
+	var count int
+	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		name := field
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			name = field.DBName
+		}
+
+		return m.DB.Raw(
+			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)",
+			stmt.Table, `%"`+name+`" %`, `%`+name+` %`,
+		).Row().Scan(&count)
+	})
+	return count > 0
+}
+
+func (m Migrator) HasIndex(value interface{}, name string) bool {
+	var count int
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw(
+			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE ?",
+			stmt.Table, "%INDEX "+name+" ON%",
+		).Row().Scan(&count)
+	})
+	return count > 0
+}
+
+func (m Migrator) CreateConstraint(interface{}, string) error {
+	return gorm.ErrNotImplemented
+}
+
+func (m Migrator) DropConstraint(interface{}, string) error {
+	return gorm.ErrNotImplemented
+}
+
+func (m Migrator) CurrentDatabase() (name string) {
+	var null interface{}
+	m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
+	return
+}
+
+func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
+	for _, opt := range opts {
+		str := stmt.Quote(opt.DBName)
+		if opt.Expression != "" {
+			str = opt.Expression
+		}
+
+		if opt.Collate != "" {
+			str += " COLLATE " + opt.Collate
+		}
+
+		if opt.Sort != "" {
+			str += " " + opt.Sort
+		}
+		results = append(results, clause.Expr{SQL: str})
+	}
+	return
+}
+
+func (m Migrator) CreateIndex(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		err := fmt.Errorf("failed to create index with name %v", name)
+		indexes := stmt.Schema.ParseIndexes()
+
+		if idx, ok := indexes[name]; ok {
+			opts := m.BuildIndexOptions(idx.Fields, stmt)
+			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
+
+			createIndexSQL := "CREATE "
+			if idx.Class != "" {
+				createIndexSQL += idx.Class + " "
+			}
+			createIndexSQL += "INDEX ?"
+
+			if idx.Type != "" {
+				createIndexSQL += " USING " + idx.Type
+			}
+			createIndexSQL += " ON ??"
+
+			if idx.Where != "" {
+				createIndexSQL += " WHERE " + idx.Where
+			}
+
+			return m.DB.Exec(createIndexSQL, values...).Error
+		} else if field := stmt.Schema.LookUpField(name); field != nil {
+			for _, idx := range indexes {
+				for _, idxOpt := range idx.Fields {
+					if idxOpt.Field == field {
+						if err = m.CreateIndex(value, idx.Name); err != nil {
+							return err
+						}
+					}
+				}
+			}
+		}
+		return err
+	})
+}

+ 28 - 4
dialects/sqlite/sqlite.go

@@ -5,6 +5,8 @@ import (
 
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
+	"github.com/jinzhu/gorm/migrator"
+	"github.com/jinzhu/gorm/schema"
 	_ "github.com/mattn/go-sqlite3"
 )
 
@@ -24,14 +26,36 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 	return
 }
 
-func (Dialector) Migrator() gorm.Migrator {
-	return nil
+func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
+	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
 }
 
-func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
+func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
 	return "?"
 }
 
-func (Dialector) QuoteChars() [2]byte {
+func (dialector Dialector) QuoteChars() [2]byte {
 	return [2]byte{'`', '`'} // `name`
 }
+
+func (dialector Dialector) DataTypeOf(field *schema.Field) string {
+	switch field.DataType {
+	case schema.Bool:
+		return "NUMERIC"
+	case schema.Int, schema.Uint:
+		if field.AutoIncrement {
+			// https://www.sqlite.org/autoinc.html
+			return "INTEGER PRIMARY KEY AUTOINCREMENT"
+		} else {
+			return "INTEGER"
+		}
+	case schema.Float:
+		return "REAL"
+	case schema.String, schema.Time:
+		return "TEXT"
+	case schema.Bytes:
+		return "BLOB"
+	}
+
+	return ""
+}

+ 4 - 1
interfaces.go

@@ -3,12 +3,15 @@ package gorm
 import (
 	"context"
 	"database/sql"
+
+	"github.com/jinzhu/gorm/schema"
 )
 
 // Dialector GORM database dialector
 type Dialector interface {
 	Initialize(*DB) error
-	Migrator() Migrator
+	Migrator(db *DB) Migrator
+	DataTypeOf(*schema.Field) string
 	BindVar(stmt *Statement, v interface{}) string
 	QuoteChars() [2]byte
 }

+ 2 - 2
migrator.go

@@ -6,7 +6,7 @@ import (
 
 // Migrator returns migrator
 func (db *DB) Migrator() Migrator {
-	return db.Dialector.Migrator()
+	return db.Dialector.Migrator(db)
 }
 
 // ViewOption view option
@@ -26,7 +26,7 @@ type Migrator interface {
 	// Tables
 	CreateTable(dst ...interface{}) error
 	DropTable(dst ...interface{}) error
-	HasTable(dst ...interface{}) bool
+	HasTable(dst interface{}) bool
 	RenameTable(oldName, newName string) error
 
 	// Columns

+ 112 - 111
migrator/migrator.go

@@ -11,21 +11,21 @@ import (
 	"github.com/jinzhu/gorm/schema"
 )
 
-// Migrator migrator struct
+// Migrator m struct
 type Migrator struct {
-	*Config
+	Config
 }
 
 // Config schema config
 type Config struct {
-	CheckExistsBeforeDropping bool
-	DB                        *gorm.DB
+	DB *gorm.DB
+	gorm.Dialector
 }
 
-func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
-	stmt := migrator.DB.Statement
+func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
+	stmt := m.DB.Statement
 	if stmt == nil {
-		stmt = &gorm.Statement{DB: migrator.DB}
+		stmt = &gorm.Statement{DB: m.DB}
 	}
 
 	if err := stmt.Parse(value); err != nil {
@@ -35,20 +35,28 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement
 	return fc(stmt)
 }
 
+func (m Migrator) DataTypeOf(field *schema.Field) string {
+	if field.DBDataType != "" {
+		return field.DBDataType
+	}
+
+	return m.Dialector.DataTypeOf(field)
+}
+
 // AutoMigrate
-func (migrator Migrator) AutoMigrate(values ...interface{}) error {
+func (m Migrator) AutoMigrate(values ...interface{}) error {
 	// TODO smart migrate data type
 
 	for _, value := range values {
-		if !migrator.DB.Migrator().HasTable(value) {
-			if err := migrator.DB.Migrator().CreateTable(value); err != nil {
+		if !m.DB.Migrator().HasTable(value) {
+			if err := m.DB.Migrator().CreateTable(value); err != nil {
 				return err
 			}
 		} else {
-			if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+			if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
 				for _, field := range stmt.Schema.FieldsByDBName {
-					if !migrator.DB.Migrator().HasColumn(value, field.DBName) {
-						if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil {
+					if !m.DB.Migrator().HasColumn(value, field.DBName) {
+						if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil {
 							return err
 						}
 					}
@@ -56,16 +64,16 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error {
 
 				for _, rel := range stmt.Schema.Relationships.Relations {
 					if constraint := rel.ParseConstraint(); constraint != nil {
-						if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) {
-							if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
+						if !m.DB.Migrator().HasConstraint(value, constraint.Name) {
+							if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
 								return err
 							}
 						}
 					}
 
 					for _, chk := range stmt.Schema.ParseCheckConstraints() {
-						if !migrator.DB.Migrator().HasConstraint(value, chk.Name) {
-							if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
+						if !m.DB.Migrator().HasConstraint(value, chk.Name) {
+							if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
 								return err
 							}
 						}
@@ -73,8 +81,8 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error {
 
 					// create join table
 					joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
-					if !migrator.DB.Migrator().HasTable(joinValue) {
-						defer migrator.DB.Migrator().CreateTable(joinValue)
+					if !m.DB.Migrator().HasTable(joinValue) {
+						defer m.DB.Migrator().CreateTable(joinValue)
 					}
 				}
 				return nil
@@ -87,9 +95,9 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error {
 	return nil
 }
 
-func (migrator Migrator) CreateTable(values ...interface{}) error {
+func (m Migrator) CreateTable(values ...interface{}) error {
 	for _, value := range values {
-		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
 			var (
 				createTableSQL          = "CREATE TABLE ? ("
 				values                  = []interface{}{clause.Table{Name: stmt.Table}}
@@ -100,7 +108,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error {
 				field := stmt.Schema.FieldsByDBName[dbName]
 				createTableSQL += fmt.Sprintf("? ?")
 				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY")
-				values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType})
+				values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)})
 
 				if field.AutoIncrement {
 					createTableSQL += " AUTO_INCREMENT"
@@ -133,7 +141,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error {
 
 			for _, idx := range stmt.Schema.ParseIndexes() {
 				createTableSQL += "INDEX ? ?,"
-				values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt))
+				values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
 			}
 
 			for _, rel := range stmt.Schema.Relationships.Relations {
@@ -145,8 +153,8 @@ func (migrator Migrator) CreateTable(values ...interface{}) error {
 
 				// create join table
 				joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
-				if !migrator.DB.Migrator().HasTable(joinValue) {
-					defer migrator.DB.Migrator().CreateTable(joinValue)
+				if !m.DB.Migrator().HasTable(joinValue) {
+					defer m.DB.Migrator().CreateTable(joinValue)
 				}
 			}
 
@@ -158,7 +166,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error {
 			createTableSQL = strings.TrimSuffix(createTableSQL, ",")
 
 			createTableSQL += ")"
-			return migrator.DB.Exec(createTableSQL, values...).Error
+			return m.DB.Exec(createTableSQL, values...).Error
 		}); err != nil {
 			return err
 		}
@@ -166,10 +174,10 @@ func (migrator Migrator) CreateTable(values ...interface{}) error {
 	return nil
 }
 
-func (migrator Migrator) DropTable(values ...interface{}) error {
+func (m Migrator) DropTable(values ...interface{}) error {
 	for _, value := range values {
-		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-			return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
+		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
+			return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
 		}); err != nil {
 			return err
 		}
@@ -177,42 +185,36 @@ func (migrator Migrator) DropTable(values ...interface{}) error {
 	return nil
 }
 
-func (migrator Migrator) HasTable(values ...interface{}) bool {
+func (m Migrator) HasTable(value interface{}) bool {
 	var count int64
-	for _, value := range values {
-		err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-			currentDatabase := migrator.DB.Migrator().CurrentDatabase()
-			return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error
-		})
-
-		if err != nil || count == 0 {
-			return false
-		}
-	}
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := m.DB.Migrator().CurrentDatabase()
+		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
+	})
 
-	return true
+	return count > 0
 }
 
-func (migrator Migrator) RenameTable(oldName, newName string) error {
-	return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
+func (m Migrator) RenameTable(oldName, newName string) error {
+	return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
 }
 
-func (migrator Migrator) AddColumn(value interface{}, field string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+func (m Migrator) AddColumn(value interface{}, field string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			return migrator.DB.Exec(
+			return m.DB.Exec(
 				"ALTER TABLE ? ADD ? ?",
-				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
+				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
 			).Error
 		}
 		return fmt.Errorf("failed to look up field with name: %s", field)
 	})
 }
 
-func (migrator Migrator) DropColumn(value interface{}, field string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+func (m Migrator) DropColumn(value interface{}, field string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			return migrator.DB.Exec(
+			return m.DB.Exec(
 				"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
 			).Error
 		}
@@ -220,44 +222,41 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error {
 	})
 }
 
-func (migrator Migrator) AlterColumn(value interface{}, field string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+func (m Migrator) AlterColumn(value interface{}, field string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			return migrator.DB.Exec(
+			return m.DB.Exec(
 				"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
-				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
+				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
 			).Error
 		}
 		return fmt.Errorf("failed to look up field with name: %s", field)
 	})
 }
 
-func (migrator Migrator) HasColumn(value interface{}, field string) bool {
+func (m Migrator) HasColumn(value interface{}, field string) bool {
 	var count int64
-	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := m.DB.Migrator().CurrentDatabase()
 		name := field
 		if field := stmt.Schema.LookUpField(field); field != nil {
 			name = field.DBName
 		}
 
-		return migrator.DB.Raw(
+		return m.DB.Raw(
 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
 			currentDatabase, stmt.Table, name,
-		).Scan(&count).Error
+		).Row().Scan(&count)
 	})
 
-	if count != 0 {
-		return true
-	}
-	return false
+	return count > 0
 }
 
-func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
-			return migrator.DB.Exec(
+			oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
+			return m.DB.Exec(
 				"ALTER TABLE ? RENAME COLUMN ? TO ?",
 				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
 			).Error
@@ -266,15 +265,15 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string)
 	})
 }
 
-func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
+func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
 	return nil, gorm.ErrNotImplemented
 }
 
-func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error {
+func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
 	return gorm.ErrNotImplemented
 }
 
-func (migrator Migrator) DropView(name string) error {
+func (m Migrator) DropView(name string) error {
 	return gorm.ErrNotImplemented
 }
 
@@ -300,11 +299,11 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
 	return
 }
 
-func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+func (m Migrator) CreateConstraint(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 		checkConstraints := stmt.Schema.ParseCheckConstraints()
 		if chk, ok := checkConstraints[name]; ok {
-			return migrator.DB.Exec(
+			return m.DB.Exec(
 				"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?",
 				clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
 			).Error
@@ -313,21 +312,21 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error
 		for _, rel := range stmt.Schema.Relationships.Relations {
 			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
 				sql, values := buildConstraint(constraint)
-				return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
+				return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
 			}
 		}
 
 		err := fmt.Errorf("failed to create constraint with name %v", name)
 		if field := stmt.Schema.LookUpField(name); field != nil {
 			for _, cc := range checkConstraints {
-				if err = migrator.CreateIndex(value, cc.Name); err != nil {
+				if err = m.CreateIndex(value, cc.Name); err != nil {
 					return err
 				}
 			}
 
 			for _, rel := range stmt.Schema.Relationships.Relations {
 				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
-					if err = migrator.CreateIndex(value, constraint.Name); err != nil {
+					if err = m.CreateIndex(value, constraint.Name); err != nil {
 						return err
 					}
 				}
@@ -338,32 +337,29 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error
 	})
 }
 
-func (migrator Migrator) DropConstraint(value interface{}, name string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		return migrator.DB.Exec(
+func (m Migrator) DropConstraint(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Exec(
 			"ALTER TABLE ? DROP CONSTRAINT ?",
 			clause.Table{Name: stmt.Table}, clause.Column{Name: name},
 		).Error
 	})
 }
 
-func (migrator Migrator) HasConstraint(value interface{}, name string) bool {
+func (m Migrator) HasConstraint(value interface{}, name string) bool {
 	var count int64
-	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
-		return migrator.DB.Raw(
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := m.DB.Migrator().CurrentDatabase()
+		return m.DB.Raw(
 			"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
 			currentDatabase, stmt.Table, name,
-		).Scan(&count).Error
+		).Row().Scan(&count)
 	})
 
-	if count != 0 {
-		return true
-	}
-	return false
+	return count > 0
 }
 
-func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
+func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
 	for _, opt := range opts {
 		str := stmt.Quote(opt.DBName)
 		if opt.Expression != "" {
@@ -372,6 +368,10 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results
 			str += fmt.Sprintf("(%d)", opt.Length)
 		}
 
+		if opt.Collate != "" {
+			str += " COLLATE " + opt.Collate
+		}
+
 		if opt.Sort != "" {
 			str += " " + opt.Sort
 		}
@@ -380,13 +380,17 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results
 	return
 }
 
-func (migrator Migrator) CreateIndex(value interface{}, name string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+type BuildIndexOptionsInterface interface {
+	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
+}
+
+func (m Migrator) CreateIndex(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 		err := fmt.Errorf("failed to create index with name %v", name)
 		indexes := stmt.Schema.ParseIndexes()
 
 		if idx, ok := indexes[name]; ok {
-			opts := buildIndexOptions(idx.Fields, stmt)
+			opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
 
 			createIndexSQL := "CREATE "
@@ -404,12 +408,12 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error {
 				createIndexSQL += " USING " + idx.Type
 			}
 
-			return migrator.DB.Raw(createIndexSQL, values...).Error
+			return m.DB.Exec(createIndexSQL, values...).Error
 		} else if field := stmt.Schema.LookUpField(name); field != nil {
 			for _, idx := range indexes {
 				for _, idxOpt := range idx.Fields {
 					if idxOpt.Field == field {
-						if err = migrator.CreateIndex(value, idx.Name); err != nil {
+						if err = m.CreateIndex(value, idx.Name); err != nil {
 							return err
 						}
 					}
@@ -420,38 +424,35 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error {
 	})
 }
 
-func (migrator Migrator) DropIndex(value interface{}, name string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
+func (m Migrator) DropIndex(value interface{}, name string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
 	})
 }
 
-func (migrator Migrator) HasIndex(value interface{}, name string) bool {
+func (m Migrator) HasIndex(value interface{}, name string) bool {
 	var count int64
-	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
-		return migrator.DB.Raw(
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := m.DB.Migrator().CurrentDatabase()
+		return m.DB.Raw(
 			"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
 			currentDatabase, stmt.Table, name,
-		).Scan(&count).Error
+		).Row().Scan(&count)
 	})
 
-	if count != 0 {
-		return true
-	}
-	return false
+	return count > 0
 }
 
-func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error {
-	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		return migrator.DB.Exec(
+func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
+	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Exec(
 			"ALTER TABLE ? RENAME INDEX ? TO ?",
 			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
 		).Error
 	})
 }
 
-func (migrator Migrator) CurrentDatabase() (name string) {
-	migrator.DB.Raw("SELECT DATABASE()").Scan(&name)
+func (m Migrator) CurrentDatabase() (name string) {
+	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
 	return
 }

+ 4 - 1
schema/field.go

@@ -138,7 +138,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 	}
 
 	if num, ok := field.TagSettings["SIZE"]; ok {
-		field.Size, _ = strconv.Atoi(num)
+		var err error
+		if field.Size, err = strconv.Atoi(num); err != nil {
+			field.Size = -1
+		}
 	}
 
 	if p, ok := field.TagSettings["PRECISION"]; ok {