migrator.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package postgres
  2. import (
  3. "fmt"
  4. "github.com/jinzhu/gorm"
  5. "github.com/jinzhu/gorm/clause"
  6. "github.com/jinzhu/gorm/migrator"
  7. "github.com/jinzhu/gorm/schema"
  8. )
  9. type Migrator struct {
  10. migrator.Migrator
  11. }
  12. func (m Migrator) CurrentDatabase() (name string) {
  13. m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
  14. return
  15. }
  16. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  17. for _, opt := range opts {
  18. str := stmt.Quote(opt.DBName)
  19. if opt.Expression != "" {
  20. str = opt.Expression
  21. }
  22. if opt.Collate != "" {
  23. str += " COLLATE " + opt.Collate
  24. }
  25. if opt.Sort != "" {
  26. str += " " + opt.Sort
  27. }
  28. results = append(results, clause.Expr{SQL: str})
  29. }
  30. return
  31. }
  32. func (m Migrator) HasIndex(value interface{}, indexName string) bool {
  33. var count int64
  34. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  35. return m.DB.Raw(
  36. "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName,
  37. ).Row().Scan(&count)
  38. })
  39. return count > 0
  40. }
  41. func (m Migrator) CreateIndex(value interface{}, name string) error {
  42. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  43. err := fmt.Errorf("failed to create index with name %v", name)
  44. indexes := stmt.Schema.ParseIndexes()
  45. if idx, ok := indexes[name]; ok {
  46. opts := m.BuildIndexOptions(idx.Fields, stmt)
  47. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  48. createIndexSQL := "CREATE "
  49. if idx.Class != "" {
  50. createIndexSQL += idx.Class + " "
  51. }
  52. createIndexSQL += "INDEX ?"
  53. if idx.Type != "" {
  54. createIndexSQL += " USING " + idx.Type
  55. }
  56. createIndexSQL += " ON ??"
  57. if idx.Where != "" {
  58. createIndexSQL += " WHERE " + idx.Where
  59. }
  60. return m.DB.Exec(createIndexSQL, values...).Error
  61. } else if field := stmt.Schema.LookUpField(name); field != nil {
  62. for _, idx := range indexes {
  63. for _, idxOpt := range idx.Fields {
  64. if idxOpt.Field == field {
  65. if err = m.CreateIndex(value, idx.Name); err != nil {
  66. return err
  67. }
  68. }
  69. }
  70. }
  71. }
  72. return err
  73. })
  74. }