Skip to content

Commit b5d0bcd

Browse files
CbcWestwolfti-chi-bot
authored andcommitted
extension: make RelatedTables work when the statement fails (pingcap#50989)
close pingcap#50988
1 parent b14e28f commit b5d0bcd

File tree

9 files changed

+476
-102
lines changed

9 files changed

+476
-102
lines changed

pkg/extension/event_listener_test.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,18 @@ func TestExtensionStmtEvents(t *testing.T) {
332332
dispatchData: append([]byte{mysql.ComInitDB}, []byte("db1")...),
333333
originalText: "use `db1`",
334334
redactText: "use `db1`",
335+
tables: []stmtctx.TableEntry{
336+
{DB: "db1", Table: ""},
337+
},
335338
},
336339
{
337340
dispatchData: append([]byte{mysql.ComInitDB}, []byte("noexistdb")...),
338341
originalText: "use `noexistdb`",
339342
redactText: "use `noexistdb`",
340343
err: "[schema:1049]Unknown database 'noexistdb'",
344+
tables: []stmtctx.TableEntry{
345+
{DB: "noexistdb", Table: ""},
346+
},
341347
},
342348
{
343349
sql: "set @@tidb_session_alias='alias123'",
@@ -448,7 +454,8 @@ func TestExtensionStmtEvents(t *testing.T) {
448454
r := record.tables[j]
449455
return l.DB < r.DB || (l.DB == r.DB && l.Table < r.Table)
450456
})
451-
require.Equal(t, subCase.tables, record.tables)
457+
require.Equal(t, subCase.tables, record.tables,
458+
"sql: %s\noriginalText: %s\n", subCase.sql, subCase.originalText)
452459

453460
require.Equal(t, len(subCase.executeParams), len(record.params))
454461
for k, param := range subCase.executeParams {

pkg/extension/session.go

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ type StmtEventInfo interface {
8888
// AffectedRows will return the affected rows of the current statement
8989
AffectedRows() uint64
9090
// RelatedTables will return the related tables of the current statement
91+
// For statements succeeding to build logical plan, it uses the `visitinfo` to get the related tables
92+
// For statements failing to build logical plan, it traverses the ast node to get the related tables
9193
RelatedTables() []stmtctx.TableEntry
9294
// GetError will return the error when the current statement is failed
9395
GetError() error

pkg/parser/ast/misc.go

+7
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,13 @@ func (n *FlushStmt) Accept(v Visitor) (Node, bool) {
975975
return v.Leave(newNode)
976976
}
977977
n = newNode.(*FlushStmt)
978+
for i, t := range n.Tables {
979+
node, ok := t.Accept(v)
980+
if !ok {
981+
return n, false
982+
}
983+
n.Tables[i] = node.(*TableName)
984+
}
978985
return v.Leave(n)
979986
}
980987

pkg/planner/core/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ go_test(
224224
"rule_join_reorder_dp_test.go",
225225
"runtime_filter_generator_test.go",
226226
"stringer_test.go",
227+
"util_test.go",
227228
],
228229
data = glob(["testdata/**"]),
229230
embed = [":core"],

pkg/planner/core/logical_plan_builder.go

+119-95
Original file line numberDiff line numberDiff line change
@@ -3292,8 +3292,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
32923292
}
32933293

32943294
func tblInfoFromCol(from ast.ResultSetNode, name *types.FieldName) *model.TableInfo {
3295-
var tableList []*ast.TableName
3296-
tableList = extractTableList(from, tableList, true)
3295+
tableList := ExtractTableList(from, true)
32973296
for _, field := range tableList {
32983297
if field.Name.L == name.TblName.L {
32993298
return field.TableInfo
@@ -5753,8 +5752,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
57535752
return nil, err
57545753
}
57555754

5756-
var tableList []*ast.TableName
5757-
tableList = extractTableList(update.TableRefs.TableRefs, tableList, false)
5755+
tableList := ExtractTableList(update.TableRefs.TableRefs, false)
57585756
for _, t := range tableList {
57595757
dbName := t.Schema.L
57605758
if dbName == "" {
@@ -6282,8 +6280,7 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (Plan
62826280
}
62836281
} else {
62846282
// Delete from a, b, c, d.
6285-
var tableList []*ast.TableName
6286-
tableList = extractTableList(ds.TableRefs.TableRefs, tableList, false)
6283+
tableList := ExtractTableList(ds.TableRefs.TableRefs, false)
62876284
for _, v := range tableList {
62886285
if isCTE(v) {
62896286
return nil, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(v.Name.O, "DELETE")
@@ -7097,17 +7094,6 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error
70977094
return specsMap, nil
70987095
}
70997096

7100-
func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectList) {
7101-
for _, sel := range list.Selects {
7102-
switch s := sel.(type) {
7103-
case *ast.SelectStmt:
7104-
unfoldList.Selects = append(unfoldList.Selects, s)
7105-
case *ast.SetOprSelectList:
7106-
unfoldSelectList(s, unfoldList)
7107-
}
7108-
}
7109-
}
7110-
71117097
type updatableTableListResolver struct {
71127098
updatableTableList []*ast.TableName
71137099
}
@@ -7136,111 +7122,149 @@ func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
71367122
return inNode, true
71377123
}
71387124

7139-
// extractTableList extracts all the TableNames from node.
7125+
// ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName
71407126
// If asName is true, extract AsName prior to OrigName.
71417127
// Privilege check should use OrigName, while expression may use AsName.
7142-
// TODO: extracting all tables by vistor model maybe a better way
7143-
func extractTableList(node ast.Node, input []*ast.TableName, asName bool) []*ast.TableName {
7144-
switch x := node.(type) {
7145-
case *ast.SelectStmt:
7146-
if x.From != nil {
7147-
input = extractTableList(x.From.TableRefs, input, asName)
7148-
}
7149-
if x.Where != nil {
7150-
input = extractTableList(x.Where, input, asName)
7151-
}
7152-
if x.With != nil {
7153-
for _, cte := range x.With.CTEs {
7154-
input = extractTableList(cte.Query, input, asName)
7155-
}
7156-
}
7157-
for _, f := range x.Fields.Fields {
7158-
if s, ok := f.Expr.(*ast.SubqueryExpr); ok {
7159-
input = extractTableList(s, input, asName)
7160-
}
7161-
}
7162-
case *ast.DeleteStmt:
7163-
input = extractTableList(x.TableRefs.TableRefs, input, asName)
7164-
if x.IsMultiTable {
7165-
for _, t := range x.Tables.Tables {
7166-
input = extractTableList(t, input, asName)
7167-
}
7168-
}
7169-
if x.Where != nil {
7170-
input = extractTableList(x.Where, input, asName)
7171-
}
7172-
if x.With != nil {
7173-
for _, cte := range x.With.CTEs {
7174-
input = extractTableList(cte.Query, input, asName)
7175-
}
7176-
}
7177-
case *ast.UpdateStmt:
7178-
input = extractTableList(x.TableRefs.TableRefs, input, asName)
7179-
for _, e := range x.List {
7180-
input = extractTableList(e.Expr, input, asName)
7181-
}
7182-
if x.Where != nil {
7183-
input = extractTableList(x.Where, input, asName)
7184-
}
7185-
if x.With != nil {
7186-
for _, cte := range x.With.CTEs {
7187-
input = extractTableList(cte.Query, input, asName)
7128+
func ExtractTableList(node ast.Node, asName bool) []*ast.TableName {
7129+
if node == nil {
7130+
return []*ast.TableName{}
7131+
}
7132+
e := &tableListExtractor{
7133+
asName: asName,
7134+
tableNames: []*ast.TableName{},
7135+
}
7136+
node.Accept(e)
7137+
tableNames := e.tableNames
7138+
m := make(map[string]map[string]*ast.TableName) // k1: schemaName, k2: tableName, v: ast.TableName
7139+
for _, x := range tableNames {
7140+
k1, k2 := x.Schema.L, x.Name.L
7141+
// allow empty schema name OR empty table name
7142+
if k1 != "" || k2 != "" {
7143+
if _, ok := m[k1]; !ok {
7144+
m[k1] = make(map[string]*ast.TableName)
71887145
}
7146+
m[k1][k2] = x
71897147
}
7190-
case *ast.InsertStmt:
7191-
input = extractTableList(x.Table.TableRefs, input, asName)
7192-
input = extractTableList(x.Select, input, asName)
7193-
case *ast.SetOprStmt:
7194-
l := &ast.SetOprSelectList{}
7195-
unfoldSelectList(x.SelectList, l)
7196-
for _, s := range l.Selects {
7197-
input = extractTableList(s.(ast.ResultSetNode), input, asName)
7198-
}
7199-
case *ast.PatternInExpr:
7200-
if s, ok := x.Sel.(*ast.SubqueryExpr); ok {
7201-
input = extractTableList(s, input, asName)
7148+
}
7149+
tableNames = tableNames[:0]
7150+
for _, x := range m {
7151+
for _, v := range x {
7152+
tableNames = append(tableNames, v)
72027153
}
7203-
case *ast.ExistsSubqueryExpr:
7204-
if s, ok := x.Sel.(*ast.SubqueryExpr); ok {
7205-
input = extractTableList(s, input, asName)
7154+
}
7155+
return tableNames
7156+
}
7157+
7158+
// tableListExtractor extracts all the TableNames from node.
7159+
type tableListExtractor struct {
7160+
asName bool
7161+
tableNames []*ast.TableName
7162+
}
7163+
7164+
func (e *tableListExtractor) Enter(n ast.Node) (_ ast.Node, skipChildren bool) {
7165+
innerExtract := func(inner ast.Node) []*ast.TableName {
7166+
if inner == nil {
7167+
return nil
72067168
}
7207-
case *ast.BinaryOperationExpr:
7208-
if s, ok := x.R.(*ast.SubqueryExpr); ok {
7209-
input = extractTableList(s, input, asName)
7169+
innerExtractor := &tableListExtractor{
7170+
asName: e.asName,
7171+
tableNames: []*ast.TableName{},
72107172
}
7211-
case *ast.SubqueryExpr:
7212-
input = extractTableList(x.Query, input, asName)
7213-
case *ast.Join:
7214-
input = extractTableList(x.Left, input, asName)
7215-
input = extractTableList(x.Right, input, asName)
7173+
inner.Accept(innerExtractor)
7174+
return innerExtractor.tableNames
7175+
}
7176+
7177+
switch x := n.(type) {
7178+
case *ast.TableName:
7179+
e.tableNames = append(e.tableNames, x)
72167180
case *ast.TableSource:
72177181
if s, ok := x.Source.(*ast.TableName); ok {
7218-
if x.AsName.L != "" && asName {
7182+
if x.AsName.L != "" && e.asName {
72197183
newTableName := *s
72207184
newTableName.Name = x.AsName
72217185
newTableName.Schema = model.NewCIStr("")
7222-
input = append(input, &newTableName)
7186+
e.tableNames = append(e.tableNames, &newTableName)
72237187
} else {
7224-
input = append(input, s)
7188+
e.tableNames = append(e.tableNames, s)
72257189
}
72267190
} else if s, ok := x.Source.(*ast.SelectStmt); ok {
72277191
if s.From != nil {
7228-
var innerList []*ast.TableName
7229-
innerList = extractTableList(s.From.TableRefs, innerList, asName)
7192+
innerList := innerExtract(s.From.TableRefs)
72307193
if len(innerList) > 0 {
72317194
innerTableName := innerList[0]
7232-
if x.AsName.L != "" && asName {
7195+
if x.AsName.L != "" && e.asName {
72337196
newTableName := *innerList[0]
72347197
newTableName.Name = x.AsName
72357198
newTableName.Schema = model.NewCIStr("")
72367199
innerTableName = &newTableName
72377200
}
7238-
input = append(input, innerTableName)
7201+
e.tableNames = append(e.tableNames, innerTableName)
72397202
}
72407203
}
72417204
}
7205+
return n, true
7206+
7207+
case *ast.ShowStmt:
7208+
if x.DBName != "" {
7209+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)})
7210+
}
7211+
case *ast.CreateDatabaseStmt:
7212+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
7213+
case *ast.AlterDatabaseStmt:
7214+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
7215+
case *ast.DropDatabaseStmt:
7216+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
7217+
7218+
case *ast.FlashBackDatabaseStmt:
7219+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName})
7220+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.NewName)})
7221+
case *ast.FlashBackToTimestampStmt:
7222+
if x.DBName.L != "" {
7223+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName})
7224+
}
7225+
case *ast.FlashBackTableStmt:
7226+
if newName := x.NewName; newName != "" {
7227+
e.tableNames = append(e.tableNames, &ast.TableName{
7228+
Schema: x.Table.Schema,
7229+
Name: model.NewCIStr(newName)})
7230+
}
7231+
7232+
case *ast.GrantStmt:
7233+
if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone {
7234+
if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable {
7235+
e.tableNames = append(e.tableNames, &ast.TableName{
7236+
Schema: model.NewCIStr(x.Level.DBName),
7237+
Name: model.NewCIStr(x.Level.TableName),
7238+
})
7239+
}
7240+
}
7241+
case *ast.RevokeStmt:
7242+
if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone {
7243+
if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable {
7244+
e.tableNames = append(e.tableNames, &ast.TableName{
7245+
Schema: model.NewCIStr(x.Level.DBName),
7246+
Name: model.NewCIStr(x.Level.TableName),
7247+
})
7248+
}
7249+
}
7250+
case *ast.BRIEStmt:
7251+
if x.Kind == ast.BRIEKindBackup || x.Kind == ast.BRIEKindRestore {
7252+
for _, v := range x.Schemas {
7253+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(v)})
7254+
}
7255+
}
7256+
case *ast.UseStmt:
7257+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)})
7258+
case *ast.ExecuteStmt:
7259+
if v, ok := x.PrepStmt.(*PlanCacheStmt); ok {
7260+
e.tableNames = append(e.tableNames, innerExtract(v.PreparedAst.Stmt)...)
7261+
}
72427262
}
7243-
return input
7263+
return n, false
7264+
}
7265+
7266+
func (*tableListExtractor) Leave(n ast.Node) (ast.Node, bool) {
7267+
return n, true
72447268
}
72457269

72467270
func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) {

pkg/planner/core/point_get_plan.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -1937,8 +1937,7 @@ func buildPointUpdatePlan(ctx PlanContext, pointPlan PhysicalPlan, dbName string
19371937
}
19381938
if tbl.GetPartitionInfo() != nil {
19391939
pt := t.(table.PartitionedTable)
1940-
var updateTableList []*ast.TableName
1941-
updateTableList = extractTableList(updateStmt.TableRefs.TableRefs, updateTableList, true)
1940+
updateTableList := ExtractTableList(updateStmt.TableRefs.TableRefs, true)
19421941
updatePlan.PartitionedTable = make([]table.PartitionedTable, 0, len(updateTableList))
19431942
for _, updateTable := range updateTableList {
19441943
if len(updateTable.PartitionNames) > 0 {

pkg/planner/core/preprocess.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def
533533
}
534534

535535
// Check the bind operation is not on any temporary table.
536-
tblNames := extractTableList(originNode, nil, false)
536+
tblNames := ExtractTableList(originNode, false)
537537
for _, tn := range tblNames {
538538
tbl, err := p.tableByName(tn)
539539
if err != nil {

0 commit comments

Comments
 (0)