@@ -3419,8 +3419,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
3419
3419
}
3420
3420
3421
3421
func tblInfoFromCol (from ast.ResultSetNode , name * types.FieldName ) * model.TableInfo {
3422
- var tableList []* ast.TableName
3423
- tableList = extractTableList (from , tableList , true )
3422
+ tableList := ExtractTableList (from , true )
3424
3423
for _ , field := range tableList {
3425
3424
if field .Name .L == name .TblName .L {
3426
3425
return field .TableInfo
@@ -6094,8 +6093,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
6094
6093
return nil , err
6095
6094
}
6096
6095
6097
- var tableList []* ast.TableName
6098
- tableList = extractTableList (update .TableRefs .TableRefs , tableList , false )
6096
+ tableList := ExtractTableList (update .TableRefs .TableRefs , false )
6099
6097
for _ , t := range tableList {
6100
6098
dbName := t .Schema .L
6101
6099
if dbName == "" {
@@ -6623,8 +6621,7 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (Plan
6623
6621
}
6624
6622
} else {
6625
6623
// Delete from a, b, c, d.
6626
- var tableList []* ast.TableName
6627
- tableList = extractTableList (ds .TableRefs .TableRefs , tableList , false )
6624
+ tableList := ExtractTableList (ds .TableRefs .TableRefs , false )
6628
6625
for _ , v := range tableList {
6629
6626
if isCTE (v ) {
6630
6627
return nil , ErrNonUpdatableTable .GenWithStackByArgs (v .Name .O , "DELETE" )
@@ -7444,17 +7441,6 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error
7444
7441
return specsMap , nil
7445
7442
}
7446
7443
7447
- func unfoldSelectList (list * ast.SetOprSelectList , unfoldList * ast.SetOprSelectList ) {
7448
- for _ , sel := range list .Selects {
7449
- switch s := sel .(type ) {
7450
- case * ast.SelectStmt :
7451
- unfoldList .Selects = append (unfoldList .Selects , s )
7452
- case * ast.SetOprSelectList :
7453
- unfoldSelectList (s , unfoldList )
7454
- }
7455
- }
7456
- }
7457
-
7458
7444
type updatableTableListResolver struct {
7459
7445
updatableTableList []* ast.TableName
7460
7446
}
@@ -7483,111 +7469,149 @@ func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
7483
7469
return inNode , true
7484
7470
}
7485
7471
7486
- // extractTableList extracts all the TableNames from node.
7472
+ // ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName
7487
7473
// If asName is true, extract AsName prior to OrigName.
7488
7474
// Privilege check should use OrigName, while expression may use AsName.
7489
- // TODO: extracting all tables by vistor model maybe a better way
7490
- func extractTableList (node ast.Node , input []* ast.TableName , asName bool ) []* ast.TableName {
7491
- switch x := node .(type ) {
7492
- case * ast.SelectStmt :
7493
- if x .From != nil {
7494
- input = extractTableList (x .From .TableRefs , input , asName )
7495
- }
7496
- if x .Where != nil {
7497
- input = extractTableList (x .Where , input , asName )
7498
- }
7499
- if x .With != nil {
7500
- for _ , cte := range x .With .CTEs {
7501
- input = extractTableList (cte .Query , input , asName )
7502
- }
7503
- }
7504
- for _ , f := range x .Fields .Fields {
7505
- if s , ok := f .Expr .(* ast.SubqueryExpr ); ok {
7506
- input = extractTableList (s , input , asName )
7507
- }
7508
- }
7509
- case * ast.DeleteStmt :
7510
- input = extractTableList (x .TableRefs .TableRefs , input , asName )
7511
- if x .IsMultiTable {
7512
- for _ , t := range x .Tables .Tables {
7513
- input = extractTableList (t , input , asName )
7514
- }
7515
- }
7516
- if x .Where != nil {
7517
- input = extractTableList (x .Where , input , asName )
7518
- }
7519
- if x .With != nil {
7520
- for _ , cte := range x .With .CTEs {
7521
- input = extractTableList (cte .Query , input , asName )
7522
- }
7523
- }
7524
- case * ast.UpdateStmt :
7525
- input = extractTableList (x .TableRefs .TableRefs , input , asName )
7526
- for _ , e := range x .List {
7527
- input = extractTableList (e .Expr , input , asName )
7528
- }
7529
- if x .Where != nil {
7530
- input = extractTableList (x .Where , input , asName )
7531
- }
7532
- if x .With != nil {
7533
- for _ , cte := range x .With .CTEs {
7534
- input = extractTableList (cte .Query , input , asName )
7475
+ func ExtractTableList (node ast.Node , asName bool ) []* ast.TableName {
7476
+ if node == nil {
7477
+ return []* ast.TableName {}
7478
+ }
7479
+ e := & tableListExtractor {
7480
+ asName : asName ,
7481
+ tableNames : []* ast.TableName {},
7482
+ }
7483
+ node .Accept (e )
7484
+ tableNames := e .tableNames
7485
+ m := make (map [string ]map [string ]* ast.TableName ) // k1: schemaName, k2: tableName, v: ast.TableName
7486
+ for _ , x := range tableNames {
7487
+ k1 , k2 := x .Schema .L , x .Name .L
7488
+ // allow empty schema name OR empty table name
7489
+ if k1 != "" || k2 != "" {
7490
+ if _ , ok := m [k1 ]; ! ok {
7491
+ m [k1 ] = make (map [string ]* ast.TableName )
7535
7492
}
7493
+ m [k1 ][k2 ] = x
7536
7494
}
7537
- case * ast.InsertStmt :
7538
- input = extractTableList (x .Table .TableRefs , input , asName )
7539
- input = extractTableList (x .Select , input , asName )
7540
- case * ast.SetOprStmt :
7541
- l := & ast.SetOprSelectList {}
7542
- unfoldSelectList (x .SelectList , l )
7543
- for _ , s := range l .Selects {
7544
- input = extractTableList (s .(ast.ResultSetNode ), input , asName )
7545
- }
7546
- case * ast.PatternInExpr :
7547
- if s , ok := x .Sel .(* ast.SubqueryExpr ); ok {
7548
- input = extractTableList (s , input , asName )
7495
+ }
7496
+ tableNames = tableNames [:0 ]
7497
+ for _ , x := range m {
7498
+ for _ , v := range x {
7499
+ tableNames = append (tableNames , v )
7549
7500
}
7550
- case * ast.ExistsSubqueryExpr :
7551
- if s , ok := x .Sel .(* ast.SubqueryExpr ); ok {
7552
- input = extractTableList (s , input , asName )
7501
+ }
7502
+ return tableNames
7503
+ }
7504
+
7505
+ // tableListExtractor extracts all the TableNames from node.
7506
+ type tableListExtractor struct {
7507
+ asName bool
7508
+ tableNames []* ast.TableName
7509
+ }
7510
+
7511
+ func (e * tableListExtractor ) Enter (n ast.Node ) (_ ast.Node , skipChildren bool ) {
7512
+ innerExtract := func (inner ast.Node ) []* ast.TableName {
7513
+ if inner == nil {
7514
+ return nil
7553
7515
}
7554
- case * ast. BinaryOperationExpr :
7555
- if s , ok := x . R .( * ast. SubqueryExpr ); ok {
7556
- input = extractTableList ( s , input , asName )
7516
+ innerExtractor := & tableListExtractor {
7517
+ asName : e . asName ,
7518
+ tableNames : [] * ast. TableName {},
7557
7519
}
7558
- case * ast.SubqueryExpr :
7559
- input = extractTableList (x .Query , input , asName )
7560
- case * ast.Join :
7561
- input = extractTableList (x .Left , input , asName )
7562
- input = extractTableList (x .Right , input , asName )
7520
+ inner .Accept (innerExtractor )
7521
+ return innerExtractor .tableNames
7522
+ }
7523
+
7524
+ switch x := n .(type ) {
7525
+ case * ast.TableName :
7526
+ e .tableNames = append (e .tableNames , x )
7563
7527
case * ast.TableSource :
7564
7528
if s , ok := x .Source .(* ast.TableName ); ok {
7565
- if x .AsName .L != "" && asName {
7529
+ if x .AsName .L != "" && e . asName {
7566
7530
newTableName := * s
7567
7531
newTableName .Name = x .AsName
7568
7532
newTableName .Schema = model .NewCIStr ("" )
7569
- input = append (input , & newTableName )
7533
+ e . tableNames = append (e . tableNames , & newTableName )
7570
7534
} else {
7571
- input = append (input , s )
7535
+ e . tableNames = append (e . tableNames , s )
7572
7536
}
7573
7537
} else if s , ok := x .Source .(* ast.SelectStmt ); ok {
7574
7538
if s .From != nil {
7575
- var innerList []* ast.TableName
7576
- innerList = extractTableList (s .From .TableRefs , innerList , asName )
7539
+ innerList := innerExtract (s .From .TableRefs )
7577
7540
if len (innerList ) > 0 {
7578
7541
innerTableName := innerList [0 ]
7579
- if x .AsName .L != "" && asName {
7542
+ if x .AsName .L != "" && e . asName {
7580
7543
newTableName := * innerList [0 ]
7581
7544
newTableName .Name = x .AsName
7582
7545
newTableName .Schema = model .NewCIStr ("" )
7583
7546
innerTableName = & newTableName
7584
7547
}
7585
- input = append (input , innerTableName )
7548
+ e . tableNames = append (e . tableNames , innerTableName )
7586
7549
}
7587
7550
}
7588
7551
}
7552
+ return n , true
7553
+
7554
+ case * ast.ShowStmt :
7555
+ if x .DBName != "" {
7556
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (x .DBName )})
7557
+ }
7558
+ case * ast.CreateDatabaseStmt :
7559
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .Name })
7560
+ case * ast.AlterDatabaseStmt :
7561
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .Name })
7562
+ case * ast.DropDatabaseStmt :
7563
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .Name })
7564
+
7565
+ case * ast.FlashBackDatabaseStmt :
7566
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .DBName })
7567
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (x .NewName )})
7568
+ case * ast.FlashBackToTimestampStmt :
7569
+ if x .DBName .L != "" {
7570
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .DBName })
7571
+ }
7572
+ case * ast.FlashBackTableStmt :
7573
+ if newName := x .NewName ; newName != "" {
7574
+ e .tableNames = append (e .tableNames , & ast.TableName {
7575
+ Schema : x .Table .Schema ,
7576
+ Name : model .NewCIStr (newName )})
7577
+ }
7578
+
7579
+ case * ast.GrantStmt :
7580
+ if x .ObjectType == ast .ObjectTypeTable || x .ObjectType == ast .ObjectTypeNone {
7581
+ if x .Level .Level == ast .GrantLevelDB || x .Level .Level == ast .GrantLevelTable {
7582
+ e .tableNames = append (e .tableNames , & ast.TableName {
7583
+ Schema : model .NewCIStr (x .Level .DBName ),
7584
+ Name : model .NewCIStr (x .Level .TableName ),
7585
+ })
7586
+ }
7587
+ }
7588
+ case * ast.RevokeStmt :
7589
+ if x .ObjectType == ast .ObjectTypeTable || x .ObjectType == ast .ObjectTypeNone {
7590
+ if x .Level .Level == ast .GrantLevelDB || x .Level .Level == ast .GrantLevelTable {
7591
+ e .tableNames = append (e .tableNames , & ast.TableName {
7592
+ Schema : model .NewCIStr (x .Level .DBName ),
7593
+ Name : model .NewCIStr (x .Level .TableName ),
7594
+ })
7595
+ }
7596
+ }
7597
+ case * ast.BRIEStmt :
7598
+ if x .Kind == ast .BRIEKindBackup || x .Kind == ast .BRIEKindRestore {
7599
+ for _ , v := range x .Schemas {
7600
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (v )})
7601
+ }
7602
+ }
7603
+ case * ast.UseStmt :
7604
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (x .DBName )})
7605
+ case * ast.ExecuteStmt :
7606
+ if v , ok := x .PrepStmt .(* PlanCacheStmt ); ok {
7607
+ e .tableNames = append (e .tableNames , innerExtract (v .PreparedAst .Stmt )... )
7608
+ }
7589
7609
}
7590
- return input
7610
+ return n , false
7611
+ }
7612
+
7613
+ func (* tableListExtractor ) Leave (n ast.Node ) (ast.Node , bool ) {
7614
+ return n , true
7591
7615
}
7592
7616
7593
7617
func collectTableName (node ast.ResultSetNode , updatableName * map [string ]bool , info * map [string ]* ast.TableName ) {
0 commit comments