mssql.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. package mssql
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "regexp"
  6. "strconv"
  7. _ "github.com/denisenkom/go-mssqldb"
  8. "github.com/jinzhu/gorm"
  9. "github.com/jinzhu/gorm/callbacks"
  10. "github.com/jinzhu/gorm/logger"
  11. "github.com/jinzhu/gorm/migrator"
  12. "github.com/jinzhu/gorm/schema"
  13. )
  // Dialector is the gorm.Dialector implementation for Microsoft SQL Server.
  // Its zero value has an empty DSN; construct it with Open.
  type Dialector struct {
  	// DSN is the connection string passed unchanged to sql.Open("sqlserver", ...).
  	DSN string
  }
  17. func Open(dsn string) gorm.Dialector {
  18. return &Dialector{DSN: dsn}
  19. }
  20. func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
  21. // register callbacks
  22. callbacks.RegisterDefaultCallbacks(db)
  23. db.DB, err = sql.Open("sqlserver", dialector.DSN)
  24. return
  25. }
  26. func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
  27. return Migrator{migrator.Migrator{Config: migrator.Config{
  28. DB: db,
  29. Dialector: dialector,
  30. CreateIndexAfterCreateTable: true,
  31. }}}
  32. }
  33. func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
  34. return "@p" + strconv.Itoa(len(stmt.Vars))
  35. }
  36. func (dialector Dialector) QuoteChars() [2]byte {
  37. return [2]byte{'"', '"'} // `name`
  38. }
  // numericPlaceholder matches the "@pN" bind variables emitted by BindVar and
  // captures N, so Explain can substitute the corresponding value.
  var numericPlaceholder = regexp.MustCompile(`@p(\d+)`)
  40. func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
  41. return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
  42. }
  43. func (dialector Dialector) DataTypeOf(field *schema.Field) string {
  44. switch field.DataType {
  45. case schema.Bool:
  46. return "bit"
  47. case schema.Int, schema.Uint:
  48. var sqlType string
  49. switch {
  50. case field.Size < 16:
  51. sqlType = "smallint"
  52. case field.Size < 31:
  53. sqlType = "int"
  54. default:
  55. sqlType = "bigint"
  56. }
  57. if field.AutoIncrement {
  58. return sqlType + " IDENTITY(1,1)"
  59. }
  60. return sqlType
  61. case schema.Float:
  62. return "decimal"
  63. case schema.String:
  64. size := field.Size
  65. if field.PrimaryKey && size == 0 {
  66. size = 256
  67. }
  68. if size > 0 && size <= 4000 {
  69. return fmt.Sprintf("nvarchar(%d)", size)
  70. }
  71. return "ntext"
  72. case schema.Time:
  73. return "datetimeoffset"
  74. case schema.Bytes:
  75. return "binary"
  76. }
  77. return ""
  78. }