postgres.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package postgres
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "regexp"
  6. "strconv"
  7. "github.com/jinzhu/gorm"
  8. "github.com/jinzhu/gorm/callbacks"
  9. "github.com/jinzhu/gorm/logger"
  10. "github.com/jinzhu/gorm/migrator"
  11. "github.com/jinzhu/gorm/schema"
  12. _ "github.com/lib/pq"
  13. )
  // Dialector is the PostgreSQL dialect for GORM. Its only configuration is
  // the connection string that Initialize hands to database/sql.
  type Dialector struct {
  	// DSN is the data source name passed verbatim to
  	// sql.Open("postgres", ...) when the dialector is initialized.
  	DSN string
  }
  17. func Open(dsn string) gorm.Dialector {
  18. return &Dialector{DSN: dsn}
  19. }
  20. func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
  21. // register callbacks
  22. callbacks.RegisterDefaultCallbacks(db)
  23. db.DB, err = sql.Open("postgres", dialector.DSN)
  24. return
  25. }
  26. func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
  27. return Migrator{migrator.Migrator{Config: migrator.Config{
  28. DB: db,
  29. Dialector: dialector,
  30. CreateIndexAfterCreateTable: true,
  31. }}}
  32. }
  33. func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
  34. return "$" + strconv.Itoa(len(stmt.Vars))
  35. }
  36. func (dialector Dialector) QuoteChars() [2]byte {
  37. return [2]byte{'"', '"'} // "name"
  38. }
  // numericPlaceholder matches PostgreSQL positional parameters such as $1 or
  // $12 and captures their digits; Explain uses it to locate placeholders.
  var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
  40. func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
  41. return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
  42. }
  43. func (dialector Dialector) DataTypeOf(field *schema.Field) string {
  44. switch field.DataType {
  45. case schema.Bool:
  46. return "boolean"
  47. case schema.Int, schema.Uint:
  48. if field.AutoIncrement {
  49. switch {
  50. case field.Size < 16:
  51. return "smallserial"
  52. case field.Size < 31:
  53. return "serial"
  54. default:
  55. return "bigserial"
  56. }
  57. } else {
  58. switch {
  59. case field.Size < 16:
  60. return "smallint"
  61. case field.Size < 31:
  62. return "integer"
  63. default:
  64. return "bigint"
  65. }
  66. }
  67. case schema.Float:
  68. return "decimal"
  69. case schema.String:
  70. if field.Size > 0 {
  71. return fmt.Sprintf("varchar(%d)", field.Size)
  72. }
  73. return "text"
  74. case schema.Time:
  75. return "timestamp with time zone"
  76. case schema.Bytes:
  77. return "bytea"
  78. }
  79. return ""
  80. }