Skip to content
Merged
14 changes: 14 additions & 0 deletions go/test/endtoend/vtgate/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,20 @@ func TestCastConvert(t *testing.T) {
assertMatches(t, conn, `SELECT CAST("test" AS CHAR(60))`, `[[VARCHAR("test")]]`)
}

func TestUnion(t *testing.T) {
conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)
defer conn.Close()

assertMatches(t, conn, `SELECT 1 UNION SELECT 1 UNION SELECT 1`, `[[INT64(1)]]`)
assertMatches(t, conn, `SELECT 1,'a' UNION SELECT 1,'a' UNION SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `SELECT 1,'z' UNION SELECT 2,'q' UNION SELECT 3,'b' ORDER BY 2`, `[[INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("q")] [INT64(1) VARCHAR("z")]]`)
assertMatches(t, conn, `SELECT 1,'a' UNION ALL SELECT 1,'a' UNION ALL SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a') UNION ALL (SELECT 1,'a') UNION ALL (SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a' order by 1) union SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`)
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
t.Helper()
qr := exec(t, conn, query)
Expand Down
28 changes: 21 additions & 7 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ type (
Lock string
}

// UnionSelect represents union type and select statement after first select statement.
UnionSelect struct {
Type string
Statement SelectStatement
}

// Union represents a UNION statement.
Union struct {
Type string
Left, Right SelectStatement
OrderBy OrderBy
Limit *Limit
Lock string
FirstStatement SelectStatement
UnionSelects []*UnionSelect
OrderBy OrderBy
Limit *Limit
Lock string
}

// Stream represents a SELECT statement.
Expand Down Expand Up @@ -874,8 +880,16 @@ func (node *ParenSelect) Format(buf *TrackedBuffer) {

// Format formats the node.
func (node *Union) Format(buf *TrackedBuffer) {
buf.astPrintf(node, "%v %s %v%v%v%s", node.Left, node.Type, node.Right,
node.OrderBy, node.Limit, node.Lock)
buf.astPrintf(node, "%v", node.FirstStatement)
for _, us := range node.UnionSelects {
buf.astPrintf(node, "%v", us)
}
buf.astPrintf(node, "%v%v%s", node.OrderBy, node.Limit, node.Lock)
}

// Format formats the node.
func (node *UnionSelect) Format(buf *TrackedBuffer) {
buf.astPrintf(node, " %s %v", node.Type, node.Statement)
}

// Format formats the node.
Expand Down
14 changes: 14 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,20 @@ func (node *Union) SetLimit(limit *Limit) {
node.Limit = limit
}

//Unionize returns a UNION, either creating one or adding SELECT to an existing one
func Unionize(lhs, rhs SelectStatement, typ string, by OrderBy, limit *Limit, lock string) *Union {
union, isUnion := lhs.(*Union)
if isUnion {
union.UnionSelects = append(union.UnionSelects, &UnionSelect{Type: typ, Statement: rhs})
union.OrderBy = by
union.Limit = limit
union.Lock = lock
return union
}

return &Union{FirstStatement: lhs, UnionSelects: []*UnionSelect{{Type: typ, Statement: rhs}}, OrderBy: by, Limit: limit, Lock: lock}
}

// AtCount represents the '@' count in ColIdent
type AtCount int

Expand Down
5 changes: 4 additions & 1 deletion go/vt/sqlparser/impossible_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) {
node.GroupBy.Format(buf)
}
case *Union:
buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right)
buf.astPrintf(node, "%v", node.FirstStatement)
for _, us := range node.UnionSelects {
buf.astPrintf(node, "%v", us)
}
default:
node.Format(buf)
}
Expand Down
44 changes: 27 additions & 17 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"sync"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -141,6 +140,22 @@ var (
input: "select * from t1 where col in (select 1 from dual union select 2 from dual)",
}, {
input: "select * from t1 where exists (select a from t2 union select b from t3)",
}, {
input: "select 1 from dual union select 2 from dual union all select 3 from dual union select 4 from dual union all select 5 from dual",
}, {
input: "(select 1 from dual) order by 1 asc limit 2",
}, {
input: "(select 1 from dual order by 1 desc) order by 1 asc limit 2",
}, {
input: "(select 1 from dual)",
}, {
input: "((select 1 from dual))",
}, {
input: "select 1 from (select 1 from dual) as t",
}, {
input: "select 1 from (select 1 from dual union select 2 from dual) as t",
}, {
input: "select 1 from ((select 1 from dual) union select 2 from dual) as t",
}, {
input: "select /* distinct */ distinct 1 from t",
}, {
Expand Down Expand Up @@ -681,11 +696,9 @@ var (
input: "insert /* it accepts columns with keyword action */ into a(action, b) values (1, 2)",
output: "insert /* it accepts columns with keyword action */ into a(`action`, b) values (1, 2)",
}, {
input: "insert /* no cols & paren select */ into a(select * from t)",
output: "insert /* no cols & paren select */ into a select * from t",
input: "insert /* no cols & paren select */ into a (select * from t)",
}, {
input: "insert /* cols & paren select */ into a(a,b,c) (select * from t)",
output: "insert /* cols & paren select */ into a(a, b, c) select * from t",
input: "insert /* cols & paren select */ into a(a, b, c) (select * from t)",
}, {
input: "insert /* cols & union with paren select */ into a(b, c) (select d, e from f) union (select g from h)",
}, {
Expand Down Expand Up @@ -1653,8 +1666,8 @@ func TestValid(t *testing.T) {
tree, err := Parse(tcase.input)
require.NoError(t, err, tcase.input)
out := String(tree)
if diff := cmp.Diff(tcase.output, out); diff != "" {
t.Errorf("Parse(%q):\n%s", tcase.input, diff)
if tcase.output != out {
t.Errorf("Parsing failed. \nExpected/Got:\n%s\n%s", tcase.output, out)
}
// This test just exercises the tree walking functionality.
// There's no way automated way to verify that a node calls
Expand Down Expand Up @@ -1702,9 +1715,6 @@ func TestInvalid(t *testing.T) {
input string
err string
}{{
input: "select a from (select * from tbl)",
err: "Every derived table must have its own alias",
}, {
input: "select a, b from (select * from tbl) sort by a",
err: "syntax error",
}, {
Expand Down Expand Up @@ -2057,6 +2067,9 @@ func TestPositionedErr(t *testing.T) {
}, {
input: "select * from a left join b",
output: PositionedErr{"syntax error", 28, nil},
}, {
input: "select a from (select * from tbl)",
output: PositionedErr{"syntax error", 34, nil},
}}

for _, tcase := range invalidSQL {
Expand Down Expand Up @@ -2694,9 +2707,6 @@ var (
}, {
input: "select /* vitess-reserved keyword as unqualified column */ * from t where escape = 'test'",
output: "syntax error at position 81 near 'escape'",
}, {
input: "(select /* parenthesized select */ * from t)",
output: "syntax error at position 45",
}, {
input: "select * from t where id = ((select a from t1 union select b from t2) order by a limit 1)",
output: "syntax error at position 76 near 'order'",
Expand All @@ -2720,10 +2730,10 @@ var (

func TestErrors(t *testing.T) {
for _, tcase := range invalidSQL {
_, err := Parse(tcase.input)
if err == nil || err.Error() != tcase.output {
t.Errorf("%s: %v, want %s", tcase.input, err, tcase.output)
}
t.Run(tcase.input, func(t *testing.T) {
_, err := Parse(tcase.input)
require.Error(t, err, tcase.output)
})
}
}

Expand Down
30 changes: 24 additions & 6 deletions go/vt/sqlparser/rewriter.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading