From f392710f10b9e6aa723f581c0d36ec67be59ec3f Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 26 Feb 2025 15:34:58 -0800 Subject: [PATCH 1/6] Added a parser to the analyzer --- sql/analyzer/analyzer.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 5bd7df72ea..41e2cc56b8 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -266,6 +266,7 @@ func (ab *Builder) Build() *Analyzer { Catalog: NewCatalog(ab.provider), Coster: memo.NewDefaultCoster(), ExecBuilder: rowexec.DefaultBuilder, + Parser: sql.GlobalParser, } } @@ -288,6 +289,8 @@ type Analyzer struct { ExecBuilder sql.NodeExecBuilder // Runner represents the engine, which is represented as a separate interface to work around circular dependencies Runner StatementRunner + // Parser is the parser used to parse SQL statements. + Parser sql.Parser } // NewDefault creates a default Analyzer instance with all default Rules and configuration. From 28fdce58501a780ee49438128bcde0f8ea79b92d Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 26 Feb 2025 15:52:11 -0800 Subject: [PATCH 2/6] Use the parser to quote identifiers --- memory/table.go | 2 +- sql/analyzer/resolve_column_defaults.go | 12 +++++----- sql/expression/get_field.go | 29 ++++++++++++++++--------- sql/parser.go | 9 ++++++++ 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/memory/table.go b/memory/table.go index 79bb6d33e1..7d6248fa88 100644 --- a/memory/table.go +++ b/memory/table.go @@ -134,7 +134,7 @@ func stripTblNames(e sql.Expression) (sql.Expression, transform.TreeIdentity, er case *expression.GetField: // strip table names ne := expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable()) - ne = ne.WithBackTickNames(e.IsBackTickNames()) + ne = ne.WithQuotedNames(sql.GlobalParser, e.IsQuotedIdentifier()) return ne, transform.NewTree, nil default: } diff --git a/sql/analyzer/resolve_column_defaults.go 
b/sql/analyzer/resolve_column_defaults.go index 09f5ebf30e..13c3c44653 100644 --- a/sql/analyzer/resolve_column_defaults.go +++ b/sql/analyzer/resolve_column_defaults.go @@ -307,7 +307,7 @@ func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transfor return expression.WrapExpression(&nd), transform.NewTree, nil } -func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { +func backtickDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { span, ctx := ctx.Span("backtickDefaultColumnValueNames") defer span.End() @@ -315,7 +315,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, switch node := n.(type) { case *plan.AlterDefaultSet: eWrapper := expression.WrapExpression(node.Default) - newExpr, same, err := backtickDefault(eWrapper) + newExpr, same, err := quoteIdentifiers(a.Parser, eWrapper) if err != nil { return node, transform.SameTree, err } @@ -335,7 +335,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, return e, transform.SameTree, nil } - return backtickDefault(eWrapper) + return quoteIdentifiers(a.Parser, eWrapper) }) case *plan.ResolvedTable: ct, ok := node.Table.(*information_schema.ColumnsTable) @@ -354,7 +354,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, return e, transform.SameTree, nil } - return backtickDefault(eWrapper) + return quoteIdentifiers(a.Parser, eWrapper) }) if err != nil { @@ -376,7 +376,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, }) } -func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { +func quoteIdentifiers(parser sql.Parser, wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { 
newDefault, ok := wrap.Unwrap().(*sql.ColumnDefaultValue) if !ok { return wrap, transform.SameTree, nil @@ -388,7 +388,7 @@ func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeId newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { if e, isGf := expr.(*expression.GetField); isGf { - return e.WithBackTickNames(true), transform.NewTree, nil + return e.WithQuotedNames(parser,true), transform.NewTree, nil } return expr, transform.SameTree, nil }) diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index af6c60b99b..bf11da00f7 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -38,7 +38,11 @@ type GetField struct { fieldType2 sql.Type2 nullable bool - backTickNames bool + // parser is the parser used to parse the expression and print it + parser sql.Parser + + // quoteName indicates whether the field name should be quoted when printed with String() + quoteName bool } var _ sql.Expression = (*GetField)(nil) @@ -161,10 +165,14 @@ func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, err } func (p *GetField) String() string { - if p.table == "" { - if p.backTickNames { - return "`" + p.name + "`" + if p.quoteName { + if p.table == "" { + return p.parser.QuoteIdentifier(p.name) } + return p.parser.QuoteIdentifier(p.table) + "." + p.parser.QuoteIdentifier(p.name) + } + + if p.table == "" { return p.name } return p.table + "." + p.name @@ -188,16 +196,17 @@ func (p *GetField) WithIndex(n int) sql.Expression { return &p2 } -// WithBackTickNames returns a copy of this expression with the backtick names flag set to the given value. -func (p *GetField) WithBackTickNames(backtick bool) *GetField { +// WithQuotedNames returns a copy of this expression that uses the given parser to quote the field name when quoteNames is true.
+func (p *GetField) WithQuotedNames(parser sql.Parser, quoteNames bool) *GetField { p2 := *p - p2.backTickNames = backtick + p2.quoteName = quoteNames + p2.parser = parser return &p2 } -// IsBackTickNames returns whether the field name should be quoted with backticks. -func (p *GetField) IsBackTickNames() bool { - return p.backTickNames +// IsQuotedIdentifier returns whether the field name should be quoted. +func (p *GetField) IsQuotedIdentifier() bool { + return p.quoteName } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/parser.go b/sql/parser.go index 60bcb99e70..0a7d7c8b6e 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -16,6 +16,7 @@ package sql import ( "context" + "fmt" trace2 "runtime/trace" "strings" "unicode" @@ -44,6 +45,10 @@ type Parser interface { // the index of the start of the next query. If |query| represents a no-op statement, such as ";" or "-- comment", // then implementations must return Vitess' ErrEmpty error. ParseOneWithOptions(context.Context, string, ast.ParserOptions) (ast.Statement, int, error) + // QuoteIdentifier returns the identifier given quoted according to this parser's dialect. 
This is used to + // standardize identifiers that cannot be parsed without quoting, because they break the normal identifier naming + // rules (such as containing spaces) + QuoteIdentifier(identifier string) string } var _ Parser = &MysqlParser{} @@ -99,3 +104,7 @@ func RemoveSpaceAndDelimiter(query string, d rune) string { return r == d || unicode.IsSpace(r) }) } + +func (m *MysqlParser) QuoteIdentifier(identifier string) string { + return fmt.Sprintf("`%s`", strings.ReplaceAll(identifier, "`", "``")) +} \ No newline at end of file From bd5e86491ee168bfbfc4376ff0a0dc4ddce4a36e Mon Sep 17 00:00:00 2001 From: zachmu Date: Thu, 27 Feb 2025 02:20:31 +0000 Subject: [PATCH 3/6] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/analyzer/resolve_column_defaults.go | 2 +- sql/parser.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/analyzer/resolve_column_defaults.go b/sql/analyzer/resolve_column_defaults.go index 13c3c44653..7bac4ae010 100644 --- a/sql/analyzer/resolve_column_defaults.go +++ b/sql/analyzer/resolve_column_defaults.go @@ -388,7 +388,7 @@ func quoteIdentifiers(parser sql.Parser, wrap *expression.Wrapper) (sql.Expressi newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { if e, isGf := expr.(*expression.GetField); isGf { - return e.WithQuotedNames(parser,true), transform.NewTree, nil + return e.WithQuotedNames(parser, true), transform.NewTree, nil } return expr, transform.SameTree, nil }) diff --git a/sql/parser.go b/sql/parser.go index 0a7d7c8b6e..e1f79372ad 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -45,8 +45,8 @@ type Parser interface { // the index of the start of the next query. If |query| represents a no-op statement, such as ";" or "-- comment", // then implementations must return Vitess' ErrEmpty error. 
ParseOneWithOptions(context.Context, string, ast.ParserOptions) (ast.Statement, int, error) - // QuoteIdentifier returns the identifier given quoted according to this parser's dialect. This is used to - // standardize identifiers that cannot be parsed without quoting, because they break the normal identifier naming + // QuoteIdentifier returns the identifier given quoted according to this parser's dialect. This is used to + // standardize identifiers that cannot be parsed without quoting, because they break the normal identifier naming // rules (such as containing spaces) QuoteIdentifier(identifier string) string } @@ -107,4 +107,4 @@ func RemoveSpaceAndDelimiter(query string, d rune) string { func (m *MysqlParser) QuoteIdentifier(identifier string) string { return fmt.Sprintf("`%s`", strings.ReplaceAll(identifier, "`", "``")) -} \ No newline at end of file +} From ac3c932fc3420a49ac9e30c12fb5936d2f45768c Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Thu, 27 Feb 2025 09:39:53 -0800 Subject: [PATCH 4/6] Bug fix: not quoting table names is load bearing --- sql/expression/get_field.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index bf11da00f7..3afe3953ca 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -165,16 +165,16 @@ func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, err } func (p *GetField) String() string { - if p.quoteName { - if p.table == "" { - return p.parser.QuoteIdentifier(p.name) - } - return p.parser.QuoteIdentifier(p.table) + "." + p.parser.QuoteIdentifier(p.name) - } - + // We never quote anything if the table identifier is present. Quoting the field name is a very narrow use case + // used only for serializing column default values and related fields, in which case the table name will always be + // stripped away. The output of this method is load-bearing in many places of analysis and execution. 
if p.table == "" { + if p.quoteName { + return p.parser.QuoteIdentifier(p.name) + } return p.name } + return p.table + "." + p.name } From e921e6c9c32c138f24fb5551a676fd07e054b79c Mon Sep 17 00:00:00 2001 From: zachmu Date: Thu, 27 Feb 2025 17:42:01 +0000 Subject: [PATCH 5/6] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/get_field.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 3afe3953ca..583972ba17 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -170,11 +170,11 @@ func (p *GetField) String() string { // stripped away. The output of this method is load-bearing in many places of analysis and execution. if p.table == "" { if p.quoteName { - return p.parser.QuoteIdentifier(p.name) + return p.parser.QuoteIdentifier(p.name) } return p.name } - + return p.table + "." + p.name } From b3936fcbd7c1f024177bea0b6a86903016c479c9 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Thu, 27 Feb 2025 10:27:45 -0800 Subject: [PATCH 6/6] Renamed rule --- sql/analyzer/resolve_column_defaults.go | 4 ++-- sql/analyzer/rule_ids.go | 6 +++--- sql/analyzer/ruleid_string.go | 6 +++--- sql/analyzer/rules.go | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/analyzer/resolve_column_defaults.go b/sql/analyzer/resolve_column_defaults.go index 7bac4ae010..001735321b 100644 --- a/sql/analyzer/resolve_column_defaults.go +++ b/sql/analyzer/resolve_column_defaults.go @@ -307,8 +307,8 @@ func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transfor return expression.WrapExpression(&nd), transform.NewTree, nil } -func backtickDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - span, ctx := ctx.Span("backtickDefaultColumnValueNames") +func quoteDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n 
sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + span, ctx := ctx.Span("quoteDefaultColumnValueNames") defer span.End() return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { diff --git a/sql/analyzer/rule_ids.go b/sql/analyzer/rule_ids.go index 1d20e7498b..27dfe93adf 100644 --- a/sql/analyzer/rule_ids.go +++ b/sql/analyzer/rule_ids.go @@ -81,7 +81,7 @@ const ( validateDeleteFromId // validateDeleteFrom // after all - cacheSubqueryAliasesInJoinsId // cacheSubqueryAliasesInJoins - BacktickDefaulColumnValueNamesId // backtickDefaultColumnValueNames - TrackProcessId // trackProcess + cacheSubqueryAliasesInJoinsId // cacheSubqueryAliasesInJoins + QuoteDefaultColumnValueNamesId // quoteDefaultColumnValueNames + TrackProcessId // trackProcess ) diff --git a/sql/analyzer/ruleid_string.go b/sql/analyzer/ruleid_string.go index 3c3c6d8fe9..9030001711 100755 --- a/sql/analyzer/ruleid_string.go +++ b/sql/analyzer/ruleid_string.go @@ -76,13 +76,13 @@ func _() { _ = x[validateAggregationsId-65] _ = x[validateDeleteFromId-66] _ = x[cacheSubqueryAliasesInJoinsId-67] - _ = x[BacktickDefaulColumnValueNamesId-68] + _ = x[QuoteDefaultColumnValueNamesId-68] _ = x[TrackProcessId-69] } -const _RuleId_name = 
"applyDefaultSelectLimitvalidateOffsetAndLimitvalidateStarExpressionsvalidateCreateTablevalidateAlterTablevalidateExprSemloadStoredProceduresvalidateDropTablesresolveDropConstraintvalidateDropConstraintresolveCreateSelectresolveSubqueriesresolveUnionsvalidateColumnDefaultsvalidateCreateTriggervalidateReadOnlyDatabasevalidateReadOnlyTransactionvalidateDatabaseSetvalidatePrivilegesflattenTableAliasespushdownSubqueryAliasFiltersvalidateCheckConstraintsreplaceCountStarreplaceCrossJoinsmoveJoinConditionsToFiltersimplifyFilterspushNotFiltershoistOutOfScopeFiltersunnestInSubqueriesunnestExistsSubqueriesfinalizeSubqueriesfinalizeUnionsloadTriggersprocessTruncateresolveAlterColumnstripTableNamesFromColumnDefaultsoptimizeJoinspushFiltersapplyIndexesFromOuterScopepruneTablesassignExecIndexesinlineSubqueryAliasRefseraseProjectionflattenDistinctreplaceAggreplaceIdxSortinsertTopNNodesreplaceIdxOrderByDistanceapplyHashInresolveInsertRowsapplyTriggersapplyProceduresassignRoutinesmodifyUpdateExprsForJoinapplyForeignKeysinterpretervalidateResolvedvalidateOrderByvalidateGroupByvalidateSchemaSourcevalidateIndexCreationvalidateOperandsvalidateIntervalUsagevalidateSubqueryColumnsvalidateUnionSchemasMatchvalidateAggregationsvalidateDeleteFromcacheSubqueryAliasesInJoinsbacktickDefaultColumnValueNamestrackProcess" +const _RuleId_name = 
"applyDefaultSelectLimitvalidateOffsetAndLimitvalidateStarExpressionsvalidateCreateTablevalidateAlterTablevalidateExprSemloadStoredProceduresvalidateDropTablesresolveDropConstraintvalidateDropConstraintresolveCreateSelectresolveSubqueriesresolveUnionsvalidateColumnDefaultsvalidateCreateTriggervalidateReadOnlyDatabasevalidateReadOnlyTransactionvalidateDatabaseSetvalidatePrivilegesflattenTableAliasespushdownSubqueryAliasFiltersvalidateCheckConstraintsreplaceCountStarreplaceCrossJoinsmoveJoinConditionsToFiltersimplifyFilterspushNotFiltershoistOutOfScopeFiltersunnestInSubqueriesunnestExistsSubqueriesfinalizeSubqueriesfinalizeUnionsloadTriggersprocessTruncateresolveAlterColumnstripTableNamesFromColumnDefaultsoptimizeJoinspushFiltersapplyIndexesFromOuterScopepruneTablesassignExecIndexesinlineSubqueryAliasRefseraseProjectionflattenDistinctreplaceAggreplaceIdxSortinsertTopNNodesreplaceIdxOrderByDistanceapplyHashInresolveInsertRowsapplyTriggersapplyProceduresassignRoutinesmodifyUpdateExprsForJoinapplyForeignKeysinterpretervalidateResolvedvalidateOrderByvalidateGroupByvalidateSchemaSourcevalidateIndexCreationvalidateOperandsvalidateIntervalUsagevalidateSubqueryColumnsvalidateUnionSchemasMatchvalidateAggregationsvalidateDeleteFromcacheSubqueryAliasesInJoinsquoteDefaultColumnValueNamestrackProcess" -var _RuleId_index = [...]uint16{0, 23, 45, 68, 87, 105, 120, 140, 158, 179, 201, 220, 237, 250, 272, 293, 317, 344, 363, 381, 400, 428, 452, 468, 485, 511, 526, 540, 562, 580, 602, 620, 634, 646, 661, 679, 712, 725, 736, 762, 773, 790, 813, 828, 843, 853, 867, 882, 907, 918, 935, 948, 963, 977, 1001, 1017, 1028, 1044, 1059, 1074, 1094, 1115, 1131, 1152, 1175, 1200, 1220, 1238, 1265, 1296, 1308} +var _RuleId_index = [...]uint16{0, 23, 45, 68, 87, 105, 120, 140, 158, 179, 201, 220, 237, 250, 272, 293, 317, 344, 363, 381, 400, 428, 452, 468, 485, 511, 526, 540, 562, 580, 602, 620, 634, 646, 661, 679, 712, 725, 736, 762, 773, 790, 813, 828, 843, 853, 867, 882, 907, 918, 935, 948, 963, 
977, 1001, 1017, 1028, 1044, 1059, 1074, 1094, 1115, 1131, 1152, 1175, 1200, 1220, 1238, 1265, 1293, 1305} func (i RuleId) String() string { if i < 0 || i >= RuleId(len(_RuleId_index)-1) { diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index d352ac0a7c..bc797983c9 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -23,7 +23,7 @@ func init() { {applyProceduresId, applyProcedures}, {inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs}, {cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins}, - {BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames}, + {QuoteDefaultColumnValueNamesId, quoteDefaultColumnValueNames}, {TrackProcessId, trackProcess}, } }