Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 49 additions & 44 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -6513,113 +6513,118 @@ const (
SetScope_User SetScope = "user"
)

// VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information.
func VarScopeForColName(colName *ColName) (*ColName, SetScope, error) {
// VarScopeForColName returns the SetScope of the given ColName, along with a new ColName without the scope information,
// and a string indicating the exact scope that was specified in the original query or "" if no scope was explicitly
// specified.
func VarScopeForColName(colName *ColName) (*ColName, SetScope, string, error) {
if colName.Qualifier.IsEmpty() { // Forms are like `@@x` and `@x`
if strings.HasPrefix(colName.Name.val, "@") && strings.Index(colName.Name.val, ".") != -1 {
varName, scope, err := VarScope(strings.Split(colName.Name.val, ".")...)
varName, scope, specifiedScope, err := VarScope(strings.Split(colName.Name.val, ".")...)
if err != nil {
return nil, SetScope_None, err
return nil, SetScope_None, "", err
}
if scope == SetScope_None {
return colName, scope, nil
return colName, scope, "", nil
}
return &ColName{Name: ColIdent{val: varName}}, scope, nil
return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil
} else {
varName, scope, err := VarScope(colName.Name.val)
varName, scope, specifiedScope, err := VarScope(colName.Name.val)
if err != nil {
return nil, SetScope_None, err
return nil, SetScope_None, "", err
}
if scope == SetScope_None {
return colName, scope, nil
return colName, scope, "", nil
}
return &ColName{Name: ColIdent{val: varName}}, scope, nil
return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil
}
} else if colName.Qualifier.Qualifier.IsEmpty() { // Forms are like `@@GLOBAL.x` and `@@SESSION.x`
varName, scope, err := VarScope(colName.Qualifier.Name.v, colName.Name.val)
varName, scope, specifiedScope, err := VarScope(colName.Qualifier.Name.v, colName.Name.val)
if err != nil {
return nil, SetScope_None, err
return nil, SetScope_None, "", err
}
if scope == SetScope_None {
return colName, scope, nil
return colName, scope, "", nil
}
return &ColName{Name: ColIdent{val: varName}}, scope, nil
return &ColName{Name: ColIdent{val: varName}}, scope, specifiedScope, nil
} else { // Forms are like `@@GLOBAL.validate_password.length`, which is currently unsupported
_, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val)
return colName, SetScope_None, err
_, _, _, err := VarScope(colName.Qualifier.Qualifier.v, colName.Qualifier.Name.v, colName.Name.val)
return colName, SetScope_None, "", err
}
}

// VarScope returns the SetScope of the given name, broken into parts. For example, `@@GLOBAL.sys_var` would become
// `[]string{"@@GLOBAL", "sys_var"}`. Returns the variable name without any scope specifiers, so the aforementioned
// variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If the name parts do not
// specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this
// will always only return the last part. `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None.
func VarScope(nameParts ...string) (string, SetScope, error) {
// variable would simply return "sys_var". `[]string{"@@other_var"}` would return "other_var". If a scope is not
// explicitly specified, then the requestedScope string will be empty, otherwise it will be the exact
// scope that was explicitly specified, which can differ from the returned scope, when the returned scope is
// inferred. If the name parts do not specify a variable (returns SetScope_None), then it is recommended to use the original non-broken string, as this will always only return the last part.
// `[]string{"my_db", "my_tbl", "my_col"}` will return "my_col" with SetScope_None.
func VarScope(nameParts ...string) (string, SetScope, string, error) {
switch len(nameParts) {
case 0:
return "", SetScope_None, nil
return "", SetScope_None, "", nil
case 1:
// First case covers `@@@`, `@@@@`, etc.
if strings.HasPrefix(nameParts[0], "@@@") {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[0])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[0])
} else if strings.HasPrefix(nameParts[0], "@@") {
dotIdx := strings.Index(nameParts[0], ".")
if dotIdx != -1 {
return VarScope(nameParts[0][:dotIdx], nameParts[0][dotIdx+1:])
}
return nameParts[0][2:], SetScope_Session, nil
// Session scope is inferred here, but not explicitly requested
return nameParts[0][2:], SetScope_Session, "", nil
} else if strings.HasPrefix(nameParts[0], "@") {
return nameParts[0][1:], SetScope_User, nil
return nameParts[0][1:], SetScope_User, "", nil
} else {
return nameParts[0], SetScope_None, nil
return nameParts[0], SetScope_None, "", nil
}
case 2:
// `@user.var` is valid, so we check for it here.
if len(nameParts[0]) >= 2 && nameParts[0][0] == '@' && nameParts[0][1] != '@' &&
!strings.HasPrefix(nameParts[1], "@") { // `@user.@var` is invalid though.
return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, nil
return fmt.Sprintf("%s.%s", nameParts[0][1:], nameParts[1]), SetScope_User, "", nil
}
// We don't support variables such as `@@validate_password.length` right now, only `@@GLOBAL.sys_var`, etc.
// The `@` symbols are only valid on the first name_part. First case also catches `@@@`, etc.
if strings.HasPrefix(nameParts[1], "@@") {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
} else if strings.HasPrefix(nameParts[1], "@") {
return "", SetScope_None, fmt.Errorf("invalid user variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid user variable declaration `%s`", nameParts[1])
}
switch strings.ToLower(nameParts[0]) {
case "@@global":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Global, nil
return nameParts[1], SetScope_Global, nameParts[0][2:], nil
case "@@persist":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Persist, nil
return nameParts[1], SetScope_Persist, nameParts[0][2:], nil
case "@@persist_only":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_PersistOnly, nil
return nameParts[1], SetScope_PersistOnly, nameParts[0][2:], nil
case "@@session":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Session, nil
return nameParts[1], SetScope_Session, nameParts[0][2:], nil
case "@@local":
if strings.HasPrefix(nameParts[1], `"`) || strings.HasPrefix(nameParts[1], `'`) {
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_Session, nil
return nameParts[1], SetScope_Session, nameParts[0][2:], nil
default:
// This catches `@@@GLOBAL.sys_var`. Due to the earlier check, this does not error on `@user.var`.
if strings.HasPrefix(nameParts[0], "@") {
// Last value is column name, so we return that in the error
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[1])
}
return nameParts[1], SetScope_None, nil
return nameParts[1], SetScope_None, "", nil
}
default:
// `@user.var.name` is valid, so we check for it here.
Expand All @@ -6628,20 +6633,20 @@ func VarScope(nameParts ...string) (string, SetScope, error) {
for i := 1; i < len(nameParts); i++ {
if strings.HasPrefix(nameParts[i], "@") {
// Last value is column name, so we return that in the error
return "", SetScope_None, fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1])
return "", SetScope_None, "", fmt.Errorf("invalid user variable declaration `%s`", nameParts[len(nameParts)-1])
}
}
return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, nil
return strings.Join(append([]string{nameParts[0][1:]}, nameParts[1:]...), "."), SetScope_User, "", nil
}
// As we don't support `@@GLOBAL.validate_password.length` or anything potentially longer, we error if any part
// starts with either `@@` or `@`. We can just check for `@` though.
for _, namePart := range nameParts {
if strings.HasPrefix(namePart, "@") {
// Last value is column name, so we return that in the error
return "", SetScope_None, fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1])
return "", SetScope_None, "", fmt.Errorf("invalid system variable declaration `%s`", nameParts[len(nameParts)-1])
}
}
return nameParts[len(nameParts)-1], SetScope_None, nil
return nameParts[len(nameParts)-1], SetScope_None, "", nil
}
}

Expand Down
69 changes: 69 additions & 0 deletions go/vt/sqlparser/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"
"unsafe"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dolthub/vitess/go/sqltypes"
Expand Down Expand Up @@ -884,6 +885,74 @@ func TestSplitStatementToPieces(t *testing.T) {
}
}

func TestVarScopeForColName(t *testing.T) {
testcases := []struct {
colName ColName
expectedName ColName
expectedScope string
expectedSpecifiedScope string
}{
{
// Regular column name
colName: ColName{Name: ColIdent{val: "a"}},
expectedName: ColName{Name: ColIdent{val: "a"}},
expectedSpecifiedScope: "",
},
{
// User variable
colName: ColName{Name: ColIdent{val: "@aaa"}},
expectedName: ColName{Name: ColIdent{val: "aaa"}},
expectedSpecifiedScope: "",
expectedScope: "user",
},
{
// System variable without an explicit scope (defaults to session)
colName: ColName{Name: ColIdent{val: "@@max_allowed_packets"}},
expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}},
expectedSpecifiedScope: "",
expectedScope: "session",
},
{
// System variable with an explicit session scope
colName: ColName{Name: ColIdent{val: "@@session.max_allowed_packets"}},
expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}},
expectedSpecifiedScope: "session",
expectedScope: "session",
},
{
// System variable with an explicit global scope
colName: ColName{Name: ColIdent{val: "@@global.max_allowed_packets"}},
expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}},
expectedSpecifiedScope: "global",
expectedScope: "global",
},
{
// System variable with an explicit global scope (in all caps)
colName: ColName{Name: ColIdent{val: "@@GLOBAL.max_allowed_packets"}},
expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}},
expectedSpecifiedScope: "GLOBAL",
expectedScope: "global",
},
{
// System variable with an explicit persist scope
colName: ColName{Name: ColIdent{val: "@@persist.max_allowed_packets"}},
expectedName: ColName{Name: ColIdent{val: "max_allowed_packets"}},
expectedSpecifiedScope: "persist",
expectedScope: "persist",
},
}

for _, tt := range testcases {
t.Run(tt.colName.String(), func(t *testing.T) {
name, scope, specifiedScope, err := VarScopeForColName(&tt.colName)
assert.NoError(t, err)
assert.Equal(t, tt.expectedName, *name)
assert.Equal(t, tt.expectedScope, string(scope))
assert.Equal(t, tt.expectedSpecifiedScope, specifiedScope)
})
}
}

func TestWindowErrors(t *testing.T) {
testcases := []struct {
input string
Expand Down
6 changes: 3 additions & 3 deletions go/vt/sqlparser/sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions go/vt/sqlparser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -7907,7 +7907,7 @@ set_list:
set_expression:
set_expression_assignment
{
colName, scope, err := VarScopeForColName($1.Name)
colName, scope, _, err := VarScopeForColName($1.Name)
if err != nil {
yylex.Error(err.Error())
return 1
Expand All @@ -7918,7 +7918,7 @@ set_expression:
}
| set_scope_primary set_expression_assignment
{
_, scope, err := VarScopeForColName($2.Name)
_, scope, _, err := VarScopeForColName($2.Name)
if err != nil {
yylex.Error(err.Error())
return 1
Expand All @@ -7931,7 +7931,7 @@ set_expression:
}
| set_scope_secondary set_expression_assignment
{
_, scope, err := VarScopeForColName($2.Name)
_, scope, _, err := VarScopeForColName($2.Name)
if err != nil {
yylex.Error(err.Error())
return 1
Expand Down