diff --git a/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java index b4a3056d64f2a..228cadeced798 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java @@ -17,10 +17,14 @@ import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.array.ObjectBigArray; import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.spi.function.aggregation.GroupByIdBlock; +import com.facebook.presto.sql.gen.JoinCompiler; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; import it.unimi.dsi.fastutil.ints.IntArrayFIFOQueue; import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; @@ -31,9 +35,12 @@ import java.util.Comparator; import java.util.Iterator; import java.util.List; +import java.util.Optional; +import java.util.PrimitiveIterator; import java.util.stream.IntStream; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.operator.GroupByHash.createGroupByHash; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -55,7 +62,7 @@ public class GroupedTopNBuilder private final int topN; private final boolean produceRowNumber; private final GroupByHash groupByHash; - + private final OperatorContext operatorContext; // a map of heaps, each of which records the top N rows private final ObjectBigArray groupedRows = new ObjectBigArray<>(); // a list of input pages, each of which has information of which row in which heap references which position @@ -69,19 +76,55 @@ public class GroupedTopNBuilder // keeps track sizes of input pages and heaps private long memorySizeInBytes; private int currentPageCount; + private LocalMemoryContext localUserMemoryContext; + + public GroupedTopNBuilder( + OperatorContext operatorContext, + List sourceTypes, + List partitionTypes, + List partitionChannels, + Optional hashChannel, + int expectedPositions, + boolean isDictionaryAggregationEnabled, + JoinCompiler joinCompiler, + PageWithPositionComparator comparator, + int topN, + boolean produceRowNumber) + { + this( + operatorContext, + sourceTypes, + partitionTypes, + partitionChannels, + hashChannel, + expectedPositions, + isDictionaryAggregationEnabled, + joinCompiler, + comparator, + topN, + produceRowNumber, + UpdateMemory.NOOP); + } public GroupedTopNBuilder( + OperatorContext operatorContext, List sourceTypes, + List partitionTypes, + List partitionChannels, + Optional hashChannel, + int expectedPositions, + boolean isDictionaryAggregationEnabled, + JoinCompiler joinCompiler, PageWithPositionComparator comparator, int topN, boolean produceRowNumber, - GroupByHash groupByHash) + UpdateMemory updateMemory) { + this.operatorContext = operatorContext; this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); checkArgument(topN > 0, "topN must be > 0"); this.topN = topN; this.produceRowNumber = produceRowNumber; - this.groupByHash = requireNonNull(groupByHash, "groupByHash is not null"); this.pageWithPositionComparator = requireNonNull(comparator, "comparator is null"); // Note: this is comparator intentionally swaps left and right arguments form a "reverse order" comparator @@ -91,6 +134,21 @@ public GroupedTopNBuilder( pageReferences.get(right.getPageId()).getPage(), right.getPosition()); this.emptyPageReferenceSlots = new IntFIFOQueue(); + + if (!partitionChannels.isEmpty()) { + checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); + this.groupByHash = createGroupByHash( + partitionTypes, + Ints.toArray(partitionChannels), + hashChannel, + expectedPositions, + isDictionaryAggregationEnabled, + joinCompiler, + updateMemory); + } + else { + this.groupByHash = new NoChannelGroupByHash(); + } } public Work processPage(Page page) @@ -105,7 +163,7 @@ public Work processPage(Page page) public Iterator buildResult() { - return new ResultIterator(); + return new ResultIterator(IntStream.range(0, groupByHash.getGroupCount()).iterator()); } public long getEstimatedSizeInBytes() @@ -398,10 +456,8 @@ private class ResultIterator private static final int UNUSED_CAPACITY_DISPOSAL_THRESHOLD = 4096; private final PageBuilder pageBuilder; - // we may have 0 groups if there is no input page processed - private final int groupCount = groupByHash.getGroupCount(); + private final PrimitiveIterator.OfInt groupIds; - private int currentGroupNumber; private long currentGroupSizeInBytes; // the row number of the current position in the group @@ -411,7 +467,7 @@ private class ResultIterator private ObjectBigArray currentRows; - ResultIterator() + ResultIterator(PrimitiveIterator.OfInt groupIds) { if (produceRowNumber) { pageBuilder = new PageBuilder(new ImmutableList.Builder().add(sourceTypes).add(BIGINT).build()); @@ -421,6 +477,7 @@ private class ResultIterator } // Populate the first group currentRows = new ObjectBigArray<>(); + this.groupIds = groupIds; nextGroupedRows(); } @@ -471,11 +528,10 @@ protected Page computeNext() private void nextGroupedRows() { - if (currentGroupNumber < groupCount) { - RowHeap rows = groupedRows.getAndSet(currentGroupNumber, null); - verify(rows != null && !rows.isEmpty(), "impossible to have inserted a group without a witness row"); + if (this.groupIds.hasNext()) { + RowHeap rows = groupedRows.getAndSet(this.groupIds.nextInt(), null); + verify(rows != null && !rows.isEmpty(), "impossible to have inserted a group without a witness row. rows=%s for %s", rows, this); currentGroupSizeInBytes = rows.getEstimatedSizeInBytes(); - currentGroupNumber++; currentGroupSize = rows.size(); // sort output rows in a big array in case there are too many rows @@ -495,4 +551,9 @@ private void nextGroupedRows() } } } + + public GroupByHash getGroupByHash() + { + return groupByHash; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java index 29ca992bfc59d..9fd8018ee4180 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java @@ -27,6 +27,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static java.util.Collections.emptyIterator; +import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -106,18 +107,23 @@ public TopNOperator( this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.localUserMemoryContext = operatorContext.localUserMemoryContext(); checkArgument(n >= 0, "n must be positive"); - if (n == 0) { finishing = true; outputIterator = emptyIterator(); } else { topNBuilder = new GroupedTopNBuilder( + operatorContext, types, + emptyList(), + emptyList(), + null, + 0, + false, + null, new SimplePageWithPositionComparator(types, sortChannels, sortOrders), n, - false, - new NoChannelGroupByHash()); + false); } } @@ -152,7 +158,6 @@ public void addInput(Page page) boolean done = topNBuilder.processPage(requireNonNull(page, "page is null")).process(); // there is no grouping so work will always be done verify(done); - updateMemoryReservation(); } @Override @@ -174,15 +179,9 @@ public Page getOutput() else { outputIterator = emptyIterator(); } - updateMemoryReservation(); return output; } - private void updateMemoryReservation() - { - localUserMemoryContext.setBytes(topNBuilder.getEstimatedSizeInBytes()); - } - private boolean noMoreOutput() { return outputIterator != null && !outputIterator.hasNext(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java index b0b23aa666806..3955b3bc1c272 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java @@ -29,7 +29,6 @@ import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled; import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.operator.GroupByHash.createGroupByHash; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; @@ -130,7 +129,6 @@ public OperatorFactory duplicate() private final int[] outputChannels; - private final GroupByHash groupByHash; private final GroupedTopNBuilder groupedTopNBuilder; private boolean finishing; @@ -165,28 +163,21 @@ public TopNRowNumberOperator( checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); - if (!partitionChannels.isEmpty()) { - checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); - groupByHash = createGroupByHash( - partitionTypes, - Ints.toArray(partitionChannels), - hashChannel, - expectedPositions, - isDictionaryAggregationEnabled(operatorContext.getSession()), - joinCompiler, - this::updateMemoryReservation); - } - else { - groupByHash = new NoChannelGroupByHash(); - } - List types = toTypes(sourceTypes, outputChannels, generateRowNumber); + this.groupedTopNBuilder = new GroupedTopNBuilder( + operatorContext, ImmutableList.copyOf(sourceTypes), + partitionTypes, + partitionChannels, + hashChannel, + expectedPositions, + isDictionaryAggregationEnabled(operatorContext.getSession()), + joinCompiler, new SimplePageWithPositionComparator(types, sortChannels, sortOrders), maxRowCountPerPartition, generateRowNumber, - groupByHash); + this::updateMemoryReservation); } @Override @@ -226,7 +217,6 @@ public void addInput(Page page) if (unfinishedWork.process()) { unfinishedWork = null; } - updateMemoryReservation(); } @Override @@ -234,7 +224,6 @@ public Page getOutput() { if (unfinishedWork != null) { boolean finished = unfinishedWork.process(); - updateMemoryReservation(); if (!finished) { return null; } @@ -254,13 +243,13 @@ public Page getOutput() if (outputIterator.hasNext()) { output = outputIterator.next().extractChannels(outputChannels); } - updateMemoryReservation(); return output; } @VisibleForTesting public int getCapacity() { + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); checkState(groupByHash != null); return groupByHash.getCapacity(); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java index bccfd7a71efcd..f08adf72a9dac 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java @@ -38,8 +38,10 @@ import org.testng.annotations.Test; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.Random; import java.util.concurrent.TimeUnit; @@ -97,11 +99,35 @@ public void setup() GroupByHash groupByHash; if (groupCount > 1) { groupByHash = new BigintGroupByHash(HASH_GROUP, true, groupCount, UpdateMemory.NOOP); + topNBuilder = new GroupedTopNBuilder( + null, + types, + ImmutableList.of(types.get(HASH_GROUP)), + ImmutableList.of(HASH_GROUP), + Optional.of(HASH_GROUP), + groupCount, + false, + null, + comparator, + topN, + false, + UpdateMemory.NOOP); } else { - groupByHash = new NoChannelGroupByHash(); + topNBuilder = new GroupedTopNBuilder( + null, + types, + Collections.emptyList(), + Collections.emptyList(), + null, + groupCount, + false, + null, + comparator, + topN, + false, + UpdateMemory.NOOP); } - topNBuilder = new GroupedTopNBuilder(types, comparator, topN, false, groupByHash); } public GroupedTopNBuilder getTopNBuilder() diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java b/presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java index 945e32ab361ec..dde2a16b91756 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java @@ -40,7 +40,6 @@ import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; -import static com.facebook.presto.operator.UpdateMemory.NOOP; import static io.airlift.slice.SizeOf.sizeOf; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -71,13 +70,20 @@ public static Object[][] pageRowCounts() public void testEmptyInput() { GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + null, ImmutableList.of(BIGINT), + Collections.emptyList(), + Collections.emptyList(), + null, + 0, + false, + null, (left, leftPosition, right, rightPosition) -> { throw new UnsupportedOperationException(); }, 5, - false, - new NoChannelGroupByHash()); + false); + assertFalse(groupedTopNBuilder.buildResult().hasNext()); } @@ -106,13 +112,20 @@ public void testMultiGroupTopN(boolean produceRowNumbers) page.compact(); } - GroupByHash groupByHash = createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), NOOP); GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + null, types, + ImmutableList.of(types.get(0)), + ImmutableList.of(0), + Optional.empty(), + 2, + false, + null, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 2, - produceRowNumbers, - groupByHash); + produceRowNumbers); + + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); assertBuilderSize(groupByHash, types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); // add 4 rows for the first page and created three heaps with 1, 1, 2 rows respectively @@ -180,28 +193,36 @@ public void testSingleGroupTopN(boolean produceRowNumbers) } GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + null, types, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + 0, + false, + null, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 5, - produceRowNumbers, - new NoChannelGroupByHash()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); + produceRowNumbers); + + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); + assertBuilderSize(groupByHash, types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); // add 4 rows for the first page and created a single heap with 4 rows assertTrue(groupedTopNBuilder.processPage(input.get(0)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4), ImmutableList.of(4), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4), ImmutableList.of(4), groupedTopNBuilder.getEstimatedSizeInBytes()); // add 1 row for the second page and the heap is with 5 rows assertTrue(groupedTopNBuilder.processPage(input.get(1)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); // update 1 new row from the third page (which will be compacted into a single row only) assertTrue(groupedTopNBuilder.processPage(input.get(2)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); // the last page will be discarded assertTrue(groupedTopNBuilder.processPage(input.get(3)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult()); assertEquals(output.size(), 1); @@ -231,20 +252,30 @@ public void testYield() Page input = rowPagesBuilder(types) .row(1L, 0.3) .row(1L, 0.2) - .row(1L, 0.9) - .row(1L, 0.1) + .row(2L, 0.9) + .row(3L, 0.1) .build() .get(0); input.compact(); AtomicBoolean unblock = new AtomicBoolean(); - GroupByHash groupByHash = createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), unblock::get); GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + null, types, + ImmutableList.of(types.get(0)), + ImmutableList.of(0), + Optional.empty(), + 1, + false, + null, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 5, false, - groupByHash); + () -> { + return unblock.get(); + }); + + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); assertBuilderSize(groupByHash, types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); Work work = groupedTopNBuilder.processPage(input); @@ -256,10 +287,10 @@ public void testYield() assertEquals(output.size(), 1); Page expected = rowPagesBuilder(types) - .row(1L, 0.1) .row(1L, 0.2) .row(1L, 0.3) - .row(1L, 0.9) + .row(2L, 0.9) + .row(3L, 0.1) .build() .get(0); assertPageEquals(types, output.get(0), expected); @@ -290,11 +321,17 @@ public void testAutoCompact() .build(); GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + null, types, - new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + ImmutableList.of(types.get(0)), + ImmutableList.of(0), + Optional.empty(), 1, false, - createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), NOOP)); + null, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 1, + false); // page 1: // the first page will be compacted @@ -382,15 +419,21 @@ public void testLargePagesMemoryTracking(int pageCount, int rowCount) } List input = rowPagesBuilder.build(); - GroupByHash groupByHash = createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), NOOP); GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + null, types, + ImmutableList.of(types.get(0)), + ImmutableList.of(0), + Optional.empty(), + 1, + false, + new JoinCompiler(createTestMetadataManager(), new FeaturesConfig()), new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), pageCount * rowCount, - false, - groupByHash); + false); // Assert memory usage gradually goes up + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); for (int i = 0; i < pageCount; i++) { assertTrue(groupedTopNBuilder.processPage(input.get(i)).process()); assertBuilderSize(groupByHash, types, Collections.nCopies(i + 1, rowCount), Collections.nCopies(rowCount, i + 1), groupedTopNBuilder.getEstimatedSizeInBytes());