diff --git a/presto-docs/src/main/sphinx/functions/map.rst b/presto-docs/src/main/sphinx/functions/map.rst index 5079c3c3bc1f1..f04165403b0cb 100644 --- a/presto-docs/src/main/sphinx/functions/map.rst +++ b/presto-docs/src/main/sphinx/functions/map.rst @@ -67,6 +67,10 @@ Map Functions SELECT map_filter(MAP(ARRAY[10, 20, 30], ARRAY['a', NULL, 'c']), (k, v) -> v IS NOT NULL); -- {10 -> a, 30 -> c} SELECT map_filter(MAP(ARRAY['k1', 'k2', 'k3'], ARRAY[20, 3, 15]), (k, v) -> v > 10); -- {k1 -> 20, k3 -> 15} +.. function:: map_remove_null_values(x(K,V)) -> map(K, V) + + Removes all the entries where the value is null from the map ``x``. + .. function:: map_subset(map(K,V), array(k)) -> map(K,V) Constructs a map from those entries of ``map`` for which the key is in the array given:: @@ -98,9 +102,12 @@ Map Functions SELECT map_top_n_keys(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x < y, -1, IF(x = y, 0, 1))) --- ['c', 'b'] -.. function:: map_remove_null_values(x(K,V)) -> map(K, V) +.. function:: map_top_n(x(K,V), n) -> map(K, V) - Removes all the entries where the value is null from the map ``x``. + Truncates the map ``x``, keeping only the ``n`` entries with the largest values. + ``n`` must be a non-negative integer; a negative ``n`` raises an error:: + + SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2) --- {'b' -> 3, 'a' -> 2} .. 
function:: map_normalize(x(varchar,double)) -> map(varchar,double) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java index bae51afa0b141..1a2bfbd577463 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java @@ -24,6 +24,17 @@ public class MapSqlFunctions { private MapSqlFunctions() {} + /** Returns the SQL body of map_top_n. The SQL fails with the message n must be greater than or equal to 0 when n is negative. Otherwise it sorts the entries whose value is not null by value in descending order, appends the entries whose value is null after them, keeps the first n entries via slice(entries, 1, n) and rebuilds the map with map_from_entries. Entries with equal values compare as 0, so this comparator applies no tie-break on keys. */ @SqlInvokedScalarFunction(value = "map_top_n", deterministic = true, calledOnNullInput = true) + @Description("Truncates map items. Keeps only the top N elements by value.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint")}) + @SqlType("map(K, V)") + public static String mapTopN() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], 0, -1))) || map_entries(map_filter(input, (k, v) -> v is null)), 1, n)))"; + } + @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = false) @Description("Returns the top N keys of the given map in descending order according to the natural ordering of its values.") @TypeParameter("K") diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java new file mode 100644 index 0000000000000..99ba8020b983d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar.sql; + +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.spi.StandardErrorCode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.util.StructuralTestUtil.mapType; +import static java.util.Arrays.asList; + +/** Exercises the MAP_TOP_N SQL function through assertFunction and assertInvalidFunction: ordinary top-n selection, negative n, zero n, an empty map, a NULL map and ROW-typed keys. */ public class TestMapTopNFunction + extends AbstractTestFunctions +{ + /** Keeps the n entries with the largest values for integer, negative integer, varchar and decimal keys, including n equal to the map size. */ @Test + public void testBasic() + { + assertFunction( + "MAP_TOP_N(MAP(ARRAY[1, 2, 3], ARRAY[4, 5, 6]), 2)", + mapType(INTEGER, INTEGER), + ImmutableMap.of(3, 6, 2, 5)); + assertFunction( + "MAP_TOP_N(MAP(ARRAY[-1, -2, -3], ARRAY[4, 5, 6]), 2)", + mapType(INTEGER, INTEGER), + ImmutableMap.of(-3, 6, -2, 5)); + assertFunction( + "MAP_TOP_N(MAP(ARRAY['ab', 'bc', 'cd'], ARRAY['x', 'y', 'z']), 1)", + mapType(createVarcharType(2), createVarcharType(1)), + ImmutableMap.of("cd", "z")); + assertFunction( + "MAP_TOP_N(MAP(ARRAY[123.0, 99.5, 1000.99], ARRAY['x', 'y', 'z']), 3)", + mapType(createDecimalType(6, 2), createVarcharType(1)), + ImmutableMap.of(decimal("1000.99"), "z", decimal("99.50"), "y", decimal("123.00"), 
"x")); + } + + /** A negative n is rejected with GENERIC_USER_ERROR and the message asserted below, for integer and varchar keys. */ @Test + public void testNegativeN() + { + assertInvalidFunction( + "MAP_TOP_N(MAP(ARRAY[100, 200, 300], ARRAY[4, 5, 6]), -3)", + StandardErrorCode.GENERIC_USER_ERROR, + "n must be greater than or equal to 0"); + assertInvalidFunction( + "MAP_TOP_N(MAP(ARRAY[1, 2, 3], ARRAY[4, 5, 6]), -1)", + StandardErrorCode.GENERIC_USER_ERROR, + "n must be greater than or equal to 0"); + assertInvalidFunction( + "MAP_TOP_N(MAP(ARRAY['a', 'b', 'c'], ARRAY[4, 5, 6]), -2)", + StandardErrorCode.GENERIC_USER_ERROR, + "n must be greater than or equal to 0"); + } + + /** n equal to 0 yields an empty map; the expected key and value types match those of the input map. */ @Test + public void testZeroN() + { + assertFunction( + "MAP_TOP_N(MAP(ARRAY[-1, -2, -3], ARRAY[4, 5, 6]), 0)", + mapType(INTEGER, INTEGER), + ImmutableMap.of()); + assertFunction( + "MAP_TOP_N(MAP(ARRAY['ab', 'bc', 'cd'], ARRAY['x', 'y', 'z']), 0)", + mapType(createVarcharType(2), createVarcharType(1)), + ImmutableMap.of()); + assertFunction( + "MAP_TOP_N(MAP(ARRAY[123.0, 99.5, 1000.99], ARRAY['x', 'y', 'z']), 0)", + mapType(createDecimalType(6, 2), createVarcharType(1)), + ImmutableMap.of()); + } + + /** An empty input map yields an empty map whose expected key and value types are UNKNOWN. */ @Test + public void testEmpty() + { + assertFunction("MAP_TOP_N(MAP(ARRAY[], ARRAY[]), 5)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + } + + /** A NULL map argument is expected to produce a NULL result. */ @Test + public void testNull() + { + assertFunction("MAP_TOP_N(NULL, 1)", mapType(UNKNOWN, UNKNOWN), null); + } + + /** ROW-typed keys are accepted; in the last case n covers the whole map and the entry whose value is null is still part of the expected result. */ @Test + public void testComplexKeys() + { + assertFunction( + "MAP_TOP_N(MAP(ARRAY[ROW('x', 1), ROW('y', 2)], ARRAY[1, 2]), 1)", + mapType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER))), INTEGER), + ImmutableMap.of(ImmutableList.of("y", 2), 2)); + assertFunction( + "MAP_TOP_N(MAP(ARRAY[ROW('x', 1), ROW('x', -2)], ARRAY[2, 1]), 1)", + mapType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER))), INTEGER), + ImmutableMap.of(ImmutableList.of("x", 1), 2)); + assertFunction( + "MAP_TOP_N(MAP(ARRAY[ROW('x', 1), ROW('x', -2), ROW('y', 1)], ARRAY[100, 200, null]), 3)", + 
mapType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER))), INTEGER), + asMap(asList(ImmutableList.of("x", -2), ImmutableList.of("x", 1), ImmutableList.of("y", 1)), asList(200, 100, null))); + } +}