Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ public static Block intersect(
@SqlType("array<T>")
public static String arrayIntersectArray()
{
return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)";
return "RETURN reduce(input, IF((cardinality(input) = 0), ARRAY[], input[1]), (s, x) -> array_intersect(s, x), (s) -> s)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ public void testDuplicates()
public void testSqlFunctions()
{
assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(INTEGER), ImmutableList.of(3));
assertFunction("array_intersect(ARRAY[null, ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), null);
assertFunction("array_intersect(ARRAY[ARRAY[], null, ARRAY[1, 2, 3]])", new ArrayType(INTEGER), null);
assertFunction("array_intersect(ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of());
Comment thread
kewang1024 marked this conversation as resolved.
assertFunction("array_intersect(null)", new ArrayType(UNKNOWN), null);
assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), ImmutableList.of());
assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(INTEGER), null);
assertFunction("array_intersect(ARRAY[ARRAY[DOUBLE'1.1', DOUBLE'2.2', DOUBLE'3.3'], ARRAY[DOUBLE'1.1', DOUBLE'3.4'], ARRAY[DOUBLE'1.0', DOUBLE'1.1', DOUBLE'1.2']])", new ArrayType(DOUBLE), ImmutableList.of(1.1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public abstract class AbstractTestNanQueries
public static final String SIMPLE_DOUBLE_ARRAY_COLUMN = "simple_double_array";
public static final String SIMPLE_REAL_ARRAY_COLUMN = "simple_real_array";

public static final String ARRAY_TABLE_NAME_NO_NULL = "array_nans_table_no_null";
public static final String SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL = "simple_double_array_no_null";
public static final String SIMPLE_REAL_ARRAY_COLUMN_NO_NULL = "simple_real_array_no_null";

public static final String MAP_TABLE_NAME = "map_nans_table";
public static final String DOUBLE_MAP_COLUMN = "double_map";
public static final String REAL_MAP_COLUMN = "real_map";
Expand Down Expand Up @@ -96,7 +100,19 @@ public void setup()
"(ARRAY[DOUBLE '0', DOUBLE '1', nan(), DOUBLE '-1', nan(), DOUBLE '1', DOUBLE '1', DOUBLE'0'], ARRAY [REAL '0', REAL '1', CAST(nan() AS REAL), REAL '-1', CAST(nan() AS REAL), REAL '1', REAL '1', REAL '0'])) " +
"AS t (" + SIMPLE_DOUBLE_ARRAY_COLUMN + ", " + SIMPLE_REAL_ARRAY_COLUMN + ")";

@Language("SQL") String createArrayTableNoNullQuery = "" +
"CREATE TABLE " + ARRAY_TABLE_NAME_NO_NULL + " AS " +
"SELECT * FROM (VALUES " +
"(ARRAY[nan(), DOUBLE '0', DOUBLE '1', DOUBLE '-1'], ARRAY[cast(nan() AS REAL), REAL '0', REAL '1', REAL '-1']), " +
"(ARRAY[ DOUBLE '0', nan(), DOUBLE '1', DOUBLE '-1'], ARRAY[REAL '0', CAST(nan() AS REAL), REAL '1', REAL '-1']), " +
"(ARRAY[ DOUBLE '0', DOUBLE '1', DOUBLE '-1', nan()], ARRAY[REAL '0', REAL '1', REAL '-1', CAST(nan() AS REAL)]), " +
"(ARRAY[null, nan(), DOUBLE '200'], ARRAY[null, CAST(nan() AS REAL), REAL '200']), " +
"(ARRAY[nan(), nan()], ARRAY[CAST(nan() AS REAL), CAST(nan() AS REAL)]), " +
"(ARRAY[DOUBLE '0', DOUBLE '1', nan(), DOUBLE '-1', nan(), DOUBLE '1', DOUBLE '1', DOUBLE'0'], ARRAY [REAL '0', REAL '1', CAST(nan() AS REAL), REAL '-1', CAST(nan() AS REAL), REAL '1', REAL '1', REAL '0'])) " +
"AS t (" + SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL + ", " + SIMPLE_REAL_ARRAY_COLUMN_NO_NULL + ")";

assertUpdate(createArrayTableQuery, 7);
assertUpdate(createArrayTableNoNullQuery, 6);

@Language("SQL") String createMapTableQuery = "" +
"CREATE TABLE " + MAP_TABLE_NAME + " AS " +
Expand Down Expand Up @@ -728,6 +744,9 @@ public void testDoubleArrayIntersect2()
// Test the array of arrays function signature
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_DOUBLE_ARRAY_COLUMN, ARRAY_TABLE_NAME),
"SELECT NULL");
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL, ARRAY_TABLE_NAME_NO_NULL),
"SELECT * FROM (VALUES (ARRAY[nan()]))");
}

Expand All @@ -737,6 +756,9 @@ public void testRealArrayIntersect2()
// Test the array of arrays function signature
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_REAL_ARRAY_COLUMN, ARRAY_TABLE_NAME),
"SELECT NULL");
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_REAL_ARRAY_COLUMN_NO_NULL, ARRAY_TABLE_NAME_NO_NULL),
"SELECT * FROM (VALUES (ARRAY[CAST(nan() AS REAL)]))");
}

Expand Down