diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 5244a04c7772d..edfafb19f3876 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -81,6 +81,20 @@ Array Functions Returns the minimum value of input array. +.. function:: array_max_by(array(T), function(T, U)) -> T + + Applies the provided function to each element, and returns the element that gives the maximum value. + ``U`` can be any orderable type. :: + + SELECT array_max_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'bbb' + +.. function:: array_min_by(array(T), function(T, U)) -> T + + Applies the provided function to each element, and returns the element that gives the minimum value. + ``U`` can be any orderable type. :: + + SELECT array_min_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'a' + .. function:: array_normalize(x, p) -> array Normalizes array ``x`` by dividing each element by the p-norm of the array. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java index ee81451076fe6..b9057b6cd8e37 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.SqlInvokedScalarFunction; import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; @@ -59,7 +60,7 @@ public static String arrayAverage() @TypeParameter("T") @SqlParameter(name = "input", type = "array(T)") @SqlType("map(T, int)") - public static String arrayFrequencyBigint() + public static String arrayFrequency() { return "RETURN reduce(" + "input," + @@ -89,4 +90,30 @@ public static String arrayHasDuplicatesVarchar() { return "RETURN cardinality(array_duplicates(input)) > 0"; } + + @SqlInvokedScalarFunction(value = "array_max_by", deterministic = true, calledOnNullInput = true) + @Description("Get the maximum value of array, by using a specific transformation function") + @TypeParameter("T") + @TypeParameter("U") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "f", type = "function(T, U)")}) + @SqlType("T") + public static String arrayMaxBy() + { + return "RETURN input[" + + "array_max(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" + + "]"; + } + + @SqlInvokedScalarFunction(value = "array_min_by", deterministic = true, calledOnNullInput = true) + @Description("Get the minimum value of array, by using a specific transformation function") + @TypeParameter("T") + @TypeParameter("U") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "f", type = "function(T, U)")}) + @SqlType("T") + public static String arrayMinBy() + { + return "RETURN input[" + + "array_min(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" + + "]"; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index 32fd1ac4d5ef6..42ed89c10df37 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -27,6 +27,9 @@ import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.util.StructuralTestUtil.mapType; +import static java.util.Arrays.asList; import static java.util.Collections.singletonList; public class TestArraySqlFunctions @@ -181,4 +184,34 @@ public void testArrayDuplicates() assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, 1)])", StandardErrorCode.NOT_SUPPORTED, "ROW comparison not supported for fields with null elements"); assertInvalidFunction("array_duplicates(array[(1, null), (null, 2), (null, null)])", StandardErrorCode.NOT_SUPPORTED, "map key cannot be null or contain nulls"); } + + @Test + public void testArrayMaxBy() + { + assertFunction("ARRAY_MAX_BY(ARRAY [double'1.0', double'2.0'], i -> i)", DOUBLE, 2.0d); + assertFunction("ARRAY_MAX_BY(ARRAY [double'-3.0', double'2.0'], i -> i*i)", DOUBLE, -3.0d); + assertFunction("ARRAY_MAX_BY(ARRAY ['a', 'bb', 'c'], x -> LENGTH(x))", createVarcharType(2), "bb"); + assertFunction("ARRAY_MAX_BY(ARRAY [1, 2, 3], x -> 1-x)", INTEGER, 1); + assertFunction("ARRAY_MAX_BY(ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c']], x -> CARDINALITY(x))", new ArrayType(createVarcharType(1)), asList("b", "b")); + assertFunction("ARRAY_MAX_BY(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])], x -> x['foo'])", mapType(createVarcharType(3), INTEGER), ImmutableMap.of("foo", 1, "bar", 2)); + assertFunction("ARRAY_MAX_BY(ARRAY [CAST(ROW(0, 2.0) AS ROW(x BIGINT, y DOUBLE)), CAST(ROW(1, 3.0) AS ROW(x BIGINT, y DOUBLE))], r -> r.y).x", BIGINT, 1L); + assertFunction("ARRAY_MAX_BY(ARRAY [null, double'1.0', double'2.0'], i -> i)", DOUBLE, null); + assertFunction("ARRAY_MAX_BY(ARRAY [cast(null as double), cast(null as double)], i -> i)", DOUBLE, null); + assertFunction("ARRAY_MAX_BY(cast(null as array(double)), i -> i)", DOUBLE, null); + } + + @Test + public void testArrayMinBy() + { + assertFunction("ARRAY_MIN_BY(ARRAY [double'1.0', double'2.0'], i -> i)", DOUBLE, 1.0d); + assertFunction("ARRAY_MIN_BY(ARRAY [double'-3.0', double'2.0'], i -> i*i)", DOUBLE, 2.0d); + assertFunction("ARRAY_MIN_BY(ARRAY ['a', 'bb', 'c'], x -> LENGTH(x))", createVarcharType(2), "a"); + assertFunction("ARRAY_MIN_BY(ARRAY [1, 2, 3], x -> 1-x)", INTEGER, 3); + assertFunction("ARRAY_MIN_BY(ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c']], x -> CARDINALITY(x))", new ArrayType(createVarcharType(1)), singletonList("a")); + assertFunction("ARRAY_MIN_BY(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])], x -> x['foo'])", mapType(createVarcharType(3), INTEGER), ImmutableMap.of("foo", 0, "bar", 3)); + assertFunction("ARRAY_MIN_BY(ARRAY [CAST(ROW(0, 2.0) AS ROW(x BIGINT, y DOUBLE)), CAST(ROW(1, 3.0) AS ROW(x BIGINT, y DOUBLE))], r -> r.y).x", BIGINT, 0L); + assertFunction("ARRAY_MIN_BY(ARRAY [null, double'1.0', double'2.0'], i -> i)", DOUBLE, null); + assertFunction("ARRAY_MIN_BY(ARRAY [cast(null as double), cast(null as double)], i -> i)", DOUBLE, null); + assertFunction("ARRAY_MIN_BY(cast(null as array(double)), i -> i)", DOUBLE, null); + } }