Selaa lähdekoodia

Set nopLogger to DefaultCallback for avoid nil pointer dereference (#2742)

Shunsuke Otani 4 vuotta sitten
vanhempi
säilyke
e8c07b5531
3 muutettua tiedostoa jossa 36 lisäystä ja 7 poistoa
  1. 2 7
      callback.go
  2. 30 0
      callbacks_test.go
  3. 4 0
      logger.go

+ 2 - 7
callback.go

@@ -3,7 +3,7 @@ package gorm
 import "fmt"
 
 // DefaultCallback default callbacks defined by gorm
-var DefaultCallback = &Callback{}
+var DefaultCallback = &Callback{logger: nopLogger{}}
 
 // Callback is a struct that contains all CRUD callbacks
 //   Field `creates` contains callbacks will be call when creating object
@@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
 		}
 	}
 
-	if cp.logger != nil {
-		// note cp.logger will be nil during the default gorm callback registrations
-		// as they occur within init() blocks. However, any user-registered callbacks
-		// will happen after cp.logger exists (as the default logger or user-specified).
-		cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
-	}
+	cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
 	cp.name = callbackName
 	cp.processor = &callback
 	cp.parent.processors = append(cp.parent.processors, cp)

+ 30 - 0
callbacks_test.go

@@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) {
 		t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
 	}
 }
+
+func TestUseDefaultCallback(t *testing.T) {
+	createCallbackName := "gorm:test_use_default_callback_for_create"
+	gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
+		// nop
+	})
+	if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
+		t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
+	}
+	gorm.DefaultCallback.Create().Remove(createCallbackName)
+	if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
+		t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
+	}
+
+	updateCallbackName := "gorm:test_use_default_callback_for_update"
+	scopeValueName := "gorm:test_use_default_callback_for_update_value"
+	gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
+		scope.Set(scopeValueName, 1)
+	})
+	gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
+		scope.Set(scopeValueName, 2)
+	})
+
+	scope := DB.NewScope(nil)
+	callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
+	callback(scope)
+	if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
+		t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
+	}
+}

+ 4 - 0
logger.go

@@ -135,3 +135,7 @@ type Logger struct {
 func (logger Logger) Print(values ...interface{}) {
 	logger.Println(LogFormatter(values...)...)
 }
+
+type nopLogger struct{}
+
+func (nopLogger) Print(values ...interface{}) {}