|
@@ -1,7 +1,6 @@
|
|
|
package callbacks
|
|
|
|
|
|
import (
|
|
|
- "fmt"
|
|
|
"reflect"
|
|
|
|
|
|
"github.com/jinzhu/gorm"
|
|
@@ -11,8 +10,6 @@ import (
|
|
|
func BeforeCreate(db *gorm.DB) {
|
|
|
// before save
|
|
|
// before create
|
|
|
-
|
|
|
- // assign timestamp
|
|
|
}
|
|
|
|
|
|
func SaveBeforeAssociations(db *gorm.DB) {
|
|
@@ -22,16 +19,29 @@ func Create(db *gorm.DB) {
|
|
|
db.Statement.AddClauseIfNotExists(clause.Insert{
|
|
|
Table: clause.Table{Name: db.Statement.Table},
|
|
|
})
|
|
|
- values, _ := ConvertToCreateValues(db.Statement)
|
|
|
- db.Statement.AddClause(values)
|
|
|
+ db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
|
|
|
|
|
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
|
|
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
|
|
|
|
|
- fmt.Printf("%+v\n", values)
|
|
|
- fmt.Println(err)
|
|
|
- fmt.Println(result)
|
|
|
- fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
|
|
+ if err == nil {
|
|
|
+ if db.Statement.Schema != nil {
|
|
|
+ if insertID, err := result.LastInsertId(); err == nil {
|
|
|
+ switch db.Statement.ReflectValue.Kind() {
|
|
|
+ case reflect.Slice, reflect.Array:
|
|
|
+ for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
|
|
+ db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
|
|
+ insertID--
|
|
|
+ }
|
|
|
+ case reflect.Struct:
|
|
|
+ db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ db.RowsAffected, _ = result.RowsAffected()
|
|
|
+ } else {
|
|
|
+ db.AddError(err)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func SaveAfterAssociations(db *gorm.DB) {
|
|
@@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) {
|
|
|
}
|
|
|
|
|
|
// ConvertToCreateValues convert to create values
|
|
|
-func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) {
|
|
|
+func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
|
|
switch value := stmt.Dest.(type) {
|
|
|
case map[string]interface{}:
|
|
|
- return ConvertMapToValues(stmt, value), nil
|
|
|
+ return ConvertMapToValues(stmt, value)
|
|
|
case []map[string]interface{}:
|
|
|
- return ConvertSliceOfMapToValues(stmt, value), nil
|
|
|
+ return ConvertSliceOfMapToValues(stmt, value)
|
|
|
default:
|
|
|
var (
|
|
|
values = clause.Values{}
|
|
|
selectColumns, restricted = SelectAndOmitColumns(stmt)
|
|
|
curTime = stmt.DB.NowFunc()
|
|
|
isZero = false
|
|
|
- returnningValues []map[string]interface{}
|
|
|
)
|
|
|
|
|
|
for _, db := range stmt.Schema.DBNames {
|
|
@@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest))
|
|
|
- switch reflectValue.Kind() {
|
|
|
+ switch stmt.ReflectValue.Kind() {
|
|
|
case reflect.Slice, reflect.Array:
|
|
|
- values.Values = make([][]interface{}, reflectValue.Len())
|
|
|
+ values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
|
|
defaultValueFieldsHavingValue := map[string][]interface{}{}
|
|
|
- for i := 0; i < reflectValue.Len(); i++ {
|
|
|
- rv := reflect.Indirect(reflectValue.Index(i))
|
|
|
+ for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
|
|
+ rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
|
|
values.Values[i] = make([]interface{}, len(values.Columns))
|
|
|
for idx, column := range values.Columns {
|
|
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
|
@@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|
|
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
|
|
if v, isZero := field.ValueOf(rv); !isZero {
|
|
|
if len(defaultValueFieldsHavingValue[db]) == 0 {
|
|
|
- defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len())
|
|
|
+ defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
|
|
|
}
|
|
|
defaultValueFieldsHavingValue[db][i] = v
|
|
|
}
|
|
@@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|
|
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
|
|
for idx, column := range values.Columns {
|
|
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
|
|
- if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero {
|
|
|
+ if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
|
|
|
if field.DefaultValueInterface != nil {
|
|
|
values.Values[0][idx] = field.DefaultValueInterface
|
|
|
- field.Set(reflectValue, field.DefaultValueInterface)
|
|
|
+ field.Set(stmt.ReflectValue, field.DefaultValueInterface)
|
|
|
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
|
|
- field.Set(reflectValue, curTime)
|
|
|
- values.Values[0][idx], _ = field.ValueOf(reflectValue)
|
|
|
+ field.Set(stmt.ReflectValue, curTime)
|
|
|
+ values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
|
|
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
|
|
- if v, isZero := field.ValueOf(reflectValue); !isZero {
|
|
|
+ if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
|
|
values.Columns = append(values.Columns, clause.Column{Name: db})
|
|
|
values.Values[0] = append(values.Values[0], v)
|
|
|
}
|
|
@@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return values, returnningValues
|
|
|
+ return values
|
|
|
}
|
|
|
}
|