migrator.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package migrator
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "github.com/jinzhu/gorm"
  6. )
  7. // Migrator migrator struct
  8. type Migrator struct {
  9. *Config
  10. }
  11. // Config schema config
  12. type Config struct {
  13. CheckExistsBeforeDropping bool
  14. DB *gorm.DB
  15. }
  16. func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
  17. stmt := migrator.DB.Statement
  18. if stmt == nil {
  19. stmt = &gorm.Statement{DB: migrator.DB}
  20. }
  21. if err := stmt.Parse(value); err != nil {
  22. return err
  23. }
  24. return fc(stmt)
  25. }
  26. // AutoMigrate
  27. func (migrator Migrator) AutoMigrate(values ...interface{}) error {
  28. return gorm.ErrNotImplemented
  29. }
  30. func (migrator Migrator) CreateTable(values ...interface{}) error {
  31. return gorm.ErrNotImplemented
  32. }
  33. func (migrator Migrator) DropTable(values ...interface{}) error {
  34. for _, value := range values {
  35. if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  36. return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error
  37. }); err != nil {
  38. return err
  39. }
  40. }
  41. return nil
  42. }
  43. func (migrator Migrator) HasTable(values ...interface{}) bool {
  44. var count int64
  45. for _, value := range values {
  46. err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  47. currentDatabase := migrator.DB.Migrator().CurrentDatabase()
  48. return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error
  49. })
  50. if err != nil || count == 0 {
  51. return false
  52. }
  53. }
  54. return true
  55. }
  56. func (migrator Migrator) RenameTable(oldName, newName string) error {
  57. return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
  58. }
  59. func (migrator Migrator) AddColumn(value interface{}, field string) error {
  60. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  61. if field := stmt.Schema.LookUpField(field); field != nil {
  62. return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error
  63. }
  64. return fmt.Errorf("failed to look up field with name: %s", field)
  65. })
  66. }
  67. func (migrator Migrator) DropColumn(value interface{}, field string) error {
  68. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  69. if field := stmt.Schema.LookUpField(field); field != nil {
  70. return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error
  71. }
  72. return fmt.Errorf("failed to look up field with name: %s", field)
  73. })
  74. }
  75. func (migrator Migrator) AlterColumn(value interface{}, field string) error {
  76. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  77. if field := stmt.Schema.LookUpField(field); field != nil {
  78. return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error
  79. }
  80. return fmt.Errorf("failed to look up field with name: %s", field)
  81. })
  82. }
  83. func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
  84. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  85. if field := stmt.Schema.LookUpField(field); field != nil {
  86. oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
  87. return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error
  88. }
  89. return fmt.Errorf("failed to look up field with name: %s", field)
  90. })
  91. }
  92. func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
  93. return nil, gorm.ErrNotImplemented
  94. }
  95. func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error {
  96. return gorm.ErrNotImplemented
  97. }
  98. func (migrator Migrator) DropView(name string) error {
  99. return gorm.ErrNotImplemented
  100. }
  101. func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
  102. return gorm.ErrNotImplemented
  103. }
  104. func (migrator Migrator) DropConstraint(value interface{}, name string) error {
  105. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  106. return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error
  107. })
  108. }
  109. func (migrator Migrator) CreateIndex(value interface{}, name string) error {
  110. return gorm.ErrNotImplemented
  111. }
  112. func (migrator Migrator) DropIndex(value interface{}, name string) error {
  113. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  114. return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error
  115. })
  116. }
  117. func (migrator Migrator) HasIndex(value interface{}, name string) bool {
  118. var count int64
  119. migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  120. currentDatabase := migrator.DB.Migrator().CurrentDatabase()
  121. return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error
  122. })
  123. if count != 0 {
  124. return true
  125. }
  126. return false
  127. }
  128. func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  129. return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  130. return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error
  131. })
  132. }
  133. func (migrator Migrator) CurrentDatabase() (name string) {
  134. migrator.DB.Raw("SELECT DATABASE()").Scan(&name)
  135. return
  136. }