Skip to content

Commit 604587b

Browse files
pdabre12aditi-pandit
authored andcommitted
Extract SQL invoked inlined functions tests into AbstractTestEngineOnlyQueries
1 parent b1c7bc4 commit 604587b

File tree

3 files changed

+300
-264
lines changed

3 files changed

+300
-264
lines changed

presto-native-tests/src/test/java/com/facebook/presto/nativetests/TestDistributedEngineOnlyQueries.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,35 @@ public void testLocallyUnrepresentableTimeLiterals()
9292
@Language("SQL") String sql = DateTimeFormatter.ofPattern("'SELECT TIME '''HH:mm:ss''").format(localTimeThatDidNotOccurOn20120401);
9393
assertQueryFails(sql, timeTypeUnsupportedError);
9494
}
95+
96+
// todo: turn on these test cases when the sql invoked functions are extracted into a plugin module.
97+
@Override
98+
@Test(enabled = false)
99+
public void testArraySplitIntoChunks()
100+
{
101+
}
102+
103+
@Override
104+
@Test(enabled = false)
105+
public void testCrossJoinWithArrayNotContainsCondition()
106+
{
107+
}
108+
109+
@Override
110+
@Test(enabled = false)
111+
public void testSamplingJoinChain()
112+
{
113+
}
114+
115+
@Override
116+
@Test(enabled = false)
117+
public void testKeyBasedSampling()
118+
{
119+
}
120+
121+
@Override
122+
@Test(enabled = false)
123+
public void testDefaultSamplingPercent()
124+
{
125+
}
95126
}

presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestEngineOnlyQueries.java

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
import java.time.ZonedDateTime;
2828
import java.time.format.DateTimeFormatter;
2929

30+
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
31+
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION;
32+
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_PERCENTAGE;
33+
import static com.facebook.presto.SystemSessionProperties.PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN;
34+
import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN;
35+
import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN;
3036
import static com.google.common.base.Preconditions.checkState;
3137
import static org.testng.Assert.assertEquals;
3238

@@ -107,4 +113,267 @@ public void testLocallyUnrepresentableTimestampLiterals()
107113
assertEquals(computeScalar(sql), localTimeThatDidNotExist); // this tests Presto and the QueryRunner
108114
assertQuery(sql); // this tests H2QueryRunner
109115
}
116+
117+
@Test
118+
public void testArraySplitIntoChunks()
119+
{
120+
@Language("SQL") String sql = "select array_split_into_chunks(array[1, 2, 3, 4, 5, 6], 2)";
121+
assertQuery(sql, "values array[array[1, 2], array[3, 4], array[5, 6]]");
122+
123+
sql = "select array_split_into_chunks(array[1, 2, 3, 4, 5], 3)";
124+
assertQuery(sql, "values array[array[1, 2, 3], array[4, 5]]");
125+
126+
sql = "select array_split_into_chunks(array[1, 2, 3], 5)";
127+
assertQuery(sql, "values array[array[1, 2, 3]]");
128+
129+
sql = "select array_split_into_chunks(null, 2)";
130+
assertQuery(sql, "values null");
131+
132+
sql = "select array_split_into_chunks(array[1, 2, 3], 0)";
133+
assertQueryFails(sql, "Invalid slice size: 0. Size must be greater than zero.");
134+
135+
sql = "select array_split_into_chunks(array[1, 2, 3], -1)";
136+
assertQueryFails(sql, "Invalid slice size: -1. Size must be greater than zero.");
137+
138+
sql = "select array_split_into_chunks(array[1, null, 3, null, 5], 2)";
139+
assertQuery(sql, "values array[array[1, null], array[3, null], array[5]]");
140+
141+
sql = "select array_split_into_chunks(array['a', 'b', 'c', 'd'], 2)";
142+
assertQuery(sql, "values array[array['a', 'b'], array['c', 'd']]");
143+
144+
sql = "select array_split_into_chunks(array[1.1, 2.2, 3.3, 4.4, 5.5], 2)";
145+
assertQuery(sql, "values array[array[1.1, 2.2], array[3.3, 4.4], array[5.5]]");
146+
147+
sql = "select array_split_into_chunks(array[null, null, null], 0)";
148+
assertQueryFails(sql, "Invalid slice size: 0. Size must be greater than zero.");
149+
150+
sql = "select array_split_into_chunks(array[null, null, null], 2)";
151+
assertQuery(sql, "values array[array[null, null], array[null]]");
152+
153+
sql = "select array_split_into_chunks(array[null, 1, 2], 5)";
154+
assertQuery(sql, "values array[array[null, 1, 2]]");
155+
156+
sql = "select array_split_into_chunks(array[], 0)";
157+
assertQueryFails(sql, "Invalid slice size: 0. Size must be greater than zero.");
158+
}
159+
160+
@Test
161+
public void testCrossJoinWithArrayNotContainsCondition()
162+
{
163+
Session enableOptimization = Session.builder(getSession())
164+
.setSystemProperty(PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN, "REWRITTEN_TO_INNER_JOIN")
165+
.setSystemProperty(REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN, "true")
166+
.build();
167+
168+
@Language("SQL") String sql = "with t1 as (select * from (values (array[1, 2, 3])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
169+
"select t2.k, t2.v from t2 where not contains((select t1.arr from t1), t2.k)";
170+
assertQuery(enableOptimization, sql, "values (4, 'b')");
171+
172+
sql = "with t1 as (select * from (values (array[1, 2, 3, 3, null])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
173+
"select t2.k, t2.v from t2 where not contains((select t1.arr from t1), t2.k)";
174+
assertQuery(enableOptimization, sql, "values (4, 'b')");
175+
176+
sql = "with t1 as (select * from (values (1, 'JAPAN'), (2, 'invalid_nation')) t(k, nation)) " +
177+
"select t1.k, t1.nation from t1 where not contains((select array_agg(name) from nation), t1.nation)";
178+
assertQuery(enableOptimization, sql, "values (2, 'invalid_nation')");
179+
180+
// array is an expression that needs to be pushed down
181+
sql = "with t1 as (select * from (values (1, 'JAPAN'), (2, 'invalid_nation')) t(k, nation)) " +
182+
"select t1.k, t1.nation from t1 where not contains(array_distinct((select array_agg(name) from nation)), t1.nation)";
183+
assertQuery(enableOptimization, sql, "values (2, 'invalid_nation')");
184+
185+
// check not applicable cases for optimization
186+
187+
// optimization doesn't apply when there are additional columns on array side
188+
sql = "with t1 as (select * from (values (array[1, 1, 3], 10)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
189+
"select t1.k, t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)";
190+
assertQuery(enableOptimization, sql, "values (10, 4, 'b')");
191+
192+
// optimization doesn't apply for multi-row array tables
193+
sql = "with t1 as (select * from (values (array[1, 2, 3]), (array[4, 5, 6])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
194+
"select t1.arr, t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)";
195+
assertQuery(enableOptimization, sql, "values (array[1,2,3], 4, 'b'), (array[4,5,6], 1, 'a')");
196+
197+
// we currently don't support the optimization for cases that didn't come from a subquery
198+
sql = "with t1 as (select * from (values (array[1, 2, 3])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
199+
"select t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)";
200+
assertQuery(enableOptimization, sql, "values (4, 'b')");
201+
202+
sql = "with t1 as (select * from (values (array[1, 2, 3])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
203+
"select t1.arr, t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)";
204+
assertQuery(enableOptimization, sql, "values (array[1,2,3], 4, 'b')");
205+
206+
// transform function considered non-deterministic and doesn't get pushed down
207+
sql = "with t1 as (select * from (values (1, 'JAPAN'), (2, 'invalid_nation')) t(k, nation)) " +
208+
"select t1.k, t1.nation from t1 where not contains(transform((select array_agg(name) from nation), (x) ->lower(x)), lower(t1.nation))";
209+
assertQuery(enableOptimization, sql, "values (2, 'invalid_nation')");
210+
}
211+
212+
@Test
213+
public void testDefaultSamplingPercent()
214+
{
215+
assertQuery("select key_sampling_percent('abc')", "select 0.56");
216+
}
217+
218+
@Test
219+
public void testKeyBasedSampling()
220+
{
221+
String[] queries = {
222+
"select count(1) from orders join lineitem using(orderkey)",
223+
"select count(1) from (select custkey, max(orderkey) from orders group by custkey)",
224+
"select count_if(m >= 1) from (select max(orderkey) over(partition by custkey) m from orders)",
225+
"select cast(m as bigint) from (select sum(totalprice) over(partition by custkey order by comment) m from orders order by 1 desc limit 1)",
226+
"select count(1) from lineitem where orderkey in (select orderkey from orders where length(comment) > 7)",
227+
"select count(1) from lineitem where orderkey not in (select orderkey from orders where length(comment) > 27)",
228+
"select count(1) from (select distinct orderkey, custkey from orders)",
229+
};
230+
231+
int[] unsampledResults = {60175, 1000, 15000, 5408941, 60175, 9256, 15000};
232+
for (int i = 0; i < queries.length; i++) {
233+
assertQuery(queries[i], "select " + unsampledResults[i]);
234+
}
235+
236+
Session sessionWithKeyBasedSampling = Session.builder(getSession())
237+
.setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
238+
.setSystemProperty(KEY_BASED_SAMPLING_PERCENTAGE, "0.2")
239+
.build();
240+
241+
int[] sampled20PercentResults = {37170, 616, 9189, 5408941, 37170, 5721, 9278};
242+
for (int i = 0; i < queries.length; i++) {
243+
assertQuery(sessionWithKeyBasedSampling, queries[i], "select " + sampled20PercentResults[i]);
244+
}
245+
246+
sessionWithKeyBasedSampling = Session.builder(getSession())
247+
.setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
248+
.setSystemProperty(KEY_BASED_SAMPLING_PERCENTAGE, "0.1")
249+
.build();
250+
251+
int[] sampled10PercentResults = {33649, 557, 8377, 4644937, 33649, 5098, 8397};
252+
for (int i = 0; i < queries.length; i++) {
253+
assertQuery(sessionWithKeyBasedSampling, queries[i], "select " + sampled10PercentResults[i]);
254+
}
255+
}
256+
257+
@Test
258+
public void testLeftJoinWithArrayContainsCondition()
259+
{
260+
Session enableOptimization = Session.builder(getSession())
261+
.setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
262+
.build();
263+
264+
@Language("SQL") String sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
265+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
266+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
267+
268+
sql = "with t1 as (select * from (values (array[1, 2, 3, null], 10), (array[4, 5, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
269+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
270+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
271+
272+
sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " +
273+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
274+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')");
275+
276+
sql = "with t1 as (select * from (values (array[1, 2, 3, null, null], 10), (array[4, 5, 6, null, null], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " +
277+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
278+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')");
279+
280+
sql = "with t1 as (select * from (values (array[1, 1, 3], 10), (array[4, 4, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
281+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
282+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
283+
284+
sql = "with t1 as (select * from (values (array[1, 1, 3, null, null], 10), (array[4, 4, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
285+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
286+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
287+
288+
sql = "with t1 as (select * from (values (array[1, null, 3], 10), (array[4, null, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (null, 'b')) t(k, v)) " +
289+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
290+
assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (NULL, NULL, 'b')");
291+
292+
sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
293+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) and t1.k > 10";
294+
assertQuery(enableOptimization, sql, "values (NULL, 1, 'a'), (11, 4, 'b')");
295+
296+
sql = "with t1 as (select * from (values (array[1, 2, 3], 1), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
297+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) or t1.k = t2.k";
298+
assertQuery(enableOptimization, sql, "values (1, 1, 'a'), (11, 4, 'b')");
299+
300+
sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
301+
"select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) where o.totalprice < 2000";
302+
// Because the UDF has different names in H2, which is `array_contains`
303+
String h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
304+
"select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) where o.totalprice < 2000";
305+
assertQuery(enableOptimization, sql, h2Sql);
306+
307+
sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
308+
"select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000";
309+
h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
310+
"select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000";
311+
assertQuery(enableOptimization, sql, h2Sql);
312+
313+
// Element type and array type does not match
314+
sql = "with t1 as (select * from (values (array[cast(1 as bigint), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as integer), 'a'), (4, 'b')) t(k, v)) " +
315+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
316+
assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')");
317+
318+
sql = "with t1 as (select * from (values (array[cast(1 as integer), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as bigint), 'a'), (4, 'b')) t(k, v)) " +
319+
"select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
320+
assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')");
321+
}
322+
323+
@Test
324+
public void testKeyBasedSamplingFunctionError()
325+
{
326+
Session sessionWithKeyBasedSampling = Session.builder(getSession())
327+
.setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
328+
.setSystemProperty(KEY_BASED_SAMPLING_FUNCTION, "blah")
329+
.build();
330+
331+
assertQueryFails(sessionWithKeyBasedSampling, "select count(1) from orders join lineitem using(orderkey)", "Sampling function: blah not cannot be resolved");
332+
}
333+
334+
@Test
335+
public void testSamplingJoinChain()
336+
{
337+
Session sessionWithKeyBasedSampling = Session.builder(getSession())
338+
.setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
339+
.build();
340+
@Language("SQL") String sql = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey";
341+
342+
assertQuery(sql, "select 60175");
343+
assertQuery(sessionWithKeyBasedSampling, sql, "select 16185");
344+
}
345+
346+
@Test
347+
public void testTry()
348+
{
349+
// Test try with map method and value parameter is optional and argument is an array with null,
350+
// the error should be suppressed and just return null.
351+
assertQuery("SELECT\n" +
352+
" TRY(map_keys_by_top_n_values(c0, BIGINT '6455219767830808341'))\n" +
353+
"FROM (\n" +
354+
" VALUES\n" +
355+
" MAP(\n" +
356+
" ARRAY[1, 2], ARRAY[\n" +
357+
" ARRAY[1, null],\n" +
358+
" ARRAY[1, null]\n" +
359+
" ]\n" +
360+
" )\n" +
361+
") t(c0)", "SELECT NULL");
362+
363+
assertQuery("SELECT\n" +
364+
" TRY(map_keys_by_top_n_values(c0, BIGINT '6455219767830808341'))\n" +
365+
"FROM (\n" +
366+
" VALUES\n" +
367+
" MAP(\n" +
368+
" ARRAY[1, 2], ARRAY[\n" +
369+
" ARRAY[null, null],\n" +
370+
" ARRAY[1, 2]\n" +
371+
" ]\n" +
372+
" )\n" +
373+
") t(c0)", "SELECT NULL");
374+
375+
// Test try with array method with an input array containing null values.
376+
// the error should be suppressed and just return null.
377+
assertQuery("SELECT TRY(ARRAY_MAX(ARRAY [ARRAY[1, NULL], ARRAY[1, 2]]))", "SELECT NULL");
378+
}
110379
}

0 commit comments

Comments
 (0)