Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion go/vt/proto/vtgate/vtgate.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

77 changes: 59 additions & 18 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,33 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix
return RewriteAST(in)
}

// BindVarNeeds represents the bind vars that need to be provided as the result of expression rewriting.
type BindVarNeeds struct {
NeedLastInsertID bool
NeedDatabase bool
NeedFoundRows bool
}

// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
func RewriteAST(in Statement) (*RewriteASTResult, error) {
er := new(expressionRewriter)
er := newExpressionRewriter()
er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
Rewrite(in, er.goingDown, nil)

return &RewriteASTResult{
AST: in,
NeedLastInsertID: er.lastInsertID,
NeedDatabase: er.database,
}, nil
r := &RewriteASTResult{
AST: in,
}
if _, ok := er.bindVars[LastInsertIDName]; ok {
r.NeedLastInsertID = true
}
if _, ok := er.bindVars[DBVarName]; ok {
r.NeedDatabase = true
}
if _, ok := er.bindVars[FoundRowsName]; ok {
r.NeedFoundRows = true
}

return r, nil
}

func shouldRewriteDatabaseFunc(in Statement) bool {
Expand All @@ -63,22 +79,29 @@ func shouldRewriteDatabaseFunc(in Statement) bool {

// RewriteASTResult contains the rewritten ast and meta information about it
type RewriteASTResult struct {
AST Statement
NeedLastInsertID bool
NeedDatabase bool
BindVarNeeds
AST Statement // The rewritten AST
}

type expressionRewriter struct {
lastInsertID, database bool
bindVars map[string]struct{}
shouldRewriteDatabaseFunc bool
err error
}

func newExpressionRewriter() *expressionRewriter {
return &expressionRewriter{bindVars: make(map[string]struct{})}
}

const (
//LastInsertIDName is a reserved bind var name for last_insert_id()
LastInsertIDName = "__lastInsertId"

//DBVarName is a reserved bind var name for database()
DBVarName = "__vtdbname"

//FoundRowsName is a reserved bind var name for found_rows()
FoundRowsName = "__vtfrows"
)

func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
Expand All @@ -87,7 +110,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
if node.As.IsEmpty() {
buf := NewTrackedBuffer(nil)
node.Expr.Format(buf)
inner := new(expressionRewriter)
inner := newExpressionRewriter()
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
tmp := Rewrite(node.Expr, inner.goingDown, nil)
newExpr, ok := tmp.(Expr)
Expand All @@ -96,11 +119,12 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
return false
}
node.Expr = newExpr
er.database = er.database || inner.database
er.lastInsertID = er.lastInsertID || inner.lastInsertID
if inner.didAnythingChange() {
node.As = NewColIdent(buf.String())
}
for k := range inner.bindVars {
er.needBindVarFor(k)
}
return false
}

Expand All @@ -111,22 +135,39 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported")
} else {
cursor.Replace(bindVarExpression(LastInsertIDName))
er.lastInsertID = true
er.needBindVarFor(LastInsertIDName)
}
case node.Name.EqualString("database") && er.shouldRewriteDatabaseFunc:
case er.shouldRewriteDatabaseFunc &&
(node.Name.EqualString("database") ||
node.Name.EqualString("schema")):
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. DATABASE() takes no arguments")
er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String())
} else {
cursor.Replace(bindVarExpression(DBVarName))
er.database = true
er.needBindVarFor(DBVarName)
}
case node.Name.EqualString("found_rows"):
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported")
} else {
cursor.Replace(bindVarExpression(FoundRowsName))
er.needBindVarFor(FoundRowsName)
}
}

}
return true
}

// instead of creating new objects, we'll reuse this one
var token = struct{}{}

func (er *expressionRewriter) needBindVarFor(name string) {
er.bindVars[name] = token
}

func (er *expressionRewriter) didAnythingChange() bool {
return er.database || er.lastInsertID
return len(er.bindVars) > 0
}

func bindVarExpression(name string) *SQLVal {
Expand Down
27 changes: 19 additions & 8 deletions go/vt/sqlparser/expression_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ import (
)

type myTestCase struct {
in, expected string
liid, db bool
in, expected string
liid, db, foundRows bool
}

func TestRewrites(in *testing.T) {
tests := []myTestCase{
{
in: "SELECT 42",
expected: "SELECT 42",
db: false, liid: false,
db: false, liid: false, foundRows: false,
},
{
in: "SELECT last_insert_id()",
Expand All @@ -42,7 +42,7 @@ func TestRewrites(in *testing.T) {
{
in: "SELECT database()",
expected: "SELECT :__vtdbname as `database()`",
db: true, liid: false,
db: true, liid: false, foundRows: false,
},
{
in: "SELECT database() from test",
Expand All @@ -52,12 +52,12 @@ func TestRewrites(in *testing.T) {
{
in: "SELECT last_insert_id() as test",
expected: "SELECT :__lastInsertId as test",
db: false, liid: true,
db: false, liid: true, foundRows: false,
},
{
in: "SELECT last_insert_id() + database()",
expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`",
db: true, liid: true,
db: true, liid: true, foundRows: false,
},
{
in: "select (select database()) from test",
Expand All @@ -72,7 +72,7 @@ func TestRewrites(in *testing.T) {
{
in: "select (select database() from dual) from dual",
expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual",
db: true, liid: false,
db: true, liid: false, foundRows: false,
},
{
in: "select id from user where database()",
Expand All @@ -82,7 +82,17 @@ func TestRewrites(in *testing.T) {
{
in: "select table_name from information_schema.tables where table_schema = database()",
expected: "select table_name from information_schema.tables where table_schema = database()",
db: false, liid: false,
db: false, liid: false, foundRows: false,
},
{
in: "select schema()",
expected: "select :__vtdbname as 'schema()'",
db: true, liid: false, foundRows: false,
},
{
in: "select found_rows()",
expected: "select :__vtfrows as 'found_rows()'",
db: false, liid: false, foundRows: true,
},
}

Expand All @@ -101,6 +111,7 @@ func TestRewrites(in *testing.T) {
require.Equal(t, s, String(result.AST))
require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id")
require.Equal(t, tc.db, result.NeedDatabase, "should need database name")
require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows")
})
}
}
46 changes: 27 additions & 19 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,12 @@ var (
input: "select match(a) against ('foo') from t",
}, {
input: "select match(a1, a2) against ('foo' in natural language mode with query expansion) from t",
}, {
input: "select database()",
output: "select database() from dual",
}, {
input: "select schema()",
output: "select schema() from dual",
}, {
input: "select title from video as v where match(v.title, v.tag) against ('DEMO' in boolean mode)",
}, {
Expand Down Expand Up @@ -1522,25 +1528,27 @@ var (

func TestValid(t *testing.T) {
for _, tcase := range validSQL {
if tcase.output == "" {
tcase.output = tcase.input
}
tree, err := Parse(tcase.input)
if err != nil {
t.Errorf("Parse(%q) err: %v, want nil", tcase.input, err)
continue
}
out := String(tree)
if out != tcase.output {
t.Errorf("Parse(%q) = %q, want: %q", tcase.input, out, tcase.output)
}
// This test just exercises the tree walking functionality.
// There's no way automated way to verify that a node calls
// all its children. But we can examine code coverage and
// ensure that all walkSubtree functions were called.
Walk(func(node SQLNode) (bool, error) {
return true, nil
}, tree)
t.Run(tcase.input, func(t *testing.T) {
if tcase.output == "" {
tcase.output = tcase.input
}
tree, err := Parse(tcase.input)
if err != nil {
t.Errorf("Parse(%q) err: %v, want nil", tcase.input, err)
return
}
out := String(tree)
if out != tcase.output {
t.Errorf("Parse(%q) = %q, want: %q", tcase.input, out, tcase.output)
}
// This test just exercises the tree walking functionality.
// There's no way automated way to verify that a node calls
// all its children. But we can examine code coverage and
// ensure that all walkSubtree functions were called.
Walk(func(node SQLNode) (bool, error) {
return true, nil
}, tree)
})
}
}

Expand Down
Loading