@@ -3292,8 +3292,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
3292
3292
}
3293
3293
3294
3294
func tblInfoFromCol (from ast.ResultSetNode , name * types.FieldName ) * model.TableInfo {
3295
- var tableList []* ast.TableName
3296
- tableList = extractTableList (from , tableList , true )
3295
+ tableList := ExtractTableList (from , true )
3297
3296
for _ , field := range tableList {
3298
3297
if field .Name .L == name .TblName .L {
3299
3298
return field .TableInfo
@@ -5753,8 +5752,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
5753
5752
return nil , err
5754
5753
}
5755
5754
5756
- var tableList []* ast.TableName
5757
- tableList = extractTableList (update .TableRefs .TableRefs , tableList , false )
5755
+ tableList := ExtractTableList (update .TableRefs .TableRefs , false )
5758
5756
for _ , t := range tableList {
5759
5757
dbName := t .Schema .L
5760
5758
if dbName == "" {
@@ -6282,8 +6280,7 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (Plan
6282
6280
}
6283
6281
} else {
6284
6282
// Delete from a, b, c, d.
6285
- var tableList []* ast.TableName
6286
- tableList = extractTableList (ds .TableRefs .TableRefs , tableList , false )
6283
+ tableList := ExtractTableList (ds .TableRefs .TableRefs , false )
6287
6284
for _ , v := range tableList {
6288
6285
if isCTE (v ) {
6289
6286
return nil , plannererrors .ErrNonUpdatableTable .GenWithStackByArgs (v .Name .O , "DELETE" )
@@ -7097,17 +7094,6 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error
7097
7094
return specsMap , nil
7098
7095
}
7099
7096
7100
- func unfoldSelectList (list * ast.SetOprSelectList , unfoldList * ast.SetOprSelectList ) {
7101
- for _ , sel := range list .Selects {
7102
- switch s := sel .(type ) {
7103
- case * ast.SelectStmt :
7104
- unfoldList .Selects = append (unfoldList .Selects , s )
7105
- case * ast.SetOprSelectList :
7106
- unfoldSelectList (s , unfoldList )
7107
- }
7108
- }
7109
- }
7110
-
7111
7097
type updatableTableListResolver struct {
7112
7098
updatableTableList []* ast.TableName
7113
7099
}
@@ -7136,111 +7122,149 @@ func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
7136
7122
return inNode , true
7137
7123
}
7138
7124
7139
- // extractTableList extracts all the TableNames from node.
7125
+ // ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName
7140
7126
// If asName is true, extract AsName prior to OrigName.
7141
7127
// Privilege check should use OrigName, while expression may use AsName.
7142
- // TODO: extracting all tables by vistor model maybe a better way
7143
- func extractTableList (node ast.Node , input []* ast.TableName , asName bool ) []* ast.TableName {
7144
- switch x := node .(type ) {
7145
- case * ast.SelectStmt :
7146
- if x .From != nil {
7147
- input = extractTableList (x .From .TableRefs , input , asName )
7148
- }
7149
- if x .Where != nil {
7150
- input = extractTableList (x .Where , input , asName )
7151
- }
7152
- if x .With != nil {
7153
- for _ , cte := range x .With .CTEs {
7154
- input = extractTableList (cte .Query , input , asName )
7155
- }
7156
- }
7157
- for _ , f := range x .Fields .Fields {
7158
- if s , ok := f .Expr .(* ast.SubqueryExpr ); ok {
7159
- input = extractTableList (s , input , asName )
7160
- }
7161
- }
7162
- case * ast.DeleteStmt :
7163
- input = extractTableList (x .TableRefs .TableRefs , input , asName )
7164
- if x .IsMultiTable {
7165
- for _ , t := range x .Tables .Tables {
7166
- input = extractTableList (t , input , asName )
7167
- }
7168
- }
7169
- if x .Where != nil {
7170
- input = extractTableList (x .Where , input , asName )
7171
- }
7172
- if x .With != nil {
7173
- for _ , cte := range x .With .CTEs {
7174
- input = extractTableList (cte .Query , input , asName )
7175
- }
7176
- }
7177
- case * ast.UpdateStmt :
7178
- input = extractTableList (x .TableRefs .TableRefs , input , asName )
7179
- for _ , e := range x .List {
7180
- input = extractTableList (e .Expr , input , asName )
7181
- }
7182
- if x .Where != nil {
7183
- input = extractTableList (x .Where , input , asName )
7184
- }
7185
- if x .With != nil {
7186
- for _ , cte := range x .With .CTEs {
7187
- input = extractTableList (cte .Query , input , asName )
7128
+ func ExtractTableList (node ast.Node , asName bool ) []* ast.TableName {
7129
+ if node == nil {
7130
+ return []* ast.TableName {}
7131
+ }
7132
+ e := & tableListExtractor {
7133
+ asName : asName ,
7134
+ tableNames : []* ast.TableName {},
7135
+ }
7136
+ node .Accept (e )
7137
+ tableNames := e .tableNames
7138
+ m := make (map [string ]map [string ]* ast.TableName ) // k1: schemaName, k2: tableName, v: ast.TableName
7139
+ for _ , x := range tableNames {
7140
+ k1 , k2 := x .Schema .L , x .Name .L
7141
+ // allow empty schema name OR empty table name
7142
+ if k1 != "" || k2 != "" {
7143
+ if _ , ok := m [k1 ]; ! ok {
7144
+ m [k1 ] = make (map [string ]* ast.TableName )
7188
7145
}
7146
+ m [k1 ][k2 ] = x
7189
7147
}
7190
- case * ast.InsertStmt :
7191
- input = extractTableList (x .Table .TableRefs , input , asName )
7192
- input = extractTableList (x .Select , input , asName )
7193
- case * ast.SetOprStmt :
7194
- l := & ast.SetOprSelectList {}
7195
- unfoldSelectList (x .SelectList , l )
7196
- for _ , s := range l .Selects {
7197
- input = extractTableList (s .(ast.ResultSetNode ), input , asName )
7198
- }
7199
- case * ast.PatternInExpr :
7200
- if s , ok := x .Sel .(* ast.SubqueryExpr ); ok {
7201
- input = extractTableList (s , input , asName )
7148
+ }
7149
+ tableNames = tableNames [:0 ]
7150
+ for _ , x := range m {
7151
+ for _ , v := range x {
7152
+ tableNames = append (tableNames , v )
7202
7153
}
7203
- case * ast.ExistsSubqueryExpr :
7204
- if s , ok := x .Sel .(* ast.SubqueryExpr ); ok {
7205
- input = extractTableList (s , input , asName )
7154
+ }
7155
+ return tableNames
7156
+ }
7157
+
7158
+ // tableListExtractor extracts all the TableNames from node.
7159
+ type tableListExtractor struct {
7160
+ asName bool
7161
+ tableNames []* ast.TableName
7162
+ }
7163
+
7164
+ func (e * tableListExtractor ) Enter (n ast.Node ) (_ ast.Node , skipChildren bool ) {
7165
+ innerExtract := func (inner ast.Node ) []* ast.TableName {
7166
+ if inner == nil {
7167
+ return nil
7206
7168
}
7207
- case * ast. BinaryOperationExpr :
7208
- if s , ok := x . R .( * ast. SubqueryExpr ); ok {
7209
- input = extractTableList ( s , input , asName )
7169
+ innerExtractor := & tableListExtractor {
7170
+ asName : e . asName ,
7171
+ tableNames : [] * ast. TableName {},
7210
7172
}
7211
- case * ast.SubqueryExpr :
7212
- input = extractTableList (x .Query , input , asName )
7213
- case * ast.Join :
7214
- input = extractTableList (x .Left , input , asName )
7215
- input = extractTableList (x .Right , input , asName )
7173
+ inner .Accept (innerExtractor )
7174
+ return innerExtractor .tableNames
7175
+ }
7176
+
7177
+ switch x := n .(type ) {
7178
+ case * ast.TableName :
7179
+ e .tableNames = append (e .tableNames , x )
7216
7180
case * ast.TableSource :
7217
7181
if s , ok := x .Source .(* ast.TableName ); ok {
7218
- if x .AsName .L != "" && asName {
7182
+ if x .AsName .L != "" && e . asName {
7219
7183
newTableName := * s
7220
7184
newTableName .Name = x .AsName
7221
7185
newTableName .Schema = model .NewCIStr ("" )
7222
- input = append (input , & newTableName )
7186
+ e . tableNames = append (e . tableNames , & newTableName )
7223
7187
} else {
7224
- input = append (input , s )
7188
+ e . tableNames = append (e . tableNames , s )
7225
7189
}
7226
7190
} else if s , ok := x .Source .(* ast.SelectStmt ); ok {
7227
7191
if s .From != nil {
7228
- var innerList []* ast.TableName
7229
- innerList = extractTableList (s .From .TableRefs , innerList , asName )
7192
+ innerList := innerExtract (s .From .TableRefs )
7230
7193
if len (innerList ) > 0 {
7231
7194
innerTableName := innerList [0 ]
7232
- if x .AsName .L != "" && asName {
7195
+ if x .AsName .L != "" && e . asName {
7233
7196
newTableName := * innerList [0 ]
7234
7197
newTableName .Name = x .AsName
7235
7198
newTableName .Schema = model .NewCIStr ("" )
7236
7199
innerTableName = & newTableName
7237
7200
}
7238
- input = append (input , innerTableName )
7201
+ e . tableNames = append (e . tableNames , innerTableName )
7239
7202
}
7240
7203
}
7241
7204
}
7205
+ return n , true
7206
+
7207
+ case * ast.ShowStmt :
7208
+ if x .DBName != "" {
7209
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (x .DBName )})
7210
+ }
7211
+ case * ast.CreateDatabaseStmt :
7212
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .Name })
7213
+ case * ast.AlterDatabaseStmt :
7214
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .Name })
7215
+ case * ast.DropDatabaseStmt :
7216
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .Name })
7217
+
7218
+ case * ast.FlashBackDatabaseStmt :
7219
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .DBName })
7220
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (x .NewName )})
7221
+ case * ast.FlashBackToTimestampStmt :
7222
+ if x .DBName .L != "" {
7223
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : x .DBName })
7224
+ }
7225
+ case * ast.FlashBackTableStmt :
7226
+ if newName := x .NewName ; newName != "" {
7227
+ e .tableNames = append (e .tableNames , & ast.TableName {
7228
+ Schema : x .Table .Schema ,
7229
+ Name : model .NewCIStr (newName )})
7230
+ }
7231
+
7232
+ case * ast.GrantStmt :
7233
+ if x .ObjectType == ast .ObjectTypeTable || x .ObjectType == ast .ObjectTypeNone {
7234
+ if x .Level .Level == ast .GrantLevelDB || x .Level .Level == ast .GrantLevelTable {
7235
+ e .tableNames = append (e .tableNames , & ast.TableName {
7236
+ Schema : model .NewCIStr (x .Level .DBName ),
7237
+ Name : model .NewCIStr (x .Level .TableName ),
7238
+ })
7239
+ }
7240
+ }
7241
+ case * ast.RevokeStmt :
7242
+ if x .ObjectType == ast .ObjectTypeTable || x .ObjectType == ast .ObjectTypeNone {
7243
+ if x .Level .Level == ast .GrantLevelDB || x .Level .Level == ast .GrantLevelTable {
7244
+ e .tableNames = append (e .tableNames , & ast.TableName {
7245
+ Schema : model .NewCIStr (x .Level .DBName ),
7246
+ Name : model .NewCIStr (x .Level .TableName ),
7247
+ })
7248
+ }
7249
+ }
7250
+ case * ast.BRIEStmt :
7251
+ if x .Kind == ast .BRIEKindBackup || x .Kind == ast .BRIEKindRestore {
7252
+ for _ , v := range x .Schemas {
7253
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (v )})
7254
+ }
7255
+ }
7256
+ case * ast.UseStmt :
7257
+ e .tableNames = append (e .tableNames , & ast.TableName {Schema : model .NewCIStr (x .DBName )})
7258
+ case * ast.ExecuteStmt :
7259
+ if v , ok := x .PrepStmt .(* PlanCacheStmt ); ok {
7260
+ e .tableNames = append (e .tableNames , innerExtract (v .PreparedAst .Stmt )... )
7261
+ }
7242
7262
}
7243
- return input
7263
+ return n , false
7264
+ }
7265
+
7266
+ func (* tableListExtractor ) Leave (n ast.Node ) (ast.Node , bool ) {
7267
+ return n , true
7244
7268
}
7245
7269
7246
7270
func collectTableName (node ast.ResultSetNode , updatableName * map [string ]bool , info * map [string ]* ast.TableName ) {
0 commit comments