migrator.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. package migrator
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. "github.com/jinzhu/gorm"
  8. "github.com/jinzhu/gorm/clause"
  9. "github.com/jinzhu/gorm/schema"
  10. )
  11. // Migrator m struct
  12. type Migrator struct {
  13. Config
  14. }
  15. // Config schema config
  16. type Config struct {
  17. CreateIndexAfterCreateTable bool
  18. DB *gorm.DB
  19. gorm.Dialector
  20. }
  21. func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
  22. stmt := m.DB.Statement
  23. if stmt == nil {
  24. stmt = &gorm.Statement{DB: m.DB}
  25. }
  26. if err := stmt.Parse(value); err != nil {
  27. return err
  28. }
  29. return fc(stmt)
  30. }
  31. func (m Migrator) DataTypeOf(field *schema.Field) string {
  32. if field.DBDataType != "" {
  33. return field.DBDataType
  34. }
  35. return m.Dialector.DataTypeOf(field)
  36. }
  37. // AutoMigrate
  38. func (m Migrator) AutoMigrate(values ...interface{}) error {
  39. // TODO smart migrate data type
  40. for _, value := range values {
  41. if !m.DB.Migrator().HasTable(value) {
  42. if err := m.DB.Migrator().CreateTable(value); err != nil {
  43. return err
  44. }
  45. } else {
  46. if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  47. for _, field := range stmt.Schema.FieldsByDBName {
  48. if !m.DB.Migrator().HasColumn(value, field.DBName) {
  49. if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil {
  50. return err
  51. }
  52. }
  53. }
  54. for _, rel := range stmt.Schema.Relationships.Relations {
  55. if constraint := rel.ParseConstraint(); constraint != nil {
  56. if !m.DB.Migrator().HasConstraint(value, constraint.Name) {
  57. if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
  58. return err
  59. }
  60. }
  61. }
  62. for _, chk := range stmt.Schema.ParseCheckConstraints() {
  63. if !m.DB.Migrator().HasConstraint(value, chk.Name) {
  64. if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
  65. return err
  66. }
  67. }
  68. }
  69. // create join table
  70. if rel.JoinTable != nil {
  71. joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
  72. if !m.DB.Migrator().HasTable(joinValue) {
  73. defer m.DB.Migrator().CreateTable(joinValue)
  74. }
  75. }
  76. }
  77. return nil
  78. }); err != nil {
  79. return err
  80. }
  81. }
  82. }
  83. return nil
  84. }
  85. func (m Migrator) CreateTable(values ...interface{}) error {
  86. for _, value := range values {
  87. if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  88. var (
  89. createTableSQL = "CREATE TABLE ? ("
  90. values = []interface{}{clause.Table{Name: stmt.Table}}
  91. hasPrimaryKeyInDataType bool
  92. )
  93. for _, dbName := range stmt.Schema.DBNames {
  94. field := stmt.Schema.FieldsByDBName[dbName]
  95. createTableSQL += fmt.Sprintf("? ?")
  96. hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY")
  97. values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)})
  98. if field.AutoIncrement {
  99. createTableSQL += " AUTO_INCREMENT"
  100. }
  101. if field.NotNull {
  102. createTableSQL += " NOT NULL"
  103. }
  104. if field.Unique {
  105. createTableSQL += " UNIQUE"
  106. }
  107. if field.DefaultValue != "" {
  108. createTableSQL += " DEFAULT ?"
  109. values = append(values, clause.Expr{SQL: field.DefaultValue})
  110. }
  111. createTableSQL += ","
  112. }
  113. if !hasPrimaryKeyInDataType {
  114. createTableSQL += "PRIMARY KEY ?,"
  115. primaryKeys := []interface{}{}
  116. for _, field := range stmt.Schema.PrimaryFields {
  117. primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
  118. }
  119. values = append(values, primaryKeys)
  120. }
  121. for _, idx := range stmt.Schema.ParseIndexes() {
  122. if m.CreateIndexAfterCreateTable {
  123. m.DB.Migrator().CreateIndex(value, idx.Name)
  124. } else {
  125. createTableSQL += "INDEX ? ?,"
  126. values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
  127. }
  128. }
  129. for _, rel := range stmt.Schema.Relationships.Relations {
  130. if constraint := rel.ParseConstraint(); constraint != nil {
  131. sql, vars := buildConstraint(constraint)
  132. createTableSQL += sql + ","
  133. values = append(values, vars...)
  134. }
  135. // create join table
  136. if rel.JoinTable != nil {
  137. joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
  138. if !m.DB.Migrator().HasTable(joinValue) {
  139. defer m.DB.Migrator().CreateTable(joinValue)
  140. }
  141. }
  142. }
  143. for _, chk := range stmt.Schema.ParseCheckConstraints() {
  144. createTableSQL += "CONSTRAINT ? CHECK ?,"
  145. values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
  146. }
  147. createTableSQL = strings.TrimSuffix(createTableSQL, ",")
  148. createTableSQL += ")"
  149. return m.DB.Exec(createTableSQL, values...).Error
  150. }); err != nil {
  151. return err
  152. }
  153. }
  154. return nil
  155. }
  156. func (m Migrator) DropTable(values ...interface{}) error {
  157. for _, value := range values {
  158. if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  159. return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
  160. }); err != nil {
  161. return err
  162. }
  163. }
  164. return nil
  165. }
  166. func (m Migrator) HasTable(value interface{}) bool {
  167. var count int64
  168. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  169. currentDatabase := m.DB.Migrator().CurrentDatabase()
  170. return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
  171. })
  172. return count > 0
  173. }
  174. func (m Migrator) RenameTable(oldName, newName string) error {
  175. return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
  176. }
  177. func (m Migrator) AddColumn(value interface{}, field string) error {
  178. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  179. if field := stmt.Schema.LookUpField(field); field != nil {
  180. return m.DB.Exec(
  181. "ALTER TABLE ? ADD ? ?",
  182. clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
  183. ).Error
  184. }
  185. return fmt.Errorf("failed to look up field with name: %s", field)
  186. })
  187. }
  188. func (m Migrator) DropColumn(value interface{}, field string) error {
  189. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  190. if field := stmt.Schema.LookUpField(field); field != nil {
  191. return m.DB.Exec(
  192. "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
  193. ).Error
  194. }
  195. return fmt.Errorf("failed to look up field with name: %s", field)
  196. })
  197. }
  198. func (m Migrator) AlterColumn(value interface{}, field string) error {
  199. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  200. if field := stmt.Schema.LookUpField(field); field != nil {
  201. return m.DB.Exec(
  202. "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
  203. clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
  204. ).Error
  205. }
  206. return fmt.Errorf("failed to look up field with name: %s", field)
  207. })
  208. }
  209. func (m Migrator) HasColumn(value interface{}, field string) bool {
  210. var count int64
  211. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  212. currentDatabase := m.DB.Migrator().CurrentDatabase()
  213. name := field
  214. if field := stmt.Schema.LookUpField(field); field != nil {
  215. name = field.DBName
  216. }
  217. return m.DB.Raw(
  218. "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
  219. currentDatabase, stmt.Table, name,
  220. ).Row().Scan(&count)
  221. })
  222. return count > 0
  223. }
  224. func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
  225. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  226. if field := stmt.Schema.LookUpField(field); field != nil {
  227. oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
  228. return m.DB.Exec(
  229. "ALTER TABLE ? RENAME COLUMN ? TO ?",
  230. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
  231. ).Error
  232. }
  233. return fmt.Errorf("failed to look up field with name: %s", field)
  234. })
  235. }
  236. func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
  237. err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
  238. rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
  239. if err == nil {
  240. columnTypes, err = rows.ColumnTypes()
  241. }
  242. return err
  243. })
  244. return
  245. }
  246. func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
  247. return gorm.ErrNotImplemented
  248. }
  249. func (m Migrator) DropView(name string) error {
  250. return gorm.ErrNotImplemented
  251. }
  252. func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
  253. sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
  254. if constraint.OnDelete != "" {
  255. sql += " ON DELETE " + constraint.OnDelete
  256. }
  257. if constraint.OnUpdate != "" {
  258. sql += " ON UPDATE " + constraint.OnUpdate
  259. }
  260. var foreignKeys, references []interface{}
  261. for _, field := range constraint.ForeignKeys {
  262. foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
  263. }
  264. for _, field := range constraint.References {
  265. references = append(references, clause.Column{Name: field.DBName})
  266. }
  267. results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
  268. return
  269. }
  270. func (m Migrator) CreateConstraint(value interface{}, name string) error {
  271. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  272. checkConstraints := stmt.Schema.ParseCheckConstraints()
  273. if chk, ok := checkConstraints[name]; ok {
  274. return m.DB.Exec(
  275. "ALTER TABLE ? ADD CONSTRAINT ? CHECK ?",
  276. clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
  277. ).Error
  278. }
  279. for _, rel := range stmt.Schema.Relationships.Relations {
  280. if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
  281. sql, values := buildConstraint(constraint)
  282. return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
  283. }
  284. }
  285. err := fmt.Errorf("failed to create constraint with name %v", name)
  286. if field := stmt.Schema.LookUpField(name); field != nil {
  287. for _, cc := range checkConstraints {
  288. if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil {
  289. return err
  290. }
  291. }
  292. for _, rel := range stmt.Schema.Relationships.Relations {
  293. if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
  294. if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil {
  295. return err
  296. }
  297. }
  298. }
  299. }
  300. return err
  301. })
  302. }
  303. func (m Migrator) DropConstraint(value interface{}, name string) error {
  304. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  305. return m.DB.Exec(
  306. "ALTER TABLE ? DROP CONSTRAINT ?",
  307. clause.Table{Name: stmt.Table}, clause.Column{Name: name},
  308. ).Error
  309. })
  310. }
  311. func (m Migrator) HasConstraint(value interface{}, name string) bool {
  312. var count int64
  313. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  314. currentDatabase := m.DB.Migrator().CurrentDatabase()
  315. return m.DB.Raw(
  316. "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
  317. currentDatabase, stmt.Table, name,
  318. ).Row().Scan(&count)
  319. })
  320. return count > 0
  321. }
  322. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  323. for _, opt := range opts {
  324. str := stmt.Quote(opt.DBName)
  325. if opt.Expression != "" {
  326. str = opt.Expression
  327. } else if opt.Length > 0 {
  328. str += fmt.Sprintf("(%d)", opt.Length)
  329. }
  330. if opt.Collate != "" {
  331. str += " COLLATE " + opt.Collate
  332. }
  333. if opt.Sort != "" {
  334. str += " " + opt.Sort
  335. }
  336. results = append(results, clause.Expr{SQL: str})
  337. }
  338. return
  339. }
  340. type BuildIndexOptionsInterface interface {
  341. BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
  342. }
  343. func (m Migrator) CreateIndex(value interface{}, name string) error {
  344. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  345. err := fmt.Errorf("failed to create index with name %v", name)
  346. indexes := stmt.Schema.ParseIndexes()
  347. if idx, ok := indexes[name]; ok {
  348. opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
  349. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  350. createIndexSQL := "CREATE "
  351. if idx.Class != "" {
  352. createIndexSQL += idx.Class + " "
  353. }
  354. createIndexSQL += "INDEX ? ON ??"
  355. if idx.Comment != "" {
  356. values = append(values, idx.Comment)
  357. createIndexSQL += " COMMENT ?"
  358. }
  359. if idx.Type != "" {
  360. createIndexSQL += " USING " + idx.Type
  361. }
  362. return m.DB.Exec(createIndexSQL, values...).Error
  363. } else if field := stmt.Schema.LookUpField(name); field != nil {
  364. for _, idx := range indexes {
  365. for _, idxOpt := range idx.Fields {
  366. if idxOpt.Field == field {
  367. if err = m.CreateIndex(value, idx.Name); err != nil {
  368. return err
  369. }
  370. }
  371. }
  372. }
  373. }
  374. return err
  375. })
  376. }
  377. func (m Migrator) DropIndex(value interface{}, name string) error {
  378. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  379. return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
  380. })
  381. }
  382. func (m Migrator) HasIndex(value interface{}, name string) bool {
  383. var count int64
  384. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  385. currentDatabase := m.DB.Migrator().CurrentDatabase()
  386. return m.DB.Raw(
  387. "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
  388. currentDatabase, stmt.Table, name,
  389. ).Row().Scan(&count)
  390. })
  391. return count > 0
  392. }
  393. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  394. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  395. return m.DB.Exec(
  396. "ALTER TABLE ? RENAME INDEX ? TO ?",
  397. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
  398. ).Error
  399. })
  400. }
  401. func (m Migrator) CurrentDatabase() (name string) {
  402. m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
  403. return
  404. }