Prechádzať zdrojové kódy

Finish CreateConstraint

Jinzhu 4 rokov pred
rodič
commit
0be4817ff9
6 zmenil súbory, kde vykonal 241 pridanie a 20 odobranie
  1. 1 1
      clause/expression.go
  2. 141 11
      migrator/migrator.go
  3. 29 0
      schema/check.go
  4. 2 2
      schema/index_test.go
  5. 19 6
      schema/naming.go
  6. 49 0
      schema/relationship.go

+ 1 - 1
clause/expression.go

@@ -22,7 +22,7 @@ type Expr struct {
 func (expr Expr) Build(builder Builder) {
 	sql := expr.SQL
 	for _, v := range expr.Vars {
-		sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1)
+		sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1)
 	}
 	builder.Write(sql)
 }

+ 141 - 11
migrator/migrator.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 
 	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/clause"
 )
 
 // Migrator migrator struct
@@ -33,17 +34,25 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement
 
 // AutoMigrate
 func (migrator Migrator) AutoMigrate(values ...interface{}) error {
+	// if has table
+	// not -> create table
+	// check columns -> add column, change column type
+	// check foreign keys -> create indexes
+	// check indexes -> create indexes
+
 	return gorm.ErrNotImplemented
 }
 
 func (migrator Migrator) CreateTable(values ...interface{}) error {
+	// migrate
+	// create join table
 	return gorm.ErrNotImplemented
 }
 
 func (migrator Migrator) DropTable(values ...interface{}) error {
 	for _, value := range values {
 		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-			return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error
+			return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
 		}); err != nil {
 			return err
 		}
@@ -74,7 +83,10 @@ func (migrator Migrator) RenameTable(oldName, newName string) error {
 func (migrator Migrator) AddColumn(value interface{}, field string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error
+			return migrator.DB.Exec(
+				"ALTER TABLE ? ADD ? ?",
+				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
+			).Error
 		}
 		return fmt.Errorf("failed to look up field with name: %s", field)
 	})
@@ -83,7 +95,9 @@ func (migrator Migrator) AddColumn(value interface{}, field string) error {
 func (migrator Migrator) DropColumn(value interface{}, field string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error
+			return migrator.DB.Exec(
+				"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
+			).Error
 		}
 		return fmt.Errorf("failed to look up field with name: %s", field)
 	})
@@ -92,7 +106,10 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error {
 func (migrator Migrator) AlterColumn(value interface{}, field string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
-			return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error
+			return migrator.DB.Exec(
+				"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
+				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
+			).Error
 		}
 		return fmt.Errorf("failed to look up field with name: %s", field)
 	})
@@ -102,7 +119,10 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string)
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		if field := stmt.Schema.LookUpField(field); field != nil {
 			oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
-			return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error
+			return migrator.DB.Exec(
+				"ALTER TABLE ? RENAME COLUMN ? TO ?",
+				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
+			).Error
 		}
 		return fmt.Errorf("failed to look up field with name: %s", field)
 	})
@@ -121,22 +141,126 @@ func (migrator Migrator) DropView(name string) error {
 }
 
 func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
-	return gorm.ErrNotImplemented
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		checkConstraints := stmt.Schema.ParseCheckConstraints()
+		if chk, ok := checkConstraints[name]; ok {
+			return migrator.DB.Exec(
+				"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?",
+				clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
+			).Error
+		}
+
+		for _, rel := range stmt.Schema.Relationships.Relations {
+			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
+				sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
+				if constraint.OnDelete != "" {
+					sql += " ON DELETE " + constraint.OnDelete
+				}
+
+				if constraint.OnUpdate != "" {
+					sql += " ON UPDATE  " + constraint.OnUpdate
+				}
+				var foreignKeys, references []interface{}
+				for _, field := range constraint.ForeignKeys {
+					foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
+				}
+
+				for _, field := range constraint.References {
+					references = append(references, clause.Column{Name: field.DBName})
+				}
+
+				return migrator.DB.Exec(
+					sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references,
+				).Error
+			}
+		}
+
+		err := fmt.Errorf("failed to create constraint with name %v", name)
+		if field := stmt.Schema.LookUpField(name); field != nil {
+			for _, cc := range checkConstraints {
+				if err = migrator.CreateIndex(value, cc.Name); err != nil {
+					return err
+				}
+			}
+
+			for _, rel := range stmt.Schema.Relationships.Relations {
+				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
+					if err = migrator.CreateIndex(value, constraint.Name); err != nil {
+						return err
+					}
+				}
+			}
+		}
+
+		return err
+	})
 }
 
 func (migrator Migrator) DropConstraint(value interface{}, name string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error
+		return migrator.DB.Exec(
+			"ALTER TABLE ? DROP CONSTRAINT ?",
+			clause.Table{Name: stmt.Table}, clause.Column{Name: name},
+		).Error
 	})
 }
 
 func (migrator Migrator) CreateIndex(value interface{}, name string) error {
-	return gorm.ErrNotImplemented
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		err := fmt.Errorf("failed to create index with name %v", name)
+		indexes := stmt.Schema.ParseIndexes()
+
+		if idx, ok := indexes[name]; ok {
+			fields := []interface{}{}
+			for _, field := range idx.Fields {
+				str := stmt.Quote(field.DBName)
+				if field.Expression != "" {
+					str = field.Expression
+				} else if field.Length > 0 {
+					str += fmt.Sprintf("(%d)", field.Length)
+				}
+
+				if field.Sort != "" {
+					str += " " + field.Sort
+				}
+				fields = append(fields, clause.Expr{SQL: str})
+			}
+			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields}
+
+			createIndexSQL := "CREATE "
+			if idx.Class != "" {
+				createIndexSQL += idx.Class + " "
+			}
+			createIndexSQL += "INDEX ? ON ??"
+
+			if idx.Comment != "" {
+				values = append(values, idx.Comment)
+				createIndexSQL += " COMMENT ?"
+			}
+
+			if idx.Type != "" {
+				createIndexSQL += " USING " + idx.Type
+			}
+
+			return migrator.DB.Raw(createIndexSQL, values...).Error
+		} else if field := stmt.Schema.LookUpField(name); field != nil {
+			for _, idx := range indexes {
+				for _, idxOpt := range idx.Fields {
+					if idxOpt.Field == field {
+						if err = migrator.CreateIndex(value, idx.Name); err != nil {
+							return err
+						}
+					}
+				}
+			}
+		}
+		return err
+	})
 }
 
 func (migrator Migrator) DropIndex(value interface{}, name string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error
+		return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
 	})
 }
 
@@ -144,7 +268,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool {
 	var count int64
 	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
 		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
-		return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error
+		return migrator.DB.Raw(
+			"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
+			currentDatabase, stmt.Table, name,
+		).Scan(&count).Error
 	})
 
 	if count != 0 {
@@ -155,7 +282,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool {
 
 func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error {
 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
-		return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error
+		return migrator.DB.Exec(
+			"ALTER TABLE ? RENAME INDEX ? TO ?",
+			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
+		).Error
 	})
 }
 

+ 29 - 0
schema/check.go

@@ -0,0 +1,29 @@
+package schema
+
+import (
+	"regexp"
+	"strings"
+)
+
+type Check struct {
+	Name       string
+	Constraint string // length(phone) >= 10
+	*Field
+}
+
+// ParseCheckConstraints parse schema check constraints
+func (schema *Schema) ParseCheckConstraints() map[string]Check {
+	var checks = map[string]Check{}
+	for _, field := range schema.FieldsByDBName {
+		if chk := field.TagSettings["CHECK"]; chk != "" {
+			names := strings.Split(chk, ",")
+			if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) {
+				checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
+			} else {
+				name := schema.namer.CheckerName(schema.Table, field.DBName)
+				checks[name] = Check{Name: name, Constraint: chk, Field: field}
+			}
+		}
+	}
+	return checks
+}

+ 2 - 2
schema/index_test.go

@@ -15,7 +15,7 @@ type UserIndex struct {
 	Name4 string `gorm:"unique_index"`
 	Name5 int64  `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"`
 	Name6 int64  `gorm:"index:profile,comment:hello \\, world,where:age > 10"`
-	Age   int64  `gorm:"index:profile,expression:(age+10)"`
+	Age   int64  `gorm:"index:profile,expression:ABS(age)"`
 }
 
 func TestParseIndex(t *testing.T) {
@@ -61,7 +61,7 @@ func TestParseIndex(t *testing.T) {
 			Comment: "hello , world",
 			Where:   "age > 10",
 			Fields: []schema.IndexOption{{}, {
-				Expression: "(age+10)",
+				Expression: "ABS(age)",
 			}},
 		},
 	}

+ 19 - 6
schema/naming.go

@@ -14,8 +14,10 @@ import (
 type Namer interface {
 	TableName(table string) string
 	ColumnName(table, column string) string
-	IndexName(table, column string) string
 	JoinTableName(table string) string
+	RelationshipFKName(Relationship) string
+	CheckerName(table, column string) string
+	IndexName(table, column string) string
 }
 
 // NamingStrategy tables, columns naming strategy
@@ -37,6 +39,22 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
 	return toDBName(column)
 }
 
+// JoinTableName convert string to join table name
+func (ns NamingStrategy) JoinTableName(str string) string {
+	return ns.TablePrefix + inflection.Plural(toDBName(str))
+}
+
+// RelationshipFKName generate fk name for relation
+func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
+	return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table)
+}
+
+// CheckerName generate checker name
+func (ns NamingStrategy) CheckerName(table, column string) string {
+	return fmt.Sprintf("chk_%s_%s", table, column)
+}
+
+// IndexName generate index name
 func (ns NamingStrategy) IndexName(table, column string) string {
 	idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
 
@@ -50,11 +68,6 @@ func (ns NamingStrategy) IndexName(table, column string) string {
 	return idxName
 }
 
-// JoinTableName convert string to join table name
-func (ns NamingStrategy) JoinTableName(str string) string {
-	return ns.TablePrefix + inflection.Plural(toDBName(str))
-}
-
 var (
 	smap sync.Map
 	// https://github.com/golang/lint/blob/master/lint.go#L770

+ 49 - 0
schema/relationship.go

@@ -3,6 +3,7 @@ package schema
 import (
 	"fmt"
 	"reflect"
+	"regexp"
 	"strings"
 
 	"github.com/jinzhu/inflection"
@@ -292,3 +293,51 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
 		relation.Type = BelongsTo
 	}
 }
+
+type Constraint struct {
+	Name            string
+	Field           *Field
+	Schema          *Schema
+	ForeignKeys     []*Field
+	ReferenceSchema *Schema
+	References      []*Field
+	OnDelete        string
+	OnUpdate        string
+}
+
+func (rel *Relationship) ParseConstraint() *Constraint {
+	str := rel.Field.TagSettings["CONSTRAINT"]
+	if str == "-" {
+		return nil
+	}
+
+	var (
+		name     string
+		idx      = strings.Index(str, ",")
+		settings = ParseTagSetting(str, ",")
+	)
+
+	if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) {
+		name = str[0:idx]
+	} else {
+		name = rel.Schema.namer.RelationshipFKName(*rel)
+	}
+
+	constraint := Constraint{
+		Name:     name,
+		Field:    rel.Field,
+		OnUpdate: settings["ONUPDATE"],
+		OnDelete: settings["ONDELETE"],
+		Schema:   rel.Schema,
+	}
+
+	for _, ref := range rel.References {
+		if ref.PrimaryKey != nil && !ref.OwnPrimaryKey {
+			constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
+			constraint.References = append(constraint.References, ref.PrimaryKey)
+			constraint.ReferenceSchema = ref.PrimaryKey.Schema
+		}
+	}
+
+	return &constraint
+}