Procházet zdrojové kódy

Almost finish Migrator

Jinzhu před 4 roky
rodič
revize
0801cdf164
2 změnil soubory, kde provedl 208 přidání a 44 odebrání
  1. 2 0
      migrator.go
  2. 206 44
      migrator/migrator.go

+ 2 - 0
migrator.go

@@ -33,6 +33,7 @@ type Migrator interface {
 	AddColumn(dst interface{}, field string) error
 	DropColumn(dst interface{}, field string) error
 	AlterColumn(dst interface{}, field string) error
+	HasColumn(dst interface{}, field string) bool
 	RenameColumn(dst interface{}, oldName, field string) error
 	ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
 
@@ -43,6 +44,7 @@ type Migrator interface {
 	// Constraints
 	CreateConstraint(dst interface{}, name string) error
 	DropConstraint(dst interface{}, name string) error
+	HasConstraint(dst interface{}, name string) bool
 
 	// Indexes
 	CreateIndex(dst interface{}, name string) error

+ 206 - 44
migrator/migrator.go

@@ -3,9 +3,12 @@ package migrator
 import (
 	"database/sql"
 	"fmt"
+	"reflect"
+	"strings"
 
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/clause"
+	"github.com/jinzhu/gorm/schema"
 )
 
 // Migrator migrator struct
@@ -34,19 +37,133 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement
 
 // AutoMigrate
 func (migrator Migrator) AutoMigrate(values ...interface{}) error {
-	// if has table
-	// not -> create table
-	// check columns -> add column, change column type
-	// check foreign keys -> create indexes
-	// check indexes -> create indexes
+	// TODO smart migrate data type
 
-	return gorm.ErrNotImplemented
+	for _, value := range values {
+		if !migrator.DB.Migrator().HasTable(value) {
+			if err := migrator.DB.Migrator().CreateTable(value); err != nil {
+				return err
+			}
+		} else {
+			if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+				for _, field := range stmt.Schema.FieldsByDBName {
+					if !migrator.DB.Migrator().HasColumn(value, field.DBName) {
+						if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil {
+							return err
+						}
+					}
+				}
+
+				for _, rel := range stmt.Schema.Relationships.Relations {
+					if constraint := rel.ParseConstraint(); constraint != nil {
+						if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) {
+							if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
+								return err
+							}
+						}
+					}
+
+					for _, chk := range stmt.Schema.ParseCheckConstraints() {
+						if !migrator.DB.Migrator().HasConstraint(value, chk.Name) {
+							if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
+								return err
+							}
+						}
+					}
+
+					// create join table
+					joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
+					if !migrator.DB.Migrator().HasTable(joinValue) {
+						defer migrator.DB.Migrator().CreateTable(joinValue)
+					}
+				}
+				return nil
+			}); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
 }
 
 func (migrator Migrator) CreateTable(values ...interface{}) error {
-	// migrate
-	// create join table
-	return gorm.ErrNotImplemented
+	for _, value := range values {
+		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+			var (
+				createTableSQL          = "CREATE TABLE ? ("
+				values                  = []interface{}{clause.Table{Name: stmt.Table}}
+				hasPrimaryKeyInDataType bool
+			)
+
+			for _, dbName := range stmt.Schema.DBNames {
+				field := stmt.Schema.FieldsByDBName[dbName]
+				createTableSQL += fmt.Sprintf("? ?")
+				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY")
+				values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType})
+
+				if field.AutoIncrement {
+					createTableSQL += " AUTO_INCREMENT"
+				}
+
+				if field.NotNull {
+					createTableSQL += " NOT NULL"
+				}
+
+				if field.Unique {
+					createTableSQL += " UNIQUE"
+				}
+
+				if field.DefaultValue != "" {
+					createTableSQL += " DEFAULT ?"
+					values = append(values, clause.Expr{SQL: field.DefaultValue})
+				}
+				createTableSQL += ","
+			}
+
+			if !hasPrimaryKeyInDataType {
+				createTableSQL += "PRIMARY KEY ?,"
+				primaryKeys := []interface{}{}
+				for _, field := range stmt.Schema.PrimaryFields {
+					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
+				}
+
+				values = append(values, primaryKeys)
+			}
+
+			for _, idx := range stmt.Schema.ParseIndexes() {
+				createTableSQL += "INDEX ? ?,"
+				values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt))
+			}
+
+			for _, rel := range stmt.Schema.Relationships.Relations {
+				if constraint := rel.ParseConstraint(); constraint != nil {
+					sql, vars := buildConstraint(constraint)
+					createTableSQL += sql + ","
+					values = append(values, vars...)
+				}
+
+				// create join table
+				joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
+				if !migrator.DB.Migrator().HasTable(joinValue) {
+					defer migrator.DB.Migrator().CreateTable(joinValue)
+				}
+			}
+
+			for _, chk := range stmt.Schema.ParseCheckConstraints() {
+				createTableSQL += "CONSTRAINT ? CHECK ?,"
+				values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
+			}
+
+			createTableSQL = strings.TrimSuffix(createTableSQL, ",")
+
+			createTableSQL += ")"
+			return migrator.DB.Exec(createTableSQL, values...).Error
+		}); err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 func (migrator Migrator) DropTable(values ...interface{}) error {
@@ -115,6 +232,27 @@ func (migrator Migrator) AlterColumn(value interface{}, field string) error {
 	})
 }
 
+func (migrator Migrator) HasColumn(value interface{}, field string) bool {
+	var count int64
+	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
+		name := field
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			name = field.DBName
+		}
+
+		return migrator.DB.Raw(
+			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
+			currentDatabase, stmt.Table, name,
+		).Scan(&count).Error
+	})
+
+	if count != 0 {
+		return true
+	}
+	return false
+}
+
 func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
@@ -140,6 +278,28 @@ func (migrator Migrator) DropView(name string) error {
 	return gorm.ErrNotImplemented
 }
 
+func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
+	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
+	if constraint.OnDelete != "" {
+		sql += " ON DELETE " + constraint.OnDelete
+	}
+
+	if constraint.OnUpdate != "" {
+		sql += " ON UPDATE  " + constraint.OnUpdate
+	}
+
+	var foreignKeys, references []interface{}
+	for _, field := range constraint.ForeignKeys {
+		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
+	}
+
+	for _, field := range constraint.References {
+		references = append(references, clause.Column{Name: field.DBName})
+	}
+	results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
+	return
+}
+
 func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		checkConstraints := stmt.Schema.ParseCheckConstraints()
@@ -152,26 +312,8 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error
 
 		for _, rel := range stmt.Schema.Relationships.Relations {
 			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
-				sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
-				if constraint.OnDelete != "" {
-					sql += " ON DELETE " + constraint.OnDelete
-				}
-
-				if constraint.OnUpdate != "" {
-					sql += " ON UPDATE  " + constraint.OnUpdate
-				}
-				var foreignKeys, references []interface{}
-				for _, field := range constraint.ForeignKeys {
-					foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
-				}
-
-				for _, field := range constraint.References {
-					references = append(references, clause.Column{Name: field.DBName})
-				}
-
-				return migrator.DB.Exec(
-					sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references,
-				).Error
+				sql, values := buildConstraint(constraint)
+				return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
 			}
 		}
 
@@ -205,27 +347,47 @@ func (migrator Migrator) DropConstraint(value interface{}, name string) error {
 	})
 }
 
+func (migrator Migrator) HasConstraint(value interface{}, name string) bool {
+	var count int64
+	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
+		return migrator.DB.Raw(
+			"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
+			currentDatabase, stmt.Table, name,
+		).Scan(&count).Error
+	})
+
+	if count != 0 {
+		return true
+	}
+	return false
+}
+
+func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
+	for _, opt := range opts {
+		str := stmt.Quote(opt.DBName)
+		if opt.Expression != "" {
+			str = opt.Expression
+		} else if opt.Length > 0 {
+			str += fmt.Sprintf("(%d)", opt.Length)
+		}
+
+		if opt.Sort != "" {
+			str += " " + opt.Sort
+		}
+		results = append(results, clause.Expr{SQL: str})
+	}
+	return
+}
+
 func (migrator Migrator) CreateIndex(value interface{}, name string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		err := fmt.Errorf("failed to create index with name %v", name)
 		indexes := stmt.Schema.ParseIndexes()
 
 		if idx, ok := indexes[name]; ok {
-			fields := []interface{}{}
-			for _, field := range idx.Fields {
-				str := stmt.Quote(field.DBName)
-				if field.Expression != "" {
-					str = field.Expression
-				} else if field.Length > 0 {
-					str += fmt.Sprintf("(%d)", field.Length)
-				}
-
-				if field.Sort != "" {
-					str += " " + field.Sort
-				}
-				fields = append(fields, clause.Expr{SQL: str})
-			}
-			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields}
+			opts := buildIndexOptions(idx.Fields, stmt)
+			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
 
 			createIndexSQL := "CREATE "
 			if idx.Class != "" {