Przeglądaj źródła

Add migrator tests for postgres

Jinzhu 4 lat temu
rodzic
commit
ce84e82c9e

+ 0 - 4
dialects/mysql/mysql_test.go

@@ -9,10 +9,6 @@ import (
 	"github.com/jinzhu/gorm/tests"
 )
 
-func TestOpen(t *testing.T) {
-	gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil)
-}
-
 var (
 	DB  *gorm.DB
 	err error

+ 26 - 0
dialects/postgres/migrator.go

@@ -87,3 +87,29 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
 		return err
 	})
 }
+
+func (m Migrator) HasTable(value interface{}) bool {
+	var count int64
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema =  CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count)
+	})
+
+	return count > 0
+}
+
+func (m Migrator) HasColumn(value interface{}, field string) bool {
+	var count int64
+	m.RunWithValue(value, func(stmt *gorm.Statement) error {
+		name := field
+		if field := stmt.Schema.LookUpField(field); field != nil {
+			name = field.DBName
+		}
+
+		return m.DB.Raw(
+			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
+			stmt.Table, name,
+		).Row().Scan(&count)
+	})
+
+	return count > 0
+}

+ 5 - 3
dialects/postgres/postgres.go

@@ -3,6 +3,7 @@ package postgres
 import (
 	"database/sql"
 	"fmt"
+	"strconv"
 
 	"github.com/jinzhu/gorm"
 	"github.com/jinzhu/gorm/callbacks"
@@ -29,13 +30,14 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 
 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
 	return Migrator{migrator.Migrator{Config: migrator.Config{
-		DB:        db,
-		Dialector: dialector,
+		DB:                          db,
+		Dialector:                   dialector,
+		CreateIndexAfterCreateTable: true,
 	}}}
 }
 
 func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
-	return "?"
+	return "$" + strconv.Itoa(len(stmt.Vars))
 }
 
 func (dialector Dialector) QuoteChars() [2]byte {

+ 29 - 0
dialects/postgres/postgres_test.go

@@ -0,0 +1,29 @@
+package postgres_test
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/jinzhu/gorm"
+	"github.com/jinzhu/gorm/dialects/postgres"
+	"github.com/jinzhu/gorm/tests"
+)
+
+var (
+	DB  *gorm.DB
+	err error
+)
+
+func init() {
+	if DB, err = gorm.Open(postgres.Open("user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"), &gorm.Config{}); err != nil {
+		panic(fmt.Sprintf("failed to initialize database, got error %v", err))
+	}
+}
+
+func TestCURD(t *testing.T) {
+	tests.RunTestsSuit(t, DB)
+}
+
+func TestMigrate(t *testing.T) {
+	tests.TestMigrate(t, DB)
+}