Browse Source

Add Migrator

Jinzhu 4 years ago
parent
commit
62dcd7896a
5 changed files with 172 additions and 7 deletions
  1. 1 4
      callbacks.go
  2. 2 0
      helpers.go
  3. 10 2
      migrator.go
  4. 152 1
      migrator/migrator.go
  5. 7 0
      statement.go

+ 1 - 4
callbacks.go

@@ -75,13 +75,10 @@ func (p *processor) Execute(db *DB) {
 		}
 
 		if stmt.Model != nil {
-			var err error
-			stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy)
+			err := stmt.Parse(stmt.Model)
 
 			if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
 				db.AddError(err)
-			} else if stmt.Table == "" && stmt.Schema != nil {
-				stmt.Table = stmt.Schema.Table
 			}
 		}
 	}

+ 2 - 0
helpers.go

@@ -15,6 +15,8 @@ var (
 	ErrInvalidTransaction = errors.New("no valid transaction")
 	// ErrUnaddressable unaddressable value
 	ErrUnaddressable = errors.New("using unaddressable value")
+	// ErrNotImplemented not implemented
+	ErrNotImplemented = errors.New("not implemented")
 )
 
 // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt

+ 10 - 2
migrator.go

@@ -4,6 +4,11 @@ import (
 	"database/sql"
 )
 
+// Migrator returns migrator
+func (db *DB) Migrator() Migrator {
+	return db.Dialector.Migrator()
+}
+
 // ViewOption view option
 type ViewOption struct {
 	Replace     bool
@@ -15,10 +20,13 @@ type Migrator interface {
 	// AutoMigrate
 	AutoMigrate(dst ...interface{}) error
 
+	// Database
+	CurrentDatabase() string
+
 	// Tables
 	CreateTable(dst ...interface{}) error
 	DropTable(dst ...interface{}) error
-	HasTable(dst ...interface{}) error
+	HasTable(dst ...interface{}) bool
 	RenameTable(oldName, newName string) error
 
 	// Columns
@@ -39,6 +47,6 @@ type Migrator interface {
 	// Indexes
 	CreateIndex(dst interface{}, name string) error
 	DropIndex(dst interface{}, name string) error
-	HasIndex(dst interface{}, name string) error
+	HasIndex(dst interface{}, name string) bool
 	RenameIndex(dst interface{}, oldName, newName string) error
 }

+ 152 - 1
migrator/migrator.go

@@ -1,6 +1,11 @@
 package migrator
 
-import "github.com/jinzhu/gorm"
+import (
+	"database/sql"
+	"fmt"
+
+	"github.com/jinzhu/gorm"
+)
 
 // Migrator migrator struct
 type Migrator struct {
@@ -12,3 +17,149 @@ type Config struct {
 	CheckExistsBeforeDropping bool
 	DB                        *gorm.DB
 }
+
+func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
+	stmt := migrator.DB.Statement
+	if stmt == nil {
+		stmt = &gorm.Statement{DB: migrator.DB}
+	}
+
+	if err := stmt.Parse(value); err != nil {
+		return err
+	}
+
+	return fc(stmt)
+}
+
+// AutoMigrate
+func (migrator Migrator) AutoMigrate(values ...interface{}) error {
+	return gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) CreateTable(values ...interface{}) error {
+	return gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) DropTable(values ...interface{}) error {
+	for _, value := range values {
+		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+			return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error
+		}); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (migrator Migrator) HasTable(values ...interface{}) bool {
+	var count int64
+	for _, value := range values {
+		err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+			currentDatabase := migrator.DB.Migrator().CurrentDatabase()
+			return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error
+		})
+
+		if err != nil || count == 0 {
+			return false
+		}
+	}
+
+	return true
+}
+
+func (migrator Migrator) RenameTable(oldName, newName string) error {
+	return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
+}
+
+func (migrator Migrator) AddColumn(value interface{}, field string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error
+		}
+		return fmt.Errorf("failed to look up field with name: %s", field)
+	})
+}
+
+func (migrator Migrator) DropColumn(value interface{}, field string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error
+		}
+		return fmt.Errorf("failed to look up field with name: %s", field)
+	})
+}
+
+func (migrator Migrator) AlterColumn(value interface{}, field string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error
+		}
+		return fmt.Errorf("failed to look up field with name: %s", field)
+	})
+}
+
+func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
+			return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error
+		}
+		return fmt.Errorf("failed to look up field with name: %s", field)
+	})
+}
+
+func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
+	return nil, gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error {
+	return gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) DropView(name string) error {
+	return gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
+	return gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) DropConstraint(value interface{}, name string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error
+	})
+}
+
+func (migrator Migrator) CreateIndex(value interface{}, name string) error {
+	return gorm.ErrNotImplemented
+}
+
+func (migrator Migrator) DropIndex(value interface{}, name string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error
+	})
+}
+
+func (migrator Migrator) HasIndex(value interface{}, name string) bool {
+	var count int64
+	migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
+		return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error
+	})
+
+	if count != 0 {
+		return true
+	}
+	return false
+}
+
+func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error {
+	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error
+	})
+}
+
+func (migrator Migrator) CurrentDatabase() (name string) {
+	migrator.DB.Raw("SELECT DATABASE()").Scan(&name)
+	return
+}

+ 7 - 0
statement.go

@@ -267,3 +267,10 @@ func (stmt *Statement) Build(clauses ...string) {
 	}
 	// TODO handle named vars
 }
+
+func (stmt *Statement) Parse(value interface{}) (err error) {
+	if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
+		stmt.Table = stmt.Schema.Table
+	}
+	return err
+}