migrator.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package migrator
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. "github.com/jinzhu/gorm"
  8. "github.com/jinzhu/gorm/clause"
  9. "github.com/jinzhu/gorm/schema"
  10. )
  11. // Migrator m struct
  12. type Migrator struct {
  13. Config
  14. }
  15. // Config schema config
  16. type Config struct {
  17. CreateIndexAfterCreateTable bool
  18. AllowDeferredConstraintsWhenAutoMigrate bool
  19. DB *gorm.DB
  20. gorm.Dialector
  21. }
  22. func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
  23. stmt := m.DB.Statement
  24. if stmt == nil {
  25. stmt = &gorm.Statement{DB: m.DB}
  26. }
  27. if err := stmt.Parse(value); err != nil {
  28. return err
  29. }
  30. return fc(stmt)
  31. }
  32. func (m Migrator) DataTypeOf(field *schema.Field) string {
  33. if field.DBDataType != "" {
  34. return field.DBDataType
  35. }
  36. return m.Dialector.DataTypeOf(field)
  37. }
  38. // AutoMigrate
  39. func (m Migrator) AutoMigrate(values ...interface{}) error {
  40. // TODO smart migrate data type
  41. for _, value := range m.ReorderModels(values, true) {
  42. tx := m.DB.Session(&gorm.Session{})
  43. if !tx.Migrator().HasTable(value) {
  44. if err := tx.Migrator().CreateTable(value); err != nil {
  45. return err
  46. }
  47. } else {
  48. if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  49. for _, field := range stmt.Schema.FieldsByDBName {
  50. if !tx.Migrator().HasColumn(value, field.DBName) {
  51. if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
  52. return err
  53. }
  54. }
  55. }
  56. for _, rel := range stmt.Schema.Relationships.Relations {
  57. if constraint := rel.ParseConstraint(); constraint != nil {
  58. if !tx.Migrator().HasConstraint(value, constraint.Name) {
  59. if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
  60. return err
  61. }
  62. }
  63. }
  64. for _, chk := range stmt.Schema.ParseCheckConstraints() {
  65. if !tx.Migrator().HasConstraint(value, chk.Name) {
  66. if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
  67. return err
  68. }
  69. }
  70. }
  71. // create join table
  72. if rel.JoinTable != nil {
  73. joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
  74. if !tx.Migrator().HasTable(joinValue) {
  75. defer tx.Migrator().CreateTable(joinValue)
  76. }
  77. }
  78. }
  79. return nil
  80. }); err != nil {
  81. return err
  82. }
  83. }
  84. }
  85. return nil
  86. }
  // CreateTable issues CREATE TABLE for each model in values. Models are ordered
  // with ReorderModels so tables referenced by foreign keys are created first.
  //
  // Each definition contains:
  //   - one column per DB name, with AUTO_INCREMENT, NOT NULL, UNIQUE and
  //     DEFAULT modifiers;
  //   - a table-level PRIMARY KEY, unless a column type already declares one or
  //     the schema has no primary fields;
  //   - INDEX clauses, or separate CreateIndex calls after the table exists when
  //     CreateIndexAfterCreateTable is set;
  //   - foreign-key and CHECK constraints.
  //
  // Many2many join tables are created after the owning table. The first error
  // is returned.
  func (m Migrator) CreateTable(values ...interface{}) error {
  	for _, value := range m.ReorderModels(values, false) {
  		tx := m.DB.Session(&gorm.Session{})
  		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  			var (
  				createTableSQL          = "CREATE TABLE ? ("
  				args                    = []interface{}{clause.Table{Name: stmt.Table}}
  				hasPrimaryKeyInDataType bool
  				// Work that can only run once the table exists.
  				pendingIndexes []string
  				joinValues     []interface{}
  			)

  			for _, dbName := range stmt.Schema.DBNames {
  				field := stmt.Schema.FieldsByDBName[dbName]
  				dataType := m.DataTypeOf(field)
  				// Inspect the type actually emitted (explicit or dialect-provided) so a
  				// column-level PRIMARY KEY is never duplicated by a table-level one.
  				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(dataType), "PRIMARY KEY")
  				createTableSQL += "? ?"
  				args = append(args, clause.Column{Name: dbName}, clause.Expr{SQL: dataType})
  				if field.AutoIncrement {
  					createTableSQL += " AUTO_INCREMENT"
  				}
  				if field.NotNull {
  					createTableSQL += " NOT NULL"
  				}
  				if field.Unique {
  					createTableSQL += " UNIQUE"
  				}
  				if field.DefaultValue != "" {
  					createTableSQL += " DEFAULT ?"
  					args = append(args, clause.Expr{SQL: field.DefaultValue})
  				}
  				createTableSQL += ","
  			}

  			// An empty key list would render as "PRIMARY KEY ()", which is invalid.
  			if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
  				createTableSQL += "PRIMARY KEY ?,"
  				primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
  				for _, field := range stmt.Schema.PrimaryFields {
  					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
  				}
  				// Appended as a single var: the slice renders as a column list.
  				args = append(args, primaryKeys)
  			}

  			for _, idx := range stmt.Schema.ParseIndexes() {
  				if m.CreateIndexAfterCreateTable {
  					// CREATE INDEX needs the table to exist, so postpone it.
  					pendingIndexes = append(pendingIndexes, idx.Name)
  					continue
  				}
  				createTableSQL += "INDEX ? ?,"
  				// Quote the index name as an identifier, as CreateIndex does.
  				args = append(args, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
  			}

  			for _, rel := range stmt.Schema.Relationships.Relations {
  				if constraint := rel.ParseConstraint(); constraint != nil {
  					constraintSQL, vars := buildConstraint(constraint)
  					createTableSQL += constraintSQL + ","
  					args = append(args, vars...)
  				}
  				if rel.JoinTable != nil {
  					joinValues = append(joinValues, reflect.New(rel.JoinTable.ModelType).Interface())
  				}
  			}

  			for _, chk := range stmt.Schema.ParseCheckConstraints() {
  				// CHECK requires a parenthesised expression; extra parentheses are
  				// harmless if the constraint text already has them.
  				createTableSQL += "CONSTRAINT ? CHECK (?),"
  				args = append(args, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
  			}

  			createTableSQL = strings.TrimSuffix(createTableSQL, ",") + ")"
  			if err := tx.Exec(createTableSQL, args...).Error; err != nil {
  				return err
  			}

  			for _, name := range pendingIndexes {
  				if err := tx.Migrator().CreateIndex(value, name); err != nil {
  					return err
  				}
  			}
  			// HasTable is checked right before creating, so a join table shared by
  			// several relations is created only once.
  			for _, joinValue := range joinValues {
  				if !tx.Migrator().HasTable(joinValue) {
  					if err := tx.Migrator().CreateTable(joinValue); err != nil {
  						return err
  					}
  				}
  			}
  			return nil
  		}); err != nil {
  			return err
  		}
  	}
  	return nil
  }
  159. func (m Migrator) DropTable(values ...interface{}) error {
  160. values = m.ReorderModels(values, false)
  161. for i := len(values) - 1; i >= 0; i-- {
  162. value := values[i]
  163. if m.DB.Migrator().HasTable(value) {
  164. tx := m.DB.Session(&gorm.Session{})
  165. if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  166. return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
  167. }); err != nil {
  168. return err
  169. }
  170. }
  171. }
  172. return nil
  173. }
  174. func (m Migrator) HasTable(value interface{}) bool {
  175. var count int64
  176. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  177. currentDatabase := m.DB.Migrator().CurrentDatabase()
  178. return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
  179. })
  180. return count > 0
  181. }
  182. func (m Migrator) RenameTable(oldName, newName string) error {
  183. return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
  184. }
  // AddColumn adds the column for field to value's table, typed with DataTypeOf.
  // field is resolved via stmt.Schema.LookUpField (AutoMigrate passes a DB
  // column name). Returns an error if the field cannot be found.
  func (m Migrator) AddColumn(value interface{}, field string) error {
  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  		f := stmt.Schema.LookUpField(field)
  		if f == nil {
  			return fmt.Errorf("failed to look up field with name: %s", field)
  		}
  		table, column := clause.Table{Name: stmt.Table}, clause.Column{Name: f.DBName}
  		return m.DB.Exec("ALTER TABLE ? ADD ? ?", table, column, clause.Expr{SQL: m.DataTypeOf(f)}).Error
  	})
  }
  // DropColumn removes field's column (resolved via LookUpField) from value's
  // table. Returns an error if the field cannot be found.
  func (m Migrator) DropColumn(value interface{}, field string) error {
  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  		f := stmt.Schema.LookUpField(field)
  		if f == nil {
  			return fmt.Errorf("failed to look up field with name: %s", field)
  		}
  		table, column := clause.Table{Name: stmt.Table}, clause.Column{Name: f.DBName}
  		return m.DB.Exec("ALTER TABLE ? DROP COLUMN ?", table, column).Error
  	})
  }
  // AlterColumn changes the type of field's column (resolved via LookUpField)
  // to DataTypeOf(field), using "ALTER COLUMN ... TYPE" syntax. Returns an error
  // if the field cannot be found.
  func (m Migrator) AlterColumn(value interface{}, field string) error {
  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  		f := stmt.Schema.LookUpField(field)
  		if f == nil {
  			return fmt.Errorf("failed to look up field with name: %s", field)
  		}
  		table, column := clause.Table{Name: stmt.Table}, clause.Column{Name: f.DBName}
  		return m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", table, column, clause.Expr{SQL: m.DataTypeOf(f)}).Error
  	})
  }
  217. func (m Migrator) HasColumn(value interface{}, field string) bool {
  218. var count int64
  219. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  220. currentDatabase := m.DB.Migrator().CurrentDatabase()
  221. name := field
  222. if field := stmt.Schema.LookUpField(field); field != nil {
  223. name = field.DBName
  224. }
  225. return m.DB.Raw(
  226. "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
  227. currentDatabase, stmt.Table, name,
  228. ).Row().Scan(&count)
  229. })
  230. return count > 0
  231. }
  // RenameColumn renames a column of value's table. oldName is converted with
  // the naming strategy's ColumnName; field is resolved via LookUpField and its
  // DBName becomes the new column name. Returns an error if field is unknown.
  func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  		target := stmt.Schema.LookUpField(field)
  		if target == nil {
  			return fmt.Errorf("failed to look up field with name: %s", field)
  		}
  		oldColumn := m.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
  		return m.DB.Exec(
  			"ALTER TABLE ? RENAME COLUMN ? TO ?",
  			clause.Table{Name: stmt.Table},
  			clause.Column{Name: oldColumn},
  			clause.Column{Name: target.DBName},
  		).Error
  	})
  }
  // ColumnTypes returns the driver-reported column metadata of value's table,
  // taken from the result set of "select * from <table>".
  //
  // The result set is always closed before returning. An unclosed *sql.Rows
  // keeps its pooled connection checked out. The returned ColumnType values are
  // copies, so they remain usable after Close.
  func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
  	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
  		rows, queryErr := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
  		if queryErr != nil {
  			return queryErr
  		}
  		defer rows.Close()

  		types, typesErr := rows.ColumnTypes()
  		if typesErr != nil {
  			return typesErr
  		}
  		columnTypes = types
  		return nil
  	})
  	return columnTypes, err
  }
  254. func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
  255. return gorm.ErrNotImplemented
  256. }
  257. func (m Migrator) DropView(name string) error {
  258. return gorm.ErrNotImplemented
  259. }
  260. func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
  261. sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
  262. if constraint.OnDelete != "" {
  263. sql += " ON DELETE " + constraint.OnDelete
  264. }
  265. if constraint.OnUpdate != "" {
  266. sql += " ON UPDATE " + constraint.OnUpdate
  267. }
  268. var foreignKeys, references []interface{}
  269. for _, field := range constraint.ForeignKeys {
  270. foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
  271. }
  272. for _, field := range constraint.References {
  273. references = append(references, clause.Column{Name: field.DBName})
  274. }
  275. results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
  276. return
  277. }
  // CreateConstraint adds the constraint called name to value's table.
  //
  // name is resolved in this order:
  //  1. a CHECK constraint parsed from the schema: ALTER TABLE ... ADD CONSTRAINT ... CHECK (...)
  //  2. the foreign-key constraint of one of the schema's relationships
  //  3. a field: every CHECK constraint of the schema and every foreign-key
  //     constraint declared on that field is created via CreateConstraint.
  //
  // If name matches none of these, an error is returned.
  func (m Migrator) CreateConstraint(value interface{}, name string) error {
  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  		checkConstraints := stmt.Schema.ParseCheckConstraints()
  		if chk, ok := checkConstraints[name]; ok {
  			// CHECK requires a parenthesised expression; extra parentheses are
  			// harmless if the constraint text already has them.
  			return m.DB.Exec(
  				"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
  				clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
  			).Error
  		}

  		for _, rel := range stmt.Schema.Relationships.Relations {
  			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
  				constraintSQL, vars := buildConstraint(constraint)
  				return m.DB.Exec("ALTER TABLE ? ADD "+constraintSQL, append([]interface{}{clause.Table{Name: stmt.Table}}, vars...)...).Error
  			}
  		}

  		err := fmt.Errorf("failed to create constraint with name %v", name)
  		if field := stmt.Schema.LookUpField(name); field != nil {
  			// These are constraint names, so they go through CreateConstraint
  			// (they resolve in steps 1/2 above). CreateIndex would normally reject
  			// them as unknown index names.
  			// NOTE(review): every CHECK constraint of the schema is created here,
  			// not only those on field — filter by field if the check type exposes it.
  			for _, cc := range checkConstraints {
  				if err = m.DB.Migrator().CreateConstraint(value, cc.Name); err != nil {
  					return err
  				}
  			}
  			for _, rel := range stmt.Schema.Relationships.Relations {
  				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
  					if err = m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
  						return err
  					}
  				}
  			}
  		}
  		return err
  	})
  }
  311. func (m Migrator) DropConstraint(value interface{}, name string) error {
  312. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  313. return m.DB.Exec(
  314. "ALTER TABLE ? DROP CONSTRAINT ?",
  315. clause.Table{Name: stmt.Table}, clause.Column{Name: name},
  316. ).Error
  317. })
  318. }
  319. func (m Migrator) HasConstraint(value interface{}, name string) bool {
  320. var count int64
  321. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  322. currentDatabase := m.DB.Migrator().CurrentDatabase()
  323. return m.DB.Raw(
  324. "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
  325. currentDatabase, stmt.Table, name,
  326. ).Row().Scan(&count)
  327. })
  328. return count > 0
  329. }
  330. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  331. for _, opt := range opts {
  332. str := stmt.Quote(opt.DBName)
  333. if opt.Expression != "" {
  334. str = opt.Expression
  335. } else if opt.Length > 0 {
  336. str += fmt.Sprintf("(%d)", opt.Length)
  337. }
  338. if opt.Collate != "" {
  339. str += " COLLATE " + opt.Collate
  340. }
  341. if opt.Sort != "" {
  342. str += " " + opt.Sort
  343. }
  344. results = append(results, clause.Expr{SQL: str})
  345. }
  346. return
  347. }
  348. type BuildIndexOptionsInterface interface {
  349. BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
  350. }
  351. func (m Migrator) CreateIndex(value interface{}, name string) error {
  352. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  353. err := fmt.Errorf("failed to create index with name %v", name)
  354. indexes := stmt.Schema.ParseIndexes()
  355. if idx, ok := indexes[name]; ok {
  356. opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
  357. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  358. createIndexSQL := "CREATE "
  359. if idx.Class != "" {
  360. createIndexSQL += idx.Class + " "
  361. }
  362. createIndexSQL += "INDEX ? ON ??"
  363. if idx.Comment != "" {
  364. values = append(values, idx.Comment)
  365. createIndexSQL += " COMMENT ?"
  366. }
  367. if idx.Type != "" {
  368. createIndexSQL += " USING " + idx.Type
  369. }
  370. return m.DB.Exec(createIndexSQL, values...).Error
  371. } else if field := stmt.Schema.LookUpField(name); field != nil {
  372. for _, idx := range indexes {
  373. for _, idxOpt := range idx.Fields {
  374. if idxOpt.Field == field {
  375. if err = m.CreateIndex(value, idx.Name); err != nil {
  376. return err
  377. }
  378. }
  379. }
  380. }
  381. }
  382. return err
  383. })
  384. }
  385. func (m Migrator) DropIndex(value interface{}, name string) error {
  386. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  387. return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
  388. })
  389. }
  390. func (m Migrator) HasIndex(value interface{}, name string) bool {
  391. var count int64
  392. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  393. currentDatabase := m.DB.Migrator().CurrentDatabase()
  394. return m.DB.Raw(
  395. "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
  396. currentDatabase, stmt.Table, name,
  397. ).Row().Scan(&count)
  398. })
  399. return count > 0
  400. }
  401. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  402. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  403. return m.DB.Exec(
  404. "ALTER TABLE ? RENAME INDEX ? TO ?",
  405. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
  406. ).Error
  407. })
  408. }
  409. func (m Migrator) CurrentDatabase() (name string) {
  410. m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
  411. return
  412. }
  // ReorderModels orders values so each model comes after the models it
  // references through foreign-key constraints. Self references are ignored.
  // When autoAdd is true, referenced models missing from values are parsed and
  // included in the result too. Models are keyed by table name, so duplicates
  // collapse to one entry.
  //
  // Circular references (A -> B -> A) are tolerated. A model is marked as placed
  // before its dependencies are walked, so the walk terminates. Within a cycle,
  // at least one model necessarily precedes a model it references.
  func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
  	type Dependency struct {
  		*gorm.Statement
  		Depends []*schema.Schema
  	}

  	var (
  		modelNames, orderedModelNames []string
  		orderedModelNamesMap          = map[string]bool{}
  		valuesMap                     = map[string]Dependency{}
  		insertIntoOrderedList         func(name string)
  	)

  	parseDependence := func(value interface{}, addToList bool) {
  		dep := Dependency{
  			Statement: &gorm.Statement{DB: m.DB, Dest: value},
  		}
  		// NOTE(review): the Parse error is ignored; if parsing can fail and leave
  		// dep.Schema nil, the loop below panics — confirm Parse's contract.
  		dep.Parse(value)

  		for _, rel := range dep.Schema.Relationships.Relations {
  			if c := rel.ParseConstraint(); c != nil && c.Schema != c.ReferenceSchema {
  				dep.Depends = append(dep.Depends, c.ReferenceSchema)
  			}
  		}

  		valuesMap[dep.Schema.Table] = dep
  		if addToList {
  			modelNames = append(modelNames, dep.Schema.Table)
  		}
  	}

  	insertIntoOrderedList = func(name string) {
  		if orderedModelNamesMap[name] {
  			return // already placed, or being placed further up the stack (cycle)
  		}
  		// Mark before recursing. Marking only after the dependencies had been
  		// placed made mutually referencing models recurse forever.
  		orderedModelNamesMap[name] = true

  		for _, d := range valuesMap[name].Depends {
  			if _, ok := valuesMap[d.Table]; ok {
  				insertIntoOrderedList(d.Table)
  			} else if autoAdd {
  				parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
  				insertIntoOrderedList(d.Table)
  			}
  		}
  		orderedModelNames = append(orderedModelNames, name)
  	}

  	for _, value := range values {
  		parseDependence(value, true)
  	}
  	for _, name := range modelNames {
  		insertIntoOrderedList(name)
  	}
  	for _, name := range orderedModelNames {
  		results = append(results, valuesMap[name].Statement.Dest)
  	}
  	return results
  }