Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 9 additions & 99 deletions go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package sqlparser

import (
"fmt"
"strconv"
"strings"
"unicode"

Expand Down Expand Up @@ -96,6 +95,15 @@ func CanNormalize(stmt Statement) bool {
return false
}

//IsSetStatement takes Statement and returns if the statement is set statement.
func IsSetStatement(stmt Statement) bool {
switch stmt.(type) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_, ok := stmt.(*Set)
return ok

case *Set:
return true
}
return false
}

// Preview analyzes the beginning of the query using a simpler and faster
// textual comparison to identify the statement type.
func Preview(sql string) StatementType {
Expand Down Expand Up @@ -365,101 +373,3 @@ func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
}
return sqltypes.PlanValue{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "expression is too complex '%v'", String(node))
}

// SetKey is the extracted key from one SetExpr
type SetKey struct {
Key string
Scope string
}

// ExtractSetValues returns a map of key-value pairs
// if the query is a SET statement. Values can be bool, int64 or string.
// Since set variable names are case insensitive, all keys are returned
// as lower case.
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
stmt, err := Parse(sql)
if err != nil {
return nil, "", err
}
setStmt, ok := stmt.(*Set)
if !ok {
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
}
result := make(map[SetKey]interface{})
for _, expr := range setStmt.Exprs {
var scope string
key := expr.Name.Lowered()

switch expr.Name.at {
case NoAt:
scope = ImplicitStr
case SingleAt:
scope = VariableStr
case DoubleAt:
switch {
case strings.HasPrefix(key, "global."):
scope = GlobalStr
key = strings.TrimPrefix(key, "global.")
case strings.HasPrefix(key, "session."):
scope = SessionStr
key = strings.TrimPrefix(key, "session.")
case strings.HasPrefix(key, "vitess_metadata."):
scope = VitessMetadataStr
key = strings.TrimPrefix(key, "vitess_metadata.")
default:
scope = SessionStr
}

// This is what correctly allows us to handle queries such as "set @@session.`autocommit`=1"
// it will remove backticks and double quotes that might surround the part after the first period
_, out := NewStringTokenizer(key).Scan()
key = string(out)
}

if setStmt.Scope != "" && scope != "" {
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
}

setKey := SetKey{
Key: key,
Scope: scope,
}

switch expr := expr.Expr.(type) {
case *SQLVal:
switch expr.Type {
case StrVal:
result[setKey] = strings.ToLower(string(expr.Val))
case IntVal:
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
if err != nil {
return nil, "", err
}
result[setKey] = num
case FloatVal:
num, err := strconv.ParseFloat(string(expr.Val), 64)
if err != nil {
return nil, "", err
}
result[setKey] = num
default:
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
}
case BoolVal:
var val int64
if expr {
val = 1
}
result[setKey] = val
case *ColName:
result[setKey] = expr.Name.String()
case *NullVal:
result[setKey] = nil
case *Default:
result[setKey] = "default"
default:
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
}
}
return result, strings.ToLower(setStmt.Scope), nil
}
185 changes: 0 additions & 185 deletions go/vt/sqlparser/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/sqltypes"
)
Expand Down Expand Up @@ -435,188 +432,6 @@ func TestNewPlanValue(t *testing.T) {
}
}

func TestExtractSetValues(t *testing.T) {
testcases := []struct {
sql string
out map[SetKey]interface{}
scope string
err string
}{{
sql: "invalid",
err: "syntax error at position 8 near 'invalid'",
}, {
sql: "select * from t",
err: "ast did not yield *sqlparser.Set: *sqlparser.Select",
}, {
sql: "set autocommit=1+1",
err: "invalid syntax: 1 + 1",
}, {
sql: "set transaction_mode='single'",
out: map[SetKey]interface{}{{Key: "transaction_mode", Scope: ImplicitStr}: "single"},
}, {
sql: "set autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(1)},
}, {
sql: "set autocommit=true",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(1)},
}, {
sql: "set autocommit=false",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(0)},
}, {
sql: "set autocommit=on",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: "on"},
}, {
sql: "set autocommit=off",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: "off"},
}, {
sql: "set @@global.autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: GlobalStr}: int64(1)},
}, {
sql: "set @@global.autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: GlobalStr}: int64(1)},
}, {
sql: "set @@session.autocommit=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.`autocommit`=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.'autocommit'=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.\"autocommit\"=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.'\"autocommit'=1",
out: map[SetKey]interface{}{{Key: "\"autocommit", Scope: SessionStr}: int64(1)},
}, {
sql: "set @@session.`autocommit'`=1",
out: map[SetKey]interface{}{{Key: "autocommit'", Scope: SessionStr}: int64(1)},
}, {
sql: "set AUTOCOMMIT=1",
out: map[SetKey]interface{}{{Key: "autocommit", Scope: ImplicitStr}: int64(1)},
}, {
sql: "SET character_set_results = NULL",
out: map[SetKey]interface{}{{Key: "character_set_results", Scope: ImplicitStr}: nil},
}, {
sql: "SET foo = 0x1234",
err: "invalid value type: 0x1234",
}, {
sql: "SET names utf8",
out: map[SetKey]interface{}{{Key: "names", Scope: ImplicitStr}: "utf8"},
}, {
sql: "SET names ascii collate ascii_bin",
out: map[SetKey]interface{}{{Key: "names", Scope: ImplicitStr}: "ascii"},
}, {
sql: "SET charset default",
out: map[SetKey]interface{}{{Key: "charset", Scope: ImplicitStr}: "default"},
}, {
sql: "SET character set ascii",
out: map[SetKey]interface{}{{Key: "charset", Scope: ImplicitStr}: "ascii"},
}, {
sql: "SET SESSION wait_timeout = 3600",
out: map[SetKey]interface{}{{Key: "wait_timeout", Scope: ImplicitStr}: int64(3600)},
scope: SessionStr,
}, {
sql: "SET GLOBAL wait_timeout = 3600",
out: map[SetKey]interface{}{{Key: "wait_timeout", Scope: ImplicitStr}: int64(3600)},
scope: GlobalStr,
}, {
sql: "set session transaction isolation level repeatable read",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelRepeatableRead},
scope: SessionStr,
}, {
sql: "set session transaction isolation level read committed",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelReadCommitted},
scope: SessionStr,
}, {
sql: "set session transaction isolation level read uncommitted",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelReadUncommitted},
scope: SessionStr,
}, {
sql: "set session transaction isolation level serializable",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelSerializable},
scope: SessionStr,
}, {
sql: "set transaction isolation level serializable",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: IsolationLevelSerializable},
}, {
sql: "set transaction read only",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: TxReadOnly},
}, {
sql: "set transaction read write",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: TxReadWrite},
}, {
sql: "set session transaction read write",
out: map[SetKey]interface{}{{Key: TransactionStr, Scope: ImplicitStr}: TxReadWrite},
scope: SessionStr,
}, {
sql: "set session tx_read_only = 0",
out: map[SetKey]interface{}{{Key: "tx_read_only", Scope: ImplicitStr}: int64(0)},
scope: SessionStr,
}, {
sql: "set session tx_read_only = 1",
out: map[SetKey]interface{}{{Key: "tx_read_only", Scope: ImplicitStr}: int64(1)},
scope: SessionStr,
}, {
sql: "set session sql_safe_updates = 0",
out: map[SetKey]interface{}{{Key: "sql_safe_updates", Scope: ImplicitStr}: int64(0)},
scope: SessionStr,
}, {
sql: "set session transaction_read_only = 0",
out: map[SetKey]interface{}{{Key: "transaction_read_only", Scope: ImplicitStr}: int64(0)},
scope: SessionStr,
}, {
sql: "set session transaction_read_only = 1",
out: map[SetKey]interface{}{{Key: "transaction_read_only", Scope: ImplicitStr}: int64(1)},
scope: SessionStr,
}, {
sql: "set session sql_safe_updates = 1",
out: map[SetKey]interface{}{{Key: "sql_safe_updates", Scope: ImplicitStr}: int64(1)},
scope: SessionStr,
}, {
sql: "set @foo = 42",
out: map[SetKey]interface{}{
{Key: "foo", Scope: VariableStr}: int64(42),
},
scope: ImplicitStr,
}, {
sql: "set @foo.bar.baz = 42",
out: map[SetKey]interface{}{
{Key: "foo.bar.baz", Scope: VariableStr}: int64(42),
},
scope: ImplicitStr,
}, {
sql: "set @`string` = 'abc', @`float` = 4.2, @`int` = 42",
out: map[SetKey]interface{}{
{Key: "string", Scope: VariableStr}: "abc",
{Key: "float", Scope: VariableStr}: 4.2,
{Key: "int", Scope: VariableStr}: int64(42),
},
scope: ImplicitStr,
}, {
sql: "set session @foo = 42",
err: "unsupported in set: scope and user defined variables",
}, {
sql: "set global @foo = 42",
err: "unsupported in set: scope and user defined variables",
}}
for _, tcase := range testcases {
t.Run(tcase.sql, func(t *testing.T) {
out, _, err := ExtractSetValues(tcase.sql)
if tcase.err != "" {
require.Error(t, err, tcase.err)
} else if err != nil {
require.NoError(t, err)
}

if diff := cmp.Diff(tcase.out, out); diff != "" {
t.Error(diff)
}
})
}
}

func newStrVal(in string) *SQLVal {
return NewStrVal([]byte(in))
}
Expand Down
Loading