migrator.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package postgres
  2. import (
  3. "fmt"
  4. "github.com/jinzhu/gorm"
  5. "github.com/jinzhu/gorm/clause"
  6. "github.com/jinzhu/gorm/migrator"
  7. "github.com/jinzhu/gorm/schema"
  8. )
  9. type Migrator struct {
  10. migrator.Migrator
  11. }
  12. func (m Migrator) CurrentDatabase() (name string) {
  13. m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
  14. return
  15. }
  16. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  17. for _, opt := range opts {
  18. str := stmt.Quote(opt.DBName)
  19. if opt.Expression != "" {
  20. str = opt.Expression
  21. }
  22. if opt.Collate != "" {
  23. str += " COLLATE " + opt.Collate
  24. }
  25. if opt.Sort != "" {
  26. str += " " + opt.Sort
  27. }
  28. results = append(results, clause.Expr{SQL: str})
  29. }
  30. return
  31. }
  32. func (m Migrator) HasIndex(value interface{}, indexName string) bool {
  33. var count int64
  34. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  35. return m.DB.Raw(
  36. "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName,
  37. ).Row().Scan(&count)
  38. })
  39. return count > 0
  40. }
  41. func (m Migrator) CreateIndex(value interface{}, name string) error {
  42. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  43. err := fmt.Errorf("failed to create index with name %v", name)
  44. indexes := stmt.Schema.ParseIndexes()
  45. if idx, ok := indexes[name]; ok {
  46. opts := m.BuildIndexOptions(idx.Fields, stmt)
  47. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  48. createIndexSQL := "CREATE "
  49. if idx.Class != "" {
  50. createIndexSQL += idx.Class + " "
  51. }
  52. createIndexSQL += "INDEX ?"
  53. if idx.Type != "" {
  54. createIndexSQL += " USING " + idx.Type
  55. }
  56. createIndexSQL += " ON ??"
  57. if idx.Where != "" {
  58. createIndexSQL += " WHERE " + idx.Where
  59. }
  60. return m.DB.Exec(createIndexSQL, values...).Error
  61. } else if field := stmt.Schema.LookUpField(name); field != nil {
  62. for _, idx := range indexes {
  63. for _, idxOpt := range idx.Fields {
  64. if idxOpt.Field == field {
  65. if err = m.CreateIndex(value, idx.Name); err != nil {
  66. return err
  67. }
  68. }
  69. }
  70. }
  71. }
  72. return err
  73. })
  74. }
  75. func (m Migrator) HasTable(value interface{}) bool {
  76. var count int64
  77. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  78. return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count)
  79. })
  80. return count > 0
  81. }
  82. func (m Migrator) HasColumn(value interface{}, field string) bool {
  83. var count int64
  84. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  85. name := field
  86. if field := stmt.Schema.LookUpField(field); field != nil {
  87. name = field.DBName
  88. }
  89. return m.DB.Raw(
  90. "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
  91. stmt.Table, name,
  92. ).Row().Scan(&count)
  93. })
  94. return count > 0
  95. }