Răsfoiți Sursa

Parse Indexes

Jinzhu 4 ani în urmă
părinte
comite
ad419855e9
3 a modificat fișierele cu 230 adăugiri și 2 ștergeri
  1. 116 0
      schema/index.go
  2. 96 0
      schema/index_test.go
  3. 18 2
      schema/naming.go

+ 116 - 0
schema/index.go

@@ -0,0 +1,116 @@
+package schema
+
+import (
+	"strconv"
+	"strings"
+)
+
+type Index struct {
+	Name   string
+	Class  string // UNIQUE | FULLTEXT | SPATIAL
+	Fields []IndexOption
+}
+
+type IndexOption struct {
+	*Field
+	Expression string
+	Sort       string // DESC, ASC
+	Collate    string
+	Length     int
+	Type       string // btree, hash, gist, spgist, gin, and brin
+	Where      string
+	Comment    string
+}
+
+// ParseIndexes parse schema indexes
+func (schema *Schema) ParseIndexes() map[string]Index {
+	var indexes = map[string]Index{}
+
+	for _, field := range schema.FieldsByDBName {
+		if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" {
+			for _, index := range parseFieldIndexes(field) {
+				idx := indexes[index.Name]
+				idx.Name = index.Name
+				if idx.Class == "" {
+					idx.Class = index.Class
+				}
+				idx.Fields = append(idx.Fields, index.Fields...)
+				indexes[index.Name] = idx
+			}
+		}
+	}
+
+	return indexes
+}
+
+func parseFieldIndexes(field *Field) (indexes []Index) {
+	for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
+		if value != "" {
+			v := strings.Split(value, ":")
+			k := strings.TrimSpace(strings.ToUpper(v[0]))
+			if k == "INDEX" || k == "UNIQUE_INDEX" {
+				var (
+					name     string
+					tag      = strings.Join(v[1:], ":")
+					settings = map[string]string{}
+				)
+
+				names := strings.Split(tag, ",")
+				for i := 0; i < len(names); i++ {
+					if len(names[i]) > 0 {
+						j := i
+						for {
+							if names[j][len(names[j])-1] == '\\' {
+								i++
+								names[j] = names[j][0:len(names[j])-1] + names[i]
+								names[i] = ""
+							} else {
+								break
+							}
+						}
+					}
+
+					if i == 0 {
+						name = names[0]
+					}
+
+					values := strings.Split(names[i], ":")
+					k := strings.TrimSpace(strings.ToUpper(values[0]))
+
+					if len(values) >= 2 {
+						settings[k] = strings.Join(values[1:], ":")
+					} else if k != "" {
+						settings[k] = k
+					}
+				}
+
+				if name == "" {
+					name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
+				}
+
+				length, _ := strconv.Atoi(settings["LENGTH"])
+
+				if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" {
+					settings["CLASS"] = "UNIQUE"
+				}
+
+				indexes = append(indexes, Index{
+					Name:  name,
+					Class: settings["CLASS"],
+					Fields: []IndexOption{{
+						Field:      field,
+						Expression: settings["EXPRESSION"],
+						Sort:       settings["SORT"],
+						Collate:    settings["COLLATE"],
+						Type:       settings["TYPE"],
+						Length:     length,
+						Where:      settings["WHERE"],
+						Comment:    settings["COMMENT"],
+					}},
+				})
+			}
+		}
+	}
+
+	return
+}

+ 96 - 0
schema/index_test.go

@@ -0,0 +1,96 @@
+package schema_test
+
+import (
+	"reflect"
+	"sync"
+	"testing"
+
+	"github.com/jinzhu/gorm/schema"
+)
+
+type UserIndex struct {
+	Name  string `gorm:"index"`
+	Name2 string `gorm:"index:idx_name,unique"`
+	Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"`
+	Name4 string `gorm:"unique_index"`
+	Name5 int64  `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"`
+	Name6 int64  `gorm:"index:profile,comment:hello \\, world,where:age > 10"`
+	Age   int64  `gorm:"index:profile,expression:(age+10)"`
+}
+
+func TestParseIndex(t *testing.T) {
+	user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse user index index, got error %v", err)
+	}
+
+	results := map[string]schema.Index{
+		"idx_user_indices_name": {
+			Name:   "idx_user_indices_name",
+			Fields: []schema.IndexOption{{}},
+		},
+		"idx_name": {
+			Name:   "idx_name",
+			Class:  "UNIQUE",
+			Fields: []schema.IndexOption{{}},
+		},
+		"idx_user_indices_name3": {
+			Name: "idx_user_indices_name3",
+			Fields: []schema.IndexOption{{
+				Sort:    "desc",
+				Collate: "utf8",
+				Length:  10,
+				Type:    "btree",
+				Where:   "name3 != 'jinzhu'",
+			}},
+		},
+		"idx_user_indices_name4": {
+			Name:   "idx_user_indices_name4",
+			Class:  "UNIQUE",
+			Fields: []schema.IndexOption{{}},
+		},
+		"idx_user_indices_name5": {
+			Name:  "idx_user_indices_name5",
+			Class: "FULLTEXT",
+			Fields: []schema.IndexOption{{
+				Comment: "hello , world",
+				Where:   "age > 10",
+			}},
+		},
+		"profile": {
+			Name: "profile",
+			Fields: []schema.IndexOption{{
+				Comment: "hello , world",
+				Where:   "age > 10",
+			}, {
+				Expression: "(age+10)",
+			}},
+		},
+	}
+
+	indices := user.ParseIndexes()
+
+	for k, result := range results {
+		v, ok := indices[k]
+		if !ok {
+			t.Errorf("Failed to found index %v from parsed indices %+v", k, indices)
+		}
+
+		if result.Name != v.Name {
+			t.Errorf("index %v name should equal, expects %v, got %v", k, result.Name, v.Name)
+		}
+
+		if result.Class != v.Class {
+			t.Errorf("index %v Class should equal, expects %v, got %v", k, result.Class, v.Class)
+		}
+
+		for idx, ef := range result.Fields {
+			rf := v.Fields[idx]
+			for _, name := range []string{"Expression", "Sort", "Collate", "Length", "Type", "Where"} {
+				if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
+					t.Errorf("index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface())
+				}
+			}
+		}
+	}
+}

+ 18 - 2
schema/naming.go

@@ -1,9 +1,11 @@
 package schema
 
 import (
+	"crypto/sha1"
 	"fmt"
 	"strings"
 	"sync"
+	"unicode/utf8"
 
 	"github.com/jinzhu/inflection"
 )
@@ -12,6 +14,7 @@ import (
 type Namer interface {
 	TableName(table string) string
 	ColumnName(table, column string) string
+	IndexName(table, column string) string
 	JoinTableName(table string) string
 }
 
@@ -30,8 +33,21 @@ func (ns NamingStrategy) TableName(str string) string {
 }
 
 // ColumnName convert string to column name
-func (ns NamingStrategy) ColumnName(table, str string) string {
-	return toDBName(str)
+func (ns NamingStrategy) ColumnName(table, column string) string {
+	return toDBName(column)
+}
+
+func (ns NamingStrategy) IndexName(table, column string) string {
+	idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
+
+	if utf8.RuneCountInString(idxName) > 64 {
+		h := sha1.New()
+		h.Write([]byte(idxName))
+		bs := h.Sum(nil)
+
+		idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8]
+	}
+	return idxName
 }
 
 // JoinTableName convert string to join table name