Skip to content

Commit

Permalink
ddl, parser: make generated column and expression index same as MySQL (
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Sep 12, 2024
1 parent 1097ba8 commit 337f4a0
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 22 deletions.
37 changes: 37 additions & 0 deletions ddl/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2378,6 +2378,43 @@ func TestSqlFunctionsInGeneratedColumns(t *testing.T) {
tk.MustExec("create table t (a int, b int as ((a)))")
}

func TestSchemaNameAndTableNameInGeneratedExpr(t *testing.T) {
store := testkit.CreateMockStore(t, mockstore.WithDDLChecker())

tk := testkit.NewTestKit(t, store)
tk.MustExec("create database if not exists test")
tk.MustExec("use test")
tk.MustExec("drop table if exists t")

tk.MustExec("create table t(a int, b int as (lower(test.t.a)))")
tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

tk.MustExec("drop table t")
tk.MustExec("create table t(a int)")
tk.MustExec("alter table t add column b int as (lower(test.t.a))")
tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

tk.MustGetErrCode("alter table t add index idx((lower(test.t1.a)))", errno.ErrBadField)

tk.MustExec("drop table t")
tk.MustGetErrCode("create table t(a int, b int as (lower(test1.t.a)))", errno.ErrWrongDBName)

tk.MustExec("create table t(a int)")
tk.MustGetErrCode("alter table t add column b int as (lower(test.t1.a))", errno.ErrWrongTableName)

tk.MustExec("alter table t add column c int")
tk.MustGetErrCode("alter table t modify column c int as (test.t1.a + 1) stored", errno.ErrWrongTableName)

tk.MustExec("alter table t add column d int as (lower(test.T.a))")
tk.MustExec("alter table t add column e int as (lower(Test.t.a))")
}

func TestParserIssue284(t *testing.T) {
store := testkit.CreateMockStore(t, mockstore.WithDDLChecker())

Expand Down
27 changes: 18 additions & 9 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o

var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)

for _, v := range colDef.Options {
Expand Down Expand Up @@ -1148,7 +1148,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
}
col.GeneratedExprString = sb.String()
col.GeneratedStored = v.Stored
_, dependColNames := findDependedColumnNames(colDef)
_, dependColNames, err := findDependedColumnNames(model.NewCIStr(""), model.NewCIStr(""), colDef)
if err != nil {
return nil, nil, errors.Trace(err)
}
col.Dependences = dependColNames
case ast.ColumnOptionCollate:
if field_types.HasCharset(colDef.Tp) {
Expand Down Expand Up @@ -1567,7 +1570,7 @@ func IsAutoRandomColumnID(tblInfo *model.TableInfo, colID int64) bool {
return false
}

func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) error {
func checkGeneratedColumn(ctx sessionctx.Context, schemaName model.CIStr, tableName model.CIStr, colDefs []*ast.ColumnDef) error {
var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs))
var exists bool
var autoIncrementColumn string
Expand All @@ -1582,7 +1585,10 @@ func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) erro
if containsColumnOption(colDef, ast.ColumnOptionAutoIncrement) {
exists, autoIncrementColumn = true, colDef.Name.Name.L
}
generated, depCols := findDependedColumnNames(colDef)
generated, depCols, err := findDependedColumnNames(schemaName, tableName, colDef)
if err != nil {
return errors.Trace(err)
}
if !generated {
colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{
position: i,
Expand Down Expand Up @@ -2100,7 +2106,7 @@ func CheckTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo
func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) {
// All of these rely on the AST structure of expressions, which were
// lost in the model (got serialized into strings).
if err := checkGeneratedColumn(ctx, s.Cols); err != nil {
if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil {
return errors.Trace(err)
}

Expand Down Expand Up @@ -3724,7 +3730,10 @@ func CreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model.DBInfo,
return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE")
}

_, dependColNames := findDependedColumnNames(specNewColumn)
_, dependColNames, err := findDependedColumnNames(schema.Name, t.Meta().Name, specNewColumn)
if err != nil {
return nil, errors.Trace(err)
}
if !ctx.GetSessionVars().EnableAutoIncrementInGenerated {
if err := checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil {
return nil, errors.Trace(err)
Expand Down Expand Up @@ -4530,7 +4539,7 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col
func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error {
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)

var hasDefaultValue, setOnUpdateNow bool
Expand Down Expand Up @@ -4867,7 +4876,7 @@ func GetModifiableColumnJob(
}

// As same with MySQL, we don't support modifying the stored status for generated columns.
if err = checkModifyGeneratedColumn(sctx, t, col, newCol, specNewColumn, spec.Position); err != nil {
if err = checkModifyGeneratedColumn(sctx, schema.Name, t, col, newCol, specNewColumn, spec.Position); err != nil {
return nil, errors.Trace(err)
}

Expand Down Expand Up @@ -6356,7 +6365,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as

var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
sb.Reset()
err := idxPart.Expr.Restore(restoreCtx)
Expand Down
15 changes: 12 additions & 3 deletions ddl/generated_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,19 @@ func findPositionRelativeColumn(cols []*table.Column, pos *ast.ColumnPosition) (

// findDependedColumnNames returns a set of string, which indicates
// the names of the columns that are depended by colDef.
func findDependedColumnNames(colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}) {
func findDependedColumnNames(schemaName model.CIStr, tableName model.CIStr, colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}, err error) {
colsMap = make(map[string]struct{})
for _, option := range colDef.Options {
if option.Tp == ast.ColumnOptionGenerated {
generated = true
colNames := FindColumnNamesInExpr(option.Expr)
for _, depCol := range colNames {
if depCol.Schema.L != "" && schemaName.L != "" && depCol.Schema.L != schemaName.L {
return false, nil, dbterror.ErrWrongDBName.GenWithStackByArgs(depCol.Schema.O)
}
if depCol.Table.L != "" && tableName.L != "" && depCol.Table.L != tableName.L {
return false, nil, dbterror.ErrWrongTableName.GenWithStackByArgs(depCol.Table.O)
}
colsMap[depCol.Name.L] = struct{}{}
}
break
Expand Down Expand Up @@ -192,7 +198,7 @@ func (c *generatedColumnChecker) Leave(inNode ast.Node) (node ast.Node, ok bool)
// 3. check if the modified expr contains non-deterministic functions
// 4. check whether new column refers to any auto-increment columns.
// 5. check if the new column is indexed or stored
func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error {
func checkModifyGeneratedColumn(sctx sessionctx.Context, schemaName model.CIStr, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error {
// rule 1.
oldColIsStored := !oldCol.IsGenerated() || oldCol.GeneratedStored
newColIsStored := !newCol.IsGenerated() || newCol.GeneratedStored
Expand Down Expand Up @@ -252,7 +258,10 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol
}

// rule 4.
_, dependColNames := findDependedColumnNames(newColDef)
_, dependColNames, err := findDependedColumnNames(schemaName, tbl.Meta().Name, newColDef)
if err != nil {
return errors.Trace(err)
}
if !sctx.GetSessionVars().EnableAutoIncrementInGenerated {
if err := checkAutoIncrementRef(newColDef.Name.Name.L, dependColNames, tbl.Meta()); err != nil {
return errors.Trace(err)
Expand Down
15 changes: 15 additions & 0 deletions parser/ast/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,21 @@ func TestDDLColumnOptionRestore(t *testing.T) {
runNodeRestoreTest(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc)
}

func TestGeneratedRestore(t *testing.T) {
testCases := []NodeRestoreTestCase{
{"generated always as(id + 1)", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"},
{"generated always as(id + 1) virtual", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"},
{"generated always as(id + 1) stored", "GENERATED ALWAYS AS(`id`+1) STORED"},
{"generated always as(lower(id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"},
{"generated always as(lower(child.id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"},
}
extractNodeFunc := func(node Node) Node {
return node.(*CreateTableStmt).Cols[0].Options[0]
}
runNodeRestoreTestWithFlagsStmtChange(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc,
format.DefaultRestoreFlags|format.RestoreWithoutSchemaName|format.RestoreWithoutTableName)
}

func TestDDLColumnDefRestore(t *testing.T) {
testCases := []NodeRestoreTestCase{
// for type
Expand Down
18 changes: 10 additions & 8 deletions parser/ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,17 @@ func (*TableName) resultSet() {}

// Restore implements Node interface.
func (n *TableName) restoreName(ctx *format.RestoreCtx) {
// restore db name
if n.Schema.String() != "" {
ctx.WriteName(n.Schema.String())
ctx.WritePlain(".")
} else if ctx.DefaultDB != "" {
// Try CTE, for a CTE table name, we shouldn't write the database name.
if !ctx.IsCTETableName(n.Name.L) {
ctx.WriteName(ctx.DefaultDB)
if !ctx.Flags.HasWithoutSchemaNameFlag() {
// restore db name
if n.Schema.String() != "" {
ctx.WriteName(n.Schema.String())
ctx.WritePlain(".")
} else if ctx.DefaultDB != "" {
// Try CTE, for a CTE table name, we shouldn't write the database name.
if !ctx.IsCTETableName(n.Name.L) {
ctx.WriteName(ctx.DefaultDB)
ctx.WritePlain(".")
}
}
}
// restore table name
Expand Down
4 changes: 2 additions & 2 deletions parser/ast/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,11 @@ type ColumnName struct {

// Restore implements Node interface.
func (n *ColumnName) Restore(ctx *format.RestoreCtx) error {
if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) {
if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) && !ctx.Flags.HasWithoutSchemaNameFlag() {
ctx.WriteName(n.Schema.O)
ctx.WritePlain(".")
}
if n.Table.O != "" {
if n.Table.O != "" && !ctx.Flags.HasWithoutTableNameFlag() {
ctx.WriteName(n.Table.O)
ctx.WritePlain(".")
}
Expand Down
13 changes: 13 additions & 0 deletions parser/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ const (

RestoreTiDBSpecialComment
SkipPlacementRuleForRestore
RestoreWithTTLEnableOff
RestoreWithoutSchemaName
RestoreWithoutTableName
)

const (
Expand All @@ -246,6 +249,16 @@ func (rf RestoreFlags) has(flag RestoreFlags) bool {
return rf&flag != 0
}

// HasWithoutSchemaNameFlag returns a boolean indicating when `rf` has `RestoreWithoutSchemaName` flag.
func (rf RestoreFlags) HasWithoutSchemaNameFlag() bool {
return rf.has(RestoreWithoutSchemaName)
}

// HasWithoutTableNameFlag returns a boolean indicating when `rf` has `RestoreWithoutTableName` flag.
func (rf RestoreFlags) HasWithoutTableNameFlag() bool {
return rf.has(RestoreWithoutTableName)
}

// HasStringSingleQuotesFlag returns a boolean indicating when `rf` has `RestoreStringSingleQuotes` flag.
func (rf RestoreFlags) HasStringSingleQuotesFlag() bool {
return rf.has(RestoreStringSingleQuotes)
Expand Down

0 comments on commit 337f4a0

Please sign in to comment.