dialect_mysql.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package gorm
  2. import (
  3. "crypto/sha1"
  4. "fmt"
  5. "reflect"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "unicode/utf8"
  11. )
  12. type mysql struct {
  13. commonDialect
  14. }
  15. func init() {
  16. RegisterDialect("mysql", &mysql{})
  17. }
  18. func (mysql) GetName() string {
  19. return "mysql"
  20. }
  21. func (mysql) Quote(key string) string {
  22. return fmt.Sprintf("`%s`", key)
  23. }
  24. // Get Data Type for MySQL Dialect
  25. func (s *mysql) DataTypeOf(field *StructField) string {
  26. var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
  27. // MySQL allows only one auto increment column per table, and it must
  28. // be a KEY column.
  29. if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
  30. if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
  31. field.TagSettingsDelete("AUTO_INCREMENT")
  32. }
  33. }
  34. if sqlType == "" {
  35. switch dataValue.Kind() {
  36. case reflect.Bool:
  37. sqlType = "boolean"
  38. case reflect.Int8:
  39. if s.fieldCanAutoIncrement(field) {
  40. field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
  41. sqlType = "tinyint AUTO_INCREMENT"
  42. } else {
  43. sqlType = "tinyint"
  44. }
  45. case reflect.Int, reflect.Int16, reflect.Int32:
  46. if s.fieldCanAutoIncrement(field) {
  47. field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
  48. sqlType = "int AUTO_INCREMENT"
  49. } else {
  50. sqlType = "int"
  51. }
  52. case reflect.Uint8:
  53. if s.fieldCanAutoIncrement(field) {
  54. field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
  55. sqlType = "tinyint unsigned AUTO_INCREMENT"
  56. } else {
  57. sqlType = "tinyint unsigned"
  58. }
  59. case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  60. if s.fieldCanAutoIncrement(field) {
  61. field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
  62. sqlType = "int unsigned AUTO_INCREMENT"
  63. } else {
  64. sqlType = "int unsigned"
  65. }
  66. case reflect.Int64:
  67. if s.fieldCanAutoIncrement(field) {
  68. field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
  69. sqlType = "bigint AUTO_INCREMENT"
  70. } else {
  71. sqlType = "bigint"
  72. }
  73. case reflect.Uint64:
  74. if s.fieldCanAutoIncrement(field) {
  75. field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
  76. sqlType = "bigint unsigned AUTO_INCREMENT"
  77. } else {
  78. sqlType = "bigint unsigned"
  79. }
  80. case reflect.Float32, reflect.Float64:
  81. sqlType = "double"
  82. case reflect.String:
  83. if size > 0 && size < 65532 {
  84. sqlType = fmt.Sprintf("varchar(%d)", size)
  85. } else {
  86. sqlType = "longtext"
  87. }
  88. case reflect.Struct:
  89. if _, ok := dataValue.Interface().(time.Time); ok {
  90. precision := ""
  91. if p, ok := field.TagSettingsGet("PRECISION"); ok {
  92. precision = fmt.Sprintf("(%s)", p)
  93. }
  94. if _, ok := field.TagSettingsGet("NOT NULL"); ok {
  95. sqlType = fmt.Sprintf("timestamp%v", precision)
  96. } else {
  97. sqlType = fmt.Sprintf("timestamp%v NULL", precision)
  98. }
  99. }
  100. default:
  101. if IsByteArrayOrSlice(dataValue) {
  102. if size > 0 && size < 65532 {
  103. sqlType = fmt.Sprintf("varbinary(%d)", size)
  104. } else {
  105. sqlType = "longblob"
  106. }
  107. }
  108. }
  109. }
  110. if sqlType == "" {
  111. panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
  112. }
  113. if strings.TrimSpace(additionalType) == "" {
  114. return sqlType
  115. }
  116. return fmt.Sprintf("%v %v", sqlType, additionalType)
  117. }
  118. func (s mysql) RemoveIndex(tableName string, indexName string) error {
  119. _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
  120. return err
  121. }
  122. func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
  123. _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
  124. return err
  125. }
  126. func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
  127. if limit != nil {
  128. if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
  129. sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
  130. if offset != nil {
  131. if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
  132. sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
  133. }
  134. }
  135. }
  136. }
  137. return
  138. }
  139. func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
  140. var count int
  141. currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
  142. s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
  143. return count > 0
  144. }
  145. func (s mysql) CurrentDatabase() (name string) {
  146. s.db.QueryRow("SELECT DATABASE()").Scan(&name)
  147. return
  148. }
  149. func (mysql) SelectFromDummyTable() string {
  150. return "FROM DUAL"
  151. }
  152. func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
  153. keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
  154. if utf8.RuneCountInString(keyName) <= 64 {
  155. return keyName
  156. }
  157. h := sha1.New()
  158. h.Write([]byte(keyName))
  159. bs := h.Sum(nil)
  160. // sha1 is 40 characters, keep first 24 characters of destination
  161. destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
  162. if len(destRunes) > 24 {
  163. destRunes = destRunes[:24]
  164. }
  165. return fmt.Sprintf("%s%x", string(destRunes), bs)
  166. }
  167. func (mysql) DefaultValueStr() string {
  168. return "VALUES()"
  169. }