|
@@ -3,9 +3,12 @@ package migrator
|
|
|
import (
|
|
|
"database/sql"
|
|
|
"fmt"
|
|
|
+ "reflect"
|
|
|
+ "strings"
|
|
|
|
|
|
"github.com/jinzhu/gorm"
|
|
|
"github.com/jinzhu/gorm/clause"
|
|
|
+ "github.com/jinzhu/gorm/schema"
|
|
|
)
|
|
|
|
|
|
// Migrator migrator struct
|
|
@@ -34,19 +37,133 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement
|
|
|
|
|
|
// AutoMigrate
|
|
|
func (migrator Migrator) AutoMigrate(values ...interface{}) error {
|
|
|
- // if has table
|
|
|
- // not -> create table
|
|
|
- // check columns -> add column, change column type
|
|
|
- // check foreign keys -> create indexes
|
|
|
- // check indexes -> create indexes
|
|
|
+ // TODO smart migrate data type
|
|
|
|
|
|
- return gorm.ErrNotImplemented
|
|
|
+ for _, value := range values {
|
|
|
+ if !migrator.DB.Migrator().HasTable(value) {
|
|
|
+ if err := migrator.DB.Migrator().CreateTable(value); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
+ for _, field := range stmt.Schema.FieldsByDBName {
|
|
|
+ if !migrator.DB.Migrator().HasColumn(value, field.DBName) {
|
|
|
+ if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, rel := range stmt.Schema.Relationships.Relations {
|
|
|
+ if constraint := rel.ParseConstraint(); constraint != nil {
|
|
|
+ if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) {
|
|
|
+ if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
|
|
+ if !migrator.DB.Migrator().HasConstraint(value, chk.Name) {
|
|
|
+ if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // create join table
|
|
|
+ joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
|
|
+ if !migrator.DB.Migrator().HasTable(joinValue) {
|
|
|
+ defer migrator.DB.Migrator().CreateTable(joinValue)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (migrator Migrator) CreateTable(values ...interface{}) error {
|
|
|
- // migrate
|
|
|
- // create join table
|
|
|
- return gorm.ErrNotImplemented
|
|
|
+ for _, value := range values {
|
|
|
+ if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
+ var (
|
|
|
+ createTableSQL = "CREATE TABLE ? ("
|
|
|
+ values = []interface{}{clause.Table{Name: stmt.Table}}
|
|
|
+ hasPrimaryKeyInDataType bool
|
|
|
+ )
|
|
|
+
|
|
|
+ for _, dbName := range stmt.Schema.DBNames {
|
|
|
+ field := stmt.Schema.FieldsByDBName[dbName]
|
|
|
+ createTableSQL += fmt.Sprintf("? ?")
|
|
|
+ hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY")
|
|
|
+ values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType})
|
|
|
+
|
|
|
+ if field.AutoIncrement {
|
|
|
+ createTableSQL += " AUTO_INCREMENT"
|
|
|
+ }
|
|
|
+
|
|
|
+ if field.NotNull {
|
|
|
+ createTableSQL += " NOT NULL"
|
|
|
+ }
|
|
|
+
|
|
|
+ if field.Unique {
|
|
|
+ createTableSQL += " UNIQUE"
|
|
|
+ }
|
|
|
+
|
|
|
+ if field.DefaultValue != "" {
|
|
|
+ createTableSQL += " DEFAULT ?"
|
|
|
+ values = append(values, clause.Expr{SQL: field.DefaultValue})
|
|
|
+ }
|
|
|
+ createTableSQL += ","
|
|
|
+ }
|
|
|
+
|
|
|
+ if !hasPrimaryKeyInDataType {
|
|
|
+ createTableSQL += "PRIMARY KEY ?,"
|
|
|
+ primaryKeys := []interface{}{}
|
|
|
+ for _, field := range stmt.Schema.PrimaryFields {
|
|
|
+ primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
|
|
+ }
|
|
|
+
|
|
|
+ values = append(values, primaryKeys)
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, idx := range stmt.Schema.ParseIndexes() {
|
|
|
+ createTableSQL += "INDEX ? ?,"
|
|
|
+ values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt))
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, rel := range stmt.Schema.Relationships.Relations {
|
|
|
+ if constraint := rel.ParseConstraint(); constraint != nil {
|
|
|
+ sql, vars := buildConstraint(constraint)
|
|
|
+ createTableSQL += sql + ","
|
|
|
+ values = append(values, vars...)
|
|
|
+ }
|
|
|
+
|
|
|
+ // create join table
|
|
|
+ joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
|
|
+ if !migrator.DB.Migrator().HasTable(joinValue) {
|
|
|
+ defer migrator.DB.Migrator().CreateTable(joinValue)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
|
|
+ createTableSQL += "CONSTRAINT ? CHECK ?,"
|
|
|
+ values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
|
|
+ }
|
|
|
+
|
|
|
+ createTableSQL = strings.TrimSuffix(createTableSQL, ",")
|
|
|
+
|
|
|
+ createTableSQL += ")"
|
|
|
+ return migrator.DB.Exec(createTableSQL, values...).Error
|
|
|
+ }); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (migrator Migrator) DropTable(values ...interface{}) error {
|
|
@@ -115,6 +232,27 @@ func (migrator Migrator) AlterColumn(value interface{}, field string) error {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+func (migrator Migrator) HasColumn(value interface{}, field string) bool {
|
|
|
+ var count int64
|
|
|
+ migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
+ currentDatabase := migrator.DB.Migrator().CurrentDatabase()
|
|
|
+ name := field
|
|
|
+ if field := stmt.Schema.LookUpField(field); field != nil {
|
|
|
+ name = field.DBName
|
|
|
+ }
|
|
|
+
|
|
|
+ return migrator.DB.Raw(
|
|
|
+ "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
|
|
|
+ currentDatabase, stmt.Table, name,
|
|
|
+ ).Scan(&count).Error
|
|
|
+ })
|
|
|
+
|
|
|
+ if count != 0 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
|
|
|
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
if field := stmt.Schema.LookUpField(field); field != nil {
|
|
@@ -140,6 +278,28 @@ func (migrator Migrator) DropView(name string) error {
|
|
|
return gorm.ErrNotImplemented
|
|
|
}
|
|
|
|
|
|
+func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
|
|
+ sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
|
|
+ if constraint.OnDelete != "" {
|
|
|
+ sql += " ON DELETE " + constraint.OnDelete
|
|
|
+ }
|
|
|
+
|
|
|
+ if constraint.OnUpdate != "" {
|
|
|
+ sql += " ON UPDATE " + constraint.OnUpdate
|
|
|
+ }
|
|
|
+
|
|
|
+ var foreignKeys, references []interface{}
|
|
|
+ for _, field := range constraint.ForeignKeys {
|
|
|
+ foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, field := range constraint.References {
|
|
|
+ references = append(references, clause.Column{Name: field.DBName})
|
|
|
+ }
|
|
|
+ results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
|
|
|
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
checkConstraints := stmt.Schema.ParseCheckConstraints()
|
|
@@ -152,26 +312,8 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error
|
|
|
|
|
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
|
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
|
|
|
- sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
|
|
- if constraint.OnDelete != "" {
|
|
|
- sql += " ON DELETE " + constraint.OnDelete
|
|
|
- }
|
|
|
-
|
|
|
- if constraint.OnUpdate != "" {
|
|
|
- sql += " ON UPDATE " + constraint.OnUpdate
|
|
|
- }
|
|
|
- var foreignKeys, references []interface{}
|
|
|
- for _, field := range constraint.ForeignKeys {
|
|
|
- foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
|
|
- }
|
|
|
-
|
|
|
- for _, field := range constraint.References {
|
|
|
- references = append(references, clause.Column{Name: field.DBName})
|
|
|
- }
|
|
|
-
|
|
|
- return migrator.DB.Exec(
|
|
|
- sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references,
|
|
|
- ).Error
|
|
|
+ sql, values := buildConstraint(constraint)
|
|
|
+ return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -205,27 +347,47 @@ func (migrator Migrator) DropConstraint(value interface{}, name string) error {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+func (migrator Migrator) HasConstraint(value interface{}, name string) bool {
|
|
|
+ var count int64
|
|
|
+ migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
+ currentDatabase := migrator.DB.Migrator().CurrentDatabase()
|
|
|
+ return migrator.DB.Raw(
|
|
|
+ "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
|
|
|
+ currentDatabase, stmt.Table, name,
|
|
|
+ ).Scan(&count).Error
|
|
|
+ })
|
|
|
+
|
|
|
+ if count != 0 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
|
|
+ for _, opt := range opts {
|
|
|
+ str := stmt.Quote(opt.DBName)
|
|
|
+ if opt.Expression != "" {
|
|
|
+ str = opt.Expression
|
|
|
+ } else if opt.Length > 0 {
|
|
|
+ str += fmt.Sprintf("(%d)", opt.Length)
|
|
|
+ }
|
|
|
+
|
|
|
+ if opt.Sort != "" {
|
|
|
+ str += " " + opt.Sort
|
|
|
+ }
|
|
|
+ results = append(results, clause.Expr{SQL: str})
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
func (migrator Migrator) CreateIndex(value interface{}, name string) error {
|
|
|
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
|
err := fmt.Errorf("failed to create index with name %v", name)
|
|
|
indexes := stmt.Schema.ParseIndexes()
|
|
|
|
|
|
if idx, ok := indexes[name]; ok {
|
|
|
- fields := []interface{}{}
|
|
|
- for _, field := range idx.Fields {
|
|
|
- str := stmt.Quote(field.DBName)
|
|
|
- if field.Expression != "" {
|
|
|
- str = field.Expression
|
|
|
- } else if field.Length > 0 {
|
|
|
- str += fmt.Sprintf("(%d)", field.Length)
|
|
|
- }
|
|
|
-
|
|
|
- if field.Sort != "" {
|
|
|
- str += " " + field.Sort
|
|
|
- }
|
|
|
- fields = append(fields, clause.Expr{SQL: str})
|
|
|
- }
|
|
|
- values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields}
|
|
|
+ opts := buildIndexOptions(idx.Fields, stmt)
|
|
|
+ values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
|
|
|
|
|
createIndexSQL := "CREATE "
|
|
|
if idx.Class != "" {
|