migrator.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package sqlite
  2. import (
  3. "fmt"
  4. "github.com/jinzhu/gorm"
  5. "github.com/jinzhu/gorm/clause"
  6. "github.com/jinzhu/gorm/migrator"
  7. "github.com/jinzhu/gorm/schema"
  8. )
  9. type Migrator struct {
  10. migrator.Migrator
  11. }
  12. func (m Migrator) HasTable(value interface{}) bool {
  13. var count int
  14. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  15. return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
  16. })
  17. return count > 0
  18. }
  19. func (m Migrator) HasColumn(value interface{}, field string) bool {
  20. var count int
  21. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  22. name := field
  23. if field := stmt.Schema.LookUpField(field); field != nil {
  24. name = field.DBName
  25. }
  26. return m.DB.Raw(
  27. "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  28. stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
  29. ).Row().Scan(&count)
  30. })
  31. return count > 0
  32. }
  33. func (m Migrator) HasIndex(value interface{}, name string) bool {
  34. var count int
  35. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  36. return m.DB.Raw(
  37. "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE ?",
  38. stmt.Table, "%INDEX "+name+" ON%",
  39. ).Row().Scan(&count)
  40. })
  41. return count > 0
  42. }
  43. func (m Migrator) CreateConstraint(interface{}, string) error {
  44. return gorm.ErrNotImplemented
  45. }
  46. func (m Migrator) DropConstraint(interface{}, string) error {
  47. return gorm.ErrNotImplemented
  48. }
  49. func (m Migrator) CurrentDatabase() (name string) {
  50. var null interface{}
  51. m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
  52. return
  53. }
  54. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  55. for _, opt := range opts {
  56. str := stmt.Quote(opt.DBName)
  57. if opt.Expression != "" {
  58. str = opt.Expression
  59. }
  60. if opt.Collate != "" {
  61. str += " COLLATE " + opt.Collate
  62. }
  63. if opt.Sort != "" {
  64. str += " " + opt.Sort
  65. }
  66. results = append(results, clause.Expr{SQL: str})
  67. }
  68. return
  69. }
  70. func (m Migrator) CreateIndex(value interface{}, name string) error {
  71. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  72. err := fmt.Errorf("failed to create index with name %v", name)
  73. indexes := stmt.Schema.ParseIndexes()
  74. if idx, ok := indexes[name]; ok {
  75. opts := m.BuildIndexOptions(idx.Fields, stmt)
  76. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  77. createIndexSQL := "CREATE "
  78. if idx.Class != "" {
  79. createIndexSQL += idx.Class + " "
  80. }
  81. createIndexSQL += "INDEX ?"
  82. if idx.Type != "" {
  83. createIndexSQL += " USING " + idx.Type
  84. }
  85. createIndexSQL += " ON ??"
  86. if idx.Where != "" {
  87. createIndexSQL += " WHERE " + idx.Where
  88. }
  89. return m.DB.Exec(createIndexSQL, values...).Error
  90. } else if field := stmt.Schema.LookUpField(name); field != nil {
  91. for _, idx := range indexes {
  92. for _, idxOpt := range idx.Fields {
  93. if idxOpt.Field == field {
  94. if err = m.CreateIndex(value, idx.Name); err != nil {
  95. return err
  96. }
  97. }
  98. }
  99. }
  100. }
  101. return err
  102. })
  103. }