|
27 | 27 | import java.time.ZonedDateTime; |
28 | 28 | import java.time.format.DateTimeFormatter; |
29 | 29 |
|
| 30 | +import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED; |
| 31 | +import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION; |
| 32 | +import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_PERCENTAGE; |
| 33 | +import static com.facebook.presto.SystemSessionProperties.PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN; |
| 34 | +import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN; |
| 35 | +import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN; |
30 | 36 | import static com.google.common.base.Preconditions.checkState; |
31 | 37 | import static org.testng.Assert.assertEquals; |
32 | 38 |
|
@@ -107,4 +113,267 @@ public void testLocallyUnrepresentableTimestampLiterals() |
107 | 113 | assertEquals(computeScalar(sql), localTimeThatDidNotExist); // this tests Presto and the QueryRunner |
108 | 114 | assertQuery(sql); // this tests H2QueryRunner |
109 | 115 | } |
| 116 | + |
| 117 | + @Test |
| 118 | + public void testArraySplitIntoChunks() |
| 119 | + { |
| 120 | + @Language("SQL") String sql = "select array_split_into_chunks(array[1, 2, 3, 4, 5, 6], 2)"; |
| 121 | + assertQuery(sql, "values array[array[1, 2], array[3, 4], array[5, 6]]"); |
| 122 | + |
| 123 | + sql = "select array_split_into_chunks(array[1, 2, 3, 4, 5], 3)"; |
| 124 | + assertQuery(sql, "values array[array[1, 2, 3], array[4, 5]]"); |
| 125 | + |
| 126 | + sql = "select array_split_into_chunks(array[1, 2, 3], 5)"; |
| 127 | + assertQuery(sql, "values array[array[1, 2, 3]]"); |
| 128 | + |
| 129 | + sql = "select array_split_into_chunks(null, 2)"; |
| 130 | + assertQuery(sql, "values null"); |
| 131 | + |
| 132 | + sql = "select array_split_into_chunks(array[1, 2, 3], 0)"; |
| 133 | + assertQueryFails(sql, "Invalid slice size: 0. Size must be greater than zero."); |
| 134 | + |
| 135 | + sql = "select array_split_into_chunks(array[1, 2, 3], -1)"; |
| 136 | + assertQueryFails(sql, "Invalid slice size: -1. Size must be greater than zero."); |
| 137 | + |
| 138 | + sql = "select array_split_into_chunks(array[1, null, 3, null, 5], 2)"; |
| 139 | + assertQuery(sql, "values array[array[1, null], array[3, null], array[5]]"); |
| 140 | + |
| 141 | + sql = "select array_split_into_chunks(array['a', 'b', 'c', 'd'], 2)"; |
| 142 | + assertQuery(sql, "values array[array['a', 'b'], array['c', 'd']]"); |
| 143 | + |
| 144 | + sql = "select array_split_into_chunks(array[1.1, 2.2, 3.3, 4.4, 5.5], 2)"; |
| 145 | + assertQuery(sql, "values array[array[1.1, 2.2], array[3.3, 4.4], array[5.5]]"); |
| 146 | + |
| 147 | + sql = "select array_split_into_chunks(array[null, null, null], 0)"; |
| 148 | + assertQueryFails(sql, "Invalid slice size: 0. Size must be greater than zero."); |
| 149 | + |
| 150 | + sql = "select array_split_into_chunks(array[null, null, null], 2)"; |
| 151 | + assertQuery(sql, "values array[array[null, null], array[null]]"); |
| 152 | + |
| 153 | + sql = "select array_split_into_chunks(array[null, 1, 2], 5)"; |
| 154 | + assertQuery(sql, "values array[array[null, 1, 2]]"); |
| 155 | + |
| 156 | + sql = "select array_split_into_chunks(array[], 0)"; |
| 157 | + assertQueryFails(sql, "Invalid slice size: 0. Size must be greater than zero."); |
| 158 | + } |
| 159 | + |
| 160 | + @Test |
| 161 | + public void testCrossJoinWithArrayNotContainsCondition() |
| 162 | + { |
| 163 | + Session enableOptimization = Session.builder(getSession()) |
| 164 | + .setSystemProperty(PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN, "REWRITTEN_TO_INNER_JOIN") |
| 165 | + .setSystemProperty(REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN, "true") |
| 166 | + .build(); |
| 167 | + |
| 168 | + @Language("SQL") String sql = "with t1 as (select * from (values (array[1, 2, 3])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 169 | + "select t2.k, t2.v from t2 where not contains((select t1.arr from t1), t2.k)"; |
| 170 | + assertQuery(enableOptimization, sql, "values (4, 'b')"); |
| 171 | + |
| 172 | + sql = "with t1 as (select * from (values (array[1, 2, 3, 3, null])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 173 | + "select t2.k, t2.v from t2 where not contains((select t1.arr from t1), t2.k)"; |
| 174 | + assertQuery(enableOptimization, sql, "values (4, 'b')"); |
| 175 | + |
| 176 | + sql = "with t1 as (select * from (values (1, 'JAPAN'), (2, 'invalid_nation')) t(k, nation)) " + |
| 177 | + "select t1.k, t1.nation from t1 where not contains((select array_agg(name) from nation), t1.nation)"; |
| 178 | + assertQuery(enableOptimization, sql, "values (2, 'invalid_nation')"); |
| 179 | + |
| 180 | + // array is an expression that needs to be pushed down |
| 181 | + sql = "with t1 as (select * from (values (1, 'JAPAN'), (2, 'invalid_nation')) t(k, nation)) " + |
| 182 | + "select t1.k, t1.nation from t1 where not contains(array_distinct((select array_agg(name) from nation)), t1.nation)"; |
| 183 | + assertQuery(enableOptimization, sql, "values (2, 'invalid_nation')"); |
| 184 | + |
| 185 | + // check not applicable cases for optimization |
| 186 | + |
| 187 | + // optimization doesn't apply when there are additional columns on array side |
| 188 | + sql = "with t1 as (select * from (values (array[1, 1, 3], 10)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 189 | + "select t1.k, t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)"; |
| 190 | + assertQuery(enableOptimization, sql, "values (10, 4, 'b')"); |
| 191 | + |
| 192 | + // optimization doesn't apply for multi-row array tables |
| 193 | + sql = "with t1 as (select * from (values (array[1, 2, 3]), (array[4, 5, 6])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 194 | + "select t1.arr, t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)"; |
| 195 | + assertQuery(enableOptimization, sql, "values (array[1,2,3], 4, 'b'), (array[4,5,6], 1, 'a')"); |
| 196 | + |
| 197 | + // we currently don't support the optimization for cases that didn't come from a subquery |
| 198 | + sql = "with t1 as (select * from (values (array[1, 2, 3])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 199 | + "select t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)"; |
| 200 | + assertQuery(enableOptimization, sql, "values (4, 'b')"); |
| 201 | + |
| 202 | + sql = "with t1 as (select * from (values (array[1, 2, 3])) t(arr)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 203 | + "select t1.arr, t2.k, t2.v from t1 join t2 on not contains(t1.arr, t2.k)"; |
| 204 | + assertQuery(enableOptimization, sql, "values (array[1,2,3], 4, 'b')"); |
| 205 | + |
| 206 | + // transform function considered non-deterministic and doesn't get pushed down |
| 207 | + sql = "with t1 as (select * from (values (1, 'JAPAN'), (2, 'invalid_nation')) t(k, nation)) " + |
| 208 | + "select t1.k, t1.nation from t1 where not contains(transform((select array_agg(name) from nation), (x) ->lower(x)), lower(t1.nation))"; |
| 209 | + assertQuery(enableOptimization, sql, "values (2, 'invalid_nation')"); |
| 210 | + } |
| 211 | + |
| 212 | + @Test |
| 213 | + public void testDefaultSamplingPercent() |
| 214 | + { |
| 215 | + assertQuery("select key_sampling_percent('abc')", "select 0.56"); |
| 216 | + } |
| 217 | + |
| 218 | + @Test |
| 219 | + public void testKeyBasedSampling() |
| 220 | + { |
| 221 | + String[] queries = { |
| 222 | + "select count(1) from orders join lineitem using(orderkey)", |
| 223 | + "select count(1) from (select custkey, max(orderkey) from orders group by custkey)", |
| 224 | + "select count_if(m >= 1) from (select max(orderkey) over(partition by custkey) m from orders)", |
| 225 | + "select cast(m as bigint) from (select sum(totalprice) over(partition by custkey order by comment) m from orders order by 1 desc limit 1)", |
| 226 | + "select count(1) from lineitem where orderkey in (select orderkey from orders where length(comment) > 7)", |
| 227 | + "select count(1) from lineitem where orderkey not in (select orderkey from orders where length(comment) > 27)", |
| 228 | + "select count(1) from (select distinct orderkey, custkey from orders)", |
| 229 | + }; |
| 230 | + |
| 231 | + int[] unsampledResults = {60175, 1000, 15000, 5408941, 60175, 9256, 15000}; |
| 232 | + for (int i = 0; i < queries.length; i++) { |
| 233 | + assertQuery(queries[i], "select " + unsampledResults[i]); |
| 234 | + } |
| 235 | + |
| 236 | + Session sessionWithKeyBasedSampling = Session.builder(getSession()) |
| 237 | + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") |
| 238 | + .setSystemProperty(KEY_BASED_SAMPLING_PERCENTAGE, "0.2") |
| 239 | + .build(); |
| 240 | + |
| 241 | + int[] sampled20PercentResults = {37170, 616, 9189, 5408941, 37170, 5721, 9278}; |
| 242 | + for (int i = 0; i < queries.length; i++) { |
| 243 | + assertQuery(sessionWithKeyBasedSampling, queries[i], "select " + sampled20PercentResults[i]); |
| 244 | + } |
| 245 | + |
| 246 | + sessionWithKeyBasedSampling = Session.builder(getSession()) |
| 247 | + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") |
| 248 | + .setSystemProperty(KEY_BASED_SAMPLING_PERCENTAGE, "0.1") |
| 249 | + .build(); |
| 250 | + |
| 251 | + int[] sampled10PercentResults = {33649, 557, 8377, 4644937, 33649, 5098, 8397}; |
| 252 | + for (int i = 0; i < queries.length; i++) { |
| 253 | + assertQuery(sessionWithKeyBasedSampling, queries[i], "select " + sampled10PercentResults[i]); |
| 254 | + } |
| 255 | + } |
| 256 | + |
| 257 | + @Test |
| 258 | + public void testLeftJoinWithArrayContainsCondition() |
| 259 | + { |
| 260 | + Session enableOptimization = Session.builder(getSession()) |
| 261 | + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") |
| 262 | + .build(); |
| 263 | + |
| 264 | + @Language("SQL") String sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 265 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 266 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); |
| 267 | + |
| 268 | + sql = "with t1 as (select * from (values (array[1, 2, 3, null], 10), (array[4, 5, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 269 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 270 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); |
| 271 | + |
| 272 | + sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " + |
| 273 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 274 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')"); |
| 275 | + |
| 276 | + sql = "with t1 as (select * from (values (array[1, 2, 3, null, null], 10), (array[4, 5, 6, null, null], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " + |
| 277 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 278 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')"); |
| 279 | + |
| 280 | + sql = "with t1 as (select * from (values (array[1, 1, 3], 10), (array[4, 4, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 281 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 282 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); |
| 283 | + |
| 284 | + sql = "with t1 as (select * from (values (array[1, 1, 3, null, null], 10), (array[4, 4, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 285 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 286 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); |
| 287 | + |
| 288 | + sql = "with t1 as (select * from (values (array[1, null, 3], 10), (array[4, null, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (null, 'b')) t(k, v)) " + |
| 289 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 290 | + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (NULL, NULL, 'b')"); |
| 291 | + |
| 292 | + sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 293 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) and t1.k > 10"; |
| 294 | + assertQuery(enableOptimization, sql, "values (NULL, 1, 'a'), (11, 4, 'b')"); |
| 295 | + |
| 296 | + sql = "with t1 as (select * from (values (array[1, 2, 3], 1), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + |
| 297 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) or t1.k = t2.k"; |
| 298 | + assertQuery(enableOptimization, sql, "values (1, 1, 'a'), (11, 4, 'b')"); |
| 299 | + |
| 300 | + sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + |
| 301 | + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) where o.totalprice < 2000"; |
| 302 | + // Because the UDF has different names in H2, which is `array_contains` |
| 303 | + String h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + |
| 304 | + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) where o.totalprice < 2000"; |
| 305 | + assertQuery(enableOptimization, sql, h2Sql); |
| 306 | + |
| 307 | + sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + |
| 308 | + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000"; |
| 309 | + h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + |
| 310 | + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000"; |
| 311 | + assertQuery(enableOptimization, sql, h2Sql); |
| 312 | + |
| 313 | + // Element type and array type does not match |
| 314 | + sql = "with t1 as (select * from (values (array[cast(1 as bigint), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as integer), 'a'), (4, 'b')) t(k, v)) " + |
| 315 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 316 | + assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')"); |
| 317 | + |
| 318 | + sql = "with t1 as (select * from (values (array[cast(1 as integer), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as bigint), 'a'), (4, 'b')) t(k, v)) " + |
| 319 | + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; |
| 320 | + assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')"); |
| 321 | + } |
| 322 | + |
| 323 | + @Test |
| 324 | + public void testKeyBasedSamplingFunctionError() |
| 325 | + { |
| 326 | + Session sessionWithKeyBasedSampling = Session.builder(getSession()) |
| 327 | + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") |
| 328 | + .setSystemProperty(KEY_BASED_SAMPLING_FUNCTION, "blah") |
| 329 | + .build(); |
| 330 | + |
| 331 | + assertQueryFails(sessionWithKeyBasedSampling, "select count(1) from orders join lineitem using(orderkey)", "Sampling function: blah not cannot be resolved"); |
| 332 | + } |
| 333 | + |
| 334 | + @Test |
| 335 | + public void testSamplingJoinChain() |
| 336 | + { |
| 337 | + Session sessionWithKeyBasedSampling = Session.builder(getSession()) |
| 338 | + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") |
| 339 | + .build(); |
| 340 | + @Language("SQL") String sql = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey"; |
| 341 | + |
| 342 | + assertQuery(sql, "select 60175"); |
| 343 | + assertQuery(sessionWithKeyBasedSampling, sql, "select 16185"); |
| 344 | + } |
| 345 | + |
| 346 | + @Test |
| 347 | + public void testTry() |
| 348 | + { |
| 349 | + // Test try with map method and value parameter is optional and argument is an array with null, |
| 350 | + // the error should be suppressed and just return null. |
| 351 | + assertQuery("SELECT\n" + |
| 352 | + " TRY(map_keys_by_top_n_values(c0, BIGINT '6455219767830808341'))\n" + |
| 353 | + "FROM (\n" + |
| 354 | + " VALUES\n" + |
| 355 | + " MAP(\n" + |
| 356 | + " ARRAY[1, 2], ARRAY[\n" + |
| 357 | + " ARRAY[1, null],\n" + |
| 358 | + " ARRAY[1, null]\n" + |
| 359 | + " ]\n" + |
| 360 | + " )\n" + |
| 361 | + ") t(c0)", "SELECT NULL"); |
| 362 | + |
| 363 | + assertQuery("SELECT\n" + |
| 364 | + " TRY(map_keys_by_top_n_values(c0, BIGINT '6455219767830808341'))\n" + |
| 365 | + "FROM (\n" + |
| 366 | + " VALUES\n" + |
| 367 | + " MAP(\n" + |
| 368 | + " ARRAY[1, 2], ARRAY[\n" + |
| 369 | + " ARRAY[null, null],\n" + |
| 370 | + " ARRAY[1, 2]\n" + |
| 371 | + " ]\n" + |
| 372 | + " )\n" + |
| 373 | + ") t(c0)", "SELECT NULL"); |
| 374 | + |
| 375 | + // Test try with array method with an input array containing null values. |
| 376 | + // the error should be suppressed and just return null. |
| 377 | + assertQuery("SELECT TRY(ARRAY_MAX(ARRAY [ARRAY[1, NULL], ARRAY[1, 2]]))", "SELECT NULL"); |
| 378 | + } |
110 | 379 | } |
0 commit comments