diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index a3e85481c05f4..5244a04c7772d 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -48,18 +48,11 @@ Array Functions Returns a set of elements that occur more than once in ``array``. - ``T`` must be coercible to ``bigint`` or ``varchar``. - .. function:: array_except(x, y) -> array Returns an array of elements in ``x`` but not in ``y``, without duplicates. -.. function:: array_frequency(array(bigint)) -> map(bigint, int) - - Returns a map: keys are the unique elements in the ``array``, values are how many times the key appears. - Ignores null elements. Empty array returns empty map. - -.. function:: array_frequency(array(varchar)) -> map(varchar, int) +.. function:: array_frequency(array(E)) -> map(E, int) Returns a map: keys are the unique elements in the ``array``, values are how many times the key appears. Ignores null elements. Empty array returns empty map. @@ -68,16 +61,13 @@ Array Functions Returns a boolean: whether ``array`` has any elements that occur more than once. - ``T`` must be coercible to ``bigint`` or ``varchar``. - .. function:: array_intersect(x, y) -> array Returns an array of the elements in the intersection of ``x`` and ``y``, without duplicates. -.. function:: array_intersect(array(array(E))) -> array(bigint/double) +.. function:: array_intersect(array(array(E))) -> array(E) Returns an array of the elements in the intersection of all arrays in the given array, without duplicates. - E must be coercible to ``double``. Returns ``bigint`` if T is coercible to ``bigint``. Otherwise, returns ``double``. .. function:: array_join(x, delimiter, null_replacement) -> varchar diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java index 9e776732fb217..ceb42383e07cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java @@ -57,18 +57,10 @@ public static Block intersect( @SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false) @Description("Intersects elements of all arrays in the given array") - @SqlParameter(name = "input", type = "array>") - @SqlType("array") - public static String arrayIntersectBigint() - { - return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)"; - } - - @SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false) - @Description("Intersects elements of all arrays in the given array") - @SqlParameter(name = "input", type = "array>") - @SqlType("array") - public static String arrayIntersectDouble() + @TypeParameter("T") + @SqlParameter(name = "input", type = "array>") + @SqlType("array") + public static String arrayIntersectArray() { return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)"; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java index 08ba7319e26f7..ee81451076fe6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.function.SqlInvokedScalarFunction; import com.facebook.presto.spi.function.SqlParameter; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; public class ArraySqlFunctions { @@ -55,8 +56,9 @@ public static String arrayAverage() @SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false) @Description("Returns the frequency of all array elements as a map.") - @SqlParameter(name = "input", type = "array(bigint)") - @SqlType("map(bigint, int)") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array(T)") + @SqlType("map(T, int)") public static String arrayFrequencyBigint() { return "RETURN reduce(" + @@ -66,56 +68,25 @@ public static String arrayFrequencyBigint() "m -> m)"; } - @SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false) - @Description("Returns the frequency of all array elements as a map.") - @SqlParameter(name = "input", type = "array(varchar)") - @SqlType("map(varchar, int)") - public static String arrayFrequencyVarchar() - { - return "RETURN reduce(" + - "input," + - "MAP()," + - "(m, x) -> IF (x IS NOT NULL, MAP_CONCAT(m,MAP_FROM_ENTRIES(ARRAY[ROW(x, COALESCE(ELEMENT_AT(m,x) + 1, 1))])), m)," + - "m -> m)"; - } - - @SqlInvokedScalarFunction(value = "array_duplicates", alias = {"array_dupes"}, deterministic = true, calledOnNullInput = false) - @Description("Returns set of elements that have duplicates") - @SqlParameter(name = "input", type = "array(varchar)") - @SqlType("array(varchar)") - public static String arrayDuplicatesVarchar() - { - return "RETURN CONCAT(" + - "CAST(IF (cardinality(filter(input, x -> x is NULL)) > 1, ARRAY[NULL], ARRAY[]) AS ARRAY(VARCHAR))," + - "map_keys(map_filter(array_frequency(input), (k, v) -> v > 1)))"; - } - @SqlInvokedScalarFunction(value = "array_duplicates", alias = {"array_dupes"}, deterministic = true, calledOnNullInput = false) @Description("Returns set of elements that have duplicates") - @SqlParameter(name = "input", type = "array(bigint)") - @SqlType("array(bigint)") - public static String arrayDuplicatesBigint() + @SqlParameter(name = "input", type = "array(T)") + @TypeParameter("T") + @SqlType("array(T)") + public static String arrayDuplicates() { return "RETURN CONCAT(" + - "CAST(IF (cardinality(filter(input, x -> x is NULL)) > 1, ARRAY[NULL], ARRAY[]) AS ARRAY(BIGINT))," + + "IF (cardinality(filter(input, x -> x is NULL)) > 1, array[find_first(input, x -> x IS NULL)], array[])," + "map_keys(map_filter(array_frequency(input), (k, v) -> v > 1)))"; } @SqlInvokedScalarFunction(value = "array_has_duplicates", alias = {"array_has_dupes"}, deterministic = true, calledOnNullInput = false) @Description("Returns whether array has any duplicate element") - @SqlParameter(name = "input", type = "array(varchar)") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array(T)") @SqlType("boolean") public static String arrayHasDuplicatesVarchar() { return "RETURN cardinality(array_duplicates(input)) > 0"; } - - @SqlInvokedScalarFunction(value = "array_has_duplicates", alias = {"array_has_dupes"}, deterministic = true, calledOnNullInput = false) - @Description("Returns whether array has any duplicate element") - @SqlParameter(name = "input", type = "array(bigint)") - @SqlType("boolean") - public static String arrayHasDuplicatesBigint() - { - return "RETURN cardinality(array_duplicates(input)) > 0"; - } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java index cee3fecf320cf..60f09dd353ed8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -69,9 +70,15 @@ public void testDuplicates() @Test public void testSQLFunctions() { - assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(BIGINT), ImmutableList.of(3L)); - assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(BIGINT), ImmutableList.of()); - assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(BIGINT), null); - assertFunction("array_intersect(ARRAY[ARRAY[1.1, 2.2, 3.3], ARRAY[1.1, 3.4], ARRAY[1.0, 1.1, 1.2]])", new ArrayType(DOUBLE), ImmutableList.of(1.1)); + assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(INTEGER), ImmutableList.of(3)); + assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(INTEGER), null); + assertFunction("array_intersect(ARRAY[ARRAY[DOUBLE'1.1', DOUBLE'2.2', DOUBLE'3.3'], ARRAY[DOUBLE'1.1', DOUBLE'3.4'], ARRAY[DOUBLE'1.0', DOUBLE'1.1', DOUBLE'1.2']])", new ArrayType(DOUBLE), ImmutableList.of(1.1)); + + assertFunction("array_intersect(ARRAY[ARRAY[ARRAY[1], ARRAY[2]], ARRAY[ARRAY[2], ARRAY[3]]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(2))); + + RowType rowType = RowType.from(ImmutableList.of(RowType.field("x", DOUBLE), RowType.field("y", DOUBLE))); + String t = rowType.toString(); + assertFunction("array_intersect(ARRAY[ARRAY[CAST((1.0, 2.0) AS " + t + "), CAST((2.0, 3.0) AS " + t + ")], ARRAY[CAST((0.0, 1.0) AS " + t + "), CAST((1.0, 2.0) AS " + t + ")]])", new ArrayType(rowType), ImmutableList.of(ImmutableList.of(1.0, 2.0))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index 90b5f5e809674..32fd1ac4d5ef6 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -14,24 +14,20 @@ package com.facebook.presto.operator.scalar.sql; import com.facebook.presto.common.type.ArrayType; -import com.facebook.presto.common.type.MapType; -import com.facebook.presto.common.type.TestRowType; -import com.facebook.presto.common.type.TypeSignature; -import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.spi.StandardErrorCode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; -import java.util.Collections; - -import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; +import static com.facebook.presto.block.BlockAssertions.createMapType; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static java.util.Collections.singletonList; public class TestArraySqlFunctions extends AbstractTestFunctions @@ -83,48 +79,49 @@ public void testArrayAverage() @Test public void testArrayFrequencyBigint() { - FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); - MapType type = new MapType(BIGINT, - INTEGER, - methodHandle(TestRowType.class, "throwUnsupportedOperation"), - methodHandle(TestRowType.class, "throwUnsupportedOperation")); - TypeSignature typeSignature = TypeSignature.parseTypeSignature(type.getDisplayName()); - - assertFunction("array_frequency(cast(null as array(bigint)))", functionAndTypeManager.getType(typeSignature), null); - assertFunction("array_frequency(cast(array[] as array(bigint)))", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); - assertFunction("array_frequency(array[cast(null as bigint), cast(null as bigint), cast(null as bigint)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); - assertFunction("array_frequency(array[cast(null as bigint), bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 1)); - assertFunction("array_frequency(array[cast(null as bigint), bigint '1', bigint '3', cast(null as bigint), bigint '1', bigint '3', cast(null as bigint)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 3L, 2)); - assertFunction("array_frequency(array[bigint '1', bigint '1', bigint '2', bigint '2', bigint '3', bigint '1', bigint '3', bigint '2'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 3, 2L, 3, 3L, 2)); - assertFunction("array_frequency(array[bigint '45'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(45L, 1)); - assertFunction("array_frequency(array[bigint '-45'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(-45L, 1)); - assertFunction("array_frequency(array[bigint '1', bigint '3', bigint '1', bigint '3'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 3L, 2)); - assertFunction("array_frequency(array[bigint '3', bigint '1', bigint '3',bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 3L, 2)); - assertFunction("array_frequency(array[bigint '4',bigint '3',bigint '3',bigint '2',bigint '2',bigint '2',bigint '1',bigint '1',bigint '1',bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 4, 2L, 3, 3L, 2, 4L, 1)); - assertFunction("array_frequency(array[bigint '3', bigint '3', bigint '2', bigint '2', bigint '5', bigint '5', bigint '1', bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 2L, 2, 3L, 2, 5L, 2)); + assertFunction("array_frequency(cast(null as array(bigint)))", createMapType(BIGINT, INTEGER), null); + assertFunction("array_frequency(cast(array[] as array(bigint)))", createMapType(BIGINT, INTEGER), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as bigint), cast(null as bigint), cast(null as bigint)])", createMapType(BIGINT, INTEGER), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as bigint), bigint '1'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 1)); + assertFunction("array_frequency(array[cast(null as bigint), bigint '1', bigint '3', cast(null as bigint), bigint '1', bigint '3', cast(null as bigint)])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 2, 3L, 2)); + assertFunction("array_frequency(array[bigint '1', bigint '1', bigint '2', bigint '2', bigint '3', bigint '1', bigint '3', bigint '2'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 3, 2L, 3, 3L, 2)); + assertFunction("array_frequency(array[bigint '45'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(45L, 1)); + assertFunction("array_frequency(array[bigint '-45'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(-45L, 1)); + assertFunction("array_frequency(array[bigint '1', bigint '3', bigint '1', bigint '3'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 2, 3L, 2)); + assertFunction("array_frequency(array[bigint '3', bigint '1', bigint '3',bigint '1'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 2, 3L, 2)); + assertFunction("array_frequency(array[bigint '4',bigint '3',bigint '3',bigint '2',bigint '2',bigint '2',bigint '1',bigint '1',bigint '1',bigint '1'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 4, 2L, 3, 3L, 2, 4L, 1)); + assertFunction("array_frequency(array[bigint '3', bigint '3', bigint '2', bigint '2', bigint '5', bigint '5', bigint '1', bigint '1'])", createMapType(BIGINT, INTEGER), ImmutableMap.of(1L, 2, 2L, 2, 3L, 2, 5L, 2)); } @Test public void testArrayFrequencyVarchar() { - FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); - - MapType type = new MapType(VARCHAR, - INTEGER, - methodHandle(TestRowType.class, "throwUnsupportedOperation"), - methodHandle(TestRowType.class, "throwUnsupportedOperation")); - TypeSignature typeSignature = TypeSignature.parseTypeSignature(type.getDisplayName()); - - assertFunction("array_frequency(cast(null as array(varchar)))", functionAndTypeManager.getType(typeSignature), null); - assertFunction("array_frequency(cast(array[] as array(varchar)))", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); - assertFunction("array_frequency(array[cast(null as varchar), cast(null as varchar), cast(null as varchar)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); - assertFunction("array_frequency(array[varchar 'z', cast(null as varchar)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("z", 1)); - assertFunction("array_frequency(array[varchar 'a', cast(null as varchar), varchar 'b', cast(null as varchar), cast(null as varchar) ])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("a", 1, "b", 1)); - assertFunction("array_frequency(array[varchar 'a', varchar 'b', varchar 'a', varchar 'a', varchar 'a'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("a", 4, "b", 1)); - assertFunction("array_frequency(array[varchar 'a', varchar 'b', varchar 'a', varchar 'b', varchar 'c'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("a", 2, "b", 2, "c", 1)); - assertFunction("array_frequency(array[varchar 'y', varchar 'p'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("p", 1, "y", 1)); - assertFunction("array_frequency(array[varchar 'a', varchar 'a', varchar 'p'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("p", 1, "a", 2)); - assertFunction("array_frequency(array[varchar 'z'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("z", 1)); + assertFunction("array_frequency(cast(null as array(varchar)))", createMapType(VARCHAR, INTEGER), null); + assertFunction("array_frequency(cast(array[] as array(varchar)))", createMapType(VARCHAR, INTEGER), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as varchar), cast(null as varchar), cast(null as varchar)])", createMapType(VARCHAR, INTEGER), ImmutableMap.of()); + assertFunction("array_frequency(array[varchar 'z', cast(null as varchar)])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("z", 1)); + assertFunction("array_frequency(array[varchar 'a', cast(null as varchar), varchar 'b', cast(null as varchar), cast(null as varchar) ])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("a", 1, "b", 1)); + assertFunction("array_frequency(array[varchar 'a', varchar 'b', varchar 'a', varchar 'a', varchar 'a'])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("a", 4, "b", 1)); + assertFunction("array_frequency(array[varchar 'a', varchar 'b', varchar 'a', varchar 'b', varchar 'c'])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("a", 2, "b", 2, "c", 1)); + assertFunction("array_frequency(array[varchar 'y', varchar 'p'])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("p", 1, "y", 1)); + assertFunction("array_frequency(array[varchar 'a', varchar 'a', varchar 'p'])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("p", 1, "a", 2)); + assertFunction("array_frequency(array[varchar 'z'])", createMapType(VARCHAR, INTEGER), ImmutableMap.of("z", 1)); + } + + @Test + public void testArrayFrequencyComplexTypes() + { + assertFunction("array_frequency(cast(null as array(array(varchar))))", createMapType(new ArrayType(VARCHAR), INTEGER), null); + assertFunction("array_frequency(cast(array[] as array(array(varchar))))", createMapType(new ArrayType(VARCHAR), INTEGER), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as array(varchar)), cast(null as array(varchar)), cast(null as array(varchar))])", createMapType(new ArrayType(VARCHAR), INTEGER), ImmutableMap.of()); + assertFunction("array_frequency(array[array[varchar 'z'], array[varchar 'z']])", createMapType(new ArrayType(VARCHAR), INTEGER), ImmutableMap.of(singletonList("z"), 2)); + assertFunction("array_frequency(array[array[varchar 'z'], array[varchar 't']])", createMapType(new ArrayType(VARCHAR), INTEGER), ImmutableMap.of(singletonList("z"), 1, singletonList("t"), 1)); + + RowType rowType = RowType.from(ImmutableList.of(RowType.field(INTEGER), RowType.field(INTEGER))); + String t = rowType.toString(); + assertFunction("array_frequency(array[(1, 2), (1, 3), (1, 2)])", createMapType(rowType, INTEGER), ImmutableMap.of(ImmutableList.of(1, 2), 2, ImmutableList.of(1, 3), 1)); + assertInvalidFunction("array_frequency(array[(1, null), (null, 2), (null, 1)])", StandardErrorCode.NOT_SUPPORTED, "ROW comparison not supported for fields with null elements"); + assertInvalidFunction("array_frequency(array[(null, 1), (1, null), (null, null)])", StandardErrorCode.NOT_SUPPORTED, "map key cannot be null or contain nulls"); } @Test @@ -146,6 +143,13 @@ public void testArrayHasDuplicates() // Test legacy name. assertFunction("array_has_dupes(array[varchar 'a', varchar 'b', varchar 'a'])", BOOLEAN, true); + + assertFunction("array_has_duplicates(array[array[1], array[2], array[]])", BOOLEAN, false); + assertFunction("array_has_duplicates(array[array[1], array[2], array[2]])", BOOLEAN, true); + assertFunction("array_has_duplicates(array[(1, 2), (1, 2)])", BOOLEAN, true); + assertFunction("array_has_duplicates(array[(1, 2), (2, 2)])", BOOLEAN, false); + assertInvalidFunction("array_has_duplicates(array[(1, null), (null, 2), (null, 1)])", StandardErrorCode.NOT_SUPPORTED, "ROW comparison not supported for fields with null elements"); + assertInvalidFunction("array_has_duplicates(array[(1, null), (null, 2), (null, null)])", StandardErrorCode.NOT_SUPPORTED, "map key cannot be null or contain nulls"); } @Test @@ -158,14 +162,23 @@ public void testArrayDuplicates() assertFunction("array_duplicates(array[varchar 'a', varchar 'b'])", new ArrayType(VARCHAR), ImmutableList.of()); assertFunction("array_duplicates(array[varchar 'a', varchar 'a'])", new ArrayType(VARCHAR), ImmutableList.of("a")); - assertFunction("array_duplicates(array[1, 2, 1])", new ArrayType(BIGINT), ImmutableList.of(1L)); - assertFunction("array_duplicates(array[1, 2])", new ArrayType(BIGINT), ImmutableList.of()); - assertFunction("array_duplicates(array[1, 1, 1])", new ArrayType(BIGINT), ImmutableList.of(1L)); + assertFunction("array_duplicates(array[1, 2, 1])", new ArrayType(INTEGER), ImmutableList.of(1)); + assertFunction("array_duplicates(array[1, 2])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("array_duplicates(array[1, 1, 1])", new ArrayType(INTEGER), ImmutableList.of(1)); - assertFunction("array_duplicates(array[0, null])", new ArrayType(BIGINT), ImmutableList.of()); - assertFunction("array_duplicates(array[0, null, null])", new ArrayType(BIGINT), Collections.singletonList(null)); + assertFunction("array_duplicates(array[0, null])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("array_duplicates(array[0, null, null])", new ArrayType(INTEGER), singletonList(null)); // Test legacy name. - assertFunction("array_dupes(array[1, 2, 1])", new ArrayType(BIGINT), ImmutableList.of(1L)); + assertFunction("array_dupes(array[1, 2, 1])", new ArrayType(INTEGER), ImmutableList.of(1)); + + RowType rowType = RowType.from(ImmutableList.of(RowType.field(INTEGER), RowType.field(INTEGER))); + String t = rowType.toString(); + assertFunction("array_duplicates(array[array[1], array[2], array[]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of()); + assertFunction("array_duplicates(array[array[1], array[2], array[2]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(2))); + assertFunction("array_duplicates(array[(1, 2), (1, 2)])", new ArrayType(rowType), ImmutableList.of(ImmutableList.of(1, 2))); + assertFunction("array_duplicates(array[(1, 2), (2, 2)])", new ArrayType(rowType), ImmutableList.of()); + assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, 1)])", StandardErrorCode.NOT_SUPPORTED, "ROW comparison not supported for fields with null elements"); + assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, null)])", StandardErrorCode.NOT_SUPPORTED, "map key cannot be null or contain nulls"); } }