schema.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package schema
  2. import (
  3. "errors"
  4. "fmt"
  5. "go/ast"
  6. "reflect"
  7. "sync"
  8. "github.com/jinzhu/gorm/logger"
  9. )
  // Sentinel errors reported while building a Schema.
var (
	// ErrUnsupportedDataType is returned by Parse, wrapped with %w, when the
	// value to parse does not resolve to a struct type once any slice and
	// pointer layers have been stripped. Match it with errors.Is.
	ErrUnsupportedDataType = errors.New("unsupported data type")
)
  12. type Schema struct {
  13. Name string
  14. ModelType reflect.Type
  15. Table string
  16. PrioritizedPrimaryField *Field
  17. DBNames []string
  18. PrimaryFields []*Field
  19. Fields []*Field
  20. FieldsByName map[string]*Field
  21. FieldsByDBName map[string]*Field
  22. FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
  23. Relationships Relationships
  24. BeforeCreate, AfterCreate bool
  25. BeforeUpdate, AfterUpdate bool
  26. BeforeDelete, AfterDelete bool
  27. BeforeSave, AfterSave bool
  28. AfterFind bool
  29. err error
  30. namer Namer
  31. cacheStore *sync.Map
  32. }
  33. func (schema Schema) String() string {
  34. if schema.ModelType.Name() == "" {
  35. return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
  36. }
  37. return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
  38. }
  39. func (schema Schema) LookUpField(name string) *Field {
  40. if field, ok := schema.FieldsByDBName[name]; ok {
  41. return field
  42. }
  43. if field, ok := schema.FieldsByName[name]; ok {
  44. return field
  45. }
  46. return nil
  47. }
  48. // get data type from dialector
  49. func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
  50. modelType := reflect.ValueOf(dest).Type()
  51. for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
  52. modelType = modelType.Elem()
  53. }
  54. if modelType.Kind() != reflect.Struct {
  55. if modelType.PkgPath() == "" {
  56. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  57. }
  58. return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  59. }
  60. if v, ok := cacheStore.Load(modelType); ok {
  61. return v.(*Schema), nil
  62. }
  63. schema := &Schema{
  64. Name: modelType.Name(),
  65. ModelType: modelType,
  66. Table: namer.TableName(modelType.Name()),
  67. FieldsByName: map[string]*Field{},
  68. FieldsByDBName: map[string]*Field{},
  69. Relationships: Relationships{Relations: map[string]*Relationship{}},
  70. cacheStore: cacheStore,
  71. namer: namer,
  72. }
  73. defer func() {
  74. if schema.err != nil {
  75. logger.Default.Error(schema.err.Error())
  76. cacheStore.Delete(modelType)
  77. }
  78. }()
  79. for i := 0; i < modelType.NumField(); i++ {
  80. if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
  81. if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
  82. schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
  83. } else {
  84. schema.Fields = append(schema.Fields, field)
  85. }
  86. }
  87. }
  88. for _, field := range schema.Fields {
  89. if field.DBName == "" && field.DataType != "" {
  90. field.DBName = namer.ColumnName(schema.Table, field.Name)
  91. }
  92. if field.DBName != "" {
  93. // nonexistence or shortest path or first appear prioritized if has permission
  94. if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
  95. if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
  96. schema.DBNames = append(schema.DBNames, field.DBName)
  97. }
  98. schema.FieldsByDBName[field.DBName] = field
  99. schema.FieldsByName[field.Name] = field
  100. if v != nil && v.PrimaryKey {
  101. if schema.PrioritizedPrimaryField == v {
  102. schema.PrioritizedPrimaryField = nil
  103. }
  104. for idx, f := range schema.PrimaryFields {
  105. if f == v {
  106. schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
  107. } else if schema.PrioritizedPrimaryField == nil {
  108. schema.PrioritizedPrimaryField = f
  109. }
  110. }
  111. }
  112. if field.PrimaryKey {
  113. if schema.PrioritizedPrimaryField == nil {
  114. schema.PrioritizedPrimaryField = field
  115. }
  116. schema.PrimaryFields = append(schema.PrimaryFields, field)
  117. }
  118. }
  119. }
  120. if _, ok := schema.FieldsByName[field.Name]; !ok {
  121. schema.FieldsByName[field.Name] = field
  122. }
  123. field.setupValuerAndSetter()
  124. }
  125. if f := schema.LookUpField("id"); f != nil {
  126. if f.PrimaryKey {
  127. schema.PrioritizedPrimaryField = f
  128. } else if len(schema.PrimaryFields) == 0 {
  129. f.PrimaryKey = true
  130. schema.PrioritizedPrimaryField = f
  131. schema.PrimaryFields = append(schema.PrimaryFields, f)
  132. }
  133. }
  134. schema.FieldsWithDefaultDBValue = map[string]*Field{}
  135. for db, field := range schema.FieldsByDBName {
  136. if field.HasDefaultValue && field.DefaultValueInterface == nil {
  137. schema.FieldsWithDefaultDBValue[db] = field
  138. }
  139. }
  140. if schema.PrioritizedPrimaryField != nil {
  141. switch schema.PrioritizedPrimaryField.DataType {
  142. case Int, Uint:
  143. schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField
  144. }
  145. }
  146. reflectValue := reflect.Indirect(reflect.New(modelType))
  147. callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
  148. for _, name := range callbacks {
  149. if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
  150. switch methodValue.Type().String() {
  151. case "func(*gorm.DB)": // TODO hack
  152. reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
  153. default:
  154. logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
  155. }
  156. }
  157. }
  158. cacheStore.Store(modelType, schema)
  159. // parse relations for unidentified fields
  160. for _, field := range schema.Fields {
  161. if field.DataType == "" && field.Creatable {
  162. if schema.parseRelation(field); schema.err != nil {
  163. return schema, schema.err
  164. }
  165. }
  166. }
  167. return schema, schema.err
  168. }