diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java index 5555cff238a17..bcbd29d78dc91 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java @@ -74,6 +74,11 @@ public boolean isEmpty() return positionCount == 0; } + public Type getKeyType() + { + return keyType; + } + public void serialize(BlockBuilder out) { BlockBuilder blockBuilder = out.beginBlockEntry(); @@ -111,12 +116,26 @@ public void popAll(BlockBuilder resultBlockBuilder) } } + public void popAll(BlockBuilder valueResultBlockBuilder, BlockBuilder keyResultBlockBuilder) + { + while (positionCount > 0) { + pop(valueResultBlockBuilder, keyResultBlockBuilder); + } + } + public void pop(BlockBuilder resultBlockBuilder) { valueType.appendTo(valueBlockBuilder, heapIndex[0], resultBlockBuilder); remove(); } + public void pop(BlockBuilder valueResultBlockBuilder, BlockBuilder keyResultBlockBuilder) + { + valueType.appendTo(valueBlockBuilder, heapIndex[0], valueResultBlockBuilder); + keyType.appendTo(keyBlockBuilder, heapIndex[0], keyResultBlockBuilder); + remove(); + } + private void remove() { positionCount--; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java index e5778b53e9d5b..4d5d20a12e40a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java @@ -140,16 +140,17 @@ public static void output(ArrayType outputType, MinMaxByNState state, BlockBuild } Type elementType = outputType.getElementType(); + Type keyType = heap.getKeyType(); BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(null, heap.getCapacity()); - long startSize = heap.getEstimatedSize(); - heap.popAll(reversedBlockBuilder); - state.addMemoryUsage(heap.getEstimatedSize() - startSize); + BlockBuilder reversedValueBlockBuilder = elementType.createBlockBuilder(null, heap.getCapacity()); + BlockBuilder reversedKeyBlockBuilder = keyType.createBlockBuilder(null, heap.getCapacity()); + heap.popAll(reversedValueBlockBuilder, reversedKeyBlockBuilder); - for (int i = reversedBlockBuilder.getPositionCount() - 1; i >= 0; i--) { - elementType.appendTo(reversedBlockBuilder, i, arrayBlockBuilder); + for (int i = reversedValueBlockBuilder.getPositionCount() - 1; i >= 0; i--) { + elementType.appendTo(reversedValueBlockBuilder, i, arrayBlockBuilder); } + heap.addAll(reversedKeyBlockBuilder, reversedValueBlockBuilder); out.closeEntry(); } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 27442bcbab82e..dd0c5b3a114a3 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -1816,6 +1816,21 @@ public void testMinMaxN() "SELECT orderkey, ARRAY[cast(1 as bigint), cast(2 as bigint), cast(3 as bigint)] t FROM orders"); } + @Test + public void testMinMaxByN() + { + assertQuery("SELECT MIN_BY(c0, c0, c1) OVER ( PARTITION BY c2 ORDER BY c3 ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW ) FROM " + + "( VALUES (1, 10, FALSE, 0), (2, 10, FALSE, 1) ) AS t(c0, c1, c2, c3)", "values array[1], array[1, 2]"); + assertQuery("SELECT MIN_BY(c0, c3, c1) OVER ( PARTITION BY c2 ORDER BY c3 ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW ) FROM " + + "( VALUES (1, 10, FALSE, 2), (2, 10, FALSE, 1) ) AS t(c0, c1, c2, c3)", "values array[2], array[2, 1]"); + assertQuery("select max_by(k1, k2, 2) over (partition by k3 order by k4) from (values (1, 'a', 1, 3), (2, 'b', 1, 2), (5, 'd', 2, 1), (8, 'c', 2, 0)) " + + "t(k1, k2, k3, k4)", "values array[8], array[5, 8], array[2], array[2, 1]"); + assertQuery("select max_by(k1, k2, 2) over (partition by k3 order by k4) from (values (1, 'a', 1, 3), (2, 'b', 1, 2), (5, 'd', 2, 1), (8, 'c', 2, 0), (7, 'e', 1, 8), " + + "(9, 'f', 2, 10), (0, 'g', 1, 9)) t(k1, k2, k3, k4)", "values array[8], array[5, 8], array[9, 5], array[2], array[2, 1], array[7, 2], array[0, 7]"); + assertQuery("select min_by(k1, k2, 2) over (partition by k3 order by k4) from (values (1, 'a', 1, 3), (2, 'b', 1, 2), (5, 'd', 2, 1), (8, 'c', 2, 0), (7, 'e', 1, 8), " + + "(9, 'f', 2, 10), (0, 'g', 1, 9)) t(k1, k2, k3, k4)", "values array[2], array[1, 2], array[1, 2], array[1, 2], array[8], array[8, 5], array[8, 5]"); + } + @Test public void testRowNumberFilterAndLimit() {