diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java index b9c7b1268f306..d977420d28339 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java @@ -68,6 +68,6 @@ public static Block intersect( @SqlType("array") public static String arrayIntersectArray() { - return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)"; + return "RETURN reduce(input, IF((cardinality(input) = 0), ARRAY[], input[1]), (s, x) -> array_intersect(s, x), (s) -> s)"; } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java index 528eb9521144f..7a77020585434 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java @@ -215,6 +215,10 @@ public void testDuplicates() public void testSqlFunctions() { assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(INTEGER), ImmutableList.of(3)); + assertFunction("array_intersect(ARRAY[null, ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), null); + assertFunction("array_intersect(ARRAY[ARRAY[], null, ARRAY[1, 2, 3]])", new ArrayType(INTEGER), null); + assertFunction("array_intersect(ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); + assertFunction("array_intersect(null)", new ArrayType(UNKNOWN), null); assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), ImmutableList.of()); assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(INTEGER), null); assertFunction("array_intersect(ARRAY[ARRAY[DOUBLE'1.1', DOUBLE'2.2', DOUBLE'3.3'], ARRAY[DOUBLE'1.1', DOUBLE'3.4'], ARRAY[DOUBLE'1.0', DOUBLE'1.1', DOUBLE'1.2']])", new ArrayType(DOUBLE), ImmutableList.of(1.1)); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestNanQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestNanQueries.java index 5e02f43295a40..05d82512c4777 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestNanQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestNanQueries.java @@ -46,6 +46,10 @@ public abstract class AbstractTestNanQueries public static final String SIMPLE_DOUBLE_ARRAY_COLUMN = "simple_double_array"; public static final String SIMPLE_REAL_ARRAY_COLUMN = "simple_real_array"; + public static final String ARRAY_TABLE_NAME_NO_NULL = "array_nans_table_no_null"; + public static final String SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL = "simple_double_array_no_null"; + public static final String SIMPLE_REAL_ARRAY_COLUMN_NO_NULL = "simple_real_array_no_null"; + public static final String MAP_TABLE_NAME = "map_nans_table"; public static final String DOUBLE_MAP_COLUMN = "double_map"; public static final String REAL_MAP_COLUMN = "real_map"; @@ -96,7 +100,19 @@ public void setup() "(ARRAY[DOUBLE '0', DOUBLE '1', nan(), DOUBLE '-1', nan(), DOUBLE '1', DOUBLE '1', DOUBLE'0'], ARRAY [REAL '0', REAL '1', CAST(nan() AS REAL), REAL '-1', CAST(nan() AS REAL), REAL '1', REAL '1', REAL '0'])) " + "AS t (" + SIMPLE_DOUBLE_ARRAY_COLUMN + ", " + SIMPLE_REAL_ARRAY_COLUMN + ")"; + @Language("SQL") String createArrayTableNoNullQuery = "" + + "CREATE TABLE " + ARRAY_TABLE_NAME_NO_NULL + " AS " + + "SELECT * FROM (VALUES " + + "(ARRAY[nan(), DOUBLE '0', DOUBLE '1', DOUBLE '-1'], ARRAY[cast(nan() AS REAL), REAL '0', REAL '1', REAL '-1']), " + + "(ARRAY[ DOUBLE '0', nan(), DOUBLE '1', DOUBLE '-1'], ARRAY[REAL '0', CAST(nan() AS REAL), REAL '1', REAL '-1']), " + + "(ARRAY[ DOUBLE '0', DOUBLE '1', DOUBLE '-1', nan()], ARRAY[REAL '0', REAL '1', REAL '-1', CAST(nan() AS REAL)]), " + + "(ARRAY[null, nan(), DOUBLE '200'], ARRAY[null, CAST(nan() AS REAL), REAL '200']), " + + "(ARRAY[nan(), nan()], ARRAY[CAST(nan() AS REAL), CAST(nan() AS REAL)]), " + + "(ARRAY[DOUBLE '0', DOUBLE '1', nan(), DOUBLE '-1', nan(), DOUBLE '1', DOUBLE '1', DOUBLE'0'], ARRAY [REAL '0', REAL '1', CAST(nan() AS REAL), REAL '-1', CAST(nan() AS REAL), REAL '1', REAL '1', REAL '0'])) " + + "AS t (" + SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL + ", " + SIMPLE_REAL_ARRAY_COLUMN_NO_NULL + ")"; + assertUpdate(createArrayTableQuery, 7); + assertUpdate(createArrayTableNoNullQuery, 6); @Language("SQL") String createMapTableQuery = "" + "CREATE TABLE " + MAP_TABLE_NAME + " AS " + @@ -728,6 +744,9 @@ public void testDoubleArrayIntersect2() // Test the array of arrays function signature assertQueryWithSameQueryRunner( format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_DOUBLE_ARRAY_COLUMN, ARRAY_TABLE_NAME), + "SELECT NULL"); + assertQueryWithSameQueryRunner( + format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL, ARRAY_TABLE_NAME_NO_NULL), "SELECT * FROM (VALUES (ARRAY[nan()]))"); } @@ -737,6 +756,9 @@ public void testRealArrayIntersect2() // Test the array of arrays function signature assertQueryWithSameQueryRunner( format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_REAL_ARRAY_COLUMN, ARRAY_TABLE_NAME), + "SELECT NULL"); + assertQueryWithSameQueryRunner( + format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_REAL_ARRAY_COLUMN_NO_NULL, ARRAY_TABLE_NAME_NO_NULL), "SELECT * FROM (VALUES (ARRAY[CAST(nan() AS REAL)]))"); }