Skip to content

Commit ff042d8

Browse files
committed
Fix output variable alignment in ExtractSystemTableFilterRuleSet
When FilterScanRule transforms Exchange -> Filter -> TableScan to Filter -> Exchange -> TableScan, add a ProjectNode if the Filter's output variables don't match the original Exchange's expected output.
1 parent 89ab915 commit ff042d8

File tree

3 files changed

+327
-0
lines changed

3 files changed

+327
-0
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
import com.facebook.presto.matching.Captures;
1818
import com.facebook.presto.matching.Pattern;
1919
import com.facebook.presto.metadata.FunctionAndTypeManager;
20+
import com.facebook.presto.spi.plan.Assignments;
2021
import com.facebook.presto.spi.plan.FilterNode;
2122
import com.facebook.presto.spi.plan.PartitioningScheme;
2223
import com.facebook.presto.spi.plan.PlanNode;
2324
import com.facebook.presto.spi.plan.ProjectNode;
2425
import com.facebook.presto.spi.plan.TableScanNode;
26+
import com.facebook.presto.spi.relation.VariableReferenceExpression;
2527
import com.facebook.presto.sql.planner.PlannerUtils;
2628
import com.facebook.presto.sql.planner.iterative.Rule;
2729
import com.facebook.presto.sql.planner.plan.ExchangeNode;
@@ -290,6 +292,24 @@ public Result apply(ExchangeNode node, Captures captures, Context context)
290292
newExchange,
291293
filterNode.getPredicate());
292294

295+
// Check if the original exchange's output variables match the filter's output
296+
// If not, add a project node to align them
297+
if (!exchangeNode.getOutputVariables().equals(newFilter.getOutputVariables())) {
298+
Assignments.Builder assignments = Assignments.builder();
299+
for (VariableReferenceExpression variable : exchangeNode.getOutputVariables()) {
300+
assignments.put(variable, variable);
301+
}
302+
303+
ProjectNode projectNode = new ProjectNode(
304+
exchangeNode.getSourceLocation(),
305+
context.getIdAllocator().getNextId(),
306+
newFilter,
307+
assignments.build(),
308+
ProjectNode.Locality.LOCAL);
309+
310+
return Result.ofPlanNode(projectNode);
311+
}
312+
293313
return Result.ofPlanNode(newFilter);
294314
}
295315
}

presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,4 +633,35 @@ public void testComplexJoinWithMultipleCppFunctions()
633633
"nationkey", "nationkey",
634634
"name", "name"))))))));
635635
}
636+
637+
@Test
638+
public void testSystemTableFilterWithOutputVariableMismatch()
639+
{
640+
assertNativeDistributedPlan(
641+
"SELECT table_name FROM information_schema.columns WHERE cpp_foo(ordinal_position) > 5",
642+
output(
643+
project(ImmutableMap.of("table_name", expression("table_name")),
644+
filter("cpp_foo(ordinal_position) > BIGINT'5'",
645+
exchange(REMOTE_STREAMING, GATHER,
646+
tableScan("columns", ImmutableMap.of(
647+
"ordinal_position", "ordinal_position",
648+
"table_name", "table_name")))))));
649+
}
650+
651+
@Test
652+
public void testSystemTableFilterWithMultipleColumnsAndPartialSelection()
653+
{
654+
assertNativeDistributedPlan(
655+
"SELECT table_schema, table_name FROM information_schema.columns " +
656+
"WHERE cpp_foo(ordinal_position) > 0 AND cpp_baz(ordinal_position) < 100",
657+
output(
658+
project(ImmutableMap.of("table_schema", expression("table_schema"),
659+
"table_name", expression("table_name")),
660+
filter("cpp_foo(ordinal_position) > BIGINT'0' AND cpp_baz(ordinal_position) < BIGINT'100'",
661+
exchange(REMOTE_STREAMING, GATHER,
662+
tableScan("columns", ImmutableMap.of(
663+
"ordinal_position", "ordinal_position",
664+
"table_schema", "table_schema",
665+
"table_name", "table_name")))))));
666+
}
636667
}
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.sidecar;
15+
16+
import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils;
17+
import com.facebook.presto.testing.QueryRunner;
18+
import com.facebook.presto.tests.AbstractTestQueryFramework;
19+
import com.facebook.presto.tests.DistributedQueryRunner;
20+
import org.testng.annotations.Test;
21+
22+
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
23+
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
24+
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
25+
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrdersEx;
26+
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion;
27+
28+
public class TestNativeSidecarQueriesOnSystemTables
29+
extends AbstractTestQueryFramework
30+
{
31+
@Override
32+
protected void createTables()
33+
{
34+
QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner();
35+
createLineitem(queryRunner);
36+
createNation(queryRunner);
37+
createOrders(queryRunner);
38+
createOrdersEx(queryRunner);
39+
createRegion(queryRunner);
40+
}
41+
42+
@Override
43+
protected QueryRunner createQueryRunner()
44+
throws Exception
45+
{
46+
DistributedQueryRunner queryRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder()
47+
.setAddStorageFormatToPath(true)
48+
.setCoordinatorSidecarEnabled(true)
49+
.build();
50+
TestNativeSidecarPlugin.setupNativeSidecarPlugin(queryRunner);
51+
return queryRunner;
52+
}
53+
54+
@Override
55+
protected QueryRunner createExpectedQueryRunner()
56+
throws Exception
57+
{
58+
return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
59+
.setAddStorageFormatToPath(true)
60+
.build();
61+
}
62+
63+
@Test
64+
public void testExtractSystemTableFilterCorrectness()
65+
{
66+
// FilterScanRule - Basic filter with CPP function on system table
67+
assertQuery("SELECT table_name, ordinal_position FROM information_schema.columns " +
68+
"WHERE abs(ordinal_position) > 0 AND table_catalog = 'hive' AND table_name != 'roles' " +
69+
"ORDER BY table_name, ordinal_position");
70+
71+
// FilterScanRule - Complex predicate with multiple CPP functions
72+
assertQuery("SELECT table_name, ordinal_position FROM information_schema.columns " +
73+
"WHERE (abs(ordinal_position) > 1 AND ordinal_position < 5) " +
74+
"OR (abs(ordinal_position) + abs(ordinal_position) = 2 * ordinal_position) " +
75+
"AND table_catalog = 'hive' AND table_name != 'roles'" +
76+
"ORDER BY table_name, ordinal_position");
77+
78+
// ProjectScanRule - CPP function in projection
79+
assertQuery("SELECT table_name, abs(ordinal_position) as abs_pos FROM information_schema.columns " +
80+
"WHERE table_catalog = 'hive' AND table_name IN ('nation', 'region', 'lineitem', 'orders') " +
81+
"ORDER BY table_name, abs_pos");
82+
83+
// FilterScanRule with output variable mismatch
84+
assertQuery("SELECT table_name " +
85+
"FROM information_schema.columns " +
86+
"WHERE abs(ordinal_position) > 2 " +
87+
"AND table_catalog = 'hive' AND table_name IN ('nation', 'region', 'lineitem', 'orders') " +
88+
"ORDER BY table_name");
89+
90+
// ProjectFilterScanRule - Project with CPP and Filter on system table
91+
assertQuery("SELECT table_name, abs(ordinal_position) as abs_pos " +
92+
"FROM information_schema.columns " +
93+
"WHERE ordinal_position > 0 AND abs(ordinal_position) < 10 " +
94+
"AND table_catalog = 'hive' AND table_name IN ('nation', 'region', 'lineitem', 'orders') " +
95+
"ORDER BY table_name, abs_pos");
96+
97+
// Join system table with regular table using CPP function
98+
assertQuery("SELECT c.table_name, c.ordinal_position, n.name " +
99+
"FROM information_schema.columns c " +
100+
"JOIN nation n ON abs(c.ordinal_position) = n.nationkey " +
101+
"WHERE c.table_catalog = 'hive' AND c.table_name IN ('nation', 'region', 'lineitem', 'orders') " +
102+
"ORDER BY c.table_name, c.ordinal_position");
103+
104+
// Aggregation with CPP function on system table
105+
assertQuery("SELECT table_name, COUNT(*), SUM(abs(ordinal_position)) " +
106+
"FROM information_schema.columns " +
107+
"WHERE table_catalog = 'hive' AND table_name IN ('nation', 'region', 'lineitem', 'orders') " +
108+
"GROUP BY table_name " +
109+
"ORDER BY table_name");
110+
111+
// Nested CPP functions
112+
assertQuery("SELECT table_name, ordinal_position " +
113+
"FROM information_schema.columns " +
114+
"WHERE abs(abs(ordinal_position)) = ordinal_position " +
115+
"AND table_catalog = 'hive' AND table_name IN ('nation', 'region') " +
116+
"ORDER BY table_name, ordinal_position");
117+
118+
// CPP function in IN predicate
119+
assertQuery("SELECT table_name, ordinal_position " +
120+
"FROM information_schema.columns " +
121+
"WHERE abs(ordinal_position) IN (1, 2, 3) " +
122+
"AND table_catalog = 'hive' AND table_name != 'roles' " +
123+
"ORDER BY table_name, ordinal_position");
124+
125+
// CPP function with NULL handling
126+
assertQuery("SELECT table_name, " +
127+
"COALESCE(abs(ordinal_position), 0) as abs_pos " +
128+
"FROM information_schema.columns " +
129+
"WHERE table_catalog = 'hive' AND table_name IN ('nation', 'region') " +
130+
"ORDER BY table_name, ordinal_position");
131+
}
132+
133+
@Test
134+
public void testExtractSystemTableFilterWithJoins()
135+
{
136+
// Self-join on system table with CPP function
137+
assertQuery("SELECT c1.table_name, c1.ordinal_position, c2.ordinal_position " +
138+
"FROM information_schema.columns c1 " +
139+
"JOIN information_schema.columns c2 " +
140+
"ON c1.table_name = c2.table_name " +
141+
"AND abs(c1.ordinal_position) = abs(c2.ordinal_position) " +
142+
"WHERE c1.table_catalog = 'hive' AND c2.table_catalog = 'hive' " +
143+
"AND c1.table_name = 'nation' " +
144+
"ORDER BY c1.table_name, c1.ordinal_position, c2.ordinal_position");
145+
146+
// Join with CPP function in join condition
147+
assertQuery("SELECT c.table_name, c.column_name, t.table_type " +
148+
"FROM information_schema.columns c " +
149+
"JOIN information_schema.tables t " +
150+
"ON c.table_schema = t.table_schema " +
151+
"AND c.table_name = t.table_name " +
152+
"WHERE abs(c.ordinal_position) <= 3 " +
153+
"AND c.table_catalog = 'hive' " +
154+
"AND t.table_catalog = 'hive' " +
155+
"AND c.table_name IN ('nation', 'region') " +
156+
"ORDER BY c.table_name, c.column_name");
157+
158+
// Join system table with aggregation using CPP function
159+
assertQuery("SELECT t.table_name, COUNT(c.column_name), MAX(abs(c.ordinal_position)) " +
160+
"FROM information_schema.tables t " +
161+
"JOIN information_schema.columns c " +
162+
"ON t.table_schema = c.table_schema AND t.table_name = c.table_name " +
163+
"WHERE t.table_catalog = 'hive' AND c.table_catalog = 'hive' " +
164+
"AND t.table_name IN ('nation', 'region') " +
165+
"GROUP BY t.table_name " +
166+
"ORDER BY t.table_name");
167+
168+
// Complex join with multiple CPP functions
169+
assertQuery("SELECT c1.table_name, COUNT(DISTINCT c2.column_name) " +
170+
"FROM information_schema.columns c1 " +
171+
"JOIN information_schema.columns c2 " +
172+
"ON c1.table_schema = c2.table_schema " +
173+
"WHERE abs(c1.ordinal_position) + abs(c2.ordinal_position) > 3 " +
174+
"AND c1.table_catalog = 'hive' AND c2.table_catalog = 'hive' " +
175+
"AND c1.table_name IN ('nation', 'region') " +
176+
"GROUP BY c1.table_name " +
177+
"ORDER BY c1.table_name");
178+
179+
// Left join with CPP function
180+
assertQuery("SELECT t.table_name, c.column_name, abs(c.ordinal_position) " +
181+
"FROM information_schema.tables t " +
182+
"LEFT JOIN information_schema.columns c " +
183+
"ON t.table_schema = c.table_schema " +
184+
"AND t.table_name = c.table_name " +
185+
"AND abs(c.ordinal_position) <= 2 " +
186+
"WHERE t.table_catalog = 'hive' " +
187+
"AND t.table_name IN ('nation', 'region') " +
188+
"ORDER BY t.table_name, c.ordinal_position");
189+
}
190+
191+
@Test
192+
public void testExtractSystemTableFilterWithSubqueries()
193+
{
194+
// CPP function in subquery
195+
assertQuery("SELECT table_name FROM information_schema.tables " +
196+
"WHERE table_catalog = 'hive' " +
197+
"AND table_name IN (" +
198+
" SELECT table_name FROM information_schema.columns " +
199+
" WHERE abs(ordinal_position) = 1 " +
200+
" AND table_catalog = 'hive' AND table_name != 'roles'" +
201+
") " +
202+
"ORDER BY table_name");
203+
204+
// Correlated subquery with CPP function
205+
assertQuery("SELECT DISTINCT t.table_name " +
206+
"FROM information_schema.tables t " +
207+
"WHERE t.table_catalog = 'hive' " +
208+
"AND EXISTS (" +
209+
" SELECT 1 FROM information_schema.columns c " +
210+
" WHERE c.table_name = t.table_name " +
211+
" AND c.table_catalog = t.table_catalog " +
212+
" AND abs(c.ordinal_position) > 2" +
213+
") " +
214+
"AND t.table_name IN ('nation', 'region', 'lineitem', 'orders') " +
215+
"ORDER BY t.table_name");
216+
217+
// Scalar subquery with CPP function
218+
assertQuery("SELECT table_name, " +
219+
"(SELECT COUNT(*) FROM information_schema.columns c2 " +
220+
" WHERE c2.table_name = c1.table_name " +
221+
" AND c2.table_catalog = c1.table_catalog " +
222+
" AND abs(c2.ordinal_position) <= 3) as col_count " +
223+
"FROM information_schema.columns c1 " +
224+
"WHERE c1.table_catalog = 'hive' " +
225+
"AND c1.table_name IN ('nation', 'region') " +
226+
"AND c1.ordinal_position = 1 " +
227+
"ORDER BY c1.table_name");
228+
}
229+
230+
@Test
231+
public void testExtractSystemTableFilterWithWindowFunctions()
232+
{
233+
// Window function with CPP function in partition
234+
assertQuery("SELECT table_name, ordinal_position, " +
235+
"row_number() OVER (PARTITION BY table_name ORDER BY abs(ordinal_position)) as rn " +
236+
"FROM information_schema.columns " +
237+
"WHERE table_catalog = 'hive' " +
238+
"AND table_name IN ('nation', 'region') " +
239+
"ORDER BY table_name, ordinal_position");
240+
241+
// Window function with CPP function filter
242+
assertQuery("SELECT * FROM (" +
243+
" SELECT table_name, ordinal_position, " +
244+
" row_number() OVER (PARTITION BY table_name ORDER BY ordinal_position) as rn " +
245+
" FROM information_schema.columns " +
246+
" WHERE table_catalog = 'hive' " +
247+
" AND table_name IN ('nation', 'region')" +
248+
") " +
249+
"WHERE abs(rn) <= 2 " +
250+
"ORDER BY table_name, ordinal_position");
251+
}
252+
253+
@Test
254+
public void testExtractSystemTableFilterWithSetOperations()
255+
{
256+
// UNION with CPP functions
257+
assertQuery("SELECT table_name, abs(ordinal_position) as pos " +
258+
"FROM information_schema.columns " +
259+
"WHERE table_catalog = 'hive' AND table_name = 'nation' " +
260+
"UNION ALL " +
261+
"SELECT table_name, abs(ordinal_position) as pos " +
262+
"FROM information_schema.columns " +
263+
"WHERE table_catalog = 'hive' AND table_name = 'region' " +
264+
"ORDER BY table_name, pos");
265+
266+
// INTERSECT with CPP functions
267+
assertQuery("SELECT abs(ordinal_position) as pos " +
268+
"FROM information_schema.columns " +
269+
"WHERE table_catalog = 'hive' AND table_name = 'nation' " +
270+
"INTERSECT " +
271+
"SELECT abs(ordinal_position) as pos " +
272+
"FROM information_schema.columns " +
273+
"WHERE table_catalog = 'hive' AND table_name = 'region' " +
274+
"ORDER BY pos");
275+
}
276+
}

0 commit comments

Comments
 (0)