Browse Source

Merge pull request #2396 from emirb/fix-singulartable-race-condition

Fix SingularTable race condition
Emir Beganović 5 years ago
parent
commit
e3987fd4b8
5 changed files with 60 additions and 20 deletions
  1. 3 1
      main.go
  2. 29 4
      main_test.go
  3. 15 2
      model_struct.go
  4. 1 1
      preload_test.go
  5. 12 12
      wercker.yml

+ 3 - 1
main.go

@@ -12,6 +12,7 @@ import (
 
 // DB contains information for current db connection
 type DB struct {
+	sync.RWMutex
 	Value        interface{}
 	Error        error
 	RowsAffected int64
@@ -170,7 +171,8 @@ func (s *DB) HasBlockGlobalUpdate() bool {
 
 // SingularTable use singular table by default
 func (s *DB) SingularTable(enable bool) {
-	modelStructsMap = sync.Map{}
+	s.parent.Lock()
+	defer s.parent.Unlock()
 	s.parent.singularTable = enable
 }
 

+ 29 - 4
main_test.go

@@ -9,6 +9,7 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -277,6 +278,30 @@ func TestTableName(t *testing.T) {
 	DB.SingularTable(false)
 }
 
+func TestTableNameConcurrently(t *testing.T) {
+	DB := DB.Model("")
+	if DB.NewScope(Order{}).TableName() != "orders" {
+		t.Errorf("Order's table name should be orders")
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(10)
+
+	for i := 1; i <= 10; i++ {
+		go func(db *gorm.DB) {
+			DB.SingularTable(true)
+			wg.Done()
+		}(DB)
+	}
+	wg.Wait()
+
+	if DB.NewScope(Order{}).TableName() != "order" {
+		t.Errorf("Order's singular table name should be order")
+	}
+
+	DB.SingularTable(false)
+}
+
 func TestNullValues(t *testing.T) {
 	DB.DropTable(&NullValue{})
 	DB.AutoMigrate(&NullValue{})
@@ -1066,12 +1091,12 @@ func TestCountWithHaving(t *testing.T) {
 
 	DB.Create(getPreparedUser("user1", "pluck_user"))
 	DB.Create(getPreparedUser("user2", "pluck_user"))
-	user3:=getPreparedUser("user3", "pluck_user")
-	user3.Languages=[]Language{}
+	user3 := getPreparedUser("user3", "pluck_user")
+	user3.Languages = []Language{}
 	DB.Create(user3)
 
 	var count int
-	err:=db.Model(User{}).Select("users.id").
+	err := db.Model(User{}).Select("users.id").
 		Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id").
 		Joins("LEFT JOIN languages ON user_languages.language_id = languages.id").
 		Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error
@@ -1080,7 +1105,7 @@ func TestCountWithHaving(t *testing.T) {
 		t.Error("Unexpected error on query count with having")
 	}
 
-	if count!=2{
+	if count != 2 {
 		t.Error("Unexpected result on query count with having")
 	}
 }

+ 15 - 2
model_struct.go

@@ -40,9 +40,11 @@ func (s *ModelStruct) TableName(db *DB) string {
 			s.defaultTableName = tabler.TableName()
 		} else {
 			tableName := ToTableName(s.ModelType.Name())
+			db.parent.RLock()
 			if db == nil || (db.parent != nil && !db.parent.singularTable) {
 				tableName = inflection.Plural(tableName)
 			}
+			db.parent.RUnlock()
 			s.defaultTableName = tableName
 		}
 	}
@@ -163,7 +165,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 	}
 
 	// Get Cached model struct
-	if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
+	isSingularTable := false
+	if scope.db != nil && scope.db.parent != nil {
+		scope.db.parent.RLock()
+		isSingularTable = scope.db.parent.singularTable
+		scope.db.parent.RUnlock()
+	}
+
+	hashKey := struct {
+		singularTable bool
+		reflectType   reflect.Type
+	}{isSingularTable, reflectType}
+	if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
 		return value.(*ModelStruct)
 	}
 
@@ -612,7 +625,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
 		}
 	}
 
-	modelStructsMap.Store(reflectType, &modelStruct)
+	modelStructsMap.Store(hashKey, &modelStruct)
 
 	return &modelStruct
 }

+ 1 - 1
preload_test.go

@@ -1677,7 +1677,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
 	lvl := Level1{
 		Name: "l1",
 		Level2s: []Level2{
-			Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
+			{Name: "l2-1"}, {Name: "l2-2"},
 		},
 	}
 	DB.Save(&lvl)

+ 12 - 12
wercker.yml

@@ -83,7 +83,7 @@ build:
                 code: |
                     cd $WERCKER_SOURCE_DIR
                     go version
-                    go get -t ./...
+                    go get -t -v ./...
 
         # Build the project
         - script:
@@ -95,54 +95,54 @@ build:
         - script:
                 name: test sqlite
                 code: |
-                    go test ./...
+                    go test -race -v ./...
 
         - script:
                 name: test mariadb
                 code: |
-                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
+                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
 
         - script:
                 name: test mysql5.7
                 code: |
-                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
+                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
 
         - script:
                 name: test mysql5.6
                 code: |
-                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
+                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
 
         - script:
                 name: test mysql5.5
                 code: |
-                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
+                    GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test -race ./...
 
         - script:
                 name: test postgres
                 code: |
-                    GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
+                    GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
 
         - script:
                 name: test postgres96
                 code: |
-                    GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
+                    GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
 
         - script:
                 name: test postgres95
                 code: |
-                    GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
+                    GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
 
         - script:
                 name: test postgres94
                 code: |
-                    GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
+                    GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
 
         - script:
                 name: test postgres93
                 code: |
-                    GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
+                    GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./...
 
         - script:
                 name: test mssql
                 code: |
-                    GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...
+                    GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test -race ./...