diff --git a/contrib/drivers/dm/dm_table_fields.go b/contrib/drivers/dm/dm_table_fields.go index 892973d37c1..9dc3c6c8ca1 100644 --- a/contrib/drivers/dm/dm_table_fields.go +++ b/contrib/drivers/dm/dm_table_fields.go @@ -11,12 +11,21 @@ import ( "fmt" "strings" + "github.com/gogf/gf/v2/container/gmap" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/util/gutil" ) +// escapeSingleQuote escapes single quotes in the string to prevent SQL injection. +// In SQL, single quotes are escaped by doubling them (two single quotes). +func escapeSingleQuote(s string) string { + return strings.ReplaceAll(s, "'", "''") +} + const ( - tableFieldsSqlTmp = `SELECT * FROM ALL_TAB_COLUMNS WHERE Table_Name= '%s' AND OWNER = '%s'` + tableFieldsSqlTmp = `SELECT c.COLUMN_NAME, c.DATA_TYPE, c.DATA_DEFAULT, c.NULLABLE, cc.COMMENTS FROM ALL_TAB_COLUMNS c LEFT JOIN ALL_COL_COMMENTS cc ON c.COLUMN_NAME = cc.COLUMN_NAME AND c.TABLE_NAME = cc.TABLE_NAME AND c.OWNER = cc.OWNER WHERE c.TABLE_NAME = '%s' AND c.OWNER = '%s'` + tableFieldsPkSqlSchemaTmp = `SELECT COLS.COLUMN_NAME AS PRIMARY_KEY_COLUMN FROM USER_CONSTRAINTS CONS JOIN USER_CONS_COLUMNS COLS ON CONS.CONSTRAINT_NAME = COLS.CONSTRAINT_NAME WHERE CONS.TABLE_NAME = '%s' AND CONS.CONSTRAINT_TYPE = 'P'` + tableFieldsPkSqlDBATmp = `SELECT COLS.COLUMN_NAME AS PRIMARY_KEY_COLUMN FROM DBA_CONSTRAINTS CONS JOIN DBA_CONS_COLUMNS COLS ON CONS.CONSTRAINT_NAME = COLS.CONSTRAINT_NAME WHERE CONS.TABLE_NAME = '%s' AND CONS.OWNER = '%s' AND CONS.CONSTRAINT_TYPE = 'P'` ) // TableFields retrieves and returns the fields' information of specified table of current schema. @@ -24,8 +33,9 @@ func (d *Driver) TableFields( ctx context.Context, table string, schema ...string, ) (fields map[string]*gdb.TableField, err error) { var ( - result gdb.Result - link gdb.Link + result gdb.Result + pkResult gdb.Result + link gdb.Link // When no schema is specified, the configuration item is returned by default usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) ) @@ -38,14 +48,35 @@ func (d *Driver) TableFields( ctx, link, fmt.Sprintf( tableFieldsSqlTmp, - strings.ToUpper(table), - strings.ToUpper(d.GetSchema()), + escapeSingleQuote(strings.ToUpper(table)), + escapeSingleQuote(strings.ToUpper(d.GetSchema())), ), ) if err != nil { return nil, err } + // Query the primary key field + pkResult, err = d.DoSelect( + ctx, link, + fmt.Sprintf(tableFieldsPkSqlSchemaTmp, escapeSingleQuote(strings.ToUpper(table))), + ) + if err != nil { + return nil, err + } + if pkResult.IsEmpty() { + pkResult, err = d.DoSelect( + ctx, link, + fmt.Sprintf(tableFieldsPkSqlDBATmp, escapeSingleQuote(strings.ToUpper(table)), escapeSingleQuote(strings.ToUpper(d.GetSchema()))), + ) + if err != nil { + return nil, err + } + } fields = make(map[string]*gdb.TableField) + pkFields := gmap.NewStrStrMap() + for _, pk := range pkResult { + pkFields.Set(pk["PRIMARY_KEY_COLUMN"].String(), "PRI") + } for i, m := range result { // m[NULLABLE] returns "N" "Y" // "N" means not null @@ -60,9 +91,9 @@ func (d *Driver) TableFields( Type: m["DATA_TYPE"].String(), Null: nullable, Default: m["DATA_DEFAULT"].Val(), - // Key: m["Key"].String(), + Key: pkFields.Get(m["COLUMN_NAME"].String()), // Extra: m["Extra"].String(), - // Comment: m["Comment"].String(), + Comment: m["COMMENTS"].String(), } } return fields, nil diff --git a/contrib/drivers/dm/dm_z_unit_basic_test.go b/contrib/drivers/dm/dm_z_unit_basic_test.go index 7e0f9c976a5..aeb83a70a7f 100644 --- a/contrib/drivers/dm/dm_z_unit_basic_test.go +++ b/contrib/drivers/dm/dm_z_unit_basic_test.go @@ -28,10 +28,7 @@ func Test_DB_Ping(t *testing.T) { } func TestTables(t *testing.T) { - tables := []string{"A_tables", "A_tables2"} - for _, v := range tables { - createInitTable(v) - } + tables := createInitTables(2) gtest.C(t, func(t *gtest.T) { result, err := db.Tables(ctx) gtest.AssertNil(err) @@ -39,7 +36,7 @@ func TestTables(t *testing.T) { for i := 0; i < len(tables); i++ { find := false for j := 0; j < len(result); j++ { - if strings.ToUpper(tables[i]) == result[j] { + if strings.ToUpper(tables[i]) == strings.ToUpper(result[j]) { find = true break } @@ -52,7 +49,7 @@ func TestTables(t *testing.T) { for i := 0; i < len(tables); i++ { find := false for j := 0; j < len(result); j++ { - if strings.ToUpper(tables[i]) == result[j] { + if strings.ToUpper(tables[i]) == strings.ToUpper(result[j]) { find = true break } @@ -91,9 +88,6 @@ func TestTableFields(t *testing.T) { "CREATED_TIME": {"TIMESTAMP", false}, } - _, err := dbErr.TableFields(ctx, "Fields") - gtest.AssertNE(err, nil) - res, err := db.TableFields(ctx, tables) gtest.AssertNil(err) @@ -114,6 +108,14 @@ func TestTableFields(t *testing.T) { }) } +func TestTableFields_WithWrongPassword(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + // dbErr is configured with wrong password, so it should return an error + _, err := dbErr.TableFields(ctx, "Fields") + gtest.AssertNE(err, nil) + }) +} + func Test_DB_Query(t *testing.T) { tableName := "A_tables" createInitTable(tableName) @@ -153,7 +155,6 @@ func TestModelSave(t *testing.T) { result sql.Result err error ) - db.SetDebug(true) result, err = db.Model(table).Data(g.Map{ "id": 1, diff --git a/contrib/drivers/dm/dm_z_unit_init_test.go b/contrib/drivers/dm/dm_z_unit_init_test.go index 5d81f649a3e..94d056716a9 100644 --- a/contrib/drivers/dm/dm_z_unit_init_test.go +++ b/contrib/drivers/dm/dm_z_unit_init_test.go @@ -67,7 +67,6 @@ func init() { UpdatedAt: "updated_time", } - // todo nodeLink := gdb.ConfigNode{ Type: TestDBType, Name: TestDBName, @@ -111,6 +110,8 @@ func init() { } ctx = context.Background() + + // db.SetDebug(true) } func dropTable(table string) { @@ -143,7 +144,7 @@ func createTable(table ...string) (name string) { CREATE TABLE "%s" ( "ID" BIGINT NOT NULL, -"ACCOUNT_NAME" VARCHAR(128) DEFAULT '' NOT NULL, +"ACCOUNT_NAME" VARCHAR(128) DEFAULT '' NOT NULL COMMENT 'Account Name', "PWD_RESET" TINYINT DEFAULT 0 NOT NULL, "ENABLED" INT DEFAULT 1 NOT NULL, "DELETED" INT DEFAULT 0 NOT NULL, @@ -156,7 +157,6 @@ NOT CLUSTER PRIMARY KEY("ID")) STORAGE(ON "MAIN", CLUSTERBTR) ; `, name)); err != nil { gtest.Fatal(err) } - return } @@ -169,7 +169,7 @@ func createInitTable(table ...string) (name string) { "account_name": fmt.Sprintf(`name_%d`, i), "pwd_reset": 0, "attr_index": i, - "create_time": gtime.Now().String(), + "created_time": gtime.Now(), }) } result, err := db.Schema(TestDBName).Insert(context.Background(), name, array.Slice()) @@ -212,3 +212,11 @@ NOT CLUSTER PRIMARY KEY("ID")) STORAGE(ON "MAIN", CLUSTERBTR) ; return name, nil } + +func createInitTables(len int) []string { + tables := make([]string, 0, len) + for range len { + tables = append(tables, createInitTable()) + } + return tables +} diff --git a/contrib/drivers/dm/dm_z_unit_pr_test.go b/contrib/drivers/dm/dm_z_unit_pr_test.go new file mode 100644 index 00000000000..65711f16968 --- /dev/null +++ b/contrib/drivers/dm/dm_z_unit_pr_test.go @@ -0,0 +1,40 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package dm_test + +import ( + "testing" + + "github.com/gogf/gf/v2/test/gtest" +) + +// PR #4157 WherePri +func Test_WherePri_PR4157(t *testing.T) { + tableName := "A_tables" + createInitTable(tableName) + defer dropTable(tableName) + gtest.C(t, func(t *gtest.T) { + var resOne *User + err := db.Model(tableName).WherePri(1).Scan(&resOne) + t.AssertNil(err) + t.AssertNQ(resOne, nil) + t.AssertEQ(resOne.ID, int64(1)) + }) +} + +// PR #4157 get table field comments +func Test_TableFields_Comment_PR4157(t *testing.T) { + tableName := "A_tables" + schema := "SYSDBA" + createInitTable(tableName) + defer dropTable(tableName) + gtest.C(t, func(t *gtest.T) { + fields, err := db.Model().TableFields(tableName, schema) + t.AssertNil(err) + t.AssertEQ(fields["ACCOUNT_NAME"].Comment, "Account Name") + }) +}