diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 07fd70b06d7..00530116537 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -6513,113 +6513,118 @@ const ( SetScope_User SetScope = "user" ) -// VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information. -func VarScopeForColName(colName *ColName) (*ColName, SetScope, error) { +// VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information, +// and a string indicating the exact scope that was specified in the original query or "" if no scope was explicitly +// specified. +func VarScopeForColName(colName *ColName) (*ColName, SetScope, string, error) { if colName.Qualifier.IsEmpty() { // Forms are like `@@x` and `@x` if strings.HasPrefix(colName.Name.val, "@") && strings.Index(colName.Name.val, ".") != -1 { - varName, scope, err := VarScope(strings.Split(colName.Name.val, ".")...) + varName, scope, specifiedScope, err := VarScope(strings.Split(colName.Name.val, ".")...) if err != nil { - return nil, SetScope_None, err + return nil, SetScope_None, "", err } if scope == SetScope_None { - return colName, scope, nil + return colName, scope, "", nil } - return &ColName{Name: ColIdent{val: varName}}, scope, nil + return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil } else { - varName, scope, err := VarScope(colName.Name.val) + varName, scope, specifiedScope, err := VarScope(colName.Name.val) if err != nil { - return nil, SetScope_None, err + return nil, SetScope_None, "", err } if scope == SetScope_None { - return colName, scope, nil + return colName, scope, "", nil } - return &ColName{Name: ColIdent{val: varName}}, scope, nil + return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil } } else if colName.Qualifier.Qualifier.IsEmpty() { // Forms are like `@@GLOBAL.x` and `@@SESSION.x` - varName, scope, err := VarScope(colName.Qualifier.Name.v, colName.Name.val) + varName, scope, specifiedScope, err := VarScope(colName.Qualifier.Name.v, colName.Name.val) if err != nil { - return nil, SetScope_None, err + return nil, SetScope_None, "", err } if scope == SetScope_None { - return colName, scope, nil + return colName, scope, "", nil } - return &ColName{Name: ColIdent{val: varName}}, scope, nil + return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil } else { // Forms are like `@@GLOBAL.validate_password.length`, which is currently unsupported - _, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val) - return colName, SetScope_None, err + _, _, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val) + return colName, SetScope_None, "", err } } // VarScope returns the SetScope of the given name, broken into parts. For example, `@@GLOBAL.sys_var` would become // `[]string{"@@GLOBAL", "sys_var"}`. Returns the variable name without any scope specifiers, so the aforementioned -// variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If the name parts do not -// specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this -// will always only return the last part. `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None. -func VarScope(nameParts ...string) (string, SetScope, error) { +// variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If a scope is not +// explicitly specified, then the requestedScope string will be empty, otherwise it will be the exact +// scope that was explicitly specified, which can differ from the returned scope, when the returned scope is +// inferred. If the name parts do not specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this will always only return the last part. +// `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None. +func VarScope(nameParts ...string) (string, SetScope, string, error) { switch len(nameParts) { case 0: - return "", SetScope_None, nil + return "", SetScope_None, "", nil case 1: // First case covers `@@@`, `@@@@`, etc. if strings.HasPrefix(nameParts[0], "@@@") { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[0]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[0]) } else if strings.HasPrefix(nameParts[0], "@@") { dotIdx := strings.Index(nameParts[0], ".") if dotIdx != -1 { return VarScope(nameParts[0][:dotIdx], nameParts[0][dotIdx+1:]) } - return nameParts[0][2:], SetScope_Session, nil + // Session scope is inferred here, but not explicitly requested + return nameParts[0][2:], SetScope_Session, "", nil } else if strings.HasPrefix(nameParts[0], "@") { - return nameParts[0][1:], SetScope_User, nil + return nameParts[0][1:], SetScope_User, "", nil } else { - return nameParts[0], SetScope_None, nil + return nameParts[0], SetScope_None, "", nil } case 2: // `@user.var` is valid, so we check for it here. if len(nameParts[0]) >= 2 && nameParts[0][0] == '@' && nameParts[0][1] != '@' && !strings.HasPrefix(nameParts[1], "@") { // `@user.@var` is invalid though. - return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, nil + return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, "", nil } // We don't support variables such as `@@validate_password.length` right now, only `@@GLOBAL.sys_var`, etc. // The `@` symbols are only valid on the first name_part. First case also catches `@@@`, etc. if strings.HasPrefix(nameParts[1], "@@") { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } else if strings.HasPrefix(nameParts[1], "@") { - return "", SetScope_None, fmt.Errorf("invalid user variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid user variable declaration `%s`", nameParts[1]) } switch strings.ToLower(nameParts[0]) { case "@@global": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Global, nil + return nameParts[1], SetScope_Global, nameParts[0][2:], nil case "@@persist": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Persist, nil + return nameParts[1], SetScope_Persist, nameParts[0][2:], nil case "@@persist_only": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_PersistOnly, nil + return nameParts[1], SetScope_PersistOnly, nameParts[0][2:], nil case "@@session": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Session, nil + return nameParts[1], SetScope_Session, nameParts[0][2:], nil case "@@local": if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) { - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_Session, nil + return nameParts[1], SetScope_Session, nameParts[0][2:], nil default: // This catches `@@@GLOBAL.sys_var`. Due to the earlier check, this does not error on `@user.var`. if strings.HasPrefix(nameParts[0], "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1]) } - return nameParts[1], SetScope_None, nil + return nameParts[1], SetScope_None, "", nil } default: // `@user.var.name` is valid, so we check for it here. @@ -6628,20 +6633,20 @@ func VarScope(nameParts ...string) (string, SetScope, error) { for i := 1; i < len(nameParts); i++ { if strings.HasPrefix(nameParts[i], "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1]) + return "", SetScope_None, "", fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1]) } } - return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, nil + return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, "", nil } // As we don't support `@@GLOBAL.validate_password.length` or anything potentially longer, we error if any part // starts with either `@@` or `@`. We can just check for `@` though. for _, namePart := range nameParts { if strings.HasPrefix(namePart, "@") { // Last value is column name, so we return that in the error - return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1]) + return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1]) } } - return nameParts[len(nameParts)-1], SetScope_None, nil + return nameParts[len(nameParts)-1], SetScope_None, "", nil } } diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 459c26d265b..5edc447a038 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -24,6 +24,7 @@ import ( "testing" "unsafe" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dolthub/vitess/go/sqltypes" @@ -884,6 +885,74 @@ func TestSplitStatementToPieces(t *testing.T) { } } +func TestVarScopeForColName(t *testing.T) { + testcases := []struct { + colName ColName + expectedName ColName + expectedScope string + expectedSpecifiedScope string + }{ + { + // Regular column name + colName: ColName{Name: ColIdent{val: "a"}}, + expectedName: ColName{Name: ColIdent{val: "a"}}, + expectedSpecifiedScope: "", + }, + { + // User variable + colName: ColName{Name: ColIdent{val: "@aaa"}}, + expectedName: ColName{Name: ColIdent{val: "aaa"}}, + expectedSpecifiedScope: "", + expectedScope: "user", + }, + { + // System variable without an explicit scope (defaults to session) + colName: ColName{Name: ColIdent{val: "@@max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "", + expectedScope: "session", + }, + { + // System variable with an explicit session scope + colName: ColName{Name: ColIdent{val: "@@session.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "session", + expectedScope: "session", + }, + { + // System variable with an explicit global scope + colName: ColName{Name: ColIdent{val: "@@global.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "global", + expectedScope: "global", + }, + { + // System variable with an explicit global scope (in all caps) + colName: ColName{Name: ColIdent{val: "@@GLOBAL.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "GLOBAL", + expectedScope: "global", + }, + { + // System variable with an explicit persist scope + colName: ColName{Name: ColIdent{val: "@@persist.max_allowed_packets"}}, + expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}}, + expectedSpecifiedScope: "persist", + expectedScope: "persist", + }, + } + + for _, tt := range testcases { + t.Run(tt.colName.String(), func(t *testing.T) { + name, scope, specifiedScope, err := VarScopeForColName(&tt.colName) + assert.NoError(t, err) + assert.Equal(t, tt.expectedName, *name) + assert.Equal(t, tt.expectedScope, string(scope)) + assert.Equal(t, tt.expectedSpecifiedScope, specifiedScope) + }) + } +} + func TestWindowErrors(t *testing.T) { testcases := []struct { input string diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index aebdfd5cc36..87ade8c9531 100755 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -20871,7 +20871,7 @@ yydefault: yyDollar = yyS[yypt-1 : yypt+1] //line sql.y:7909 { - colName, scope, err := VarScopeForColName(yyDollar[1].setVarExpr.Name) + colName, scope, _, err := VarScopeForColName(yyDollar[1].setVarExpr.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -20884,7 +20884,7 @@ yydefault: yyDollar = yyS[yypt-2 : yypt+1] //line sql.y:7920 { - _, scope, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) + _, scope, _, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -20899,7 +20899,7 @@ yydefault: yyDollar = yyS[yypt-2 : yypt+1] //line sql.y:7933 { - _, scope, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) + _, scope, _, err := VarScopeForColName(yyDollar[2].setVarExpr.Name) if err != nil { yylex.Error(err.Error()) return 1 diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 7c8dddbcc69..5cbaac88139 100755 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -7907,7 +7907,7 @@ set_list: set_expression: set_expression_assignment { - colName, scope, err := VarScopeForColName($1.Name) + colName, scope, _, err := VarScopeForColName($1.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -7918,7 +7918,7 @@ set_expression: } | set_scope_primary set_expression_assignment { - _, scope, err := VarScopeForColName($2.Name) + _, scope, _, err := VarScopeForColName($2.Name) if err != nil { yylex.Error(err.Error()) return 1 @@ -7931,7 +7931,7 @@ set_expression: } | set_scope_secondary set_expression_assignment { - _, scope, err := VarScopeForColName($2.Name) + _, scope, _, err := VarScopeForColName($2.Name) if err != nil { yylex.Error(err.Error()) return 1