Browse Source

Make inesrt into db works

Jinzhu 4 years ago
parent
commit
fa22807e12
10 changed files with 92 additions and 66 deletions
  1. 1 1
      callbacks.go
  2. 33 25
      callbacks/create.go
  3. 2 6
      callbacks/query.go
  4. 13 10
      logger/logger.go
  5. 10 1
      logger/sql.go
  6. 1 1
      schema/field.go
  7. 2 2
      schema/relationship.go
  8. 8 7
      schema/schema.go
  9. 19 13
      statement.go
  10. 3 0
      tests/tests.go

+ 1 - 1
callbacks.go

@@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) {
 
 	if stmt := db.Statement; stmt != nil {
 		db.Logger.Trace(curTime, func() (string, int64) {
-			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected
+			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
 		}, db.Error)
 	}
 }

+ 33 - 25
callbacks/create.go

@@ -1,7 +1,6 @@
 package callbacks
 
 import (
-	"fmt"
 	"reflect"
 
 	"github.com/jinzhu/gorm"
@@ -11,8 +10,6 @@ import (
 func BeforeCreate(db *gorm.DB) {
 	// before save
 	// before create
-
-	// assign timestamp
 }
 
 func SaveBeforeAssociations(db *gorm.DB) {
@@ -22,16 +19,29 @@ func Create(db *gorm.DB) {
 	db.Statement.AddClauseIfNotExists(clause.Insert{
 		Table: clause.Table{Name: db.Statement.Table},
 	})
-	values, _ := ConvertToCreateValues(db.Statement)
-	db.Statement.AddClause(values)
+	db.Statement.AddClause(ConvertToCreateValues(db.Statement))
 
 	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 
-	fmt.Printf("%+v\n", values)
-	fmt.Println(err)
-	fmt.Println(result)
-	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
+	if err == nil {
+		if db.Statement.Schema != nil {
+			if insertID, err := result.LastInsertId(); err == nil {
+				switch db.Statement.ReflectValue.Kind() {
+				case reflect.Slice, reflect.Array:
+					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
+						db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
+						insertID--
+					}
+				case reflect.Struct:
+					db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
+				}
+			}
+		}
+		db.RowsAffected, _ = result.RowsAffected()
+	} else {
+		db.AddError(err)
+	}
 }
 
 func SaveAfterAssociations(db *gorm.DB) {
@@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) {
 }
 
 // ConvertToCreateValues convert to create values
-func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) {
+func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
 	switch value := stmt.Dest.(type) {
 	case map[string]interface{}:
-		return ConvertMapToValues(stmt, value), nil
+		return ConvertMapToValues(stmt, value)
 	case []map[string]interface{}:
-		return ConvertSliceOfMapToValues(stmt, value), nil
+		return ConvertSliceOfMapToValues(stmt, value)
 	default:
 		var (
 			values                    = clause.Values{}
 			selectColumns, restricted = SelectAndOmitColumns(stmt)
 			curTime                   = stmt.DB.NowFunc()
 			isZero                    = false
-			returnningValues          []map[string]interface{}
 		)
 
 		for _, db := range stmt.Schema.DBNames {
@@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 			}
 		}
 
-		reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest))
-		switch reflectValue.Kind() {
+		switch stmt.ReflectValue.Kind() {
 		case reflect.Slice, reflect.Array:
-			values.Values = make([][]interface{}, reflectValue.Len())
+			values.Values = make([][]interface{}, stmt.ReflectValue.Len())
 			defaultValueFieldsHavingValue := map[string][]interface{}{}
-			for i := 0; i < reflectValue.Len(); i++ {
-				rv := reflect.Indirect(reflectValue.Index(i))
+			for i := 0; i < stmt.ReflectValue.Len(); i++ {
+				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
 				values.Values[i] = make([]interface{}, len(values.Columns))
 				for idx, column := range values.Columns {
 					field := stmt.Schema.FieldsByDBName[column.Name]
@@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 					if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
 						if v, isZero := field.ValueOf(rv); !isZero {
 							if len(defaultValueFieldsHavingValue[db]) == 0 {
-								defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len())
+								defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
 							}
 							defaultValueFieldsHavingValue[db][i] = v
 						}
@@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
 			for idx, column := range values.Columns {
 				field := stmt.Schema.FieldsByDBName[column.Name]
-				if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero {
+				if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
 					if field.DefaultValueInterface != nil {
 						values.Values[0][idx] = field.DefaultValueInterface
-						field.Set(reflectValue, field.DefaultValueInterface)
+						field.Set(stmt.ReflectValue, field.DefaultValueInterface)
 					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
-						field.Set(reflectValue, curTime)
-						values.Values[0][idx], _ = field.ValueOf(reflectValue)
+						field.Set(stmt.ReflectValue, curTime)
+						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
 					}
 				}
 			}
 
 			for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
 				if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
-					if v, isZero := field.ValueOf(reflectValue); !isZero {
+					if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
 						values.Columns = append(values.Columns, clause.Column{Name: db})
 						values.Values[0] = append(values.Values[0], v)
 					}
@@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
 			}
 		}
 
-		return values, returnningValues
+		return values
 	}
 }

+ 2 - 6
callbacks/query.go

@@ -1,8 +1,6 @@
 package callbacks
 
 import (
-	"fmt"
-
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/clause"
 )
@@ -15,10 +13,8 @@ func Query(db *gorm.DB) {
 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
 	}
 
-	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
-	fmt.Println(err)
-	fmt.Println(result)
-	fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
+	rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
+	db.AddError(err)
 }
 
 func Preload(db *gorm.DB) {

+ 13 - 10
logger/logger.go

@@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface {
 	)
 
 	if config.Colorful {
-		infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset
-		warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset
-		errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset
+		infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset
+		warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
+		errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
 		tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s"
 		traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s"
 	}
@@ -93,29 +93,28 @@ type logger struct {
 
 // LogMode log mode
 func (l logger) LogMode(level LogLevel) Interface {
-	config := l.Config
-	config.LogLevel = level
-	return logger{Writer: l.Writer, Config: config}
+	l.LogLevel = level
+	return l
 }
 
 // Info print info
 func (l logger) Info(msg string, data ...interface{}) {
 	if l.LogLevel >= Info {
-		l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
+		l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 	}
 }
 
 // Warn print warn messages
 func (l logger) Warn(msg string, data ...interface{}) {
 	if l.LogLevel >= Warn {
-		l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
+		l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 	}
 }
 
 // Error print error messages
 func (l logger) Error(msg string, data ...interface{}) {
 	if l.LogLevel >= Error {
-		l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
+		l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
 	}
 }
 
@@ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) {
 func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
 	if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold {
 		sql, rows := fc()
-		l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
+		fileline := utils.FileWithLineNum()
+		if err != nil {
+			fileline += " " + err.Error()
+		}
+		l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql)
 	} else if l.LogLevel >= Info {
 		sql, rows := fc()
 		l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)

+ 10 - 1
logger/sql.go

@@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
 		case bool:
 			vars[idx] = fmt.Sprint(v)
 		case time.Time:
-			vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
+			if v.IsZero() {
+				vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
+			} else {
+				vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
+			}
 		case []byte:
 			if isPrintable(v) {
 				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
@@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
 				vars[idx] = "NULL"
 			} else {
 				rv := reflect.Indirect(reflect.ValueOf(v))
+				if !rv.IsValid() {
+					vars[idx] = "NULL"
+					return
+				}
+
 				for _, t := range convertableTypes {
 					if rv.Type().ConvertibleTo(t) {
 						convertParams(rv.Convert(t).Interface(), idx)

+ 1 - 1
schema/field.go

@@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		var err error
 		field.Creatable = false
 		field.Updatable = false
-		if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
+		if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
 			schema.err = err
 		}
 		for _, ef := range field.EmbeddedSchema.Fields {

+ 2 - 2
schema/relationship.go

@@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
 		}
 	)
 
-	if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
+	if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
 		schema.err = err
 		return
 	}
@@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 		}
 	}
 
-	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
+	if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
 		schema.err = err
 	}
 	relation.JoinTable.Name = many2many

+ 8 - 7
schema/schema.go

@@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field {
 }
 
 // get data type from dialector
-func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
-	modelType := reflect.ValueOf(dest).Type()
+func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
+	reflectValue := reflect.ValueOf(dest)
+	modelType := reflectValue.Type()
 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
 		modelType = modelType.Elem()
 	}
 
 	if modelType.Kind() != reflect.Struct {
 		if modelType.PkgPath() == "" {
-			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
+			return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 		}
-		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
+		return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 	}
 
 	if v, ok := cacheStore.Load(modelType); ok {
-		return v.(*Schema), nil
+		return v.(*Schema), reflectValue, nil
 	}
 
 	schema := &Schema{
@@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 	for _, field := range schema.Fields {
 		if field.DataType == "" && field.Creatable {
 			if schema.parseRelation(field); schema.err != nil {
-				return schema, schema.err
+				return schema, reflectValue, schema.err
 			}
 		}
 	}
 
-	return schema, schema.err
+	return schema, reflectValue, schema.err
 }

+ 19 - 13
statement.go

@@ -5,6 +5,7 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"fmt"
+	"reflect"
 	"strconv"
 	"strings"
 	"sync"
@@ -32,22 +33,23 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
 func (inst *Instance) AddError(err error) {
 	if inst.Error == nil {
 		inst.Error = err
-	} else {
+	} else if err != nil {
 		inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
 	}
 }
 
 // Statement statement
 type Statement struct {
-	Table    string
-	Model    interface{}
-	Dest     interface{}
-	Clauses  map[string]clause.Clause
-	Selects  []string // selected columns
-	Omits    []string // omit columns
-	Settings sync.Map
-	DB       *DB
-	Schema   *schema.Schema
+	Table        string
+	Model        interface{}
+	Dest         interface{}
+	ReflectValue reflect.Value
+	Clauses      map[string]clause.Clause
+	Selects      []string // selected columns
+	Omits        []string // omit columns
+	Settings     sync.Map
+	DB           *DB
+	Schema       *schema.Schema
 
 	// SQL Builder
 	SQL       strings.Builder
@@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
 // BuildCondtion build condition
 func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
 	if sql, ok := query.(string); ok {
-		if i, err := strconv.Atoi(sql); err != nil {
+		if i, err := strconv.Atoi(sql); err == nil {
 			query = i
 		} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
 			return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
@@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) {
 }
 
 func (stmt *Statement) Parse(value interface{}) (err error) {
-	if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
-		stmt.Table = stmt.Schema.Table
+	if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
+		stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
+
+		if stmt.Table == "" {
+			stmt.Table = stmt.Schema.Table
+		}
 	}
 	return err
 }

+ 3 - 0
tests/tests.go

@@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
 }
 
 func TestCreate(t *testing.T, db *gorm.DB) {
+	db.AutoMigrate(&User{})
+	db = db.Debug()
+
 	t.Run("Create", func(t *testing.T) {
 		var user = User{
 			Name:     "create",