From 43038f0e733215f64faf19d03d560ad67bb74baf Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Wed, 17 Jan 2024 11:07:12 -0800 Subject: [PATCH 1/3] Changing the VarScope function so that it returns whether a scope was explicitly specified, or if one has been inferred. --- go/vt/sqlparser/ast.go | 87 ++++++++++++++++++++++-------------------- go/vt/sqlparser/sql.go | 6 +-- go/vt/sqlparser/sql.y | 6 +-- 3 files changed, 51 insertions(+), 48 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index c2b0a3d02d1..52600d905ec 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -6497,113 +6497,116 @@ const ( SetScope_User SetScope = "user" ) -// VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information. -func VarScopeForColName(colName *ColName) (*ColName, SetScope, error) { +// VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information, +// and a boolean that indicates if the scope was explicitly specified, or inferred. +func VarScopeForColName(colName *ColName) (*ColName, SetScope, bool, error) { if colName.Qualifier.IsEmpty() { // Forms are like `@@x` and `@x` if strings.HasPrefix(colName.Name.val, "@") && strings.Index(colName.Name.val, ".") != -1 { - varName, scope, err := VarScope(strings.Split(colName.Name.val, ".")...) + varName, scope, explicit, err := VarScope(strings.Split(colName.Name.val, ".")...) if err != nil { - return nil, SetScope_None, err + return nil, SetScope_None, false, err } if scope == SetScope_None { - return colName, scope, nil + return colName, scope, false, nil } - return &ColName{Name: ColIdent{val: varName}}, scope, nil + return &ColName{Name: ColIdent{val: varName}}, scope, explicit, nil } else { - varName, scope, err := VarScope(colName.Name.val) + varName, scope, explicit, err := VarScope(colName.Name.val) if err != nil { - return nil, SetScope_None, err + return nil, SetScope_None, false, err } if scope == SetScope_None { - return colName, scope, nil + return colName, scope, false, nil } - return &ColName{Name: ColIdent{val: varName}}, scope, nil + return &ColName{Name: ColIdent{val: varName}}, scope, explicit, nil } } else if colName.Qualifier.Qualifier.IsEmpty() { // Forms are like `@@GLOBAL.x` and `@@SESSION.x` - varName, scope, err := VarScope(colName.Qualifier.Name.v, colName.Name.val) + varName, scope, explicit, err := VarScope(colName.Qualifier.Name.v, colName.Name.val) if err != nil { - return nil, SetScope_None, err + return nil, SetScope_None, false, err } if scope == SetScope_None { - return colName, scope, nil + return colName, scope, false, nil } - return &ColName{Name: ColIdent{val: varName}}, scope, nil + return &ColName{Name: ColIdent{val: varName}}, scope, explicit, nil } else { // Forms are like `@@GLOBAL.validate_password.length`, which is currently unsupported - _, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val) - return colName, SetScope_None, err + _, _, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val) + return colName, SetScope_None, false, err } } // VarScope returns the SetScope of the given name, broken into parts. For example, `@@GLOBAL.sys_var` would become // `[]string{"@@GLOBAL", "sys_var"}`. Returns the variable name without any scope specifiers, so the aforementioned -// variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If the name parts do not +// variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If a scope is not +// explicitly specified, then the boolean return parameter will be false. If the name parts do not // specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this // will always only return the last part. `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None. -func VarScope(nameParts ...string) (string, SetScope, error) { +func VarScope(nameParts ...string) (string, SetScope, bool, error) { switch len(nameParts) { case 0: - return "", SetScope_None, nil + return "", SetScope_None, false, nil case 1: // First case covers `@@@`, `@@@@`, etc. if strings.HasPrefix(nameParts[0], "@@@") { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[0]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[0]) } else if strings.HasPrefix(nameParts[0], "@@") { dotIdx := strings.Index(nameParts[0], ".") if dotIdx != -1 { return VarScope(nameParts[0][:dotIdx], nameParts[0][dotIdx+1:]) } - return nameParts[0][2:], SetScope_Session, nil + // Session scope is inferred here, but not explicitly requested + return nameParts[0][2:], SetScope_Session, false, nil } else if strings.HasPrefix(nameParts[0], "@") { - return nameParts[0][1:], SetScope_User, nil + return nameParts[0][1:], SetScope_User, false, nil } else { - return nameParts[0], SetScope_None, nil + return nameParts[0], SetScope_None, false, nil } case 2: // `@user.var` is valid, so we check for it here. if len(nameParts[0]) >= 2 && nameParts[0][0] == '@' && nameParts[0][1] != '@' && !strings.HasPrefix(nameParts[1], "@") { // `@user.@var` is invalid though. - return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, nil + return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, true, nil } // We don't support variables such as `@@validate_password.length` right now, only `@@GLOBAL.sys_var`, etc. // The `@` symbols are only valid on the first name_part. First case also catches `@@@`, etc. if strings.HasPrefix(nameParts[1], "@@") { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } else if strings.HasPrefix(nameParts[1], "@") { - return "", SetScope_None, fmt.Errorf("invalid user variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid user variable declaration `%s`", nameParts[1]) } switch strings.ToLower(nameParts[0]) { case "@@global": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Global, nil + return nameParts[1], SetScope_Global, true, nil case "@@persist": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Persist, nil + return nameParts[1], SetScope_Persist, true, nil case "@@persist_only": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_PersistOnly, nil + return nameParts[1], SetScope_PersistOnly, true, nil case "@@session": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Session, nil + return nameParts[1], SetScope_Session, true, nil case "@@local": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Session, nil + return nameParts[1], SetScope_Session, true, nil default: // This catches `@@@GLOBAL.sys_var`. Due to the earlier check, this does not error on `@user.var`. if strings.HasPrefix(nameParts[0], "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_None, nil + return nameParts[1], SetScope_None, false, nil } default: // `@user.var.name` is valid, so we check for it here. @@ -6612,20 +6615,20 @@ func VarScope(nameParts ...string) (string, SetScope, error) { for i := 1; i < len(nameParts); i++ { if strings.HasPrefix(nameParts[i], "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1]) + return "", SetScope_None, false, fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1]) } } - return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, nil + return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, false, nil } // As we don't support `@@GLOBAL.validate_password.length` or anything potentially longer, we error if any part // starts with either `@@` or `@`. We can just check for `@` though. for _, namePart := range nameParts { if strings.HasPrefix(namePart, "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1]) + return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1]) } } - return nameParts[len(nameParts)-1], SetScope_None, nil + return nameParts[len(nameParts)-1], SetScope_None, false, nil } } diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index c42651643da..5da0feeb1c8 100755 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -20802,7 +20802,7 @@ yydefault: yyDollar = yyS[yypt-1 : yypt+1] //line sql.y:7894 { - colName, scope, err := VarScopeForColName(yyDollar[1].setVarExpr.Name) + colName, scope, _, err := VarScopeForColName(yyDollar[1].setVarExpr.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -20815,7 +20815,7 @@ yydefault: yyDollar = yyS[yypt-2 : yypt+1] //line sql.y:7905 { - _, scope, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) + _, scope, _, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -20830,7 +20830,7 @@ yydefault: yyDollar = yyS[yypt-2 : yypt+1] //line sql.y:7918 { - _, scope, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) + _, scope, _, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) if err != nil { yylex.Error(err.Error()) return 1 diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 69e15ee0511..a548e67debe 100755 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -7892,7 +7892,7 @@ set_list: set_expression: set_expression_assignment { - colName, scope, err := VarScopeForColName($1.Name) + colName, scope, _, err := VarScopeForColName($1.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -7903,7 +7903,7 @@ set_expression: } | set_scope_primary set_expression_assignment { - _, scope, err := VarScopeForColName($2.Name) + _, scope, _, err := VarScopeForColName($2.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -7916,7 +7916,7 @@ set_expression: } | set_scope_secondary set_expression_assignment { - _, scope, err := VarScopeForColName($2.Name) + _, scope, _, err := VarScopeForColName($2.Name) if err != nil { yylex.Error(err.Error()) return 1 From bd94646f66fce9b11dd3b302e7836a09b4328a42 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Wed, 17 Jan 2024 14:23:26 -0800 Subject: [PATCH 2/3] Adding a unit test for the VarScopeForColName function --- go/vt/sqlparser/ast_test.go | 62 +++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 459c26d265b..91acde42202 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -24,6 +24,7 @@ import ( "testing" "unsafe" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dolthub/vitess/go/sqltypes" @@ -884,6 +885,67 @@ func TestSplitStatementToPieces(t *testing.T) { } } +func TestVarScopeForColName(t *testing.T) { + testcases := []struct { + colName ColName + expectedName ColName + expectedScope string + expectedExplicit bool + }{ + { + // Regular column name + colName: ColName{Name: ColIdent{val: "a"}}, + expectedName: ColName{Name: ColIdent{val: "a"}}, + expectedExplicit: false, + }, + { + // User variable + colName: ColName{Name: ColIdent{val: "@aaa"}}, + expectedName: ColName{Name: ColIdent{val: "aaa"}}, + expectedExplicit: false, + expectedScope: "user", + }, + { + // System variable without an explicit scope (defaults to session) + colName: ColName{Name: ColIdent{val: "@@max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedExplicit: false, + expectedScope: "session", + }, + { + // System variable with an explicit session scope + colName: ColName{Name: ColIdent{val: "@@session.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedExplicit: true, + expectedScope: "session", + }, + { + // System variable with an explicit global scope + colName: ColName{Name: ColIdent{val: "@@global.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedExplicit: true, + expectedScope: "global", + }, + { + // System variable with an explicit persist scope + colName: ColName{Name: ColIdent{val: "@@persist.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedExplicit: true, + expectedScope: "persist", + }, + } + + for _, tt := range testcases { + t.Run(tt.colName.String(), func(t *testing.T) { + name, scope, b, err := VarScopeForColName(&tt.colName) + assert.NoError(t, err) + assert.Equal(t, tt.expectedName, *name) + assert.Equal(t, tt.expectedScope, string(scope)) + assert.Equal(t, tt.expectedExplicit, b) + }) + } +} + func TestWindowErrors(t *testing.T) { testcases := []struct { input string From 33f24ba514a13589be8cf9358ed86ea2dadc5956 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Wed, 17 Jan 2024 14:40:22 -0800 Subject: [PATCH 3/3] Changing the VarScope function to report the originally specified scope, so that we can match the exact specified casing in the field metadata for the result set --- go/vt/sqlparser/ast.go | 88 +++++++++++++++++++------------------ go/vt/sqlparser/ast_test.go | 61 +++++++++++++------------ 2 files changed, 79 insertions(+), 70 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 52600d905ec..de48e256ab6 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -6498,115 +6498,117 @@ const ( ) // VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information, -// and a boolean that indicates if the scope was explicitly specified, or inferred. -func VarScopeForColName(colName *ColName) (*ColName, SetScope, bool, error) { +// and a string indicating the exact scope that was specified in the original query or "" if no scope was explicitly +// specified. +func VarScopeForColName(colName *ColName) (*ColName, SetScope, string, error) { if colName.Qualifier.IsEmpty() { // Forms are like `@@x` and `@x` if strings.HasPrefix(colName.Name.val, "@") && strings.Index(colName.Name.val, ".") != -1 { - varName, scope, explicit, err := VarScope(strings.Split(colName.Name.val, ".")...) + varName, scope, specifiedScope, err := VarScope(strings.Split(colName.Name.val, ".")...) if err != nil { - return nil, SetScope_None, false, err + return nil, SetScope_None, "", err } if scope == SetScope_None { - return colName, scope, false, nil + return colName, scope, "", nil } - return &ColName{Name: ColIdent{val: varName}}, scope, explicit, nil + return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil } else { - varName, scope, explicit, err := VarScope(colName.Name.val) + varName, scope, specifiedScope, err := VarScope(colName.Name.val) if err != nil { - return nil, SetScope_None, false, err + return nil, SetScope_None, "", err } if scope == SetScope_None { - return colName, scope, false, nil + return colName, scope, "", nil } - return &ColName{Name: ColIdent{val: varName}}, scope, explicit, nil + return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil } } else if colName.Qualifier.Qualifier.IsEmpty() { // Forms are like `@@GLOBAL.x` and `@@SESSION.x` - varName, scope, explicit, err := VarScope(colName.Qualifier.Name.v, colName.Name.val) + varName, scope, specifiedScope, err := VarScope(colName.Qualifier.Name.v, colName.Name.val) if err != nil { - return nil, SetScope_None, false, err + return nil, SetScope_None, "", err } if scope == SetScope_None { - return colName, scope, false, nil + return colName, scope, "", nil } - return &ColName{Name: ColIdent{val: varName}}, scope, explicit, nil + return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil } else { // Forms are like `@@GLOBAL.validate_password.length`, which is currently unsupported _, _, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val) - return colName, SetScope_None, false, err + return colName, SetScope_None, "", err } } // VarScope returns the SetScope of the given name, broken into parts. For example, `@@GLOBAL.sys_var` would become // `[]string{"@@GLOBAL", "sys_var"}`. Returns the variable name without any scope specifiers, so the aforementioned // variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If a scope is not -// explicitly specified, then the boolean return parameter will be false. If the name parts do not -// specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this -// will always only return the last part. `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None. -func VarScope(nameParts ...string) (string, SetScope, bool, error) { +// explicitly specified, then the requestedScope string will be empty, otherwise it will be the exact +// scope that was explicitly specified, which can differ from the returned scope, when the returned scope is +// inferred. If the name parts do not specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this will always only return the last part. +// `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None. +func VarScope(nameParts ...string) (string, SetScope, string, error) { switch len(nameParts) { case 0: - return "", SetScope_None, false, nil + return "", SetScope_None, "", nil case 1: // First case covers `@@@`, `@@@@`, etc. if strings.HasPrefix(nameParts[0], "@@@") { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[0]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[0]) } else if strings.HasPrefix(nameParts[0], "@@") { dotIdx := strings.Index(nameParts[0], ".") if dotIdx != -1 { return VarScope(nameParts[0][:dotIdx], nameParts[0][dotIdx+1:]) } // Session scope is inferred here, but not explicitly requested - return nameParts[0][2:], SetScope_Session, false, nil + return nameParts[0][2:], SetScope_Session, "", nil } else if strings.HasPrefix(nameParts[0], "@") { - return nameParts[0][1:], SetScope_User, false, nil + return nameParts[0][1:], SetScope_User, "", nil } else { - return nameParts[0], SetScope_None, false, nil + return nameParts[0], SetScope_None, "", nil } case 2: // `@user.var` is valid, so we check for it here. if len(nameParts[0]) >= 2 && nameParts[0][0] == '@' && nameParts[0][1] != '@' && !strings.HasPrefix(nameParts[1], "@") { // `@user.@var` is invalid though. - return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, true, nil + return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, "", nil } // We don't support variables such as `@@validate_password.length` right now, only `@@GLOBAL.sys_var`, etc. // The `@` symbols are only valid on the first name_part. First case also catches `@@@`, etc. if strings.HasPrefix(nameParts[1], "@@") { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } else if strings.HasPrefix(nameParts[1], "@") { - return "", SetScope_None, false, fmt.Errorf("invalid user variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid user variable declaration `%s`", nameParts[1]) } switch strings.ToLower(nameParts[0]) { case "@@global": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Global, true, nil + return nameParts[1], SetScope_Global, nameParts[0][2:], nil case "@@persist": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Persist, true, nil + return nameParts[1], SetScope_Persist, nameParts[0][2:], nil case "@@persist_only": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_PersistOnly, true, nil + return nameParts[1], SetScope_PersistOnly, nameParts[0][2:], nil case "@@session": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Session, true, nil + return nameParts[1], SetScope_Session, nameParts[0][2:], nil case "@@local": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Session, true, nil + return nameParts[1], SetScope_Session, nameParts[0][2:], nil default: // This catches `@@@GLOBAL.sys_var`. Due to the earlier check, this does not error on `@user.var`. if strings.HasPrefix(nameParts[0], "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_None, false, nil + return nameParts[1], SetScope_None, "", nil } default: // `@user.var.name` is valid, so we check for it here. @@ -6615,20 +6617,20 @@ func VarScope(nameParts ...string) (string, SetScope, bool, error) { for i := 1; i < len(nameParts); i++ { if strings.HasPrefix(nameParts[i], "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, false, fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1]) + return "", SetScope_None, "", fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1]) } } - return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, false, nil + return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, "", nil } // As we don't support `@@GLOBAL.validate_password.length` or anything potentially longer, we error if any part // starts with either `@@` or `@`. We can just check for `@` though. for _, namePart := range nameParts { if strings.HasPrefix(namePart, "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, false, fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1]) } } - return nameParts[len(nameParts)-1], SetScope_None, false, nil + return nameParts[len(nameParts)-1], SetScope_None, "", nil } } diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 91acde42202..5edc447a038 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -889,59 +889,66 @@ func TestVarScopeForColName(t *testing.T) { testcases := []struct { colName ColName expectedName ColName - expectedScope string - expectedExplicit bool + expectedScope string + expectedSpecifiedScope string }{ { // Regular column name - colName: ColName{Name: ColIdent{val: "a"}}, - expectedName: ColName{Name: ColIdent{val: "a"}}, - expectedExplicit: false, + colName: ColName{Name: ColIdent{val: "a"}}, + expectedName: ColName{Name: ColIdent{val: "a"}}, + expectedSpecifiedScope: "", }, { // User variable - colName: ColName{Name: ColIdent{val: "@aaa"}}, - expectedName: ColName{Name: ColIdent{val: "aaa"}}, - expectedExplicit: false, - expectedScope: "user", + colName: ColName{Name: ColIdent{val: "@aaa"}}, + expectedName: ColName{Name: ColIdent{val: "aaa"}}, + expectedSpecifiedScope: "", + expectedScope: "user", }, { // System variable without an explicit scope (defaults to session) - colName: ColName{Name: ColIdent{val: "@@max_allowed_packets"}}, - expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, - expectedExplicit: false, - expectedScope: "session", + colName: ColName{Name: ColIdent{val: "@@max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "", + expectedScope: "session", }, { // System variable with an explicit session scope - colName: ColName{Name: ColIdent{val: "@@session.max_allowed_packets"}}, - expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, - expectedExplicit: true, - expectedScope: "session", + colName: ColName{Name: ColIdent{val: "@@session.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "session", + expectedScope: "session", }, { // System variable with an explicit global scope - colName: ColName{Name: ColIdent{val: "@@global.max_allowed_packets"}}, - expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, - expectedExplicit: true, - expectedScope: "global", + colName: ColName{Name: ColIdent{val: "@@global.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "global", + expectedScope: "global", + }, + { + // System variable with an explicit global scope (in all caps) + colName: ColName{Name: ColIdent{val: "@@GLOBAL.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "GLOBAL", + expectedScope: "global", }, { // System variable with an explicit persist scope - colName: ColName{Name: ColIdent{val: "@@persist.max_allowed_packets"}}, - expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, - expectedExplicit: true, - expectedScope: "persist", + colName: ColName{Name: ColIdent{val: "@@persist.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "persist", + expectedScope: "persist", }, } for _, tt := range testcases { t.Run(tt.colName.String(), func(t *testing.T) { - name, scope, b, err := VarScopeForColName(&tt.colName) + name, scope, specifiedScope, err := VarScopeForColName(&tt.colName) assert.NoError(t, err) assert.Equal(t, tt.expectedName, *name) assert.Equal(t, tt.expectedScope, string(scope)) - assert.Equal(t, tt.expectedExplicit, b) + assert.Equal(t, tt.expectedSpecifiedScope, specifiedScope) }) } }