mysql.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package gorm
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. )
  // mysql is the MySQL dialect. It is an empty struct: it holds no state and
  // only serves as the receiver for the dialect's methods.
  type mysql struct{}

  // BinVar returns the bind-variable placeholder for the i-th query argument.
  // The index is ignored; every argument gets the same "$$" marker.
  // NOTE(review): the original carried a "// ?" hint here — presumably "$$" is
  // rewritten to MySQL's "?" placeholder elsewhere before execution; confirm
  // before changing it.
  func (s *mysql) BinVar(i int) string {
  	const marker = "$$"
  	return marker
  }

  // SupportLastInsertId always reports true: this dialect claims support for
  // retrieving generated keys through LastInsertId.
  func (s *mysql) SupportLastInsertId() bool {
  	supported := true
  	return supported
  }

  // HasTop always reports false for this dialect (presumably because MySQL
  // has no SELECT TOP syntax).
  func (s *mysql) HasTop() bool {
  	const top = false
  	return top
  }
  17. func (d *mysql) SqlTag(value reflect.Value, size int) string {
  18. switch value.Kind() {
  19. case reflect.Bool:
  20. return "boolean"
  21. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  22. return "int"
  23. case reflect.Int64, reflect.Uint64:
  24. return "bigint"
  25. case reflect.Float32, reflect.Float64:
  26. return "double"
  27. case reflect.String:
  28. if size > 0 && size < 65532 {
  29. return fmt.Sprintf("varchar(%d)", size)
  30. } else {
  31. return "longtext"
  32. }
  33. case reflect.Struct:
  34. if value.Type() == timeType {
  35. return "datetime"
  36. }
  37. default:
  38. if _, ok := value.Interface().([]byte); ok {
  39. if size > 0 && size < 65532 {
  40. return fmt.Sprintf("varbinary(%d)", size)
  41. } else {
  42. return "longblob"
  43. }
  44. }
  45. }
  46. panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
  47. }
  48. func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
  49. suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
  50. switch value.Kind() {
  51. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  52. return "int" + suffix_str
  53. case reflect.Int64, reflect.Uint64:
  54. return "bigint" + suffix_str
  55. default:
  56. panic("Invalid primary key type")
  57. }
  58. }
  59. func (s *mysql) ReturningStr(tableName, key string) string {
  60. return ""
  61. }
  62. func (s *mysql) SelectFromDummyTable() string {
  63. return "FROM DUAL"
  64. }
  65. func (s *mysql) Quote(key string) string {
  66. return fmt.Sprintf("`%s`", key)
  67. }
  68. func (s *mysql) databaseName(scope *Scope) string {
  69. from := strings.Index(scope.db.parent.source, "/") + 1
  70. to := strings.Index(scope.db.parent.source, "?")
  71. if to == -1 {
  72. to = len(scope.db.parent.source)
  73. }
  74. return scope.db.parent.source[from:to]
  75. }
  76. func (s *mysql) HasTable(scope *Scope, tableName string) bool {
  77. var count int
  78. newScope := scope.New(nil)
  79. newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
  80. newScope.AddToVars(tableName),
  81. newScope.AddToVars(s.databaseName(scope))))
  82. newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
  83. return count > 0
  84. }
  85. func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool {
  86. var count int
  87. newScope := scope.New(nil)
  88. newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v",
  89. newScope.AddToVars(s.databaseName(scope)),
  90. newScope.AddToVars(tableName),
  91. newScope.AddToVars(columnName),
  92. ))
  93. newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
  94. return count > 0
  95. }
  96. func (s *mysql) RemoveIndex(scope *Scope, indexName string) {
  97. scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec()
  98. }