diff --git a/go/vt/sqlparser/impossible_query.go b/go/vt/sqlparser/impossible_query.go index a6bf1ea8736..91377fcc868 100644 --- a/go/vt/sqlparser/impossible_query.go +++ b/go/vt/sqlparser/impossible_query.go @@ -41,10 +41,15 @@ func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) { node.GroupBy.Format(buf) } case *Union: + if node.With != nil { + node.With.Format(buf) + } if requiresParen(node.Left) { - buf.astPrintf(node, "(%v)", node.Left) + buf.WriteString("(") + FormatImpossibleQuery(buf, node.Left) + buf.WriteString(")") } else { - buf.astPrintf(node, "%v", node.Left) + FormatImpossibleQuery(buf, node.Left) } buf.WriteString(" ") @@ -56,9 +61,11 @@ func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) { buf.WriteString(" ") if requiresParen(node.Right) { - buf.astPrintf(node, "(%v)", node.Right) + buf.WriteString("(") + FormatImpossibleQuery(buf, node.Right) + buf.WriteString(")") } else { - buf.astPrintf(node, "%v", node.Right) + FormatImpossibleQuery(buf, node.Right) } default: node.Format(buf) diff --git a/go/vt/sqlparser/impossible_query_test.go b/go/vt/sqlparser/impossible_query_test.go new file mode 100644 index 00000000000..48e35fff0b9 --- /dev/null +++ b/go/vt/sqlparser/impossible_query_test.go @@ -0,0 +1,223 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatImpossibleQuery_Select(t *testing.T) { + parser := NewTestParser() + + testcases := []struct { + name string + input string + expected string + }{ + { + name: "basic select", + input: "select * from t", + expected: "select * from t where 1 != 1", + }, + { + name: "select with existing where", + input: "select * from t where id = 1", + expected: "select * from t where 1 != 1", + }, + { + name: "select with group by", + input: "select col from t group by col", + expected: "select col from t where 1 != 1 group by col", + }, + { + name: "select with multiple columns", + input: "select id, name from users", + expected: "select id, `name` from users where 1 != 1", + }, + { + name: "select with multiple tables", + input: "select * from t1, t2", + expected: "select * from t1, t2 where 1 != 1", + }, + { + name: "select with join", + input: "select * from t1 join t2 on t1.id = t2.id", + expected: "select * from t1 join t2 on t1.id = t2.id where 1 != 1", + }, + { + name: "select with where and group by", + input: "select col from t where id > 5 group by col", + expected: "select col from t where 1 != 1 group by col", + }, + { + name: "select with with clause", + input: "with cte as (select * from t) select * from cte", + expected: "with cte as (select * from t) select * from cte where 1 != 1", + }, + { + name: "select with multiple CTEs", + input: "with cte1 as (select * from t1), cte2 as (select * from t2) select * from cte1 join cte2 on cte1.id = cte2.id", + expected: "with cte1 as (select * from t1) , cte2 as (select * from t2) select * from cte1 join cte2 on cte1.id = cte2.id where 1 != 1", + }, + { + name: "select with recursive CTE", + input: "with recursive cte as (select id from t1 union select id + 1 from cte where id < 10) select * from cte", + expected: "with recursive cte as (select id from t1 union select id + 1 from cte where id < 10) select * from cte where 1 != 1", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + stmt, err := parser.Parse(tc.input) + require.NoError(t, err) + + buf := NewTrackedBuffer(nil) + FormatImpossibleQuery(buf, stmt) + + assert.Equal(t, tc.expected, buf.String()) + }) + } +} + +func TestFormatImpossibleQuery_Union(t *testing.T) { + parser := NewTestParser() + + testcases := []struct { + name string + input string + expected string + }{ + { + name: "simple union", + input: "select * from t1 union select * from t2", + expected: "select * from t1 where 1 != 1 union select * from t2 where 1 != 1", + }, + { + name: "union all", + input: "select * from t1 union all select * from t2", + expected: "select * from t1 where 1 != 1 union all select * from t2 where 1 != 1", + }, + { + name: "union with parentheses", + input: "(select * from t1) union (select * from t2)", + expected: "select * from t1 where 1 != 1 union select * from t2 where 1 != 1", + }, + { + name: "nested union", + input: "select * from t1 union select * from t2 union select * from t3", + expected: "select * from t1 where 1 != 1 union select * from t2 where 1 != 1 union select * from t3 where 1 != 1", + }, + { + name: "union with where clauses", + input: "select * from t1 where id > 1 union select * from t2 where name = 'test'", + expected: "select * from t1 where 1 != 1 union select * from t2 where 1 != 1", + }, + { + name: "union with complex select", + input: "select id, name from users group by id union select id, title from posts", + expected: "select id, `name` from users where 1 != 1 group by id union select id, title from posts where 1 != 1", + }, + { + name: "union with with clause", + input: "with cte as (select * from t1) select * from cte union select * from t2", + expected: "with cte as (select * from t1) select * from cte where 1 != 1 union select * from t2 where 1 != 1", + }, + { + name: "union with multiple CTEs", + input: "with cte1 as (select * from t1), cte2 as (select * from t2) select * from cte1 union select * from cte2", + expected: "with cte1 as (select * from t1) , cte2 as (select * from t2) select * from cte1 where 1 != 1 union select * from cte2 where 1 != 1", + }, + { + name: "union with recursive CTE", + input: "with recursive cte as (select id from t1 union select id + 1 from cte where id < 10) select * from cte union select * from t2", + expected: "with recursive cte as (select id from t1 union select id + 1 from cte where id < 10) select * from cte where 1 != 1 union select * from t2 where 1 != 1", + }, + { + name: "union with CTE containing where clause", + input: "with cte as (select * from t1 where id > 5) select * from cte union select * from t2 where name = 'test'", + expected: "with cte as (select * from t1 where id > 5) select * from cte where 1 != 1 union select * from t2 where 1 != 1", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + stmt, err := parser.Parse(tc.input) + require.NoError(t, err) + + buf := NewTrackedBuffer(nil) + FormatImpossibleQuery(buf, stmt) + + assert.Equal(t, tc.expected, buf.String()) + }) + } +} + +func TestFormatImpossibleQuery_Other(t *testing.T) { + parser := NewTestParser() + + testcases := []struct { + name string + input string + expected string + }{ + { + name: "insert statement", + input: "insert into t values (1, 'test')", + expected: "insert into t values (1, 'test')", + }, + { + name: "update statement", + input: "update t set name = 'test' where id = 1", + expected: "update t set `name` = 'test' where id = 1", + }, + { + name: "delete statement", + input: "delete from t where id = 1", + expected: "delete from t where id = 1", + }, + { + name: "create table statement", + input: "create table t (id int, name varchar(50))", + expected: "create table t (\n\tid int,\n\t`name` varchar(50)\n)", + }, + { + name: "show statement", + input: "show tables", + expected: "show tables", + }, + { + name: "describe statement", + input: "describe t", + expected: "explain t", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + stmt, err := parser.Parse(tc.input) + require.NoError(t, err) + + buf := NewTrackedBuffer(nil) + FormatImpossibleQuery(buf, stmt) + + assert.Equal(t, tc.expected, buf.String()) + }) + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index a92f476cc25..12a38e7a9c5 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -1862,7 +1862,7 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select 'count_a' as tab, num from count_a where 1 != 1 union select 'count_b' as tab, num from count_b where 1 != 1", + "FieldQuery": "with count_a as (select count(id) as num from unsharded_a where 1 != 1) , count_b as (select count(id) as num from unsharded_b where 1 != 1) select 'count_a' as tab, num from count_a where 1 != 1 union select 'count_b' as tab, num from count_b where 1 != 1", "Query": "with count_a as (select count(id) as num from unsharded_a) , count_b as (select count(id) as num from unsharded_b) select 'count_a' as tab, num from count_a union select 'count_b' as tab, num from count_b" }, "TablesUsed": [