From eb01da2bed2356a71f448102e4845ac872104d77 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 9 Jun 2020 21:43:30 +0200 Subject: [PATCH] Make it possible to rewrite function calls inside function calls Signed-off-by: Andres Taylor --- go/vt/sqlparser/expression_rewriting.go | 40 ++++++++++---------- go/vt/sqlparser/expression_rewriting_test.go | 5 +++ 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index ece90c80fac..79df8d1ae23 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -132,26 +132,28 @@ const ( func (er *expressionRewriter) goingDown(cursor *Cursor) bool { switch node := cursor.Node().(type) { // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` - case *AliasedExpr: - if node.As.IsEmpty() { - buf := NewTrackedBuffer(nil) - node.Expr.Format(buf) - inner := newExpressionRewriter() - inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := Rewrite(node.Expr, inner.goingDown, nil) - newExpr, ok := tmp.(Expr) - if !ok { - log.Errorf("failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) - return false + case *Select: + for _, col := range node.SelectExprs { + aliasedExpr, ok := col.(*AliasedExpr) + if ok && aliasedExpr.As.IsEmpty() { + buf := NewTrackedBuffer(nil) + aliasedExpr.Expr.Format(buf) + inner := newExpressionRewriter() + inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc + tmp := Rewrite(aliasedExpr.Expr, inner.goingDown, nil) + newExpr, ok := tmp.(Expr) + if !ok { + log.Errorf("failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) + return false + } + aliasedExpr.Expr = newExpr + if inner.didAnythingChange() { + aliasedExpr.As = NewColIdent(buf.String()) + } + for k := range inner.bindVars { + er.needBindVarFor(k) + } } - node.Expr = newExpr - if inner.didAnythingChange() { - node.As = NewColIdent(buf.String()) - } - for k := range inner.bindVars { - er.needBindVarFor(k) - } - return false } case *FuncExpr: er.funcRewrite(cursor, node) diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index 65433767afd..4e10c0418fd 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -114,6 +114,11 @@ func TestRewrites(in *testing.T) { expected: "select :__vtrcount as `row_count()`", rowCount: true, }, + { + in: "SELECT lower(database())", + expected: "SELECT lower(:__vtdbname) as `lower(database())`", + db: true, + }, } for _, tc := range tests {