field.go 18 KB


  1. package schema
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "fmt"
  6. "reflect"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/jinzhu/now"
  12. )
  13. type DataType string
  14. type TimeType int64
  15. const (
  16. UnixSecond TimeType = 1
  17. UnixNanosecond TimeType = 2
  18. )
  19. const (
  20. Bool DataType = "bool"
  21. Int = "int"
  22. Uint = "uint"
  23. Float = "float"
  24. String = "string"
  25. Time = "time"
  26. Bytes = "bytes"
  27. )
// Field describes one struct field of a parsed model: its Go-side identity,
// the column settings read from its struct tag, and the accessors used to
// read and write it on a model value. ParseField fills in the descriptive
// members; setupValuerAndSetter installs ValueOf, ReflectValueOf and Set.
type Field struct {
	// Name is the Go struct field name.
	Name string
	// DBName is the column name; ParseField takes it from the COLUMN tag
	// setting, and fields of an embedded struct get EMBEDDEDPREFIX prepended.
	DBName string
	// BindNames starts as the field's own name; for fields of an embedded
	// struct the embedding field's name is prepended.
	BindNames []string
	// DataType is the generic data type derived from the field's Go kind.
	DataType DataType
	// DBDataType is the raw value of the TYPE tag setting.
	DBDataType string
	// PrimaryKey is set by a truthy PRIMARYKEY tag setting.
	PrimaryKey bool
	// AutoIncrement is set by a truthy AUTOINCREMENT tag setting.
	AutoIncrement bool
	// Creatable and Updatable default to true; ParseField clears both for
	// fields tagged "-" and for embedded fields.
	Creatable bool
	Updatable bool
	// HasDefaultValue is set when a DEFAULT setting is present or the field
	// is auto-increment.
	HasDefaultValue bool
	// AutoCreateTime and AutoUpdateTime hold the timestamp precision chosen
	// through the AUTOCREATETIME / AUTOUPDATETIME tag settings ("NANO"
	// selects UnixNanosecond, anything else UnixSecond).
	AutoCreateTime TimeType
	AutoUpdateTime TimeType
	// DefaultValue is the raw DEFAULT tag setting.
	DefaultValue string
	// DefaultValueInterface is DefaultValue parsed according to DataType
	// (bool, int64, uint64, float64 or string).
	DefaultValueInterface interface{}
	// NotNull is set by a truthy "NOT NULL" tag setting.
	NotNull bool
	// Unique is set by a truthy UNIQUE tag setting.
	Unique bool
	// Comment is the COMMENT tag setting.
	Comment string
	// Size is parsed from the SIZE tag setting.
	Size int
	// Precision is parsed from the PRECISION tag setting; the string setter
	// also uses it as the number of decimals when formatting floats.
	Precision int
	// FieldType is the declared Go type of the field.
	FieldType reflect.Type
	// IndirectFieldType is FieldType with all pointer levels removed.
	IndirectFieldType reflect.Type
	// StructField is the reflect descriptor of the field. For fields of an
	// embedded struct, Index is extended with the embedding field's index;
	// a negative entry i encodes "field -i-1 reached through a pointer".
	StructField reflect.StructField
	// Tag is the field's complete struct tag.
	Tag reflect.StructTag
	// TagSettings is the result of ParseTagSetting applied to Tag.
	TagSettings map[string]string
	// Schema is the schema the field belongs to.
	Schema *Schema
	// EmbeddedSchema is the schema parsed from the field's own type when the
	// field is anonymous or tagged EMBEDDED.
	EmbeddedSchema *Schema
	// ReflectValueOf returns the settable reflect.Value of this field inside
	// the given model value.
	ReflectValueOf func(reflect.Value) reflect.Value
	// ValueOf returns the field's current value in the given model value and
	// whether that value is the zero value.
	ValueOf func(reflect.Value) (value interface{}, zero bool)
	// Set stores the given value into this field of the given model value,
	// converting it where possible.
	Set func(reflect.Value, interface{}) error
}
  59. func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
  60. field := &Field{
  61. Name: fieldStruct.Name,
  62. BindNames: []string{fieldStruct.Name},
  63. FieldType: fieldStruct.Type,
  64. IndirectFieldType: fieldStruct.Type,
  65. StructField: fieldStruct,
  66. Creatable: true,
  67. Updatable: true,
  68. Tag: fieldStruct.Tag,
  69. TagSettings: ParseTagSetting(fieldStruct.Tag),
  70. Schema: schema,
  71. }
  72. for field.IndirectFieldType.Kind() == reflect.Ptr {
  73. field.IndirectFieldType = field.IndirectFieldType.Elem()
  74. }
  75. fieldValue := reflect.New(field.IndirectFieldType)
  76. // if field is valuer, used its value or first fields as data type
  77. if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
  78. var overrideFieldValue bool
  79. if v, err := valuer.Value(); v != nil && err == nil {
  80. overrideFieldValue = true
  81. fieldValue = reflect.ValueOf(v)
  82. }
  83. if field.IndirectFieldType.Kind() == reflect.Struct {
  84. for i := 0; i < field.IndirectFieldType.NumField(); i++ {
  85. if !overrideFieldValue {
  86. newFieldType := field.IndirectFieldType.Field(i).Type
  87. for newFieldType.Kind() == reflect.Ptr {
  88. newFieldType = newFieldType.Elem()
  89. }
  90. fieldValue = reflect.New(newFieldType)
  91. overrideFieldValue = true
  92. }
  93. // copy tag settings from valuer
  94. for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) {
  95. if _, ok := field.TagSettings[key]; !ok {
  96. field.TagSettings[key] = value
  97. }
  98. }
  99. }
  100. }
  101. }
  102. // setup permission
  103. if _, ok := field.TagSettings["-"]; ok {
  104. field.Creatable = false
  105. field.Updatable = false
  106. }
  107. if dbName, ok := field.TagSettings["COLUMN"]; ok {
  108. field.DBName = dbName
  109. }
  110. if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) {
  111. field.PrimaryKey = true
  112. }
  113. if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) {
  114. field.AutoIncrement = true
  115. field.HasDefaultValue = true
  116. }
  117. if v, ok := field.TagSettings["DEFAULT"]; ok {
  118. field.HasDefaultValue = true
  119. field.DefaultValue = v
  120. }
  121. if num, ok := field.TagSettings["SIZE"]; ok {
  122. field.Size, _ = strconv.Atoi(num)
  123. }
  124. if p, ok := field.TagSettings["PRECISION"]; ok {
  125. field.Precision, _ = strconv.Atoi(p)
  126. }
  127. if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) {
  128. field.NotNull = true
  129. }
  130. if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) {
  131. field.Unique = true
  132. }
  133. if val, ok := field.TagSettings["COMMENT"]; ok {
  134. field.Comment = val
  135. }
  136. if val, ok := field.TagSettings["TYPE"]; ok {
  137. field.DBDataType = val
  138. }
  139. if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) {
  140. if strings.ToUpper(v) == "NANO" {
  141. field.AutoCreateTime = UnixNanosecond
  142. } else {
  143. field.AutoCreateTime = UnixSecond
  144. }
  145. }
  146. if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) {
  147. if strings.ToUpper(v) == "NANO" {
  148. field.AutoUpdateTime = UnixNanosecond
  149. } else {
  150. field.AutoUpdateTime = UnixSecond
  151. }
  152. }
  153. switch fieldValue.Elem().Kind() {
  154. case reflect.Bool:
  155. field.DataType = Bool
  156. if field.HasDefaultValue {
  157. field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
  158. }
  159. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  160. field.DataType = Int
  161. if field.HasDefaultValue {
  162. field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
  163. }
  164. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  165. field.DataType = Uint
  166. if field.HasDefaultValue {
  167. field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
  168. }
  169. case reflect.Float32, reflect.Float64:
  170. field.DataType = Float
  171. if field.HasDefaultValue {
  172. field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
  173. }
  174. case reflect.String:
  175. field.DataType = String
  176. if field.HasDefaultValue {
  177. field.DefaultValueInterface = field.DefaultValue
  178. }
  179. case reflect.Struct:
  180. if _, ok := fieldValue.Interface().(*time.Time); ok {
  181. field.DataType = Time
  182. } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
  183. field.DataType = Time
  184. }
  185. case reflect.Array, reflect.Slice:
  186. if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
  187. field.DataType = Bytes
  188. }
  189. }
  190. if field.Size == 0 {
  191. switch fieldValue.Kind() {
  192. case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
  193. field.Size = 64
  194. case reflect.Int8, reflect.Uint8:
  195. field.Size = 8
  196. case reflect.Int16, reflect.Uint16:
  197. field.Size = 16
  198. case reflect.Int32, reflect.Uint32, reflect.Float32:
  199. field.Size = 32
  200. }
  201. }
  202. if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
  203. var err error
  204. field.Creatable = false
  205. field.Updatable = false
  206. if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
  207. schema.err = err
  208. }
  209. for _, ef := range field.EmbeddedSchema.Fields {
  210. ef.Schema = schema
  211. ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
  212. // index is negative means is pointer
  213. if field.FieldType.Kind() == reflect.Struct {
  214. ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
  215. } else {
  216. ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
  217. }
  218. if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
  219. ef.DBName = prefix + ef.DBName
  220. }
  221. for k, v := range field.TagSettings {
  222. ef.TagSettings[k] = v
  223. }
  224. }
  225. }
  226. return field
  227. }
  228. // create valuer, setter when parse struct
  229. func (field *Field) setupValuerAndSetter() {
  230. // ValueOf
  231. switch {
  232. case len(field.StructField.Index) == 1:
  233. field.ValueOf = func(value reflect.Value) (interface{}, bool) {
  234. fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
  235. return fieldValue.Interface(), fieldValue.IsZero()
  236. }
  237. case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
  238. field.ValueOf = func(value reflect.Value) (interface{}, bool) {
  239. fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
  240. return fieldValue.Interface(), fieldValue.IsZero()
  241. }
  242. default:
  243. field.ValueOf = func(value reflect.Value) (interface{}, bool) {
  244. v := reflect.Indirect(value)
  245. for _, idx := range field.StructField.Index {
  246. if idx >= 0 {
  247. v = v.Field(idx)
  248. } else {
  249. v = v.Field(-idx - 1)
  250. if v.Type().Elem().Kind() == reflect.Struct {
  251. if !v.IsNil() {
  252. v = v.Elem()
  253. }
  254. } else {
  255. return nil, true
  256. }
  257. }
  258. }
  259. return v.Interface(), v.IsZero()
  260. }
  261. }
  262. // ReflectValueOf
  263. switch {
  264. case len(field.StructField.Index) == 1:
  265. if field.FieldType.Kind() == reflect.Ptr {
  266. field.ReflectValueOf = func(value reflect.Value) reflect.Value {
  267. fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
  268. if fieldValue.IsNil() {
  269. fieldValue.Set(reflect.New(field.FieldType.Elem()))
  270. }
  271. return fieldValue
  272. }
  273. } else {
  274. field.ReflectValueOf = func(value reflect.Value) reflect.Value {
  275. return reflect.Indirect(value).Field(field.StructField.Index[0])
  276. }
  277. }
  278. case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
  279. field.ReflectValueOf = func(value reflect.Value) reflect.Value {
  280. return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
  281. }
  282. default:
  283. field.ReflectValueOf = func(value reflect.Value) reflect.Value {
  284. v := reflect.Indirect(value)
  285. for _, idx := range field.StructField.Index {
  286. if idx >= 0 {
  287. v = v.Field(idx)
  288. } else {
  289. v = v.Field(-idx - 1)
  290. }
  291. if v.Kind() == reflect.Ptr {
  292. if v.Type().Elem().Kind() == reflect.Struct {
  293. if v.IsNil() {
  294. v.Set(reflect.New(v.Type().Elem()))
  295. }
  296. }
  297. if idx < len(field.StructField.Index)-1 {
  298. v = v.Elem()
  299. }
  300. }
  301. }
  302. return v
  303. }
  304. }
  305. recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
  306. reflectV := reflect.ValueOf(v)
  307. if reflectV.Type().ConvertibleTo(field.FieldType) {
  308. field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
  309. } else if valuer, ok := v.(driver.Valuer); ok {
  310. if v, err = valuer.Value(); err == nil {
  311. return setter(value, v)
  312. }
  313. } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
  314. field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
  315. } else if reflectV.Kind() == reflect.Ptr {
  316. return field.Set(value, reflectV.Elem().Interface())
  317. } else {
  318. return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
  319. }
  320. return err
  321. }
  322. // Set
  323. switch field.FieldType.Kind() {
  324. case reflect.Bool:
  325. field.Set = func(value reflect.Value, v interface{}) error {
  326. switch data := v.(type) {
  327. case bool:
  328. field.ReflectValueOf(value).SetBool(data)
  329. case *bool:
  330. field.ReflectValueOf(value).SetBool(*data)
  331. default:
  332. return recoverFunc(value, v, field.Set)
  333. }
  334. return nil
  335. }
  336. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  337. field.Set = func(value reflect.Value, v interface{}) (err error) {
  338. switch data := v.(type) {
  339. case int64:
  340. field.ReflectValueOf(value).SetInt(data)
  341. case int:
  342. field.ReflectValueOf(value).SetInt(int64(data))
  343. case int8:
  344. field.ReflectValueOf(value).SetInt(int64(data))
  345. case int16:
  346. field.ReflectValueOf(value).SetInt(int64(data))
  347. case int32:
  348. field.ReflectValueOf(value).SetInt(int64(data))
  349. case uint:
  350. field.ReflectValueOf(value).SetInt(int64(data))
  351. case uint8:
  352. field.ReflectValueOf(value).SetInt(int64(data))
  353. case uint16:
  354. field.ReflectValueOf(value).SetInt(int64(data))
  355. case uint32:
  356. field.ReflectValueOf(value).SetInt(int64(data))
  357. case uint64:
  358. field.ReflectValueOf(value).SetInt(int64(data))
  359. case float32:
  360. field.ReflectValueOf(value).SetInt(int64(data))
  361. case float64:
  362. field.ReflectValueOf(value).SetInt(int64(data))
  363. case []byte:
  364. return field.Set(value, string(data))
  365. case string:
  366. if i, err := strconv.ParseInt(data, 0, 64); err == nil {
  367. field.ReflectValueOf(value).SetInt(i)
  368. } else {
  369. return err
  370. }
  371. case time.Time:
  372. if field.AutoCreateTime == UnixNanosecond {
  373. field.ReflectValueOf(value).SetInt(data.UnixNano())
  374. } else {
  375. field.ReflectValueOf(value).SetInt(data.Unix())
  376. }
  377. case *time.Time:
  378. if data != nil {
  379. if field.AutoCreateTime == UnixNanosecond {
  380. field.ReflectValueOf(value).SetInt(data.UnixNano())
  381. } else {
  382. field.ReflectValueOf(value).SetInt(data.Unix())
  383. }
  384. } else {
  385. field.ReflectValueOf(value).SetInt(0)
  386. }
  387. default:
  388. return recoverFunc(value, v, field.Set)
  389. }
  390. return err
  391. }
  392. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  393. field.Set = func(value reflect.Value, v interface{}) (err error) {
  394. switch data := v.(type) {
  395. case uint64:
  396. field.ReflectValueOf(value).SetUint(data)
  397. case uint:
  398. field.ReflectValueOf(value).SetUint(uint64(data))
  399. case uint8:
  400. field.ReflectValueOf(value).SetUint(uint64(data))
  401. case uint16:
  402. field.ReflectValueOf(value).SetUint(uint64(data))
  403. case uint32:
  404. field.ReflectValueOf(value).SetUint(uint64(data))
  405. case int64:
  406. field.ReflectValueOf(value).SetUint(uint64(data))
  407. case int:
  408. field.ReflectValueOf(value).SetUint(uint64(data))
  409. case int8:
  410. field.ReflectValueOf(value).SetUint(uint64(data))
  411. case int16:
  412. field.ReflectValueOf(value).SetUint(uint64(data))
  413. case int32:
  414. field.ReflectValueOf(value).SetUint(uint64(data))
  415. case float32:
  416. field.ReflectValueOf(value).SetUint(uint64(data))
  417. case float64:
  418. field.ReflectValueOf(value).SetUint(uint64(data))
  419. case []byte:
  420. return field.Set(value, string(data))
  421. case string:
  422. if i, err := strconv.ParseUint(data, 0, 64); err == nil {
  423. field.ReflectValueOf(value).SetUint(i)
  424. } else {
  425. return err
  426. }
  427. default:
  428. return recoverFunc(value, v, field.Set)
  429. }
  430. return err
  431. }
  432. case reflect.Float32, reflect.Float64:
  433. field.Set = func(value reflect.Value, v interface{}) (err error) {
  434. switch data := v.(type) {
  435. case float64:
  436. field.ReflectValueOf(value).SetFloat(data)
  437. case float32:
  438. field.ReflectValueOf(value).SetFloat(float64(data))
  439. case int64:
  440. field.ReflectValueOf(value).SetFloat(float64(data))
  441. case int:
  442. field.ReflectValueOf(value).SetFloat(float64(data))
  443. case int8:
  444. field.ReflectValueOf(value).SetFloat(float64(data))
  445. case int16:
  446. field.ReflectValueOf(value).SetFloat(float64(data))
  447. case int32:
  448. field.ReflectValueOf(value).SetFloat(float64(data))
  449. case uint:
  450. field.ReflectValueOf(value).SetFloat(float64(data))
  451. case uint8:
  452. field.ReflectValueOf(value).SetFloat(float64(data))
  453. case uint16:
  454. field.ReflectValueOf(value).SetFloat(float64(data))
  455. case uint32:
  456. field.ReflectValueOf(value).SetFloat(float64(data))
  457. case uint64:
  458. field.ReflectValueOf(value).SetFloat(float64(data))
  459. case []byte:
  460. return field.Set(value, string(data))
  461. case string:
  462. if i, err := strconv.ParseFloat(data, 64); err == nil {
  463. field.ReflectValueOf(value).SetFloat(i)
  464. } else {
  465. return err
  466. }
  467. default:
  468. return recoverFunc(value, v, field.Set)
  469. }
  470. return err
  471. }
  472. case reflect.String:
  473. field.Set = func(value reflect.Value, v interface{}) (err error) {
  474. switch data := v.(type) {
  475. case string:
  476. field.ReflectValueOf(value).SetString(data)
  477. case []byte:
  478. field.ReflectValueOf(value).SetString(string(data))
  479. case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
  480. field.ReflectValueOf(value).SetString(fmt.Sprint(data))
  481. case float64, float32:
  482. field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
  483. default:
  484. return recoverFunc(value, v, field.Set)
  485. }
  486. return err
  487. }
  488. default:
  489. fieldValue := reflect.New(field.FieldType)
  490. switch fieldValue.Elem().Interface().(type) {
  491. case time.Time:
  492. field.Set = func(value reflect.Value, v interface{}) error {
  493. switch data := v.(type) {
  494. case time.Time:
  495. field.ReflectValueOf(value).Set(reflect.ValueOf(v))
  496. case *time.Time:
  497. field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem())
  498. case string:
  499. if t, err := now.Parse(data); err == nil {
  500. field.ReflectValueOf(value).Set(reflect.ValueOf(t))
  501. } else {
  502. return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
  503. }
  504. default:
  505. return recoverFunc(value, v, field.Set)
  506. }
  507. return nil
  508. }
  509. case *time.Time:
  510. field.Set = func(value reflect.Value, v interface{}) error {
  511. switch data := v.(type) {
  512. case time.Time:
  513. field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v))
  514. case *time.Time:
  515. field.ReflectValueOf(value).Set(reflect.ValueOf(v))
  516. case string:
  517. if t, err := now.Parse(data); err == nil {
  518. field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t))
  519. } else {
  520. return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
  521. }
  522. default:
  523. return recoverFunc(value, v, field.Set)
  524. }
  525. return nil
  526. }
  527. default:
  528. if _, ok := fieldValue.Interface().(sql.Scanner); ok {
  529. // struct scanner
  530. field.Set = func(value reflect.Value, v interface{}) (err error) {
  531. reflectV := reflect.ValueOf(v)
  532. if reflectV.Type().ConvertibleTo(field.FieldType) {
  533. field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
  534. } else if valuer, ok := v.(driver.Valuer); ok {
  535. if v, err = valuer.Value(); err == nil {
  536. err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
  537. }
  538. } else {
  539. err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
  540. }
  541. return
  542. }
  543. } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
  544. // pointer scanner
  545. field.Set = func(value reflect.Value, v interface{}) (err error) {
  546. reflectV := reflect.ValueOf(v)
  547. if reflectV.Type().ConvertibleTo(field.FieldType) {
  548. field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
  549. } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
  550. field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
  551. } else if valuer, ok := v.(driver.Valuer); ok {
  552. if v, err = valuer.Value(); err == nil {
  553. err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
  554. }
  555. } else {
  556. err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
  557. }
  558. return
  559. }
  560. } else {
  561. field.Set = func(value reflect.Value, v interface{}) (err error) {
  562. return recoverFunc(value, v, field.Set)
  563. }
  564. }
  565. }
  566. }
  567. }