Parcourir la source

Handle constraint dependencies smartly

Jinzhu il y a 4 ans
Parent
commit
d3c63a03cb
2 fichiers modifiés avec 80 ajouts et 9 suppressions
  1. 74 3
      migrator/migrator.go
  2. 6 6
      tests/migrate.go

+ 74 - 3
migrator/migrator.go

@@ -48,7 +48,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
 // AutoMigrate
 func (m Migrator) AutoMigrate(values ...interface{}) error {
 	// TODO smart migrate data type
-	for _, value := range values {
+	for _, value := range m.ReorderModels(values, true) {
 		tx := m.DB.Session(&gorm.Session{})
 		if !tx.Migrator().HasTable(value) {
 			if err := tx.Migrator().CreateTable(value); err != nil {
@@ -100,7 +100,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
 }
 
 func (m Migrator) CreateTable(values ...interface{}) error {
-	for _, value := range values {
+	for _, value := range m.ReorderModels(values, false) {
 		tx := m.DB.Session(&gorm.Session{})
 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
 			var (
@@ -186,7 +186,9 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 }
 
 func (m Migrator) DropTable(values ...interface{}) error {
-	for _, value := range values {
+	values = m.ReorderModels(values, false)
+	for i := len(values) - 1; i >= 0; i-- {
+		value := values[i]
 		tx := m.DB.Session(&gorm.Session{})
 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
 			return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
@@ -475,3 +477,72 @@ func (m Migrator) CurrentDatabase() (name string) {
 	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
 	return
 }
+
+// ReorderModels reorder models according to constraint dependencies
+func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
+	type Dependency struct {
+		Table   string
+		Depends []*schema.Schema
+	}
+
+	var (
+		modelNames, orderedModelNames []string
+		orderedModelNamesMap          = map[string]bool{}
+		valuesMap                     = map[string]*gorm.Statement{}
+		dependencies                  = map[string]Dependency{}
+		insertIntoOrderedMap          func(name string)
+	)
+
+	parseDependence := func(value interface{}, addToMap bool) {
+		stmt := &gorm.Statement{DB: m.DB, Dest: value}
+		stmt.Parse(value)
+		dep := Dependency{Table: stmt.Schema.Table}
+
+		for _, rel := range stmt.Schema.Relationships.Relations {
+			if constraint := rel.ParseConstraint(); constraint != nil {
+				dep.Depends = append(dep.Depends, constraint.ReferenceSchema)
+			}
+		}
+		dependencies[stmt.Schema.Table] = dep
+
+		if addToMap {
+			modelNames = append(modelNames, stmt.Schema.Table)
+			valuesMap[stmt.Schema.Table] = stmt
+		}
+	}
+
+	for _, value := range values {
+		parseDependence(value, true)
+	}
+
+	insertIntoOrderedMap = func(name string) {
+		// avoid loop
+		if _, ok := orderedModelNamesMap[name]; ok {
+			return
+		}
+
+		dep := dependencies[name]
+		for _, d := range dep.Depends {
+			if _, ok := valuesMap[d.Table]; ok {
+				if _, ok := orderedModelNamesMap[d.Table]; !ok && name != d.Table {
+					insertIntoOrderedMap(d.Table)
+				}
+			} else if autoAdd {
+				parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
+				insertIntoOrderedMap(d.Table)
+			}
+		}
+
+		orderedModelNames = append(orderedModelNames, name)
+		orderedModelNamesMap[name] = true
+	}
+
+	for _, name := range modelNames {
+		insertIntoOrderedMap(name)
+	}
+
+	for _, name := range orderedModelNames {
+		results = append(results, valuesMap[name].Dest)
+	}
+	return
+}

+ 6 - 6
tests/migrate.go

@@ -1,20 +1,20 @@
 package tests
 
 import (
+	"math/rand"
 	"testing"
+	"time"
 
 	"github.com/jinzhu/gorm"
 )
 
 func TestMigrate(t *testing.T, db *gorm.DB) {
 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}}
+	rand.Seed(time.Now().UnixNano())
+	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
 
-	for _, m := range allModels {
-		if db.Migrator().HasTable(m) {
-			if err := db.Migrator().DropTable(m); err != nil {
-				t.Errorf("Failed to drop table, got error %v", err)
-			}
-		}
+	if err := db.Migrator().DropTable(allModels...); err != nil {
+		t.Errorf("Failed to drop table, got error %v", err)
 	}
 
 	if err := db.AutoMigrate(allModels...); err != nil {