Skip to content

Commit cd5a32a

Browse files
CbcWestwolfti-chi-bot
authored andcommitted
extension: make RelatedTables work when the statement fails (pingcap#50989)
close pingcap#50988
1 parent 1f29133 commit cd5a32a

File tree

9 files changed

+476
-102
lines changed

9 files changed

+476
-102
lines changed

pkg/extension/event_listener_test.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,18 @@ func TestExtensionStmtEvents(t *testing.T) {
332332
dispatchData: append([]byte{mysql.ComInitDB}, []byte("db1")...),
333333
originalText: "use `db1`",
334334
redactText: "use `db1`",
335+
tables: []stmtctx.TableEntry{
336+
{DB: "db1", Table: ""},
337+
},
335338
},
336339
{
337340
dispatchData: append([]byte{mysql.ComInitDB}, []byte("noexistdb")...),
338341
originalText: "use `noexistdb`",
339342
redactText: "use `noexistdb`",
340343
err: "[schema:1049]Unknown database 'noexistdb'",
344+
tables: []stmtctx.TableEntry{
345+
{DB: "noexistdb", Table: ""},
346+
},
341347
},
342348
{
343349
sql: "set @@tidb_session_alias='alias123'",
@@ -448,7 +454,8 @@ func TestExtensionStmtEvents(t *testing.T) {
448454
r := record.tables[j]
449455
return l.DB < r.DB || (l.DB == r.DB && l.Table < r.Table)
450456
})
451-
require.Equal(t, subCase.tables, record.tables)
457+
require.Equal(t, subCase.tables, record.tables,
458+
"sql: %s\noriginalText: %s\n", subCase.sql, subCase.originalText)
452459

453460
require.Equal(t, len(subCase.executeParams), len(record.params))
454461
for k, param := range subCase.executeParams {

pkg/extension/session.go

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ type StmtEventInfo interface {
8888
// AffectedRows will return the affected rows of the current statement
8989
AffectedRows() uint64
9090
// RelatedTables will return the related tables of the current statement
91+
// For statements succeeding to build logical plan, it uses the `visitinfo` to get the related tables
92+
// For statements failing to build logical plan, it traverses the ast node to get the related tables
9193
RelatedTables() []stmtctx.TableEntry
9294
// GetError will return the error when the current statement is failed
9395
GetError() error

pkg/parser/ast/misc.go

+7
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,13 @@ func (n *FlushStmt) Accept(v Visitor) (Node, bool) {
979979
return v.Leave(newNode)
980980
}
981981
n = newNode.(*FlushStmt)
982+
for i, t := range n.Tables {
983+
node, ok := t.Accept(v)
984+
if !ok {
985+
return n, false
986+
}
987+
n.Tables[i] = node.(*TableName)
988+
}
982989
return v.Leave(n)
983990
}
984991

pkg/planner/core/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ go_test(
223223
"rule_join_reorder_test.go",
224224
"runtime_filter_generator_test.go",
225225
"stringer_test.go",
226+
"util_test.go",
226227
],
227228
data = glob(["testdata/**"]),
228229
embed = [":core"],

pkg/planner/core/logical_plan_builder.go

+119-95
Original file line numberDiff line numberDiff line change
@@ -3419,8 +3419,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
34193419
}
34203420

34213421
func tblInfoFromCol(from ast.ResultSetNode, name *types.FieldName) *model.TableInfo {
3422-
var tableList []*ast.TableName
3423-
tableList = extractTableList(from, tableList, true)
3422+
tableList := ExtractTableList(from, true)
34243423
for _, field := range tableList {
34253424
if field.Name.L == name.TblName.L {
34263425
return field.TableInfo
@@ -6094,8 +6093,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
60946093
return nil, err
60956094
}
60966095

6097-
var tableList []*ast.TableName
6098-
tableList = extractTableList(update.TableRefs.TableRefs, tableList, false)
6096+
tableList := ExtractTableList(update.TableRefs.TableRefs, false)
60996097
for _, t := range tableList {
61006098
dbName := t.Schema.L
61016099
if dbName == "" {
@@ -6623,8 +6621,7 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (Plan
66236621
}
66246622
} else {
66256623
// Delete from a, b, c, d.
6626-
var tableList []*ast.TableName
6627-
tableList = extractTableList(ds.TableRefs.TableRefs, tableList, false)
6624+
tableList := ExtractTableList(ds.TableRefs.TableRefs, false)
66286625
for _, v := range tableList {
66296626
if isCTE(v) {
66306627
return nil, ErrNonUpdatableTable.GenWithStackByArgs(v.Name.O, "DELETE")
@@ -7444,17 +7441,6 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error
74447441
return specsMap, nil
74457442
}
74467443

7447-
func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectList) {
7448-
for _, sel := range list.Selects {
7449-
switch s := sel.(type) {
7450-
case *ast.SelectStmt:
7451-
unfoldList.Selects = append(unfoldList.Selects, s)
7452-
case *ast.SetOprSelectList:
7453-
unfoldSelectList(s, unfoldList)
7454-
}
7455-
}
7456-
}
7457-
74587444
type updatableTableListResolver struct {
74597445
updatableTableList []*ast.TableName
74607446
}
@@ -7483,111 +7469,149 @@ func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
74837469
return inNode, true
74847470
}
74857471

7486-
// extractTableList extracts all the TableNames from node.
7472+
// ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName
74877473
// If asName is true, extract AsName prior to OrigName.
74887474
// Privilege check should use OrigName, while expression may use AsName.
7489-
// TODO: extracting all tables by vistor model maybe a better way
7490-
func extractTableList(node ast.Node, input []*ast.TableName, asName bool) []*ast.TableName {
7491-
switch x := node.(type) {
7492-
case *ast.SelectStmt:
7493-
if x.From != nil {
7494-
input = extractTableList(x.From.TableRefs, input, asName)
7495-
}
7496-
if x.Where != nil {
7497-
input = extractTableList(x.Where, input, asName)
7498-
}
7499-
if x.With != nil {
7500-
for _, cte := range x.With.CTEs {
7501-
input = extractTableList(cte.Query, input, asName)
7502-
}
7503-
}
7504-
for _, f := range x.Fields.Fields {
7505-
if s, ok := f.Expr.(*ast.SubqueryExpr); ok {
7506-
input = extractTableList(s, input, asName)
7507-
}
7508-
}
7509-
case *ast.DeleteStmt:
7510-
input = extractTableList(x.TableRefs.TableRefs, input, asName)
7511-
if x.IsMultiTable {
7512-
for _, t := range x.Tables.Tables {
7513-
input = extractTableList(t, input, asName)
7514-
}
7515-
}
7516-
if x.Where != nil {
7517-
input = extractTableList(x.Where, input, asName)
7518-
}
7519-
if x.With != nil {
7520-
for _, cte := range x.With.CTEs {
7521-
input = extractTableList(cte.Query, input, asName)
7522-
}
7523-
}
7524-
case *ast.UpdateStmt:
7525-
input = extractTableList(x.TableRefs.TableRefs, input, asName)
7526-
for _, e := range x.List {
7527-
input = extractTableList(e.Expr, input, asName)
7528-
}
7529-
if x.Where != nil {
7530-
input = extractTableList(x.Where, input, asName)
7531-
}
7532-
if x.With != nil {
7533-
for _, cte := range x.With.CTEs {
7534-
input = extractTableList(cte.Query, input, asName)
7475+
func ExtractTableList(node ast.Node, asName bool) []*ast.TableName {
7476+
if node == nil {
7477+
return []*ast.TableName{}
7478+
}
7479+
e := &tableListExtractor{
7480+
asName: asName,
7481+
tableNames: []*ast.TableName{},
7482+
}
7483+
node.Accept(e)
7484+
tableNames := e.tableNames
7485+
m := make(map[string]map[string]*ast.TableName) // k1: schemaName, k2: tableName, v: ast.TableName
7486+
for _, x := range tableNames {
7487+
k1, k2 := x.Schema.L, x.Name.L
7488+
// allow empty schema name OR empty table name
7489+
if k1 != "" || k2 != "" {
7490+
if _, ok := m[k1]; !ok {
7491+
m[k1] = make(map[string]*ast.TableName)
75357492
}
7493+
m[k1][k2] = x
75367494
}
7537-
case *ast.InsertStmt:
7538-
input = extractTableList(x.Table.TableRefs, input, asName)
7539-
input = extractTableList(x.Select, input, asName)
7540-
case *ast.SetOprStmt:
7541-
l := &ast.SetOprSelectList{}
7542-
unfoldSelectList(x.SelectList, l)
7543-
for _, s := range l.Selects {
7544-
input = extractTableList(s.(ast.ResultSetNode), input, asName)
7545-
}
7546-
case *ast.PatternInExpr:
7547-
if s, ok := x.Sel.(*ast.SubqueryExpr); ok {
7548-
input = extractTableList(s, input, asName)
7495+
}
7496+
tableNames = tableNames[:0]
7497+
for _, x := range m {
7498+
for _, v := range x {
7499+
tableNames = append(tableNames, v)
75497500
}
7550-
case *ast.ExistsSubqueryExpr:
7551-
if s, ok := x.Sel.(*ast.SubqueryExpr); ok {
7552-
input = extractTableList(s, input, asName)
7501+
}
7502+
return tableNames
7503+
}
7504+
7505+
// tableListExtractor extracts all the TableNames from node.
7506+
type tableListExtractor struct {
7507+
asName bool
7508+
tableNames []*ast.TableName
7509+
}
7510+
7511+
func (e *tableListExtractor) Enter(n ast.Node) (_ ast.Node, skipChildren bool) {
7512+
innerExtract := func(inner ast.Node) []*ast.TableName {
7513+
if inner == nil {
7514+
return nil
75537515
}
7554-
case *ast.BinaryOperationExpr:
7555-
if s, ok := x.R.(*ast.SubqueryExpr); ok {
7556-
input = extractTableList(s, input, asName)
7516+
innerExtractor := &tableListExtractor{
7517+
asName: e.asName,
7518+
tableNames: []*ast.TableName{},
75577519
}
7558-
case *ast.SubqueryExpr:
7559-
input = extractTableList(x.Query, input, asName)
7560-
case *ast.Join:
7561-
input = extractTableList(x.Left, input, asName)
7562-
input = extractTableList(x.Right, input, asName)
7520+
inner.Accept(innerExtractor)
7521+
return innerExtractor.tableNames
7522+
}
7523+
7524+
switch x := n.(type) {
7525+
case *ast.TableName:
7526+
e.tableNames = append(e.tableNames, x)
75637527
case *ast.TableSource:
75647528
if s, ok := x.Source.(*ast.TableName); ok {
7565-
if x.AsName.L != "" && asName {
7529+
if x.AsName.L != "" && e.asName {
75667530
newTableName := *s
75677531
newTableName.Name = x.AsName
75687532
newTableName.Schema = model.NewCIStr("")
7569-
input = append(input, &newTableName)
7533+
e.tableNames = append(e.tableNames, &newTableName)
75707534
} else {
7571-
input = append(input, s)
7535+
e.tableNames = append(e.tableNames, s)
75727536
}
75737537
} else if s, ok := x.Source.(*ast.SelectStmt); ok {
75747538
if s.From != nil {
7575-
var innerList []*ast.TableName
7576-
innerList = extractTableList(s.From.TableRefs, innerList, asName)
7539+
innerList := innerExtract(s.From.TableRefs)
75777540
if len(innerList) > 0 {
75787541
innerTableName := innerList[0]
7579-
if x.AsName.L != "" && asName {
7542+
if x.AsName.L != "" && e.asName {
75807543
newTableName := *innerList[0]
75817544
newTableName.Name = x.AsName
75827545
newTableName.Schema = model.NewCIStr("")
75837546
innerTableName = &newTableName
75847547
}
7585-
input = append(input, innerTableName)
7548+
e.tableNames = append(e.tableNames, innerTableName)
75867549
}
75877550
}
75887551
}
7552+
return n, true
7553+
7554+
case *ast.ShowStmt:
7555+
if x.DBName != "" {
7556+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)})
7557+
}
7558+
case *ast.CreateDatabaseStmt:
7559+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
7560+
case *ast.AlterDatabaseStmt:
7561+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
7562+
case *ast.DropDatabaseStmt:
7563+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
7564+
7565+
case *ast.FlashBackDatabaseStmt:
7566+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName})
7567+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.NewName)})
7568+
case *ast.FlashBackToTimestampStmt:
7569+
if x.DBName.L != "" {
7570+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName})
7571+
}
7572+
case *ast.FlashBackTableStmt:
7573+
if newName := x.NewName; newName != "" {
7574+
e.tableNames = append(e.tableNames, &ast.TableName{
7575+
Schema: x.Table.Schema,
7576+
Name: model.NewCIStr(newName)})
7577+
}
7578+
7579+
case *ast.GrantStmt:
7580+
if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone {
7581+
if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable {
7582+
e.tableNames = append(e.tableNames, &ast.TableName{
7583+
Schema: model.NewCIStr(x.Level.DBName),
7584+
Name: model.NewCIStr(x.Level.TableName),
7585+
})
7586+
}
7587+
}
7588+
case *ast.RevokeStmt:
7589+
if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone {
7590+
if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable {
7591+
e.tableNames = append(e.tableNames, &ast.TableName{
7592+
Schema: model.NewCIStr(x.Level.DBName),
7593+
Name: model.NewCIStr(x.Level.TableName),
7594+
})
7595+
}
7596+
}
7597+
case *ast.BRIEStmt:
7598+
if x.Kind == ast.BRIEKindBackup || x.Kind == ast.BRIEKindRestore {
7599+
for _, v := range x.Schemas {
7600+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(v)})
7601+
}
7602+
}
7603+
case *ast.UseStmt:
7604+
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)})
7605+
case *ast.ExecuteStmt:
7606+
if v, ok := x.PrepStmt.(*PlanCacheStmt); ok {
7607+
e.tableNames = append(e.tableNames, innerExtract(v.PreparedAst.Stmt)...)
7608+
}
75897609
}
7590-
return input
7610+
return n, false
7611+
}
7612+
7613+
func (*tableListExtractor) Leave(n ast.Node) (ast.Node, bool) {
7614+
return n, true
75917615
}
75927616

75937617
func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) {

pkg/planner/core/point_get_plan.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -1677,8 +1677,7 @@ func buildPointUpdatePlan(ctx sessionctx.Context, pointPlan PhysicalPlan, dbName
16771677
}
16781678
if tbl.GetPartitionInfo() != nil {
16791679
pt := t.(table.PartitionedTable)
1680-
var updateTableList []*ast.TableName
1681-
updateTableList = extractTableList(updateStmt.TableRefs.TableRefs, updateTableList, true)
1680+
updateTableList := ExtractTableList(updateStmt.TableRefs.TableRefs, true)
16821681
updatePlan.PartitionedTable = make([]table.PartitionedTable, 0, len(updateTableList))
16831682
for _, updateTable := range updateTableList {
16841683
if len(updateTable.PartitionNames) > 0 {

pkg/planner/core/preprocess.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def
528528
}
529529

530530
// Check the bind operation is not on any temporary table.
531-
tblNames := extractTableList(originNode, nil, false)
531+
tblNames := ExtractTableList(originNode, false)
532532
for _, tn := range tblNames {
533533
tbl, err := p.tableByName(tn)
534534
if err != nil {

0 commit comments

Comments
 (0)