瀏覽代碼

Implement naming strategy

Jinzhu 4 年之前
父節點
當前提交
bc68fde6aa
共有 5 個文件被更改,包括 142 次插入4 次删除
  1. 2 0
      go.mod
  2. 2 0
      go.sum
  3. 8 4
      gorm.go
  4. 96 0
      schema/naming.go
  5. 34 0
      schema/naming_test.go

+ 2 - 0
go.mod

@@ -1,3 +1,5 @@
 module github.com/jinzhu/gorm
 
 go 1.13
+
+require github.com/jinzhu/inflection v1.0.0

+ 2 - 0
go.sum

@@ -0,0 +1,2 @@
+github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
+github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=

+ 8 - 4
gorm.go

@@ -6,18 +6,18 @@ import (
 
 	"github.com/jinzhu/gorm/clause"
 	"github.com/jinzhu/gorm/logger"
+	"github.com/jinzhu/gorm/schema"
 )
 
 // Config GORM config
 type Config struct {
-	// Set true to use singular table name, by default, GORM will pluralize your struct's name as table name
-	// Refer https://github.com/jinzhu/inflection for inflection rules
-	SingularTable bool
-
 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 	// You can cancel it by setting `SkipDefaultTransaction` to true
 	SkipDefaultTransaction bool
 
+	// NamingStrategy tables, columns naming strategy
+	NamingStrategy schema.Namer
+
 	// Logger
 	Logger logger.Interface
 
@@ -48,6 +48,10 @@ type Session struct {
 
 // Open initialize db session based on dialector
 func Open(dialector Dialector, config *Config) (db *DB, err error) {
+	if config.NamingStrategy == nil {
+		config.NamingStrategy = schema.NamingStrategy{}
+	}
+
 	return &DB{
 		Config:    config,
 		Dialector: dialector,

+ 96 - 0
schema/naming.go

@@ -0,0 +1,96 @@
+package schema
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+
+	"github.com/jinzhu/inflection"
+)
+
+// Namer namer interface
+type Namer interface {
+	TableName(string) string
+	ColumnName(string) string
+}
+
+// NamingStrategy tables, columns naming strategy
+type NamingStrategy struct {
+	TablePrefix   string
+	SingularTable bool
+}
+
+// TableName convert string to table name
+func (ns NamingStrategy) TableName(str string) string {
+	if ns.SingularTable {
+		return ns.TablePrefix + toDBName(str)
+	}
+	return ns.TablePrefix + inflection.Plural(toDBName(str))
+}
+
+// ColumnName convert string to column name
+func (ns NamingStrategy) ColumnName(str string) string {
+	return toDBName(str)
+}
+
+var (
+	smap sync.Map
+	// https://github.com/golang/lint/blob/master/lint.go#L770
+	commonInitialisms         = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
+	commonInitialismsReplacer *strings.Replacer
+)
+
+func init() {
+	var commonInitialismsForReplacer []string
+	for _, initialism := range commonInitialisms {
+		commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
+	}
+	commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
+}
+
+func toDBName(name string) string {
+	if name == "" {
+		return ""
+	} else if v, ok := smap.Load(name); ok {
+		return fmt.Sprint(v)
+	}
+
+	var (
+		value                          = commonInitialismsReplacer.Replace(name)
+		buf                            strings.Builder
+		lastCase, nextCase, nextNumber bool // upper case == true
+		curCase                        = value[0] <= 'Z' && value[0] >= 'A'
+	)
+
+	for i, v := range value[:len(value)-1] {
+		nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
+		nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
+
+		if curCase {
+			if lastCase && (nextCase || nextNumber) {
+				buf.WriteRune(v + 32)
+			} else {
+				if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
+					buf.WriteByte('_')
+				}
+				buf.WriteRune(v + 32)
+			}
+		} else {
+			buf.WriteRune(v)
+		}
+
+		lastCase = curCase
+		curCase = nextCase
+	}
+
+	if curCase {
+		if !lastCase && len(value) > 1 {
+			buf.WriteByte('_')
+		}
+		buf.WriteByte(value[len(value)-1] + 32)
+	} else {
+		buf.WriteByte(value[len(value)-1])
+	}
+
+	return buf.String()
+}

+ 34 - 0
schema/naming_test.go

@@ -0,0 +1,34 @@
+package schema
+
+import (
+	"testing"
+)
+
+func TestToDBName(t *testing.T) {
+	var maps = map[string]string{
+		"":                          "",
+		"x":                         "x",
+		"X":                         "x",
+		"userRestrictions":          "user_restrictions",
+		"ThisIsATest":               "this_is_a_test",
+		"PFAndESI":                  "pf_and_esi",
+		"AbcAndJkl":                 "abc_and_jkl",
+		"EmployeeID":                "employee_id",
+		"SKU_ID":                    "sku_id",
+		"FieldX":                    "field_x",
+		"HTTPAndSMTP":               "http_and_smtp",
+		"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
+		"UUID":                      "uuid",
+		"HTTPURL":                   "http_url",
+		"HTTP_URL":                  "http_url",
+		"SHA256Hash":                "sha256_hash",
+		"SHA256HASH":                "sha256_hash",
+		"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
+	}
+
+	for key, value := range maps {
+		if toDBName(key) != value {
+			t.Errorf("%v toName should equal %v, but got %v", key, value, toDBName(key))
+		}
+	}
+}