schema.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package schema
  2. import (
  3. "errors"
  4. "fmt"
  5. "go/ast"
  6. "reflect"
  7. "sync"
  8. "github.com/jinzhu/gorm/logger"
  9. )
  var (
	// ErrUnsupportedDataType is the sentinel error that Parse wraps (via %w)
	// when the value it is given does not resolve to a struct type after
	// pointers and slices are peeled off. Detect it with errors.Is.
	ErrUnsupportedDataType = errors.New("unsupported data type")
)
  12. type Schema struct {
  13. Name string
  14. ModelType reflect.Type
  15. Table string
  16. PrioritizedPrimaryField *Field
  17. DBNames []string
  18. PrimaryFields []*Field
  19. Fields []*Field
  20. FieldsByName map[string]*Field
  21. FieldsByDBName map[string]*Field
  22. FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
  23. Relationships Relationships
  24. BeforeCreate, AfterCreate bool
  25. BeforeUpdate, AfterUpdate bool
  26. BeforeDelete, AfterDelete bool
  27. BeforeSave, AfterSave bool
  28. AfterFind bool
  29. err error
  30. namer Namer
  31. cacheStore *sync.Map
  32. }
  33. func (schema Schema) String() string {
  34. if schema.ModelType.Name() == "" {
  35. return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
  36. }
  37. return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
  38. }
  39. func (schema Schema) LookUpField(name string) *Field {
  40. if field, ok := schema.FieldsByDBName[name]; ok {
  41. return field
  42. }
  43. if field, ok := schema.FieldsByName[name]; ok {
  44. return field
  45. }
  46. return nil
  47. }
  48. // get data type from dialector
  49. func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
  50. reflectValue := reflect.ValueOf(dest)
  51. modelType := reflectValue.Type()
  52. for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
  53. modelType = modelType.Elem()
  54. }
  55. if modelType.Kind() != reflect.Struct {
  56. if modelType.PkgPath() == "" {
  57. return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  58. }
  59. return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  60. }
  61. if v, ok := cacheStore.Load(modelType); ok {
  62. return v.(*Schema), reflectValue, nil
  63. }
  64. schema := &Schema{
  65. Name: modelType.Name(),
  66. ModelType: modelType,
  67. Table: namer.TableName(modelType.Name()),
  68. FieldsByName: map[string]*Field{},
  69. FieldsByDBName: map[string]*Field{},
  70. Relationships: Relationships{Relations: map[string]*Relationship{}},
  71. cacheStore: cacheStore,
  72. namer: namer,
  73. }
  74. defer func() {
  75. if schema.err != nil {
  76. logger.Default.Error(schema.err.Error())
  77. cacheStore.Delete(modelType)
  78. }
  79. }()
  80. for i := 0; i < modelType.NumField(); i++ {
  81. if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
  82. if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
  83. schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
  84. } else {
  85. schema.Fields = append(schema.Fields, field)
  86. }
  87. }
  88. }
  89. for _, field := range schema.Fields {
  90. if field.DBName == "" && field.DataType != "" {
  91. field.DBName = namer.ColumnName(schema.Table, field.Name)
  92. }
  93. if field.DBName != "" {
  94. // nonexistence or shortest path or first appear prioritized if has permission
  95. if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
  96. if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
  97. schema.DBNames = append(schema.DBNames, field.DBName)
  98. }
  99. schema.FieldsByDBName[field.DBName] = field
  100. schema.FieldsByName[field.Name] = field
  101. if v != nil && v.PrimaryKey {
  102. if schema.PrioritizedPrimaryField == v {
  103. schema.PrioritizedPrimaryField = nil
  104. }
  105. for idx, f := range schema.PrimaryFields {
  106. if f == v {
  107. schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
  108. } else if schema.PrioritizedPrimaryField == nil {
  109. schema.PrioritizedPrimaryField = f
  110. }
  111. }
  112. }
  113. if field.PrimaryKey {
  114. if schema.PrioritizedPrimaryField == nil {
  115. schema.PrioritizedPrimaryField = field
  116. }
  117. schema.PrimaryFields = append(schema.PrimaryFields, field)
  118. }
  119. }
  120. }
  121. if _, ok := schema.FieldsByName[field.Name]; !ok {
  122. schema.FieldsByName[field.Name] = field
  123. }
  124. field.setupValuerAndSetter()
  125. }
  126. if f := schema.LookUpField("id"); f != nil {
  127. if f.PrimaryKey {
  128. schema.PrioritizedPrimaryField = f
  129. } else if len(schema.PrimaryFields) == 0 {
  130. f.PrimaryKey = true
  131. schema.PrioritizedPrimaryField = f
  132. schema.PrimaryFields = append(schema.PrimaryFields, f)
  133. }
  134. }
  135. schema.FieldsWithDefaultDBValue = map[string]*Field{}
  136. for db, field := range schema.FieldsByDBName {
  137. if field.HasDefaultValue && field.DefaultValueInterface == nil {
  138. schema.FieldsWithDefaultDBValue[db] = field
  139. }
  140. }
  141. if schema.PrioritizedPrimaryField != nil {
  142. switch schema.PrioritizedPrimaryField.DataType {
  143. case Int, Uint:
  144. schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField
  145. }
  146. }
  147. callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
  148. for _, name := range callbacks {
  149. if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
  150. switch methodValue.Type().String() {
  151. case "func(*gorm.DB)": // TODO hack
  152. reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
  153. default:
  154. logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
  155. }
  156. }
  157. }
  158. cacheStore.Store(modelType, schema)
  159. // parse relations for unidentified fields
  160. for _, field := range schema.Fields {
  161. if field.DataType == "" && field.Creatable {
  162. if schema.parseRelation(field); schema.err != nil {
  163. return schema, reflectValue, schema.err
  164. }
  165. }
  166. }
  167. return schema, reflectValue, schema.err
  168. }