Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions presto-docs/src/main/sphinx/functions/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ Map Functions

.. function:: map_top_n(x(K,V), n) -> map(K, V)

Truncates map items. Keeps only the top N elements by value.
``n`` must be a non-negative integer.::
Truncates map items. Keeps only the top ``n`` elements by value. Keys are used to break ties with the max key being chosen. Both keys and values should be orderable.
``n`` must be a non-negative integer. ::

SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2) --- {'b' -> 3, 'a' -> 2}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public static String mapKeysExists()
@SqlType("map(K, V)")
public static String mapTopN()
{
return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) || map_entries(map_filter(input, (k, v) -> v is null)), 1, n)))";
return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) "
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required but if there's any plausible way to use indentation and line breaks here to make this code easier to read, that would help a lot. Without that, it's very hard to follow the logic.

+ "|| ARRAY_SORT(MAP_ENTRIES(MAP_FILTER(input, (k, v) -> v IS NULL)), (x, y) -> IF( x[1] < y[1], 1, -1)), 1, n)))";
}

@SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.HashMap;
import java.util.Map;

import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
Expand Down Expand Up @@ -95,6 +98,25 @@ public void testEmpty()
public void testNull()
{
assertFunction("MAP_TOP_N(NULL, 1)", mapType(UNKNOWN, UNKNOWN), null);

// If values are null, then use keys to break ties.
Map<Integer, Integer> expected = new HashMap<Integer, Integer>() {{
put(4, 4);
put(3, 1);
put(5, null);
}};

assertFunction("MAP_TOP_N(MAP(ARRAY[1, 2, 3, 4, 5], ARRAY[NULL, NULL, 1, 4, NULL]), 3)", mapType(INTEGER, INTEGER),
expected);

Map<String, Integer> expectedStringKey = new HashMap<String, Integer>() {{
put("ef", 6);
put("cd", 4);
put("ab", -1);
put("hi", null);
}};
assertFunction("MAP_TOP_N(MAP(ARRAY['ab', 'bc', 'ef', 'cd', 'hi'], ARRAY[-1, NULL, 6, 4, NULL]), 4)", mapType(createVarcharType(2), INTEGER),
expectedStringKey);
}

@Test
Expand Down