diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 05dee6a99447b..5aeccd90ba83a 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -207,6 +207,16 @@ Array Functions SELECT array_sort_desc(ARRAY [null, 100, null, 1, 10, 50]); -- [100, 50, 10, 1, null, null] SELECT array_sort_desc(ARRAY [ARRAY ["a", null], null, ARRAY ["a"]); -- [["a", null], ["a"], null] +.. function:: array_split_into_chunks(array(T), int) -> array(array(T)) + + Returns an ``array`` of arrays splitting the input ``array`` into chunks of given length. + If the ``array`` is not evenly divisible it will split into as many possible chunks and return + the left over elements for the last ``array``. Ignores null inputs, but not elements. + + SELECT array_split_into_chunks(ARRAY [1, 2, 3, 4], 3); -- [[1, 2, 3], [4]] + SELECT array_split_into_chunks(null, null); -- null + SELECT array_split_into_chunks(array[1, 2, 3, cast(null as int)], 2]); -- [[1, 2], [3, null]] + .. function:: array_sum(array(T)) -> bigint/double Returns the sum of all non-null elements of the ``array``. If there is no non-null elements, returns ``0``. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java index fa275eca5362f..b51855cc9a2ee 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java @@ -37,6 +37,24 @@ public static String arrayAverage() "s -> if(s[2] = 0, cast(null as double), s[1] / cast(s[2] as double)))"; } + @SqlInvokedScalarFunction(value = "array_split_into_chunks", deterministic = true, calledOnNullInput = false) + @Description("Returns an array of arrays splitting input array into chunks of given length. " + + "If array is not evenly divisible it will split into as many possible chunks and " + + "return the left over elements for the last array. Returns null for null inputs, but not elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "sz", type = "int")}) + @SqlType("array(array(T))") + public static String arraySplitIntoChunks() + { + return "RETURN IF(sz <= 0, " + + "fail('Invalid slice size: ' || cast(sz as varchar) || '. Size must be greater than zero.'), " + + "IF(cardinality(input) / sz > 10000, " + + "fail('Cannot split array of size: ' || cast(cardinality(input) as varchar) || ' into more than 10000 parts.'), " + + "transform(" + + "sequence(1, cardinality(input), sz), " + + "x -> slice(input, x, sz))))"; + } + @SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false) @Description("Returns the frequency of all array elements as a map.") @TypeParameter("T") diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index a0da487dae0a4..216f27b8d44cd 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -22,6 +22,9 @@ import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + import static com.facebook.presto.block.BlockAssertions.createMapType; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -52,6 +55,82 @@ public void testArrayAverage() assertFunction("array_average(null)", DOUBLE, null); } + @Test + public void testArraySplitIntoChunksBigint() + { + assertFunction("array_split_into_chunks(array[bigint '1', bigint '2', bigint '3'], 2)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(ImmutableList.of(1L, 2L), ImmutableList.of(3L))); + assertFunction("array_split_into_chunks(array[bigint '1', bigint '2', bigint '3', bigint '4' , bigint '5'], 2)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(ImmutableList.of(1L, 2L), ImmutableList.of(3L, 4L), ImmutableList.of(5L))); + assertFunction("array_split_into_chunks(array[bigint '2', bigint '3', bigint '4' , bigint '5'], 4)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(ImmutableList.of(2L, 3L, 4L, 5L))); + assertFunction("array_split_into_chunks(array[bigint '-66', bigint '3', bigint '-66' , bigint '5'], 1)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(ImmutableList.of(-66L), ImmutableList.of(3L), ImmutableList.of(-66L), ImmutableList.of(5L))); + assertFunction("array_split_into_chunks(array[bigint '-1', bigint '2', bigint '3' , bigint '-11'], 6)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(ImmutableList.of(-1L, 2L, 3L, -11L))); + assertFunction("array_split_into_chunks(array[bigint '1', bigint '2', bigint '3', bigint '4', bigint '5', bigint '6', bigint '7'], 3)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(ImmutableList.of(1L, 2L, 3L), ImmutableList.of(4L, 5L, 6L), ImmutableList.of(7L))); + assertInvalidFunction("array_split_into_chunks(array[bigint '-1', bigint '2', bigint '3' , bigint '-11'], 0)", StandardErrorCode.GENERIC_USER_ERROR, "Invalid slice size: 0. Size must be greater than zero."); + assertInvalidFunction( + "array_split_into_chunks(array[" + IntStream.rangeClosed(1, 12001).mapToObj(Long::toString).collect(Collectors.joining(", ")) + "], 1)", + StandardErrorCode.GENERIC_USER_ERROR, + "Cannot split array of size: 12001 into more than 10000 parts."); + } + + @Test + public void testArraySplitIntoChunksVarchar() + { + assertFunction("array_split_into_chunks(array[varchar 'a', varchar 'b', varchar 'c'], 2)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("a", "b"), ImmutableList.of("c"))); + assertFunction("array_split_into_chunks(array[varchar 'a', varchar 'b', varchar 'c', varchar 'd', varchar 'e'], 2)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("a", "b"), ImmutableList.of("c", "d"), ImmutableList.of("e"))); + assertFunction("array_split_into_chunks(array[varchar 'z', varchar 'y', varchar 'x', varchar 'w', varchar 'v', varchar 'u'], 6)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("z", "y", "x", "w", "v", "u"))); + assertFunction("array_split_into_chunks(array[varchar 'k', varchar 'l', varchar 'm'], 1)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("k"), ImmutableList.of("l"), ImmutableList.of("m"))); + assertFunction("array_split_into_chunks(array[varchar 'k', varchar 'l', varchar 'm'], 8)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("k", "l", "m"))); + assertFunction("array_split_into_chunks(array[varchar 'k', varchar 'l', varchar 'm', varchar 'n', varchar 'o', varchar 'p', varchar 'q', varchar 'r'], 3)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("k", "l", "m"), ImmutableList.of("n", "o", "p"), ImmutableList.of("q", "r"))); + assertInvalidFunction("array_split_into_chunks(array[varchar 'a', varchar 'b', varchar 'c', varchar 'd'], 0)", StandardErrorCode.GENERIC_USER_ERROR, "Invalid slice size: 0. Size must be greater than zero."); + assertInvalidFunction( + "array_split_into_chunks(array[" + IntStream.rangeClosed(1, 10002).mapToObj(s -> "'" + s + "'").collect(Collectors.joining(", ")) + "], 1)", + StandardErrorCode.GENERIC_USER_ERROR, + "Cannot split array of size: 10002 into more than 10000 parts."); + } + + @Test + public void testArraySplitIntoChunksInteger() + { + assertFunction("array_split_into_chunks(array[1, 2, 3], 2)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2), ImmutableList.of(3))); + assertFunction("array_split_into_chunks(array[1, 2, 3], 2)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2), ImmutableList.of(3))); + assertFunction("array_split_into_chunks(array[1, 2, 5, 7, 9, 11], 4)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2, 5, 7), ImmutableList.of(9, 11))); + assertFunction("array_split_into_chunks(array[1, 2, 5, 7, 9, 11, 22], 7)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2, 5, 7, 9, 11, 22))); + assertFunction("array_split_into_chunks(array[9, 10, 11, 12], 1)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(9), ImmutableList.of(10), ImmutableList.of(11), ImmutableList.of(12))); + assertFunction("array_split_into_chunks(array[9, 10, 11, 12], 20)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(9, 10, 11, 12))); + assertFunction("array_split_into_chunks(array[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 4)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(9, 10, 11, 12), ImmutableList.of(13, 14, 15, 16), ImmutableList.of(17, 18, 19, 20))); + assertInvalidFunction("array_split_into_chunks(array[5, 6, 7, 8], 0)", StandardErrorCode.GENERIC_USER_ERROR, "Invalid slice size: 0. Size must be greater than zero."); + assertInvalidFunction( + "array_split_into_chunks(array[" + IntStream.rangeClosed(1, 10001).mapToObj(Integer::toString).collect(Collectors.joining(", ")) + "], 1)", + StandardErrorCode.GENERIC_USER_ERROR, + "Cannot split array of size: 10001 into more than 10000 parts."); + } + + @Test + public void testArraySplitIntoChunksDouble() + { + assertFunction("array_split_into_chunks(array[cast(1.0 as double), cast(2.0 as double), cast(3.0 as double)], 2)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(ImmutableList.of(1.0, 2.0), ImmutableList.of(3.0))); + assertFunction("array_split_into_chunks(array[cast(1.2 as double), cast(2.3 as double), cast(3.4 as double)], 1)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(ImmutableList.of(1.2), ImmutableList.of(2.3), ImmutableList.of(3.4))); + assertFunction("array_split_into_chunks(array[cast(1.2 as double), cast(2.3 as double), cast(3.4 as double), cast(4.5 as double), cast(5.6 as double)], 5)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(ImmutableList.of(1.2, 2.3, 3.4, 4.5, 5.6))); + assertFunction("array_split_into_chunks(array[cast(1.2 as double), cast(2.3 as double), cast(3.4 as double), cast(4.5 as double), cast(5.6 as double)], 100)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(ImmutableList.of(1.2, 2.3, 3.4, 4.5, 5.6))); + assertFunction("array_split_into_chunks(array[cast(1.2 as double), cast(2.3 as double), cast(3.4 as double), cast(4.5 as double), cast(5.6 as double), cast(6.7 as double), cast(7.8 as double), cast(8.9 as double), cast(9.1 as double)], 3)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(ImmutableList.of(1.2, 2.3, 3.4), ImmutableList.of(4.5, 5.6, 6.7), ImmutableList.of(7.8, 8.9, 9.1))); + assertInvalidFunction("array_split_into_chunks(array[cast(1.2 as double), cast(2.3 as double), cast(3.4 as double)], 0)", StandardErrorCode.GENERIC_USER_ERROR, "Invalid slice size: 0. Size must be greater than zero."); + assertInvalidFunction( + "array_split_into_chunks(array[" + IntStream.rangeClosed(1, 10001).mapToObj(Double::toString).collect(Collectors.joining(", ")) + "], 1)", + StandardErrorCode.GENERIC_USER_ERROR, + "Cannot split array of size: 10001 into more than 10000 parts."); + } + + @Test + public void testArraySplitIntoChunksNulls() + { + assertFunction("array_split_into_chunks(array[cast(null as bigint), bigint '1', cast(null as bigint), bigint '2'], 2)", new ArrayType(new ArrayType(BIGINT)), ImmutableList.of(asList(null, 1L), asList(null, 2L))); + assertFunction("array_split_into_chunks(array[cast(null as varchar), cast(null as varchar)], 2)", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(asList(null, null))); + assertFunction("array_split_into_chunks(array[cast(null as double), 1.1, 2.1, 3.1], 2)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(asList(null, 1.1), asList(2.1, 3.1))); + assertFunction("array_split_into_chunks(array[1, 2, 3, cast(null as int)], 2)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2), asList(3, null))); + assertFunction("array_split_into_chunks(null, null)", new ArrayType(new ArrayType(UNKNOWN)), null); + assertFunction("array_split_into_chunks(null, 1)", new ArrayType(new ArrayType(UNKNOWN)), null); + assertFunction("array_split_into_chunks(array[1], null)", new ArrayType(new ArrayType(INTEGER)), null); + } + @Test public void testArrayFrequencyBigint() {