Ver código fonte

Move ModifyColumn implemention to Dialect

Jinzhu 6 anos atrás
pai
commit
89a726ce5d
6 arquivos alterados com 19 adições e 5 exclusões
  1. 2 0
      dialect.go
  2. 5 0
      dialect_common.go
  3. 5 0
      dialect_mysql.go
  4. 5 0
      dialects/mssql/mssql.go
  5. 1 4
      migration_test.go
  6. 1 1
      scope.go

+ 2 - 0
dialect.go

@@ -33,6 +33,8 @@ type Dialect interface {
 	HasTable(tableName string) bool
 	// HasColumn check has column or not
 	HasColumn(tableName string, columnName string) bool
+	// ModifyColumn modify column's type
+	ModifyColumn(tableName string, columnName string, typ string) error
 
 	// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
 	LimitAndOffsetSQL(limit, offset interface{}) string

+ 5 - 0
dialect_common.go

@@ -120,6 +120,11 @@ func (s commonDialect) HasColumn(tableName string, columnName string) bool {
 	return count > 0
 }
 
+func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
+	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
+	return err
+}
+
 func (s commonDialect) CurrentDatabase() (name string) {
 	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 	return

+ 5 - 0
dialect_mysql.go

@@ -127,6 +127,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error {
 	return err
 }
 
+func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
+	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
+	return err
+}
+
 func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 	if limit != nil {
 		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {

+ 5 - 0
dialects/mssql/mssql.go

@@ -140,6 +140,11 @@ func (s mssql) HasColumn(tableName string, columnName string) bool {
 	return count > 0
 }
 
+func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
+	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
+	return err
+}
+
 func (s mssql) CurrentDatabase() (name string) {
 	s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
 	return

+ 1 - 4
migration_test.go

@@ -435,10 +435,7 @@ func TestMultipleIndexes(t *testing.T) {
 }
 
 func TestModifyColumnType(t *testing.T) {
-	dialect := os.Getenv("GORM_DIALECT")
-	if dialect != "postgres" &&
-		dialect != "mysql" &&
-		dialect != "mssql" {
+	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
 		t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
 	}
 

+ 1 - 1
scope.go

@@ -1139,7 +1139,7 @@ func (scope *Scope) dropTable() *Scope {
 }
 
 func (scope *Scope) modifyColumn(column string, typ string) {
-	scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
+	scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
 }
 
 func (scope *Scope) dropColumn(column string) {