|
@@ -100,10 +100,11 @@ func (scope *Scope) HasError() bool {
|
|
|
return scope.db.Error != nil
|
|
|
}
|
|
|
|
|
|
-func (scope *Scope) PrimaryFields() []*Field {
|
|
|
- var fields = []*Field{}
|
|
|
- for _, field := range scope.GetModelStruct().PrimaryFields {
|
|
|
- fields = append(fields, scope.Fields()[field.DBName])
|
|
|
+func (scope *Scope) PrimaryFields() (fields []*Field) {
|
|
|
+ for _, field := range scope.Fields() {
|
|
|
+ if field.IsPrimaryKey {
|
|
|
+ fields = append(fields, field)
|
|
|
+ }
|
|
|
}
|
|
|
return fields
|
|
|
}
|
|
@@ -111,11 +112,11 @@ func (scope *Scope) PrimaryFields() []*Field {
|
|
|
func (scope *Scope) PrimaryField() *Field {
|
|
|
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
|
|
if len(primaryFields) > 1 {
|
|
|
- if field, ok := scope.Fields()["id"]; ok {
|
|
|
+ if field, ok := scope.FieldByName("id"); ok {
|
|
|
return field
|
|
|
}
|
|
|
}
|
|
|
- return scope.Fields()[primaryFields[0].DBName]
|
|
|
+ return scope.PrimaryFields()[0]
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
@@ -164,20 +165,23 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
|
|
updateAttrs[field.DBName] = value
|
|
|
return field.Set(value)
|
|
|
} else if name, ok := column.(string); ok {
|
|
|
- if field, ok := scope.Fields()[name]; ok {
|
|
|
- updateAttrs[field.DBName] = value
|
|
|
- return field.Set(value)
|
|
|
- }
|
|
|
-
|
|
|
- dbName := ToDBName(name)
|
|
|
- if field, ok := scope.Fields()[dbName]; ok {
|
|
|
- updateAttrs[field.DBName] = value
|
|
|
- return field.Set(value)
|
|
|
+ var (
|
|
|
+ dbName = ToDBName(name)
|
|
|
+ mostMatchedField *Field
|
|
|
+ )
|
|
|
+ for _, field := range scope.Fields() {
|
|
|
+ if field.DBName == value {
|
|
|
+ updateAttrs[field.DBName] = value
|
|
|
+ return field.Set(value)
|
|
|
+ }
|
|
|
+ if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
|
|
|
+ mostMatchedField = field
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- if field, ok := scope.FieldByName(name); ok {
|
|
|
- updateAttrs[field.DBName] = value
|
|
|
- return field.Set(value)
|
|
|
+ if mostMatchedField != nil {
|
|
|
+ updateAttrs[mostMatchedField.DBName] = value
|
|
|
+ return mostMatchedField.Set(value)
|
|
|
}
|
|
|
}
|
|
|
return errors.New("could not convert column to field")
|
|
@@ -286,12 +290,20 @@ func (scope *Scope) CombinedConditionSql() string {
|
|
|
|
|
|
// FieldByName find gorm.Field with name and db name
|
|
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
|
|
+ var (
|
|
|
+ dbName = ToDBName(name)
|
|
|
+ mostMatchedField *Field
|
|
|
+ )
|
|
|
+
|
|
|
for _, field := range scope.Fields() {
|
|
|
if field.Name == name || field.DBName == name {
|
|
|
return field, true
|
|
|
}
|
|
|
+ if field.DBName == dbName {
|
|
|
+ mostMatchedField = field
|
|
|
+ }
|
|
|
}
|
|
|
- return nil, false
|
|
|
+ return mostMatchedField, mostMatchedField != nil
|
|
|
}
|
|
|
|
|
|
// Raw set sql
|
|
@@ -390,12 +402,12 @@ func (scope *Scope) OmitAttrs() []string {
|
|
|
return scope.Search.omits
|
|
|
}
|
|
|
|
|
|
-func (scope *Scope) scan(rows *sql.Rows, columns []string, fields map[string]*Field) {
|
|
|
+func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
|
|
var values = make([]interface{}, len(columns))
|
|
|
var ignored interface{}
|
|
|
|
|
|
for index, column := range columns {
|
|
|
- if field, ok := fields[column]; ok {
|
|
|
+ if field, ok := fieldsMap[column]; ok {
|
|
|
if field.Field.Kind() == reflect.Ptr {
|
|
|
values[index] = field.Field.Addr().Interface()
|
|
|
} else {
|
|
@@ -411,7 +423,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields map[string]*Fi
|
|
|
scope.Err(rows.Scan(values...))
|
|
|
|
|
|
for index, column := range columns {
|
|
|
- if field, ok := fields[column]; ok {
|
|
|
+ if field, ok := fieldsMap[column]; ok {
|
|
|
if field.Field.Kind() != reflect.Ptr {
|
|
|
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
|
|
field.Field.Set(v)
|