diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 767d2acae0e58..98e4b6e346a3e 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -48,6 +48,16 @@ Array Functions Returns an array of elements in ``x`` but not in ``y``, without duplicates. +.. function:: array_frequency(array(bigint)) -> map(bigint, int) + + Returns a map: keys are the unique elements in the ``array``, values are how many times the key appears. + Ignores null elements. Empty array returns empty map. + +.. function:: array_frequency(array(varchar)) -> map(varchar, int) + + Returns a map: keys are the unique elements in the ``array``, values are how many times the key appears. + Ignores null elements. Empty array returns empty map. + .. function:: array_intersect(x, y) -> array Returns an array of the elements in the intersection of ``x`` and ``y``, without duplicates. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArrayArithmeticFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArrayArithmeticFunctions.java index aeb6d7daf74fc..be61800bc9bcb 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArrayArithmeticFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArrayArithmeticFunctions.java @@ -52,4 +52,30 @@ public static String arrayAverage() "(s, x) -> IF(x IS NOT NULL, (s[1] + x, s[2] + 1), s), " + "s -> if(s[2] = 0, cast(null as double), s[1] / cast(s[2] as double)))"; } + + @SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false) + @Description("Returns the frequency of all array elements as a map.") + @SqlParameter(name = "input", type = "array(bigint)") + @SqlType("map(bigint, int)") + public static String arrayFrequencyBigint() + { + return "RETURN reduce(" + + "input," + + "MAP()," + + "(m, x) -> IF (x IS NOT NULL, MAP_CONCAT(m,MAP_FROM_ENTRIES(ARRAY[ROW(x, COALESCE(ELEMENT_AT(m,x) + 1, 1))])), m)," + + "m -> m)"; + } + + @SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false) + @Description("Returns the frequency of all array elements as a map.") + @SqlParameter(name = "input", type = "array(varchar)") + @SqlType("map(varchar, int)") + public static String arrayFrequencyVarchar() + { + return "RETURN reduce(" + + "input," + + "MAP()," + + "(m, x) -> IF (x IS NOT NULL, MAP_CONCAT(m,MAP_FROM_ENTRIES(ARRAY[ROW(x, COALESCE(ELEMENT_AT(m,x) + 1, 1))])), m)," + + "m -> m)"; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArrayArithmeticFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArrayArithmeticFunctions.java index c017d3c532123..35e60c61e10f9 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArrayArithmeticFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArrayArithmeticFunctions.java @@ -13,11 +13,20 @@ */ package com.facebook.presto.operator.scalar.sql; +import com.facebook.presto.common.type.MapType; +import com.facebook.presto.common.type.TestRowType; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; +import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; public class TestArrayArithmeticFunctions extends AbstractTestFunctions @@ -65,4 +74,51 @@ public void testArrayAverage() assertFunction("array_average(array[null, null])", DOUBLE, null); assertFunction("array_average(null)", DOUBLE, null); } + + @Test + public void testArrayFrequencyBigint() + { + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + MapType type = new MapType(BIGINT, + INTEGER, + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation")); + TypeSignature typeSignature = TypeSignature.parseTypeSignature(type.getDisplayName()); + + assertFunction("array_frequency(cast(null as array(bigint)))", functionAndTypeManager.getType(typeSignature), null); + assertFunction("array_frequency(cast(array[] as array(bigint)))", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as bigint), cast(null as bigint), cast(null as bigint)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as bigint), bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 1)); + assertFunction("array_frequency(array[cast(null as bigint), bigint '1', bigint '3', cast(null as bigint), bigint '1', bigint '3', cast(null as bigint)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 3L, 2)); + assertFunction("array_frequency(array[bigint '1', bigint '1', bigint '2', bigint '2', bigint '3', bigint '1', bigint '3', bigint '2'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 3, 2L, 3, 3L, 2)); + assertFunction("array_frequency(array[bigint '45'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(45L, 1)); + assertFunction("array_frequency(array[bigint '-45'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(-45L, 1)); + assertFunction("array_frequency(array[bigint '1', bigint '3', bigint '1', bigint '3'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 3L, 2)); + assertFunction("array_frequency(array[bigint '3', bigint '1', bigint '3',bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 3L, 2)); + assertFunction("array_frequency(array[bigint '4',bigint '3',bigint '3',bigint '2',bigint '2',bigint '2',bigint '1',bigint '1',bigint '1',bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 4, 2L, 3, 3L, 2, 4L, 1)); + assertFunction("array_frequency(array[bigint '3', bigint '3', bigint '2', bigint '2', bigint '5', bigint '5', bigint '1', bigint '1'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of(1L, 2, 2L, 2, 3L, 2, 5L, 2)); + } + + @Test + public void testArrayFrequencyVarchar() + { + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + + MapType type = new MapType(VARCHAR, + INTEGER, + methodHandle(TestRowType.class, "throwUnsupportedOperation"), + methodHandle(TestRowType.class, "throwUnsupportedOperation")); + TypeSignature typeSignature = TypeSignature.parseTypeSignature(type.getDisplayName()); + + assertFunction("array_frequency(cast(null as array(varchar)))", functionAndTypeManager.getType(typeSignature), null); + assertFunction("array_frequency(cast(array[] as array(varchar)))", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); + assertFunction("array_frequency(array[cast(null as varchar), cast(null as varchar), cast(null as varchar)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of()); + assertFunction("array_frequency(array[varchar 'z', cast(null as varchar)])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("z", 1)); + assertFunction("array_frequency(array[varchar 'a', cast(null as varchar), varchar 'b', cast(null as varchar), cast(null as varchar) ])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("a", 1, "b", 1)); + assertFunction("array_frequency(array[varchar 'a', varchar 'b', varchar 'a', varchar 'a', varchar 'a'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("a", 4, "b", 1)); + assertFunction("array_frequency(array[varchar 'a', varchar 'b', varchar 'a', varchar 'b', varchar 'c'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("a", 2, "b", 2, "c", 1)); + assertFunction("array_frequency(array[varchar 'y', varchar 'p'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("p", 1, "y", 1)); + assertFunction("array_frequency(array[varchar 'a', varchar 'a', varchar 'p'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("p", 1, "a", 2)); + assertFunction("array_frequency(array[varchar 'z'])", functionAndTypeManager.getType(typeSignature), ImmutableMap.of("z", 1)); + } }