Jinzhu 4 лет назад
Родитель
Сommit
ea0b13f7a3
5 измененных файлов с 80 добавлено и 81 удалено
  1. 2 2
      schema/field.go
  2. 27 43
      schema/index.go
  3. 24 22
      schema/index_test.go
  4. 1 1
      schema/schema_helper_test.go
  5. 26 13
      schema/utils.go

+ 2 - 2
schema/field.go

@@ -74,7 +74,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		Creatable:         true,
 		Updatable:         true,
 		Tag:               fieldStruct.Tag,
-		TagSettings:       ParseTagSetting(fieldStruct.Tag),
+		TagSettings:       ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
 		Schema:            schema,
 	}
 
@@ -104,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 				}
 
 				// copy tag settings from valuer
-				for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) {
+				for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
 					if _, ok := field.TagSettings[key]; !ok {
 						field.TagSettings[key] = value
 					}

+ 27 - 43
schema/index.go

@@ -6,9 +6,12 @@ import (
 )
 
 type Index struct {
-	Name   string
-	Class  string // UNIQUE | FULLTEXT | SPATIAL
-	Fields []IndexOption
+	Name    string
+	Class   string // UNIQUE | FULLTEXT | SPATIAL
+	Type    string // btree, hash, gist, spgist, gin, and brin
+	Where   string
+	Comment string
+	Fields  []IndexOption
 }
 
 type IndexOption struct {
@@ -17,9 +20,6 @@ type IndexOption struct {
 	Sort       string // DESC, ASC
 	Collate    string
 	Length     int
-	Type       string // btree, hash, gist, spgist, gin, and brin
-	Where      string
-	Comment    string
 }
 
 // ParseIndexes parse schema indexes
@@ -34,6 +34,15 @@ func (schema *Schema) ParseIndexes() map[string]Index {
 				if idx.Class == "" {
 					idx.Class = index.Class
 				}
+				if idx.Type == "" {
+					idx.Type = index.Type
+				}
+				if idx.Where == "" {
+					idx.Where = index.Where
+				}
+				if idx.Comment == "" {
+					idx.Comment = index.Comment
+				}
 				idx.Fields = append(idx.Fields, index.Fields...)
 				indexes[index.Name] = idx
 			}
@@ -50,62 +59,37 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
 			k := strings.TrimSpace(strings.ToUpper(v[0]))
 			if k == "INDEX" || k == "UNIQUE_INDEX" {
 				var (
-					name     string
-					tag      = strings.Join(v[1:], ":")
-					settings = map[string]string{}
+					name      string
+					tag       = strings.Join(v[1:], ":")
+					idx       = strings.Index(tag, ",")
+					settings  = ParseTagSetting(tag, ",")
+					length, _ = strconv.Atoi(settings["LENGTH"])
 				)
 
-				names := strings.Split(tag, ",")
-				for i := 0; i < len(names); i++ {
-					if len(names[i]) > 0 {
-						j := i
-						for {
-							if names[j][len(names[j])-1] == '\\' {
-								i++
-								names[j] = names[j][0:len(names[j])-1] + names[i]
-								names[i] = ""
-							} else {
-								break
-							}
-						}
-					}
-
-					if i == 0 {
-						name = names[0]
-					}
-
-					values := strings.Split(names[i], ":")
-					k := strings.TrimSpace(strings.ToUpper(values[0]))
-
-					if len(values) >= 2 {
-						settings[k] = strings.Join(values[1:], ":")
-					} else if k != "" {
-						settings[k] = k
-					}
+				if idx != -1 {
+					name = tag[0:idx]
 				}
 
 				if name == "" {
 					name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
 				}
 
-				length, _ := strconv.Atoi(settings["LENGTH"])
-
 				if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" {
 					settings["CLASS"] = "UNIQUE"
 				}
 
 				indexes = append(indexes, Index{
-					Name:  name,
-					Class: settings["CLASS"],
+					Name:    name,
+					Class:   settings["CLASS"],
+					Type:    settings["TYPE"],
+					Where:   settings["WHERE"],
+					Comment: settings["COMMENT"],
 					Fields: []IndexOption{{
 						Field:      field,
 						Expression: settings["EXPRESSION"],
 						Sort:       settings["SORT"],
 						Collate:    settings["COLLATE"],
-						Type:       settings["TYPE"],
 						Length:     length,
-						Where:      settings["WHERE"],
-						Comment:    settings["COMMENT"],
 					}},
 				})
 			}

+ 24 - 22
schema/index_test.go

@@ -35,13 +35,13 @@ func TestParseIndex(t *testing.T) {
 			Fields: []schema.IndexOption{{}},
 		},
 		"idx_user_indices_name3": {
-			Name: "idx_user_indices_name3",
+			Name:  "idx_user_indices_name3",
+			Type:  "btree",
+			Where: "name3 != 'jinzhu'",
 			Fields: []schema.IndexOption{{
 				Sort:    "desc",
 				Collate: "utf8",
 				Length:  10,
-				Type:    "btree",
-				Where:   "name3 != 'jinzhu'",
 			}},
 		},
 		"idx_user_indices_name4": {
@@ -50,19 +50,17 @@ func TestParseIndex(t *testing.T) {
 			Fields: []schema.IndexOption{{}},
 		},
 		"idx_user_indices_name5": {
-			Name:  "idx_user_indices_name5",
-			Class: "FULLTEXT",
-			Fields: []schema.IndexOption{{
-				Comment: "hello , world",
-				Where:   "age > 10",
-			}},
+			Name:    "idx_user_indices_name5",
+			Class:   "FULLTEXT",
+			Comment: "hello , world",
+			Where:   "age > 10",
+			Fields:  []schema.IndexOption{{}},
 		},
 		"profile": {
-			Name: "profile",
-			Fields: []schema.IndexOption{{
-				Comment: "hello , world",
-				Where:   "age > 10",
-			}, {
+			Name:    "profile",
+			Comment: "hello , world",
+			Where:   "age > 10",
+			Fields: []schema.IndexOption{{}, {
 				Expression: "(age+10)",
 			}},
 		},
@@ -76,19 +74,23 @@ func TestParseIndex(t *testing.T) {
 			t.Errorf("Failed to found index %v from parsed indices %+v", k, indices)
 		}
 
-		if result.Name != v.Name {
-			t.Errorf("index %v name should equal, expects %v, got %v", k, result.Name, v.Name)
-		}
-
-		if result.Class != v.Class {
-			t.Errorf("index %v Class should equal, expects %v, got %v", k, result.Class, v.Class)
+		for _, name := range []string{"Name", "Class", "Type", "Where", "Comment"} {
+			if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
+				t.Errorf(
+					"index %v %v should equal, expects %v, got %v",
+					k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
+				)
+			}
 		}
 
 		for idx, ef := range result.Fields {
 			rf := v.Fields[idx]
-			for _, name := range []string{"Expression", "Sort", "Collate", "Length", "Type", "Where"} {
+			for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
 				if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
-					t.Errorf("index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface())
+					t.Errorf(
+						"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
+						reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
+					)
 				}
 			}
 		}

+ 1 - 1
schema/schema_helper_test.go

@@ -44,7 +44,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
 
 		if f.TagSettings == nil {
 			if f.Tag != "" {
-				f.TagSettings = schema.ParseTagSetting(f.Tag)
+				f.TagSettings = schema.ParseTagSetting(f.Tag.Get("gorm"), ";")
 			} else {
 				f.TagSettings = map[string]string{}
 			}

+ 26 - 13
schema/utils.go

@@ -6,22 +6,35 @@ import (
 	"strings"
 )
 
-func ParseTagSetting(tags reflect.StructTag) map[string]string {
-	setting := map[string]string{}
-
-	for _, value := range strings.Split(tags.Get("gorm"), ";") {
-		if value != "" {
-			v := strings.Split(value, ":")
-			k := strings.TrimSpace(strings.ToUpper(v[0]))
-
-			if len(v) >= 2 {
-				setting[k] = strings.Join(v[1:], ":")
-			} else {
-				setting[k] = k
+func ParseTagSetting(str string, sep string) map[string]string {
+	settings := map[string]string{}
+	names := strings.Split(str, sep)
+
+	for i := 0; i < len(names); i++ {
+		j := i
+		if len(names[j]) > 0 {
+			for {
+				if names[j][len(names[j])-1] == '\\' {
+					i++
+					names[j] = names[j][0:len(names[j])-1] + sep + names[i]
+					names[i] = ""
+				} else {
+					break
+				}
 			}
 		}
+
+		values := strings.Split(names[j], ":")
+		k := strings.TrimSpace(strings.ToUpper(values[0]))
+
+		if len(values) >= 2 {
+			settings[k] = strings.Join(values[1:], ":")
+		} else if k != "" {
+			settings[k] = k
+		}
 	}
-	return setting
+
+	return settings
 }
 
 func checkTruth(val string) bool {