diff --git a/presto-docs/src/main/sphinx/functions/map.rst b/presto-docs/src/main/sphinx/functions/map.rst index 7ba9c46a6a201..6e9cd04f3ba4d 100644 --- a/presto-docs/src/main/sphinx/functions/map.rst +++ b/presto-docs/src/main/sphinx/functions/map.rst @@ -107,6 +107,23 @@ Map Functions Returns all the values in the map ``x``. +.. function:: map_top_n_values(x(K,V), n) -> array(K) + + Returns top n values in the map ``x``. + ``n`` must be a positive integer + For bottom ``n`` values, use the function with lambda operator to perform custom sorting + + SELECT map_top_n_values(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2) --- [3, 2] + +.. function:: map_top_n_values(x(K,V), n, function(V,V,int)) -> array(V) + + Returns top n values in the map ``x`` based on the given comparator ``function``. The comparator will take + two nullable arguments representing two values of the ``map``. It returns -1, 0, or 1 + as the first value is less than, equal to, or greater than the second value. + If the comparator function returns other values (including ``NULL``), the query will fail and raise an error :: + + SELECT map_top_n_values(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x < y, -1, IF(x = y, 0, 1))) --- [3, 2] + .. function:: map_zip_with(map(K,V1), map(K,V2), function(K,V1,V2,V3)) -> map(K,V3) Merges the two given maps into a single map by applying ``function`` to the pair of values with the same key. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java index 737a8212750e8..2c19dc7d762aa 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java @@ -45,4 +45,26 @@ public static String mapTopNKeysComparator() { return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input), f)), 1, n))"; } + + @SqlInvokedScalarFunction(value = "map_top_n_values", deterministic = true, calledOnNullInput = false) + @Description("Returns the top N values of the given map in descending order according to the natural ordering of its values.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint")}) + @SqlType("array") + public static String mapTopNValues() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(array_sort_desc(map_values(input)), 1, n))"; + } + + @SqlInvokedScalarFunction(value = "map_top_n_values", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(V, V, int)")}) + @SqlType("array") + public static String mapTopNValuesComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(remove_nulls(map_values(input)), f)) || filter(map_values(input), x -> x is null), 1, n))"; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNValuesComparatorFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNValuesComparatorFunction.java new file mode 100644 index 0000000000000..8e55b6bd9b9f7 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNValuesComparatorFunction.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar.sql; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.sql.analyzer.SemanticErrorCode; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; + +public class TestMapTopNValuesComparatorFunction + extends AbstractTestFunctions +{ + @Test + public void testBasic() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(1, 2)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(3, 2)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY['a', 'b', 'c']), 3, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(1)), ImmutableList.of("c", "b", "a")); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY['a1', 'b2', 'c3']), 1, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(2)), ImmutableList.of("c3")); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['a', 'b', 'c', 'd'], ARRAY['xyz', 'xy', 'yz', 'z']), 4, (x, y) -> CASE " + + "WHEN LENGTH(x) > LENGTH(y) THEN 1 " + + "WHEN LENGTH(x) < LENGTH(y) THEN -1 " + + "WHEN x > y THEN 1 " + + "WHEN x < y THEN -1 " + + "ELSE -1 END)", + new ArrayType(createVarcharType(3)), ImmutableList.of("xyz", "yz", "xy", "z")); + } + + @Test + public void testNLargerThanMapSize() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 8, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 9, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(3, 2, 1)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY['abc', 'bcd', 'cde']), 10, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(3)), ImmutableList.of("cde", "bcd", "abc")); + } + + @Test + public void testNegativeN() + { + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), -1, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0"); + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1.99, -2.12, 3.01]), -2, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0"); + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[1, 2, 3], ARRAY['x', 'y', 'z']), -3, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0"); + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY['a', 'b', 'c', 'd'], ARRAY['xyz', 'xy', 'yz', 'z']), -2, (x, y) -> CASE " + + "WHEN LENGTH(x) > LENGTH(y) THEN 1 " + + "WHEN LENGTH(x) < LENGTH(y) THEN -1 " + + "WHEN x > y THEN 1 " + + "WHEN x < y THEN -1 " + + "ELSE -1 END)", + StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0"); + } + + @Test + public void testZeroN() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 0, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 0, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[1, 2, 3], ARRAY['x', 'y', 'z']), 0, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(1)), ImmutableList.of()); + } + + @Test + public void testEmpty() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[], ARRAY[]), 1, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(UNKNOWN), ImmutableList.of()); + } + + @Test + public void testNull() + { + assertFunction("MAP_TOP_N_VALUES(NULL, 1, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(UNKNOWN), null); + } + + @Test + public void testComplexValues() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['a', 'b', 'c'], ARRAY[ROW('x', 1), ROW('y', 2), ROW('z', 3)]), 3," + + "(x, y) -> IF(x[1] < y[1], 1, IF(x[1] = y[1], 0, -1)))", + new ArrayType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER)))), + ImmutableList.of(ImmutableList.of("x", 1), ImmutableList.of("y", 2), ImmutableList.of("z", 3))); + } + + @Test + public void testBadLambda() + { + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 1, (x, y) -> 10)", StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Lambda comparator must return either -1, 0, or 1"); + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2, null)", SemanticErrorCode.FUNCTION_NOT_FOUND); + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 3, (x, y) -> IF(x = 'test', 1, -1))", SemanticErrorCode.TYPE_MISMATCH); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNValuesFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNValuesFunction.java new file mode 100644 index 0000000000000..c5bcf0b3dc2c7 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNValuesFunction.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar.sql; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.spi.StandardErrorCode; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public class TestMapTopNValuesFunction + extends AbstractTestFunctions +{ + @Test + public void testBasic() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2)", new ArrayType(INTEGER), ImmutableList.of(3, 2)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[-1, -2, -3]), 2)", new ArrayType(INTEGER), ImmutableList.of(-1, -2)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY['ab', 'bc', 'cd']), 1)", new ArrayType(createVarcharType(2)), ImmutableList.of("cd")); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY[123.0, 99.5, 1000.99]), 3)", new ArrayType(createDecimalType(6, 2)), ImmutableList.of(decimal("1000.99"), decimal("123.00"), decimal("99.50"))); + } + + @Test + public void tesMayHaveNull() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, null, 3]), 3)", new ArrayType(INTEGER), asList(3, 1, null)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[-1, -2, null]), 2)", new ArrayType(INTEGER), asList(-1, -2)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY[null, 'bc', 'cd']), 3)", new ArrayType(createVarcharType(2)), asList("cd", "bc", null)); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY['x', 'y', 'z'], ARRAY[123.0, 99.5, null]), 3)", new ArrayType(createDecimalType(4, 1)), asList(decimal("123.0"), decimal("99.5"), null)); + } + + @Test + public void testNegativeN() + { + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, null, 3]), -1)", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0"); + assertInvalidFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY['a', 'b', 'c']), -2)", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0"); + } + + @Test + public void testZeroN() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY[1, null, 3]), 0)", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[4, 5, 6], ARRAY['a', 'b', 'c']), 0)", new ArrayType(createVarcharType(1)), ImmutableList.of()); + } + + @Test + public void testEmpty() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[], ARRAY[]), 5)", new ArrayType(UNKNOWN), ImmutableList.of()); + } + + @Test + public void testNull() + { + assertFunction("MAP_TOP_N_VALUES(NULL, 1)", new ArrayType(UNKNOWN), null); + } + + @Test + public void testComplexValues() + { + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[1, 2], ARRAY[ROW('x', 1), ROW('y', 2)]), 1)", new ArrayType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER)))), ImmutableList.of(ImmutableList.of("y", 2))); + assertFunction("MAP_TOP_N_VALUES(MAP(ARRAY[1, 2], ARRAY[ROW('x', 1), ROW('x', -2)]), 1)", new ArrayType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER)))), ImmutableList.of(ImmutableList.of("x", 1))); + } +}