diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 539813c87b622..54645ee2ed94a 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -131,6 +131,7 @@ public final class SystemSessionProperties public static final String SPILL_ENABLED = "spill_enabled"; public static final String JOIN_SPILL_ENABLED = "join_spill_enabled"; public static final String AGGREGATION_SPILL_ENABLED = "aggregation_spill_enabled"; + public static final String TOPN_SPILL_ENABLED = "topn_spill_enabled"; public static final String DISTINCT_AGGREGATION_SPILL_ENABLED = "distinct_aggregation_spill_enabled"; public static final String DEDUP_BASED_DISTINCT_AGGREGATION_SPILL_ENABLED = "dedup_based_distinct_aggregation_spill_enabled"; public static final String DISTINCT_AGGREGATION_LARGE_BLOCK_SPILL_ENABLED = "distinct_aggregation_large_block_spill_enabled"; @@ -139,6 +140,7 @@ public final class SystemSessionProperties public static final String WINDOW_SPILL_ENABLED = "window_spill_enabled"; public static final String ORDER_BY_SPILL_ENABLED = "order_by_spill_enabled"; public static final String AGGREGATION_OPERATOR_UNSPILL_MEMORY_LIMIT = "aggregation_operator_unspill_memory_limit"; + public static final String TOPN_OPERATOR_UNSPILL_MEMORY_LIMIT = "topn_operator_unspill_memory_limit"; public static final String QUERY_MAX_REVOCABLE_MEMORY_PER_NODE = "query_max_revocable_memory_per_node"; public static final String TEMP_STORAGE_SPILLER_BUFFER_SIZE = "temp_storage_spiller_buffer_size"; public static final String OPTIMIZE_DISTINCT_AGGREGATIONS = "optimize_mixed_distinct_aggregations"; @@ -689,6 +691,11 @@ public SystemSessionProperties( "Enable aggregate spilling if spill_enabled", featuresConfig.isAggregationSpillEnabled(), false), + booleanProperty( + TOPN_SPILL_ENABLED, + "Enable topN spilling if spill_enabled", + featuresConfig.isTopNSpillEnabled(), + false), booleanProperty( DISTINCT_AGGREGATION_SPILL_ENABLED, "Enable spill for distinct aggregations if spill_enabled and aggregation_spill_enabled", @@ -737,6 +744,15 @@ public SystemSessionProperties( false, value -> DataSize.valueOf((String) value), DataSize::toString), + new PropertyMetadata<>( + TOPN_OPERATOR_UNSPILL_MEMORY_LIMIT, + "How much memory can should be allocated per topN operator in unspilling process", + VARCHAR, + DataSize.class, + featuresConfig.getTopNOperatorUnspillMemoryLimit(), + false, + value -> DataSize.valueOf((String) value), + DataSize::toString), new PropertyMetadata<>( QUERY_MAX_REVOCABLE_MEMORY_PER_NODE, "Maximum amount of revocable memory a query can use", @@ -1762,6 +1778,11 @@ public static boolean isAggregationSpillEnabled(Session session) return session.getSystemProperty(AGGREGATION_SPILL_ENABLED, Boolean.class) && isSpillEnabled(session); } + public static boolean isTopNSpillEnabled(Session session) + { + return session.getSystemProperty(TOPN_SPILL_ENABLED, Boolean.class) && isSpillEnabled(session); + } + public static boolean isDistinctAggregationSpillEnabled(Session session) { return session.getSystemProperty(DISTINCT_AGGREGATION_SPILL_ENABLED, Boolean.class) && isAggregationSpillEnabled(session); @@ -1804,6 +1825,13 @@ public static DataSize getAggregationOperatorUnspillMemoryLimit(Session session) return memoryLimitForMerge; } + public static DataSize getTopNOperatorUnspillMemoryLimit(Session session) + { + DataSize unspillMemoryLimit = session.getSystemProperty(TOPN_OPERATOR_UNSPILL_MEMORY_LIMIT, DataSize.class); + checkArgument(unspillMemoryLimit.toBytes() >= 0, "%s must be positive", TOPN_OPERATOR_UNSPILL_MEMORY_LIMIT); + return unspillMemoryLimit; + } + public static DataSize getQueryMaxRevocableMemoryPerNode(Session session) { return session.getSystemProperty(QUERY_MAX_REVOCABLE_MEMORY_PER_NODE, DataSize.class); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/GroupByHash.java b/presto-main/src/main/java/com/facebook/presto/operator/GroupByHash.java index aa3eac4131fce..ffa771daa48de 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/GroupByHash.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/GroupByHash.java @@ -16,10 +16,12 @@ import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.array.IntBigArray; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.aggregation.GroupByIdBlock; import com.facebook.presto.sql.gen.JoinCompiler; import com.google.common.annotations.VisibleForTesting; +import it.unimi.dsi.fastutil.ints.IntIterator; import java.util.List; import java.util.Optional; @@ -85,4 +87,34 @@ default boolean contains(int position, Page page, int[] hashChannels, long rawHa @VisibleForTesting int getCapacity(); + + default IntIterator getHashSortedGroupIds() + { + IntBigArray groupIds = new IntBigArray(); + groupIds.ensureCapacity(getGroupCount()); + for (int i = 0; i < getGroupCount(); i++) { + groupIds.set(i, i); + } + + groupIds.sort(0, getGroupCount(), (leftGroupId, rightGroupId) -> + Long.compare(getRawHash(leftGroupId), getRawHash(rightGroupId))); + + return new IntIterator() + { + private final int totalPositions = getGroupCount(); + private int position; + + @Override + public boolean hasNext() + { + return position < totalPositions; + } + + @Override + public int nextInt() + { + return groupIds.get(position++); + } + }; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java index b4a3056d64f2a..19f6d1e0e4ea4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/GroupedTopNBuilder.java @@ -14,485 +14,29 @@ package com.facebook.presto.operator; import com.facebook.presto.common.Page; -import com.facebook.presto.common.PageBuilder; -import com.facebook.presto.common.array.ObjectBigArray; -import com.facebook.presto.common.type.Type; -import com.facebook.presto.spi.function.aggregation.GroupByIdBlock; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.ImmutableList; -import it.unimi.dsi.fastutil.ints.IntArrayFIFOQueue; -import it.unimi.dsi.fastutil.ints.IntIterator; -import it.unimi.dsi.fastutil.ints.IntOpenHashSet; -import it.unimi.dsi.fastutil.ints.IntSet; -import it.unimi.dsi.fastutil.objects.ObjectHeapPriorityQueue; -import org.openjdk.jol.info.ClassLayout; +import com.google.common.util.concurrent.ListenableFuture; -import java.util.Comparator; import java.util.Iterator; -import java.util.List; -import java.util.stream.IntStream; -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.slice.SizeOf.sizeOf; -import static java.util.Objects.requireNonNull; - -/** - * This class finds the top N rows defined by {@param comparator} for each group specified by {@param groupByHash}. - */ -public class GroupedTopNBuilder +public interface GroupedTopNBuilder { - private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedTopNBuilder.class).instanceSize(); - // compact a page when 50% of its positions are unreferenced - private static final int COMPACT_THRESHOLD = 2; - - private final Type[] sourceTypes; - private final int topN; - private final boolean produceRowNumber; - private final GroupByHash groupByHash; - - // a map of heaps, each of which records the top N rows - private final ObjectBigArray groupedRows = new ObjectBigArray<>(); - // a list of input pages, each of which has information of which row in which heap references which position - private final ObjectBigArray pageReferences = new ObjectBigArray<>(); - // for heap element comparison - private final PageWithPositionComparator pageWithPositionComparator; - private final Comparator rowHeapComparator; - // when there is no row referenced in a page, it will be removed instead of compacted; use a list to record those empty slots to reuse them - private final IntFIFOQueue emptyPageReferenceSlots; - - // keeps track sizes of input pages and heaps - private long memorySizeInBytes; - private int currentPageCount; - - public GroupedTopNBuilder( - List sourceTypes, - PageWithPositionComparator comparator, - int topN, - boolean produceRowNumber, - GroupByHash groupByHash) - { - this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); - checkArgument(topN > 0, "topN must be > 0"); - this.topN = topN; - this.produceRowNumber = produceRowNumber; - this.groupByHash = requireNonNull(groupByHash, "groupByHash is not null"); - - this.pageWithPositionComparator = requireNonNull(comparator, "comparator is null"); - // Note: this is comparator intentionally swaps left and right arguments form a "reverse order" comparator - this.rowHeapComparator = (right, left) -> this.pageWithPositionComparator.compareTo( - pageReferences.get(left.getPageId()).getPage(), - left.getPosition(), - pageReferences.get(right.getPageId()).getPage(), - right.getPosition()); - this.emptyPageReferenceSlots = new IntFIFOQueue(); - } - - public Work processPage(Page page) - { - return new TransformWork<>( - groupByHash.getGroupIds(page), - groupIds -> { - processPage(page, groupIds); - return null; - }); - } - - public Iterator buildResult() - { - return new ResultIterator(); - } - - public long getEstimatedSizeInBytes() - { - return INSTANCE_SIZE + - memorySizeInBytes + - groupByHash.getEstimatedSize() + - groupedRows.sizeOf() + - pageReferences.sizeOf() + - emptyPageReferenceSlots.getEstimatedSizeInBytes(); - } - - @VisibleForTesting - List getBufferedPages() - { - return IntStream.range(0, currentPageCount) - .filter(i -> pageReferences.get(i) != null) - .mapToObj(i -> pageReferences.get(i).getPage()) - .collect(toImmutableList()); - } - - private void processPage(Page newPage, GroupByIdBlock groupIds) - { - checkArgument(newPage != null); - checkArgument(groupIds != null); - - int firstPositionToInsert = findFirstPositionToInsert(newPage, groupIds); - if (firstPositionToInsert < 0) { - // no insertions required - return; - } - - PageReference newPageReference = new PageReference(newPage); - memorySizeInBytes += newPageReference.getEstimatedSizeInBytes(); - int newPageId; - if (emptyPageReferenceSlots.isEmpty()) { - // all the previous slots are full; create a new one - pageReferences.ensureCapacity(currentPageCount + 1); - newPageId = currentPageCount; - currentPageCount++; - } - else { - // reuse a previously removed page's slot - newPageId = emptyPageReferenceSlots.dequeueInt(); - } - verify(pageReferences.setIfNull(newPageId, newPageReference), "should not overwrite a non-empty slot"); - - // ensure sufficient group capacity outside of the loop - groupedRows.ensureCapacity(groupIds.getGroupCount()); - // update the affected heaps and record candidate pages that need compaction - IntSet pagesToCompact = new IntOpenHashSet(); - for (int position = firstPositionToInsert; position < newPage.getPositionCount(); position++) { - long groupId = groupIds.getGroupId(position); - RowHeap rows = groupedRows.get(groupId); - if (rows == null) { - // a new group - rows = new RowHeap(rowHeapComparator); - groupedRows.set(groupId, rows); - } - else { - // update an existing group; - // remove the memory usage for this group for now; add it back after update - memorySizeInBytes -= rows.getEstimatedSizeInBytes(); - } - - if (rows.size() < topN) { - Row row = new Row(newPageId, position); - newPageReference.reference(row); - rows.enqueue(row); - } - else { - // may compare with the topN-th element with in the heap to decide if update is necessary - Row previousRow = rows.first(); - PageReference previousPageReference = pageReferences.get(previousRow.getPageId()); - if (pageWithPositionComparator.compareTo(newPage, position, previousPageReference.getPage(), previousRow.getPosition()) < 0) { - // update reference and the heap - rows.dequeue(); - previousPageReference.dereference(previousRow.getPosition()); - - Row newRow = new Row(newPageId, position); - newPageReference.reference(newRow); - rows.enqueue(newRow); - - // compact a page if it is not the current input page and the reference count is below the threshold - if (previousPageReference.getPage() != newPage && - previousPageReference.getUsedPositionCount() * COMPACT_THRESHOLD < previousPageReference.getPage().getPositionCount()) { - pagesToCompact.add(previousRow.getPageId()); - } - } - } - - memorySizeInBytes += rows.getEstimatedSizeInBytes(); - } - - // may compact the new page as well - if (newPageReference.getUsedPositionCount() * COMPACT_THRESHOLD < newPage.getPositionCount()) { - verify(pagesToCompact.add(newPageId)); - } - - // compact pages - IntIterator iterator = pagesToCompact.iterator(); - while (iterator.hasNext()) { - int pageId = iterator.nextInt(); - PageReference pageReference = pageReferences.get(pageId); - if (pageReference.getUsedPositionCount() == 0) { - pageReferences.set(pageId, null); - emptyPageReferenceSlots.enqueue(pageId); - memorySizeInBytes -= pageReference.getEstimatedSizeInBytes(); - } - else { - memorySizeInBytes -= pageReference.getEstimatedSizeInBytes(); - pageReference.compact(); - memorySizeInBytes += pageReference.getEstimatedSizeInBytes(); - } - } - } - - private int findFirstPositionToInsert(Page newPage, GroupByIdBlock groupIds) - { - for (int position = 0; position < newPage.getPositionCount(); position++) { - long groupId = groupIds.getGroupId(position); - if (groupedRows.getCapacity() <= groupId) { - return position; - } - - RowHeap rows = groupedRows.get(groupId); - if (rows == null || rows.size() < topN) { - return position; - } - // check against current minimum - Row previousRow = rows.first(); - PageReference pageReference = pageReferences.get(previousRow.getPageId()); - if (pageWithPositionComparator.compareTo(newPage, position, pageReference.getPage(), previousRow.getPosition()) < 0) { - return position; - } - } - // no positions to insert - return -1; - } - - /** - * The class is a pointer to a row in a page. - * The actual position in the page is mutable because as pages are compacted, the position will change. - */ - private static class Row - { - private final int pageId; - private int position; - - private Row(int pageId, int position) - { - this.pageId = pageId; - reset(position); - } - - public void reset(int position) - { - this.position = position; - } - - public int getPageId() - { - return pageId; - } - - public int getPosition() - { - return position; - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("pageId", pageId) - .add("position", position) - .toString(); - } - } - - private static class PageReference - { - private static final long INSTANCE_SIZE = ClassLayout.parseClass(PageReference.class).instanceSize(); - - private Page page; - private Row[] reference; - - private int usedPositionCount; - - public PageReference(Page page) - { - this.page = requireNonNull(page, "page is null"); - this.reference = new Row[page.getPositionCount()]; - } - - public void reference(Row row) - { - reference[row.getPosition()] = row; - usedPositionCount++; - } - - public boolean dereference(int position) - { - checkArgument(reference[position] != null && usedPositionCount > 0); - reference[position] = null; - return (--usedPositionCount) == 0; - } - - public int getUsedPositionCount() - { - return usedPositionCount; - } - - public void compact() - { - checkState(usedPositionCount > 0); - - if (usedPositionCount == page.getPositionCount()) { - return; - } - - // re-assign reference - Row[] newReference = new Row[usedPositionCount]; - int[] positions = new int[usedPositionCount]; - int index = 0; - // update all the elements in the heaps that reference the current page - // this does not change the elements in the heap; - // it only updates the value of the elements; while keeping the same order - for (int i = 0; i < reference.length && index < usedPositionCount; i++) { - Row value = reference[i]; - if (value != null) { - value.reset(index); - newReference[index] = value; - positions[index] = i; - index++; - } - } - verify(index == usedPositionCount); - - // compact page - page = page.copyPositions(positions, 0, usedPositionCount); - reference = newReference; - } - - public Page getPage() - { - return page; - } - - public long getEstimatedSizeInBytes() - { - return page.getRetainedSizeInBytes() + sizeOf(reference) + INSTANCE_SIZE; - } - } - - // this class is for precise memory tracking - private static class IntFIFOQueue - extends IntArrayFIFOQueue - { - private static final long INSTANCE_SIZE = ClassLayout.parseClass(IntFIFOQueue.class).instanceSize(); - - private long getEstimatedSizeInBytes() - { - return INSTANCE_SIZE + sizeOf(array); - } - } - - // this class is for precise memory tracking - private static class RowHeap - extends ObjectHeapPriorityQueue - { - private static final long INSTANCE_SIZE = ClassLayout.parseClass(RowHeap.class).instanceSize(); - private static final long ROW_ENTRY_SIZE = ClassLayout.parseClass(Row.class).instanceSize(); - - private RowHeap(Comparator comparator) - { - super(1, comparator); - } - - private long getEstimatedSizeInBytes() - { - return INSTANCE_SIZE + sizeOf(heap) + size() * ROW_ENTRY_SIZE; - } - } - - private class ResultIterator - extends AbstractIterator - { - // ObjectBigArray capacity is always at least 1024, so discarding "small" BigArrays even if you don't need the entire space is wasteful - private static final int UNUSED_CAPACITY_DISPOSAL_THRESHOLD = 4096; - - private final PageBuilder pageBuilder; - // we may have 0 groups if there is no input page processed - private final int groupCount = groupByHash.getGroupCount(); - - private int currentGroupNumber; - private long currentGroupSizeInBytes; - - // the row number of the current position in the group - private int currentGroupPosition; - // number of rows in the group - private int currentGroupSize; + Work processPage(Page page); - private ObjectBigArray currentRows; + WorkProcessor buildResult(); - ResultIterator() - { - if (produceRowNumber) { - pageBuilder = new PageBuilder(new ImmutableList.Builder().add(sourceTypes).add(BIGINT).build()); - } - else { - pageBuilder = new PageBuilder(ImmutableList.copyOf(sourceTypes)); - } - // Populate the first group - currentRows = new ObjectBigArray<>(); - nextGroupedRows(); - } + ListenableFuture startMemoryRevoke(); - @Override - protected Page computeNext() - { - pageBuilder.reset(); - while (!pageBuilder.isFull()) { - if (currentRows == null) { - // no more groups - break; - } - if (currentGroupPosition == currentGroupSize) { - // the current group has produced all its rows - memorySizeInBytes -= currentGroupSizeInBytes; - currentGroupPosition = 0; - nextGroupedRows(); - continue; - } + void finishMemoryRevoke(); - // Clear the reference to the Row after access to make it reclaimable by GC - Row row = currentRows.getAndSet(currentGroupPosition, null); - PageReference pageReference = pageReferences.get(row.getPageId()); - Page page = pageReference.getPage(); - int position = row.getPosition(); - for (int i = 0; i < sourceTypes.length; i++) { - sourceTypes[i].appendTo(page.getBlock(i), position, pageBuilder.getBlockBuilder(i)); - } + long getEstimatedSizeInBytes(); - if (produceRowNumber) { - BIGINT.writeLong(pageBuilder.getBlockBuilder(sourceTypes.length), currentGroupPosition + 1); - } - pageBuilder.declarePosition(); - currentGroupPosition++; + ListenableFuture updateMemoryReservations(); - // deference the row; no need to compact the pages but remove them if completely unused - if (pageReference.dereference(position)) { - pageReferences.set(row.getPageId(), null); - memorySizeInBytes -= pageReference.getEstimatedSizeInBytes(); - } - } + GroupByHash getGroupByHash(); - if (pageBuilder.isEmpty()) { - return endOfData(); - } - return pageBuilder.build(); - } + boolean isEmpty(); - private void nextGroupedRows() - { - if (currentGroupNumber < groupCount) { - RowHeap rows = groupedRows.getAndSet(currentGroupNumber, null); - verify(rows != null && !rows.isEmpty(), "impossible to have inserted a group without a witness row"); - currentGroupSizeInBytes = rows.getEstimatedSizeInBytes(); - currentGroupNumber++; - currentGroupSize = rows.size(); + void close(); - // sort output rows in a big array in case there are too many rows - checkState(currentRows != null, "currentRows already observed the final group"); - if (currentRows.getCapacity() > UNUSED_CAPACITY_DISPOSAL_THRESHOLD && currentRows.getCapacity() > currentGroupSize * 2L) { - // Discard over-sized big array to avoid unnecessary waste - currentRows = new ObjectBigArray<>(); - } - currentRows.ensureCapacity(currentGroupSize); - for (int index = currentGroupSize - 1; index >= 0; index--) { - currentRows.set(index, rows.dequeue()); - } - } - else { - currentRows = null; - currentGroupSize = 0; - } - } - } + Iterator buildHashSortedIntermediateResult(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/InMemoryGroupedTopNBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/InMemoryGroupedTopNBuilder.java new file mode 100644 index 0000000000000..98ad3c031e6bc --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/InMemoryGroupedTopNBuilder.java @@ -0,0 +1,598 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.array.ObjectBigArray; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spi.function.aggregation.GroupByIdBlock; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.ints.IntArrayFIFOQueue; +import it.unimi.dsi.fastutil.ints.IntIterator; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.objects.ObjectHeapPriorityQueue; +import org.openjdk.jol.info.ClassLayout; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.PrimitiveIterator; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.util.Collections.emptyIterator; +import static java.util.Objects.requireNonNull; + +/** + * This class finds the top N rows defined by {@param comparator} for each group specified by {@param groupByHash}. + * + * The 3 main datastructures used are GroupByHash, RowHeap[] and PageReferences. + * GroupByHash - Is HashTable used to compute the Groups each record belongs to and the + * RowHeap[] - Is an array of Heaps/Priority-Queue + * Each heap in the array tracks the TopN for the give group + * PageReferences - List of pointers to Actual Pages buffered so far. The RowHeap contains Rows which + * are wrapper class that points to Pages in the PageReference + * + * As we receive input we populate it into the HashTable and also populate it to the Heap. + */ +public class InMemoryGroupedTopNBuilder + implements GroupedTopNBuilder +{ + private static final long INSTANCE_SIZE = ClassLayout.parseClass(InMemoryGroupedTopNBuilder.class).instanceSize(); + // compact a page when 50% of its positions are unreferenced + private static final int COMPACT_THRESHOLD = 2; + + private final Type[] sourceTypes; + private final int topN; + private final boolean produceRowNumber; + private final GroupByHash groupByHash; + private LocalMemoryContext memoryContext; + + // a map of heaps, each of which records the top N rows + private final ObjectBigArray groupedRows = new ObjectBigArray<>(); + // a list of input pages, each of which has information of which row in which heap references which position + private final ObjectBigArray pageReferences = new ObjectBigArray<>(); + // for heap element comparison + private final PageWithPositionComparator pageWithPositionComparator; + private final Comparator rowHeapComparator; + // when there is no row referenced in a page, it will be removed instead of compacted; use a list to record those empty slots to reuse them + private final IntFIFOQueue emptyPageReferenceSlots; + + // keeps track sizes of input pages and heaps + private long memorySizeInBytes; + private int currentPageCount; + + public InMemoryGroupedTopNBuilder( + List sourceTypes, + PageWithPositionComparator comparator, + int topN, + boolean produceRowNumber, + LocalMemoryContext memoryContext, + GroupByHash groupByHash) + { + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); + checkArgument(topN > 0, "topN must be > 0"); + this.topN = topN; + this.produceRowNumber = produceRowNumber; + this.groupByHash = requireNonNull(groupByHash, "groupByHash is null"); + this.memoryContext = requireNonNull(memoryContext, "memoryContext is null"); + + this.pageWithPositionComparator = requireNonNull(comparator, "comparator is null"); + // Note: this is comparator intentionally swaps left and right arguments form a "reverse order" comparator + this.rowHeapComparator = (right, left) -> this.pageWithPositionComparator.compareTo( + pageReferences.get(left.getPageId()).getPage(), + left.getPosition(), + pageReferences.get(right.getPageId()).getPage(), + right.getPosition()); + this.emptyPageReferenceSlots = new IntFIFOQueue(); + } + + @Override + public Work processPage(Page page) + { + return new TransformWork<>( + groupByHash.getGroupIds(page), + groupIds -> { + processPage(page, groupIds); + return null; + }); + } + + @Override + public WorkProcessor buildResult() + { + if (groupByHash.getGroupCount() == 0) { + return WorkProcessor.fromIterator(emptyIterator()); + } + return WorkProcessor.fromIterator(new ResultIterator(IntStream.range(0, groupByHash.getGroupCount()).iterator(), false)); + } + + @Override + public ListenableFuture startMemoryRevoke() + { + throw new UnsupportedOperationException("InMemoryGroupedTopNBuilder does not support startMemoryRevoke"); + } + + @Override + public void finishMemoryRevoke() + { + throw new UnsupportedOperationException("InMemoryGroupedTopNBuilder does not support finishMemoryRevoke"); + } + + public long getEstimatedSizeInBytes() + { + return INSTANCE_SIZE + + memorySizeInBytes + + groupByHash.getEstimatedSize() + + groupedRows.sizeOf() + + pageReferences.sizeOf() + + getGroupIdsSortingSize() + + emptyPageReferenceSlots.getEstimatedSizeInBytes(); + } + + @Override + public ListenableFuture updateMemoryReservations() + { + return memoryContext.setBytes(getEstimatedSizeInBytes()); + } + + @VisibleForTesting + List getBufferedPages() + { + return IntStream.range(0, currentPageCount) + .filter(i -> pageReferences.get(i) != null) + .mapToObj(i -> pageReferences.get(i).getPage()) + .collect(toImmutableList()); + } + + private void processPage(Page newPage, GroupByIdBlock groupIds) + { + checkArgument(newPage != null); + checkArgument(groupIds != null); + + int firstPositionToInsert = findFirstPositionToInsert(newPage, groupIds); + if (firstPositionToInsert < 0) { + // no insertions required + return; + } + + PageReference newPageReference = new PageReference(newPage); + memorySizeInBytes += newPageReference.getEstimatedSizeInBytes(); + int newPageId; + if (emptyPageReferenceSlots.isEmpty()) { + // all the previous slots are full; create a new one + pageReferences.ensureCapacity(currentPageCount + 1); + newPageId = currentPageCount; + currentPageCount++; + } + else { + // reuse a previously removed page's slot + newPageId = emptyPageReferenceSlots.dequeueInt(); + } + verify(pageReferences.setIfNull(newPageId, newPageReference), "should not overwrite a non-empty slot"); + + // ensure sufficient group capacity outside of the loop + groupedRows.ensureCapacity(groupIds.getGroupCount()); + // update the affected heaps and record candidate pages that need compaction + IntSet pagesToCompact = new IntOpenHashSet(); + for (int position = firstPositionToInsert; position < newPage.getPositionCount(); position++) { + long groupId = groupIds.getGroupId(position); + RowHeap rows = groupedRows.get(groupId); + if (rows == null) { + // a new group + rows = new RowHeap(rowHeapComparator); + groupedRows.set(groupId, rows); + } + else { + // update an existing group; + // remove the memory usage for this group for now; add it back after update + memorySizeInBytes -= rows.getEstimatedSizeInBytes(); + } + + if (rows.size() < topN) { + Row row = new Row(newPageId, position); + newPageReference.reference(row); + rows.enqueue(row); + } + else { + // may compare with the topN-th element with in the heap to decide if update is necessary + Row previousRow = rows.first(); + PageReference previousPageReference = pageReferences.get(previousRow.getPageId()); + if (pageWithPositionComparator.compareTo(newPage, position, previousPageReference.getPage(), previousRow.getPosition()) < 0) { + // update reference and the heap + rows.dequeue(); + previousPageReference.dereference(previousRow.getPosition()); + + Row newRow = new Row(newPageId, position); + newPageReference.reference(newRow); + rows.enqueue(newRow); + + // compact a page if it is not the current input page and the reference count is below the threshold + if (previousPageReference.getPage() != newPage && + previousPageReference.getUsedPositionCount() * COMPACT_THRESHOLD < previousPageReference.getPage().getPositionCount()) { + pagesToCompact.add(previousRow.getPageId()); + } + } + } + + memorySizeInBytes += rows.getEstimatedSizeInBytes(); + } + + // may compact the new page as well + if (newPageReference.getUsedPositionCount() * COMPACT_THRESHOLD < newPage.getPositionCount()) { + verify(pagesToCompact.add(newPageId)); + } + + // compact pages + IntIterator iterator = pagesToCompact.iterator(); + while (iterator.hasNext()) { + int pageId = iterator.nextInt(); + PageReference pageReference = pageReferences.get(pageId); + if (pageReference.getUsedPositionCount() == 0) { + pageReferences.set(pageId, null); + emptyPageReferenceSlots.enqueue(pageId); + memorySizeInBytes -= pageReference.getEstimatedSizeInBytes(); + } + else { + memorySizeInBytes -= pageReference.getEstimatedSizeInBytes(); + pageReference.compact(); + memorySizeInBytes += pageReference.getEstimatedSizeInBytes(); + } + } + } + + private int findFirstPositionToInsert(Page newPage, GroupByIdBlock groupIds) + { + for (int position = 0; position < newPage.getPositionCount(); position++) { + long groupId = groupIds.getGroupId(position); + if (groupedRows.getCapacity() <= groupId) { + return position; + } + + RowHeap rows = groupedRows.get(groupId); + if (rows == null || rows.size() < topN) { + return position; + } + // check against current minimum + Row previousRow = rows.first(); + PageReference pageReference = pageReferences.get(previousRow.getPageId()); + if (pageWithPositionComparator.compareTo(newPage, position, pageReference.getPage(), previousRow.getPosition()) < 0) { + return position; + } + } + // no positions to insert + return -1; + } + + /** + * The class is a pointer to a row in a page. + * The actual position in the page is mutable because as pages are compacted, the position will change. + */ + private static class Row + { + private final int pageId; + private int position; + + private Row(int pageId, int position) + { + this.pageId = pageId; + reset(position); + } + + public void reset(int position) + { + this.position = position; + } + + public int getPageId() + { + return pageId; + } + + public int getPosition() + { + return position; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("pageId", pageId) + .add("position", position) + .toString(); + } + } + + private static class PageReference + { + private static final long INSTANCE_SIZE = ClassLayout.parseClass(PageReference.class).instanceSize(); + + private Page page; + private Row[] reference; + + private int usedPositionCount; + + public PageReference(Page page) + { + this.page = requireNonNull(page, "page is null"); + this.reference = new Row[page.getPositionCount()]; + } + + public void reference(Row row) + { + reference[row.getPosition()] = row; + usedPositionCount++; + } + + public boolean dereference(int position) + { + checkArgument(reference[position] != null && usedPositionCount > 0); + reference[position] = null; + return (--usedPositionCount) == 0; + } + + public int getUsedPositionCount() + { + return usedPositionCount; + } + + public void compact() + { + checkState(usedPositionCount > 0); + + if (usedPositionCount == page.getPositionCount()) { + return; + } + + // re-assign reference + Row[] newReference = new Row[usedPositionCount]; + int[] positions = new int[usedPositionCount]; + int index = 0; + // update all the elements in the heaps that reference the current page + // this does not change the elements in the heap; + // it only updates the value of the elements; while keeping the same order + for (int i = 0; i < reference.length && index < usedPositionCount; i++) { + Row value = reference[i]; + if (value != null) { + value.reset(index); + newReference[index] = value; + positions[index] = i; + index++; + } + } + verify(index == usedPositionCount); + + // compact page + page = page.copyPositions(positions, 0, usedPositionCount); + reference = newReference; + } + + public Page getPage() + { + return page; + } + + public long getEstimatedSizeInBytes() + { + return page.getRetainedSizeInBytes() + sizeOf(reference) + INSTANCE_SIZE; + } + } + + // this class is for precise memory tracking + private static class IntFIFOQueue + extends IntArrayFIFOQueue + { + private static final long INSTANCE_SIZE = ClassLayout.parseClass(IntFIFOQueue.class).instanceSize(); + + private long getEstimatedSizeInBytes() + { + return INSTANCE_SIZE + sizeOf(array); + } + } + + // this class is for precise memory tracking + private static class RowHeap + extends ObjectHeapPriorityQueue + { + private static final long INSTANCE_SIZE = ClassLayout.parseClass(RowHeap.class).instanceSize(); + private static final long ROW_ENTRY_SIZE = ClassLayout.parseClass(Row.class).instanceSize(); + + private RowHeap(Comparator comparator) + { + super(1, comparator); + } + + private long getEstimatedSizeInBytes() + { + return INSTANCE_SIZE + sizeOf(heap) + size() * ROW_ENTRY_SIZE; + } + } + + private class ResultIterator + extends AbstractIterator + { + // ObjectBigArray capacity is always at least 1024, so discarding "small" BigArrays even if you don't need the entire space is wasteful + private static final int UNUSED_CAPACITY_DISPOSAL_THRESHOLD = 4096; + + private final PageBuilder pageBuilder; + private final PrimitiveIterator.OfInt groupIds; + + private long currentGroupSizeInBytes; + + // the row number of the current position in the group + private int currentGroupPosition; + // number of rows in the group + private int currentGroupSize; + + private ObjectBigArray currentRows; + boolean intermediate; + + ResultIterator(PrimitiveIterator.OfInt groupIds, boolean intermediate) + { + this.intermediate = intermediate; + + // If intermediate=True, it means that we are extracting the data for + // intermediate (spilling) output. In such cases, we do not want to add the Row Numbers + // as the RowNumbers will be computed and added after un-spilling data to produce output + if (produceRowNumber && !intermediate) { + pageBuilder = new PageBuilder(new ImmutableList.Builder().add(sourceTypes).add(BIGINT).build()); + } + else { + pageBuilder = new PageBuilder(ImmutableList.copyOf(sourceTypes)); + } + // Populate the first group + currentRows = new ObjectBigArray<>(); + this.groupIds = groupIds; + nextGroupedRows(); + } + + @Override + protected Page computeNext() + { + pageBuilder.reset(); + while (!pageBuilder.isFull()) { + if (currentRows == null) { + // no more groups + break; + } + if (currentGroupPosition == currentGroupSize) { + // the current group has produced all its rows + memorySizeInBytes -= currentGroupSizeInBytes; + currentGroupPosition = 0; + nextGroupedRows(); + continue; + } + + // Clear the reference to the Row after access to make it reclaimable by GC + Row row = currentRows.getAndSet(currentGroupPosition, null); + PageReference pageReference = pageReferences.get(row.getPageId()); + Page page = pageReference.getPage(); + int position = row.getPosition(); + for (int i = 0; i < sourceTypes.length; i++) { + sourceTypes[i].appendTo(page.getBlock(i), position, pageBuilder.getBlockBuilder(i)); + } + + if (produceRowNumber && !intermediate) { + BIGINT.writeLong(pageBuilder.getBlockBuilder(sourceTypes.length), currentGroupPosition + 1); + } + pageBuilder.declarePosition(); + currentGroupPosition++; + + // deference the row; no need to compact the pages but remove them if completely unused + if (pageReference.dereference(position)) { + pageReferences.set(row.getPageId(), null); + memorySizeInBytes -= pageReference.getEstimatedSizeInBytes(); + } + } + + if (pageBuilder.isEmpty()) { + return endOfData(); + } + return pageBuilder.build(); + } + + private void nextGroupedRows() + { + if (this.groupIds.hasNext()) { + RowHeap rows = groupedRows.getAndSet(this.groupIds.nextInt(), null); + verify(rows != null && !rows.isEmpty(), "impossible to have inserted a group without a witness row. rows=%s for %s", rows, this); + currentGroupSizeInBytes = rows.getEstimatedSizeInBytes(); + currentGroupSize = rows.size(); + + // sort output rows in a big array in case there are too many rows + checkState(currentRows != null, "currentRows already observed the final group"); + if (currentRows.getCapacity() > UNUSED_CAPACITY_DISPOSAL_THRESHOLD && currentRows.getCapacity() > currentGroupSize * 2L) { + // Discard over-sized big array to avoid unnecessary waste + currentRows = new ObjectBigArray<>(); + } + currentRows.ensureCapacity(currentGroupSize); + for (int index = currentGroupSize - 1; index >= 0; index--) { + currentRows.set(index, rows.dequeue()); + } + } + else { + currentRows = null; + currentGroupSize = 0; + } + } + } + + // Below code has been borrowed from SpillableHashAggregationBuilder. + // Fix this when SpillableHashAggregationBuilder is fixed + // + // TODO: we could skip memory reservation for inMemoryGroupedTopNBuilder.getGroupIdsSortingSize() + // if before building result from inMemoryGroupedTopNBuilder we would convert it to "read only" version. + // Read only version of GroupByHash from inMemoryGroupedTopNBuilder could be compacted by dropping + // most of it's field, freeing up some memory that could be used for sorting. + public long getGroupIdsSortingSize() + { + return (long) groupByHash.getGroupCount() * Integer.BYTES; + } + + /** This function is used in the spill flow, where we want the contents of + * the builder extracted in a hashSorted manner to write to persistent storage + * @return Iterator of hash sorted + */ + @Override + public Iterator buildHashSortedIntermediateResult() + { + return new ResultIterator(groupByHash.getHashSortedGroupIds(), true); + } + + @Override + public void close() {} + + @Override + public GroupByHash getGroupByHash() + { + return groupByHash; + } + + @Override + public boolean isEmpty() + { + return groupByHash.getGroupCount() == 0; + } + + /** + * This function is used when we want to migrate the memory accounting to a new memory context + * @param newMemoryContext + * @return + */ + public boolean migrateMemoryContext(LocalMemoryContext newMemoryContext) + { + long currentBytes = memoryContext.getBytes(); + memoryContext.setBytes(0); + boolean successFullyMigrated = newMemoryContext.trySetBytes(newMemoryContext.getBytes() + currentBytes); + if (!successFullyMigrated) { + memoryContext.setBytes(currentBytes); + return false; + } + memoryContext = newMemoryContext; + return true; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/MergeHashSort.java b/presto-main/src/main/java/com/facebook/presto/operator/MergeHashSort.java index aaacc66f2ca51..36c3c15738b14 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/MergeHashSort.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/MergeHashSort.java @@ -19,6 +19,8 @@ import com.facebook.presto.memory.context.AggregatedMemoryContext; import com.facebook.presto.util.MergeSortedPages.PageWithPosition; +import javax.annotation.Nullable; + import java.util.List; import java.util.function.BiPredicate; import java.util.stream.IntStream; @@ -45,11 +47,28 @@ public MergeHashSort(AggregatedMemoryContext memoryContext) /** * Rows with same hash value are guaranteed to be in the same result page. */ - public WorkProcessor merge(List keyTypes, List allTypes, List> channels, DriverYieldSignal driverYieldSignal) + public WorkProcessor merge(List keyTypes, List allTypes, List> pages, DriverYieldSignal driverYieldSignal) + { + return merge(keyTypes, null, allTypes, pages, driverYieldSignal); + } + + public WorkProcessor merge(List keyTypes, @Nullable List keyChannels, List allTypes, List> pages, DriverYieldSignal driverYieldSignal) { - InterpretedHashGenerator hashGenerator = InterpretedHashGenerator.createPositionalWithTypes(keyTypes); + InterpretedHashGenerator hashGenerator; + + // keyChannels=null indicates that the keyChannels are implicitly the first N channels, N being keyTypes.size() + // SpillableHashAggregationBuilder invokes this function in this manner. + // For other invocations of this function (like in SpillableGroupedTopNBuilder), the keyChannels need not be the first N channels + // and are hence explicitly specified + if (keyChannels == null) { + hashGenerator = InterpretedHashGenerator.createPositionalWithTypes(keyTypes); + } + else { + hashGenerator = new InterpretedHashGenerator(keyTypes, keyChannels); + } + return mergeSortedPages( - channels, + pages, createHashPageWithPositionComparator(hashGenerator), IntStream.range(0, allTypes.size()).boxed().collect(toImmutableList()), allTypes, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SpillableGroupedTopNBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/SpillableGroupedTopNBuilder.java new file mode 100644 index 0000000000000..37ff85c813ff3 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/SpillableGroupedTopNBuilder.java @@ -0,0 +1,330 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.AggregatedMemoryContext; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spiller.Spiller; +import com.facebook.presto.spiller.SpillerFactory; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.io.Closer; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.openjdk.jol.info.ClassLayout; + +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.facebook.presto.operator.Operator.NOT_BLOCKED; +import static com.facebook.presto.operator.SpillingUtils.checkSpillSucceeded; +import static com.facebook.presto.operator.WorkProcessor.TransformationState.blocked; +import static com.facebook.presto.operator.WorkProcessor.TransformationState.finished; +import static com.facebook.presto.operator.WorkProcessor.TransformationState.needsMoreData; +import static com.facebook.presto.operator.WorkProcessor.TransformationState.ofResult; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static java.util.Objects.requireNonNull; + +public class SpillableGroupedTopNBuilder + implements GroupedTopNBuilder +{ + private static final long INSTANCE_SIZE = ClassLayout.parseClass(SpillableGroupedTopNBuilder.class).instanceSize(); + + private final Supplier inputInMemoryGroupedTopNBuilderSupplier; + private final Supplier outputInMemoryGroupedTopNBuilderSupplier; + private final Supplier> memoryWaitingFutureSupplier; + private final SpillerFactory spillerFactory; + private final List sourceTypes; + private final List partitionTypes; + private final List partitionChannels; + + private InMemoryGroupedTopNBuilder inputInMemoryGroupedTopNBuilder; + private InMemoryGroupedTopNBuilder outputInMemoryGroupedTopNBuilder; + + private final LocalMemoryContext localUserMemoryContext; + private final LocalMemoryContext localRevocableMemoryContext; + private final AggregatedMemoryContext aggregatedMemoryContextForMerge; + private final AggregatedMemoryContext aggregatedMemoryContextForSpill; + private final DriverYieldSignal driverYieldSignal; + private final SpillContext spillContext; + + private final long unspillMemoryLimit; + + private Optional spiller = Optional.empty(); + private ListenableFuture spillInProgress = immediateFuture(null); + + public SpillableGroupedTopNBuilder( + List sourceTypes, + List partitionTypes, + List partitionChannels, + Supplier inputInMemoryGroupedTopNBuilderSupplier, + Supplier outputInMemoryGroupedTopNBuilderSupplier, + Supplier> memoryWaitingFutureSupplier, + long unspillMemoryLimit, + LocalMemoryContext localUserMemoryContext, + LocalMemoryContext localRevocableMemoryContext, + AggregatedMemoryContext aggregatedMemoryContextForMerge, + AggregatedMemoryContext aggregatedMemoryContextForSpill, + SpillContext spillContext, + DriverYieldSignal driverYieldSignal, + SpillerFactory spillerFactory) + { + this.inputInMemoryGroupedTopNBuilderSupplier = requireNonNull(inputInMemoryGroupedTopNBuilderSupplier, "inputInMemoryGroupedTopNBuilderSupplier cannot be null"); + this.outputInMemoryGroupedTopNBuilderSupplier = requireNonNull(outputInMemoryGroupedTopNBuilderSupplier, "outputInMemoryGroupedTopNBuilderSupplier cannot be null"); + this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory cannot be null"); + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes cannot be null"); + this.partitionTypes = requireNonNull(partitionTypes, "partitionTypes cannot be null"); + this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels cannot be null"); + + initializeInputInMemoryGroupedTopNBuilder(); + + this.localUserMemoryContext = requireNonNull(localUserMemoryContext, "localUserMemoryContext cannot be null"); + this.localRevocableMemoryContext = requireNonNull(localRevocableMemoryContext, "localRevocableMemoryContext cannot be null"); + this.aggregatedMemoryContextForMerge = requireNonNull(aggregatedMemoryContextForMerge, "aggregatedMemoryContextForMerge cannot be null"); + this.aggregatedMemoryContextForSpill = requireNonNull(aggregatedMemoryContextForSpill, "aggregatedMemoryContextForSpill cannot be null"); + this.driverYieldSignal = requireNonNull(driverYieldSignal, "driverYieldSignal cannot be null"); + this.spillContext = requireNonNull(spillContext, "spillContext cannot be null"); + + this.unspillMemoryLimit = requireNonNull(unspillMemoryLimit, "unspillMemoryLimit cannot be null"); + this.memoryWaitingFutureSupplier = memoryWaitingFutureSupplier; + } + + public Work processPage(Page page) + { + checkState(hasPreviousSpillCompletedSuccessfully(), "Previous spill hasn't yet finished"); + return inputInMemoryGroupedTopNBuilder.processPage(page); + } + + private boolean hasPreviousSpillCompletedSuccessfully() + { + if (spillInProgress.isDone()) { + // check for exception from previous spill for early failure + checkSpillSucceeded(spillInProgress); + return true; + } + return false; + } + + @Override + public WorkProcessor buildResult() + { + // spill could be in progress. + checkSpillSucceeded(spillInProgress); + + // Convert revocable memory to user memory as returned Iterator holds on to memory so we no longer can revoke. + if (!spiller.isPresent()) { + if (inputInMemoryGroupedTopNBuilder.isEmpty() || inputInMemoryGroupedTopNBuilder.migrateMemoryContext(localUserMemoryContext)) { + // we were able to successfully move to userMemory, so we can now safely return the result + return inputInMemoryGroupedTopNBuilder.buildResult(); + } + } + + // Spill the remaining collected input + // TODO: Possible Optimization here is to not spill the last remaining buffered input + // and instead do a memory+disk sort merge. SpillableHashAggregationBuilder does this + checkSpillSucceeded(spillToDisk()); + verify(inputInMemoryGroupedTopNBuilder.isEmpty()); + updateMemoryReservations(); + + // Collect all spill streams to merge-sort + List> sortedPageStreams = ImmutableList.>builder() + .addAll(spiller.get().getSpills().stream() + .map(WorkProcessor::fromIterator) + .collect(toImmutableList())) + .build(); + + // Sort-Merge the rows and produce group-by-group output + return getFinalResult(sortedPageStreams); + } + + @Override + public GroupByHash getGroupByHash() + { + return inputInMemoryGroupedTopNBuilder.getGroupByHash(); + } + + @Override + public boolean isEmpty() + { + return inputInMemoryGroupedTopNBuilder.isEmpty() && outputInMemoryGroupedTopNBuilder.isEmpty(); + } + + @Override + public long getEstimatedSizeInBytes() + { + return INSTANCE_SIZE + inputInMemoryGroupedTopNBuilder.getEstimatedSizeInBytes(); + } + + @Override + public ListenableFuture updateMemoryReservations() + { + ListenableFuture inputBuilderFuture = inputInMemoryGroupedTopNBuilder.updateMemoryReservations(); + + ListenableFuture outputBuilderFuture = null; + if (outputInMemoryGroupedTopNBuilder != null) { + outputBuilderFuture = outputInMemoryGroupedTopNBuilder.updateMemoryReservations(); + } + + if (!inputBuilderFuture.isDone()) { + return inputBuilderFuture; + } + if (outputBuilderFuture != null && !outputBuilderFuture.isDone()) { + return outputBuilderFuture; + } + return Futures.immediateFuture(null); + } + + @Override + public void close() + { + try (Closer closer = Closer.create()) { + if (inputInMemoryGroupedTopNBuilder != null) { + closer.register(inputInMemoryGroupedTopNBuilder::close); + } + + if (outputInMemoryGroupedTopNBuilder != null) { + closer.register(outputInMemoryGroupedTopNBuilder::close); + } + spiller.ifPresent(closer::register); + closer.register(() -> localUserMemoryContext.setBytes(0)); + closer.register(() -> localRevocableMemoryContext.setBytes(0)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + public ListenableFuture startMemoryRevoke() + { + checkState(spillInProgress.isDone()); + if (inputInMemoryGroupedTopNBuilder.isEmpty() || localRevocableMemoryContext.getBytes() == 0) { + // All revocable memory has been released in buildResult method. + // At this point, InMemoryGroupedTopNBuilder is no longer accepting any input so no point in spilling. + return NOT_BLOCKED; + } + spillToDisk(); + return spillInProgress; + } + + public void finishMemoryRevoke() + { + if (spiller.isPresent()) { + checkState(spillInProgress.isDone()); + verify(inputInMemoryGroupedTopNBuilder.isEmpty()); + spiller.get().commit(); + } + updateMemoryReservations(); + } + + @VisibleForTesting + private WorkProcessor getFinalResult(List> sortedPageStreams) + { + MergeHashSort mergeHashSort = new MergeHashSort(aggregatedMemoryContextForMerge); + WorkProcessor mergedSortedPages = mergeHashSort.merge( + partitionTypes, + partitionChannels, + sourceTypes, + sortedPageStreams, + driverYieldSignal); + + initializeOutputInMemoryGroupedTopNBuilder(); + + // Create final result by re-processing the sorted stream page-at-a-time through a GroupedTopNBuilder + return mergedSortedPages.flatTransform(new WorkProcessor.Transformation>() + { + public WorkProcessor.TransformationState> process(Optional inputPageOptional) + { + boolean inputIsPresent = inputPageOptional.isPresent(); + if (!inputIsPresent && outputInMemoryGroupedTopNBuilder.isEmpty()) { + // no more pages and builder is empty + return finished(); + } + + if (inputIsPresent) { + Page inputPage = inputPageOptional.get(); + boolean done = outputInMemoryGroupedTopNBuilder.processPage(inputPage).process(); + if (!done) { + return blocked(memoryWaitingFutureSupplier.get()); + } + if (outputInMemoryGroupedTopNBuilder.getEstimatedSizeInBytes() < unspillMemoryLimit) { + return needsMoreData(); + } + } + + // We can produce output after every input page, because input pages do not have + // hash values that span multiple pages (guaranteed by MergeHashSort) + // + // iterator to extract existing context out of builder + WorkProcessor result = outputInMemoryGroupedTopNBuilder.buildResult(); + // initialize new builder + initializeOutputInMemoryGroupedTopNBuilder(); + return ofResult(result, inputIsPresent); + } + }); + } + + private ListenableFuture spillToDisk() + { + if (!spiller.isPresent()) { + spiller = Optional.of(spillerFactory.create( + sourceTypes, + spillContext, + aggregatedMemoryContextForSpill)); + } + + // start spilling process with current content of the inMemoryGroupedTopNBuilder builder... + spillInProgress = spiller.get().spill(inputInMemoryGroupedTopNBuilder.buildHashSortedIntermediateResult()); + // ... and immediately create new inMemoryGroupedTopNBuilder so effectively memory ownership + // over inMemoryGroupedTopNBuilder is transferred from this thread to a spilling thread + initializeInputInMemoryGroupedTopNBuilder(); + + return spillInProgress; + } + + private void initializeInputInMemoryGroupedTopNBuilder() + { + if (inputInMemoryGroupedTopNBuilder != null) { + inputInMemoryGroupedTopNBuilder.close(); + } + inputInMemoryGroupedTopNBuilder = inputInMemoryGroupedTopNBuilderSupplier.get(); + } + + private void initializeOutputInMemoryGroupedTopNBuilder() + { + if (outputInMemoryGroupedTopNBuilder != null) { + outputInMemoryGroupedTopNBuilder.close(); + } + outputInMemoryGroupedTopNBuilder = outputInMemoryGroupedTopNBuilderSupplier.get(); + } + + @Override + public Iterator buildHashSortedIntermediateResult() + { + throw new UnsupportedOperationException("SpillableGroupedTopNBuilder does not support buildHashSortedIntermediateResult"); + } + + @VisibleForTesting + protected InMemoryGroupedTopNBuilder getInputInMemoryGroupedTopNBuilder() + { + return inputInMemoryGroupedTopNBuilder; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java index 29ca992bfc59d..038719a644997 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TopNOperator.java @@ -16,17 +16,14 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.Type; -import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.spi.plan.PlanNodeId; import com.google.common.collect.ImmutableList; -import java.util.Iterator; import java.util.List; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; -import static java.util.Collections.emptyIterator; import static java.util.Objects.requireNonNull; /** @@ -89,12 +86,11 @@ public OperatorFactory duplicate() } private final OperatorContext operatorContext; - private final LocalMemoryContext localUserMemoryContext; private GroupedTopNBuilder topNBuilder; private boolean finishing; - private Iterator outputIterator; + private WorkProcessor outputPages; public TopNOperator( OperatorContext operatorContext, @@ -104,19 +100,20 @@ public TopNOperator( List sortOrders) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.localUserMemoryContext = operatorContext.localUserMemoryContext(); checkArgument(n >= 0, "n must be positive"); - if (n == 0) { finishing = true; - outputIterator = emptyIterator(); + // We create an empty WorkProcessor and finish it + outputPages = WorkProcessor.of(); + outputPages.process(); } else { - topNBuilder = new GroupedTopNBuilder( + topNBuilder = new InMemoryGroupedTopNBuilder( types, new SimplePageWithPositionComparator(types, sortChannels, sortOrders), n, false, + operatorContext.localUserMemoryContext(), new NoChannelGroupByHash()); } } @@ -152,7 +149,7 @@ public void addInput(Page page) boolean done = topNBuilder.processPage(requireNonNull(page, "page is null")).process(); // there is no grouping so work will always be done verify(done); - updateMemoryReservation(); + topNBuilder.updateMemoryReservations(); } @Override @@ -162,29 +159,28 @@ public Page getOutput() return null; } - if (outputIterator == null) { + if (outputPages == null) { // start flushing - outputIterator = topNBuilder.buildResult(); + outputPages = topNBuilder.buildResult(); } - Page output = null; - if (outputIterator.hasNext()) { - output = outputIterator.next(); + if (!outputPages.process()) { + return null; } - else { - outputIterator = emptyIterator(); + + if (outputPages.isFinished()) { + topNBuilder.close(); + return null; } - updateMemoryReservation(); - return output; - } - private void updateMemoryReservation() - { - localUserMemoryContext.setBytes(topNBuilder.getEstimatedSizeInBytes()); + Page outputPage = outputPages.getResult(); + topNBuilder.updateMemoryReservations(); + + return outputPage; } private boolean noMoreOutput() { - return outputIterator != null && !outputIterator.hasNext(); + return outputPages != null && outputPages.isFinished(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java index b0b23aa666806..e7b82cded9f69 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TopNRowNumberOperator.java @@ -16,20 +16,20 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.Type; -import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spiller.SpillerFactory; import com.facebook.presto.sql.gen.JoinCompiler; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; -import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.function.Supplier; import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled; import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.operator.GroupByHash.createGroupByHash; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; @@ -57,7 +57,10 @@ public static class TopNRowNumberOperatorFactory private final boolean generateRowNumber; private boolean closed; + private final long unspillMemoryLimit; private final JoinCompiler joinCompiler; + private final SpillerFactory spillerFactory; + private final boolean spillEnabled; public TopNRowNumberOperatorFactory( int operatorId, @@ -72,7 +75,10 @@ public TopNRowNumberOperatorFactory( boolean partial, Optional hashChannel, int expectedPositions, - JoinCompiler joinCompiler) + long unspillMemoryLimit, + JoinCompiler joinCompiler, + SpillerFactory spillerFactory, + boolean spillEnabled) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); @@ -89,7 +95,10 @@ public TopNRowNumberOperatorFactory( checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); this.generateRowNumber = !partial; this.expectedPositions = expectedPositions; + this.unspillMemoryLimit = unspillMemoryLimit; this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null"); + this.spillerFactory = spillerFactory; + this.spillEnabled = spillEnabled; } @Override @@ -109,7 +118,10 @@ public Operator createOperator(DriverContext driverContext) generateRowNumber, hashChannel, expectedPositions, - joinCompiler); + unspillMemoryLimit, + joinCompiler, + spillerFactory, + spillEnabled); } @Override @@ -121,21 +133,20 @@ public void noMoreOperators() @Override public OperatorFactory duplicate() { - return new TopNRowNumberOperatorFactory(operatorId, planNodeId, sourceTypes, outputChannels, partitionChannels, partitionTypes, sortChannels, sortOrder, maxRowCountPerPartition, partial, hashChannel, expectedPositions, joinCompiler); + return new TopNRowNumberOperatorFactory(operatorId, planNodeId, sourceTypes, outputChannels, partitionChannels, partitionTypes, sortChannels, sortOrder, maxRowCountPerPartition, partial, hashChannel, expectedPositions, unspillMemoryLimit, joinCompiler, spillerFactory, spillEnabled); } } private final OperatorContext operatorContext; - private final LocalMemoryContext localUserMemoryContext; private final int[] outputChannels; - private final GroupByHash groupByHash; - private final GroupedTopNBuilder groupedTopNBuilder; + private GroupedTopNBuilder groupedTopNBuilder; private boolean finishing; + private boolean finished; private Work unfinishedWork; - private Iterator outputIterator; + private WorkProcessor outputPages; public TopNRowNumberOperator( OperatorContext operatorContext, @@ -149,10 +160,12 @@ public TopNRowNumberOperator( boolean generateRowNumber, Optional hashChannel, int expectedPositions, - JoinCompiler joinCompiler) + long unspillMemoryLimit, + JoinCompiler joinCompiler, + SpillerFactory spillerFactory, + boolean spillEnabled) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.localUserMemoryContext = operatorContext.localUserMemoryContext(); ImmutableList.Builder outputChannelsBuilder = ImmutableList.builder(); for (int channel : requireNonNull(outputChannels, "outputChannels is null")) { @@ -165,28 +178,79 @@ public TopNRowNumberOperator( checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); + List types = toTypes(sourceTypes, outputChannels, generateRowNumber); + Supplier groupByHashSupplier = () -> createGroupByHash( + partitionTypes, + partitionChannels, + hashChannel, + expectedPositions, + joinCompiler, + isDictionaryAggregationEnabled(operatorContext.getSession()), + this::updateMemoryReservation); + + if (spillEnabled) { + this.groupedTopNBuilder = new SpillableGroupedTopNBuilder( + ImmutableList.copyOf(sourceTypes), + partitionTypes, + partitionChannels, + () -> new InMemoryGroupedTopNBuilder( + ImmutableList.copyOf(sourceTypes), + new SimplePageWithPositionComparator(types, sortChannels, sortOrders), + maxRowCountPerPartition, + generateRowNumber, + operatorContext.localRevocableMemoryContext(), + groupByHashSupplier.get()), + () -> new InMemoryGroupedTopNBuilder( + ImmutableList.copyOf(sourceTypes), + new SimplePageWithPositionComparator(types, sortChannels, sortOrders), + maxRowCountPerPartition, + generateRowNumber, + operatorContext.localUserMemoryContext(), + groupByHashSupplier.get()), + operatorContext::isWaitingForMemory, + unspillMemoryLimit, + operatorContext.localUserMemoryContext(), + operatorContext.localRevocableMemoryContext(), + operatorContext.aggregateSystemMemoryContext(), + operatorContext.aggregateSystemMemoryContext(), + operatorContext.getSpillContext(), + operatorContext.getDriverContext().getYieldSignal(), + spillerFactory); + } + else { + this.groupedTopNBuilder = new InMemoryGroupedTopNBuilder( + ImmutableList.copyOf(sourceTypes), + new SimplePageWithPositionComparator(types, sortChannels, sortOrders), + maxRowCountPerPartition, + generateRowNumber, + operatorContext.localUserMemoryContext(), + groupByHashSupplier.get()); + } + } + + private GroupByHash createGroupByHash( + List partitionTypes, + List partitionChannels, + Optional inputHashChannel, + int expectedPositions, + JoinCompiler joinCompiler, + boolean isDictionaryAggregationEnabled, + UpdateMemory updateMemory) + { if (!partitionChannels.isEmpty()) { checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); - groupByHash = createGroupByHash( + return GroupByHash.createGroupByHash( partitionTypes, Ints.toArray(partitionChannels), - hashChannel, + inputHashChannel, expectedPositions, - isDictionaryAggregationEnabled(operatorContext.getSession()), + isDictionaryAggregationEnabled, joinCompiler, - this::updateMemoryReservation); + updateMemory); } else { - groupByHash = new NoChannelGroupByHash(); + return new NoChannelGroupByHash(); } - - List types = toTypes(sourceTypes, outputChannels, generateRowNumber); - this.groupedTopNBuilder = new GroupedTopNBuilder( - ImmutableList.copyOf(sourceTypes), - new SimplePageWithPositionComparator(types, sortChannels, sortOrders), - maxRowCountPerPartition, - generateRowNumber, - groupByHash); } @Override @@ -205,14 +269,14 @@ public void finish() public boolean isFinished() { // has no more input, has finished flushing, and has no unfinished work - return finishing && outputIterator != null && !outputIterator.hasNext() && unfinishedWork == null; + return finished; } @Override public boolean needsInput() { // still has more input, has not started flushing yet, and has no unfinished work - return !finishing && outputIterator == null && unfinishedWork == null; + return !finishing && outputPages == null && unfinishedWork == null; } @Override @@ -220,7 +284,7 @@ public void addInput(Page page) { checkState(!finishing, "Operator is already finishing"); checkState(unfinishedWork == null, "Cannot add input with the operator when unfinished work is not empty"); - checkState(outputIterator == null, "Cannot add input with the operator when flushing"); + checkState(outputPages == null, "Cannot add input with the operator when flushing"); requireNonNull(page, "page is null"); unfinishedWork = groupedTopNBuilder.processPage(page); if (unfinishedWork.process()) { @@ -229,9 +293,28 @@ public void addInput(Page page) updateMemoryReservation(); } + @Override + public ListenableFuture startMemoryRevoke() + { + if (finishing) { + return NOT_BLOCKED; + } + return groupedTopNBuilder.startMemoryRevoke(); + } + + @Override + public void finishMemoryRevoke() + { + groupedTopNBuilder.finishMemoryRevoke(); + } + @Override public Page getOutput() { + if (finished) { + return null; + } + if (unfinishedWork != null) { boolean finished = unfinishedWork.process(); updateMemoryReservation(); @@ -245,30 +328,46 @@ public Page getOutput() return null; } - if (outputIterator == null) { + if (outputPages == null) { + if (groupedTopNBuilder == null) { + finished = true; + return null; + } // start flushing - outputIterator = groupedTopNBuilder.buildResult(); + outputPages = groupedTopNBuilder.buildResult(); } - Page output = null; - if (outputIterator.hasNext()) { - output = outputIterator.next().extractChannels(outputChannels); + if (!outputPages.process()) { + return null; } + + if (outputPages.isFinished()) { + if (groupedTopNBuilder != null) { + groupedTopNBuilder.close(); + groupedTopNBuilder = null; + } + finished = true; + return null; + } + + Page outputPage = outputPages.getResult() + .extractChannels(outputChannels); + updateMemoryReservation(); - return output; + return outputPage; } @VisibleForTesting public int getCapacity() { + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); checkState(groupByHash != null); return groupByHash.getCapacity(); } private boolean updateMemoryReservation() { - // TODO: may need to use trySetMemoryReservation with a compaction to free memory (but that may cause GC pressure) - localUserMemoryContext.setBytes(groupedTopNBuilder.getEstimatedSizeInBytes()); + groupedTopNBuilder.updateMemoryReservations(); return operatorContext.isWaitingForMemory().isDone(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index 4a39094852cad..6f5d59843ecd1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -15,7 +15,6 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; -import com.facebook.presto.common.array.IntBigArray; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.LocalMemoryContext; @@ -38,7 +37,6 @@ import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; -import it.unimi.dsi.fastutil.ints.AbstractIntIterator; import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntIterators; @@ -261,7 +259,7 @@ public WorkProcessor buildResult() public WorkProcessor buildHashSortedResult() { - return buildResult(hashSortedGroupIds()); + return buildResult(groupByHash.getHashSortedGroupIds()); } public List buildIntermediateTypes() @@ -354,36 +352,6 @@ private IntIterator consecutiveGroupIds() return IntIterators.fromTo(0, groupByHash.getGroupCount()); } - private IntIterator hashSortedGroupIds() - { - IntBigArray groupIds = new IntBigArray(); - groupIds.ensureCapacity(groupByHash.getGroupCount()); - for (int i = 0; i < groupByHash.getGroupCount(); i++) { - groupIds.set(i, i); - } - - groupIds.sort(0, groupByHash.getGroupCount(), (leftGroupId, rightGroupId) -> - Long.compare(groupByHash.getRawHash(leftGroupId), groupByHash.getRawHash(rightGroupId))); - - return new AbstractIntIterator() - { - private final int totalPositions = groupByHash.getGroupCount(); - private int position; - - @Override - public boolean hasNext() - { - return position < totalPositions; - } - - @Override - public int nextInt() - { - return groupIds.get(position++); - } - }; - } - private static class Aggregator { private final GroupedAccumulator aggregation; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 65aecff8e79ef..596e20a5a04eb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -130,6 +130,7 @@ public class FeaturesConfig private boolean spillEnabled; private boolean joinSpillingEnabled = true; private boolean aggregationSpillEnabled = true; + private boolean topNSpillEnabled = true; private boolean distinctAggregationSpillEnabled = true; private boolean dedupBasedDistinctAggregationSpillEnabled; private boolean distinctAggregationLargeBlockSpillEnabled; @@ -138,6 +139,7 @@ public class FeaturesConfig private boolean windowSpillEnabled = true; private boolean orderBySpillEnabled = true; private DataSize aggregationOperatorUnspillMemoryLimit = new DataSize(4, MEGABYTE); + private DataSize topNOperatorUnspillMemoryLimit = new DataSize(4, MEGABYTE); private List spillerSpillPaths = ImmutableList.of(); private int spillerThreads = 4; private double spillMaxUsedSpaceThreshold = 0.9; @@ -1061,6 +1063,19 @@ public boolean isAggregationSpillEnabled() return aggregationSpillEnabled; } + @Config("experimental.topn-spill-enabled") + @ConfigDescription("Spill TopN if spill is enabled") + public FeaturesConfig setTopNSpillEnabled(boolean topNSpillEnabled) + { + this.topNSpillEnabled = topNSpillEnabled; + return this; + } + + public boolean isTopNSpillEnabled() + { + return topNSpillEnabled; + } + @Config("experimental.distinct-aggregation-spill-enabled") @ConfigDescription("Spill distinct aggregations if aggregation spill is enabled") public FeaturesConfig setDistinctAggregationSpillEnabled(boolean distinctAggregationSpillEnabled) @@ -1262,6 +1277,18 @@ public boolean isDefaultFilterFactorEnabled() return defaultFilterFactorEnabled; } + public DataSize getTopNOperatorUnspillMemoryLimit() + { + return topNOperatorUnspillMemoryLimit; + } + + @Config("experimental.topn-operator-unspill-memory-limit") + public FeaturesConfig setTopNOperatorUnspillMemoryLimit(DataSize aggregationOperatorUnspillMemoryLimit) + { + this.topNOperatorUnspillMemoryLimit = aggregationOperatorUnspillMemoryLimit; + return this; + } + public DataSize getAggregationOperatorUnspillMemoryLimit() { return aggregationOperatorUnspillMemoryLimit; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 748daa490efce..45efb218e33c5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -255,6 +255,7 @@ import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; import static com.facebook.presto.SystemSessionProperties.getTaskPartitionedWriterCount; import static com.facebook.presto.SystemSessionProperties.getTaskWriterCount; +import static com.facebook.presto.SystemSessionProperties.getTopNOperatorUnspillMemoryLimit; import static com.facebook.presto.SystemSessionProperties.isAggregationSpillEnabled; import static com.facebook.presto.SystemSessionProperties.isDistinctAggregationSpillEnabled; import static com.facebook.presto.SystemSessionProperties.isEnableDynamicFiltering; @@ -267,6 +268,7 @@ import static com.facebook.presto.SystemSessionProperties.isOrderBySpillEnabled; import static com.facebook.presto.SystemSessionProperties.isQuickDistinctLimitEnabled; import static com.facebook.presto.SystemSessionProperties.isSpillEnabled; +import static com.facebook.presto.SystemSessionProperties.isTopNSpillEnabled; import static com.facebook.presto.SystemSessionProperties.isWindowSpillEnabled; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; @@ -1064,6 +1066,8 @@ public PhysicalOperation visitTopNRowNumber(TopNRowNumberNode node, LocalExecuti outputMappings.put(node.getRowNumberVariable(), channel); } + DataSize unspillMemoryLimit = getTopNOperatorUnspillMemoryLimit(context.getSession()); + Optional hashChannel = node.getHashVariable().map(variableChannelGetter(source)); OperatorFactory operatorFactory = new TopNRowNumberOperator.TopNRowNumberOperatorFactory( context.getNextOperatorId(), @@ -1078,7 +1082,10 @@ public PhysicalOperation visitTopNRowNumber(TopNRowNumberNode node, LocalExecuti node.isPartial(), hashChannel, 1000, - joinCompiler); + unspillMemoryLimit.toBytes(), + joinCompiler, + spillerFactory, + isTopNSpillEnabled(session)); return new PhysicalOperation(operatorFactory, makeLayout(node), context, source); } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java index b6e2a6c8a1b46..a061451a2c72c 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java @@ -134,6 +134,12 @@ public Builder setQueryMaxTotalMemory(DataSize queryMaxTotalMemory) return this; } + public Builder setMaxRevocableMemory(DataSize maxRevocableMemory) + { + this.maxRevocableMemory = maxRevocableMemory; + return this; + } + public Builder setMemoryPoolSize(DataSize memoryPoolSize) { this.memoryPoolSize = memoryPoolSize; diff --git a/presto-main/src/test/java/com/facebook/presto/memory/TestingMemoryContext.java b/presto-main/src/test/java/com/facebook/presto/memory/TestingMemoryContext.java new file mode 100644 index 0000000000000..1d498180f9a01 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/memory/TestingMemoryContext.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.memory; + +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +public class TestingMemoryContext + implements LocalMemoryContext +{ + private long usedBytes; + private final long maxBytes; + + public TestingMemoryContext(long maxBytes) + { + this.maxBytes = maxBytes; + this.usedBytes = 0; + } + + @Override + public long getBytes() + { + return usedBytes; + } + + @Override + public ListenableFuture setBytes(long bytes) + { + usedBytes = bytes; + return Futures.immediateFuture(null); + } + + @Override + public ListenableFuture setBytes(long bytes, boolean enforceBroadcastMemoryLimit) + { + return setBytes(bytes); + } + + @Override + public boolean trySetBytes(long bytes) + { + if (usedBytes + bytes > maxBytes) { + return false; + } + usedBytes += bytes; + return true; + } + + @Override + public boolean trySetBytes(long bytes, boolean enforceBroadcastMemoryLimit) + { + return trySetBytes(bytes); + } + + @Override + public void close() + { + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkInMemoryGroupedTopNBuilder.java similarity index 75% rename from presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java rename to presto-main/src/test/java/com/facebook/presto/operator/BenchmarkInMemoryGroupedTopNBuilder.java index bccfd7a71efcd..4014241c0fbac 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkGroupedTopNBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkInMemoryGroupedTopNBuilder.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.TestingMemoryContext; import com.google.common.collect.ImmutableList; import io.airlift.tpch.LineItem; import io.airlift.tpch.LineItemGenerator; @@ -56,7 +57,7 @@ @Fork(4) @Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) @Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) -public class BenchmarkGroupedTopNBuilder +public class BenchmarkInMemoryGroupedTopNBuilder { private static final int HASH_GROUP = 0; private static final int EXTENDED_PRICE = 1; @@ -88,7 +89,7 @@ public static class BenchmarkData private int groupCount = 10; private List page; - private GroupedTopNBuilder topNBuilder; + private InMemoryGroupedTopNBuilder topNBuilder; @Setup public void setup() @@ -101,10 +102,10 @@ public void setup() else { groupByHash = new NoChannelGroupByHash(); } - topNBuilder = new GroupedTopNBuilder(types, comparator, topN, false, groupByHash); + topNBuilder = new InMemoryGroupedTopNBuilder(types, comparator, topN, false, new TestingMemoryContext(0L), groupByHash); } - public GroupedTopNBuilder getTopNBuilder() + public InMemoryGroupedTopNBuilder getTopNBuilder() { return topNBuilder; } @@ -118,7 +119,7 @@ public List getPages() @Benchmark public void topN(BenchmarkData data, Blackhole blackhole) { - GroupedTopNBuilder topNBuilder = data.getTopNBuilder(); + InMemoryGroupedTopNBuilder topNBuilder = data.getTopNBuilder(); for (Page page : data.getPages()) { Work work = topNBuilder.processPage(page); boolean finished; @@ -127,7 +128,7 @@ public void topN(BenchmarkData data, Blackhole blackhole) } while (!finished); } - Iterator results = topNBuilder.buildResult(); + Iterator results = topNBuilder.buildResult().iterator(); while (results.hasNext()) { blackhole.consume(results.next()); } @@ -135,7 +136,7 @@ public void topN(BenchmarkData data, Blackhole blackhole) public List topNToList(BenchmarkData data) { - GroupedTopNBuilder topNBuilder = data.getTopNBuilder(); + InMemoryGroupedTopNBuilder topNBuilder = data.getTopNBuilder(); for (Page page : data.getPages()) { Work work = topNBuilder.processPage(page); boolean finished; @@ -143,7 +144,7 @@ public List topNToList(BenchmarkData data) finished = work.process(); } while (!finished); } - return ImmutableList.copyOf(topNBuilder.buildResult()); + return ImmutableList.copyOf(topNBuilder.buildResult().iterator()); } @Test @@ -159,7 +160,7 @@ public static void main(String[] args) { Options options = new OptionsBuilder() .parent(new CommandLineOptions(args)) - .include(".*" + BenchmarkGroupedTopNBuilder.class.getSimpleName() + ".*") + .include(".*" + BenchmarkInMemoryGroupedTopNBuilder.class.getSimpleName() + ".*") .build(); new Runner(options).run(); @@ -194,4 +195,39 @@ private static List createInputPages(int positions, List types, int return pages; } + + public static List createSequentialInputPages(int positions, List types, int positionsPerPage, int groupCount, int seed) + { + List pages = new ArrayList<>(); + PageBuilder pageBuilder = new PageBuilder(types); + LineItemGenerator lineItemGenerator = new LineItemGenerator(1, 1, 1); + Iterator iterator = lineItemGenerator.iterator(); + + long mod = positions / groupCount; + long groupNumber = 0; + + for (int i = 0; i < positions; i++) { + pageBuilder.declarePosition(); + if (i % mod == 0) { + groupNumber++; + } + LineItem lineItem = iterator.next(); + BIGINT.writeLong(pageBuilder.getBlockBuilder(HASH_GROUP), groupNumber); + DOUBLE.writeDouble(pageBuilder.getBlockBuilder(EXTENDED_PRICE), i % mod * 2.0); + DOUBLE.writeDouble(pageBuilder.getBlockBuilder(DISCOUNT), lineItem.getDiscount()); + DATE.writeLong(pageBuilder.getBlockBuilder(SHIP_DATE), lineItem.getShipDate()); + DOUBLE.writeDouble(pageBuilder.getBlockBuilder(QUANTITY), lineItem.getQuantity()); + + if (pageBuilder.getPositionCount() >= positionsPerPage) { + pages.add(pageBuilder.build()); + pageBuilder.reset(); + } + } + + if (!pageBuilder.isEmpty()) { + pages.add(pageBuilder.build()); + } + + return pages; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java b/presto-main/src/test/java/com/facebook/presto/operator/TestInMemoryGroupedTopNBuilder.java similarity index 90% rename from presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java rename to presto-main/src/test/java/com/facebook/presto/operator/TestInMemoryGroupedTopNBuilder.java index 945e32ab361ec..2484a8a298921 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestGroupedTopNBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestInMemoryGroupedTopNBuilder.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.array.ObjectBigArray; import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.TestingMemoryContext; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.gen.JoinCompiler; import com.google.common.collect.ImmutableList; @@ -47,9 +48,9 @@ import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertTrue; -public class TestGroupedTopNBuilder +public class TestInMemoryGroupedTopNBuilder { - private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedTopNBuilder.class).instanceSize(); + private static final long INSTANCE_SIZE = ClassLayout.parseClass(InMemoryGroupedTopNBuilder.class).instanceSize(); private static final long INT_FIFO_QUEUE_SIZE = ClassLayout.parseClass(IntArrayFIFOQueue.class).instanceSize(); private static final long OBJECT_OVERHEAD = ClassLayout.parseClass(Object.class).instanceSize(); private static final long PAGE_REFERENCE_INSTANCE_SIZE = ClassLayout.parseClass(TestPageReference.class).instanceSize(); @@ -70,15 +71,17 @@ public static Object[][] pageRowCounts() @Test public void testEmptyInput() { - GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + InMemoryGroupedTopNBuilder groupedTopNBuilder = new InMemoryGroupedTopNBuilder( ImmutableList.of(BIGINT), (left, leftPosition, right, rightPosition) -> { throw new UnsupportedOperationException(); }, 5, false, + new TestingMemoryContext(100L), new NoChannelGroupByHash()); - assertFalse(groupedTopNBuilder.buildResult().hasNext()); + + assertFalse(groupedTopNBuilder.buildResult().iterator().hasNext()); } @Test(dataProvider = "produceRowNumbers") @@ -107,11 +110,12 @@ public void testMultiGroupTopN(boolean produceRowNumbers) } GroupByHash groupByHash = createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), NOOP); - GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + InMemoryGroupedTopNBuilder groupedTopNBuilder = new InMemoryGroupedTopNBuilder( types, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 2, produceRowNumbers, + new TestingMemoryContext(100L), groupByHash); assertBuilderSize(groupByHash, types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); @@ -131,7 +135,7 @@ public void testMultiGroupTopN(boolean produceRowNumbers) assertTrue(groupedTopNBuilder.processPage(input.get(3)).process()); assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1, 2, 0), ImmutableList.of(2, 2, 2, 1), groupedTopNBuilder.getEstimatedSizeInBytes()); - List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult()); + List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult().iterator()); assertEquals(output.size(), 1); Page expected = rowPagesBuilder(BIGINT, DOUBLE, BIGINT) @@ -179,31 +183,34 @@ public void testSingleGroupTopN(boolean produceRowNumbers) page.compact(); } - GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + InMemoryGroupedTopNBuilder groupedTopNBuilder = new InMemoryGroupedTopNBuilder( types, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 5, produceRowNumbers, + new TestingMemoryContext(100L), new NoChannelGroupByHash()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); + + GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash(); + assertBuilderSize(groupByHash, types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); // add 4 rows for the first page and created a single heap with 4 rows assertTrue(groupedTopNBuilder.processPage(input.get(0)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4), ImmutableList.of(4), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4), ImmutableList.of(4), groupedTopNBuilder.getEstimatedSizeInBytes()); // add 1 row for the second page and the heap is with 5 rows assertTrue(groupedTopNBuilder.processPage(input.get(1)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); // update 1 new row from the third page (which will be compacted into a single row only) assertTrue(groupedTopNBuilder.processPage(input.get(2)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); // the last page will be discarded assertTrue(groupedTopNBuilder.processPage(input.get(3)).process()); - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupByHash, types, ImmutableList.of(4, 1, 1), ImmutableList.of(5), groupedTopNBuilder.getEstimatedSizeInBytes()); - List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult()); + List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult().iterator()); assertEquals(output.size(), 1); Page expected = rowPagesBuilder(BIGINT, DOUBLE, BIGINT) @@ -221,7 +228,7 @@ public void testSingleGroupTopN(boolean produceRowNumbers) assertPageEquals(types, output.get(0), new Page(expected.getBlock(0), expected.getBlock(1))); } - assertBuilderSize(new NoChannelGroupByHash(), types, ImmutableList.of(0, 0, 0), ImmutableList.of(0), groupedTopNBuilder.getEstimatedSizeInBytes()); + assertBuilderSize(groupedTopNBuilder.getGroupByHash(), types, ImmutableList.of(0, 0, 0), ImmutableList.of(0), groupedTopNBuilder.getEstimatedSizeInBytes()); } @Test @@ -239,11 +246,12 @@ public void testYield() AtomicBoolean unblock = new AtomicBoolean(); GroupByHash groupByHash = createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), unblock::get); - GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + InMemoryGroupedTopNBuilder groupedTopNBuilder = new InMemoryGroupedTopNBuilder( types, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 5, false, + new TestingMemoryContext(100L), groupByHash); assertBuilderSize(groupByHash, types, ImmutableList.of(), ImmutableList.of(), groupedTopNBuilder.getEstimatedSizeInBytes()); @@ -252,7 +260,7 @@ public void testYield() assertFalse(work.process()); unblock.set(true); assertTrue(work.process()); - List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult()); + List output = ImmutableList.copyOf(groupedTopNBuilder.buildResult().iterator()); assertEquals(output.size(), 1); Page expected = rowPagesBuilder(types) @@ -289,11 +297,12 @@ public void testAutoCompact() .row(1L, 0.6) .build(); - GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + InMemoryGroupedTopNBuilder groupedTopNBuilder = new InMemoryGroupedTopNBuilder( types, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), 1, false, + new TestingMemoryContext(100L), createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), NOOP)); // page 1: @@ -383,11 +392,12 @@ public void testLargePagesMemoryTracking(int pageCount, int rowCount) List input = rowPagesBuilder.build(); GroupByHash groupByHash = createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0), NOOP); - GroupedTopNBuilder groupedTopNBuilder = new GroupedTopNBuilder( + InMemoryGroupedTopNBuilder groupedTopNBuilder = new InMemoryGroupedTopNBuilder( types, new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), pageCount * rowCount, false, + new TestingMemoryContext(100L), groupByHash); // Assert memory usage gradually goes up @@ -399,7 +409,7 @@ public void testLargePagesMemoryTracking(int pageCount, int rowCount) // Assert memory usage gradually goes down (i.e., proportional to the number of rows/pages we have produced) int outputPageCount = 0; int remainingRows = pageCount * rowCount; - Iterator output = groupedTopNBuilder.buildResult(); + Iterator output = groupedTopNBuilder.buildResult().iterator(); while (output.hasNext()) { remainingRows -= output.next().getPositionCount(); assertBuilderSize( @@ -485,6 +495,7 @@ private static void assertBuilderSize( rowHeapsSizeInBytes + pageReferencesSizeInBytes + groupedRowsSizeInBytes + + (long) groupByHash.getGroupCount() * Integer.BYTES + emptyPageReferenceSlotsSizeInBytes; assertEquals(actualSizeInBytes, expectedSizeInBytes); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestSpillableGroupedTopNBuilder.java b/presto-main/src/test/java/com/facebook/presto/operator/TestSpillableGroupedTopNBuilder.java new file mode 100644 index 0000000000000..09bb63cb9b9e4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestSpillableGroupedTopNBuilder.java @@ -0,0 +1,361 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.RowPagesBuilder; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.TestingMemoryContext; +import com.facebook.presto.memory.context.AggregatedMemoryContext; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spiller.TestingSpillContext; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.gen.JoinCompiler; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.operator.UpdateMemory.NOOP; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestSpillableGroupedTopNBuilder +{ + @DataProvider + public static Object[][] produceRowNumbers() + { + return new Object[][] {{true}, {false}}; + } + + @Test(dataProvider = "produceRowNumbers") + public void testThatRevokeSpillsDuringAddInput(boolean produceRowNumbers) + { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + List types = ImmutableList.of(BIGINT, DOUBLE); + Supplier groupByHashSupplier = () -> createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0)); + + LocalMemoryContext userMemoryContext = new TestingMemoryContext(200L); + LocalMemoryContext revocableMemoryContext = new TestingMemoryContext(1000L); + DriverYieldSignal driverYieldSignal = new DriverYieldSignal(); + AggregatedMemoryContext aggregatedMemoryContextForMerge = AggregatedMemoryContext.newSimpleAggregatedMemoryContext(); + AggregatedMemoryContext aggregatedMemoryContextForSpill = AggregatedMemoryContext.newSimpleAggregatedMemoryContext(); + TestingSpillContext spillContext = new TestingSpillContext(); + + SpillableGroupedTopNBuilder spillableGroupedTopNBuilder = new SpillableGroupedTopNBuilder( + types, + ImmutableList.of(BIGINT), + ImmutableList.of(0), + () -> new InMemoryGroupedTopNBuilder( + types, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 4, + produceRowNumbers, + revocableMemoryContext, + groupByHashSupplier.get()), + () -> new InMemoryGroupedTopNBuilder( + types, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 4, + produceRowNumbers, + revocableMemoryContext, + groupByHashSupplier.get()), + () -> immediateFuture(null), + 100_000, + userMemoryContext, + revocableMemoryContext, + aggregatedMemoryContextForMerge, + aggregatedMemoryContextForSpill, + spillContext, + driverYieldSignal, + spillerFactory); + + List inputPages = generatePages(1000, 10, 100); + + long emptyBuilderSize = spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes(); + + // add input + for (int i = 0; i < 20; i++) { + spillableGroupedTopNBuilder.processPage(inputPages.get(i)).process(); + spillableGroupedTopNBuilder.updateMemoryReservations(); + } + + // revoke + spillableGroupedTopNBuilder.startMemoryRevoke(); + spillableGroupedTopNBuilder.finishMemoryRevoke(); + // assert that spill files were created + assertEquals(spillerFactory.getSpillsCount(), 1); + // assert that the memory was emptied + assertEquals(spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes(), emptyBuilderSize); + assertEquals(userMemoryContext.getBytes(), 0); + // assert that input uses revocable memory and that spillable builder ensures revocable memory is updated with input builder memory + assertEquals(revocableMemoryContext.getBytes(), spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes()); + assertEquals(userMemoryContext.getBytes(), 0); + + // add input + for (int i = 21; i < 40; i++) { + spillableGroupedTopNBuilder.processPage(inputPages.get(i)).process(); + spillableGroupedTopNBuilder.updateMemoryReservations(); + } + // revoke + spillableGroupedTopNBuilder.startMemoryRevoke(); + spillableGroupedTopNBuilder.finishMemoryRevoke(); + // assert that spill files were created + assertEquals(spillerFactory.getSpillsCount(), 2); + // assert that the revocable memory was emptied + assertEquals(spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes(), emptyBuilderSize); + + // add input + for (int i = 41; i < 100; i++) { + spillableGroupedTopNBuilder.processPage(inputPages.get(i)).process(); + spillableGroupedTopNBuilder.updateMemoryReservations(); + } + + WorkProcessor result = spillableGroupedTopNBuilder.buildResult(); + // when we call buildResult, we should have either moved the last chunk of input + // from revocable memory to user memory, if it doesn't fit, we should have spilled it + + while (!result.isFinished()) { + boolean res = result.process(); + if (res && !result.isFinished()) { + Page resPage = result.getResult(); + } + } + assertEquals(spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes(), emptyBuilderSize); + + // assert that builder.close clears memory accounts + spillableGroupedTopNBuilder.close(); + assertEquals(userMemoryContext.getBytes(), 0); + assertEquals(revocableMemoryContext.getBytes(), 0); + } + + @Test(dataProvider = "produceRowNumbers") + public void testNoSpilling(boolean produceRowNumbers) + { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + List types = ImmutableList.of(BIGINT, DOUBLE); + Supplier groupByHashSupplier = () -> createGroupByHash(ImmutableList.of(types.get(0)), ImmutableList.of(0)); + + // set userMemory high enough that no spilling is needed + LocalMemoryContext userMemoryContext = new TestingMemoryContext(1000000L); + LocalMemoryContext revocableMemoryContext = new TestingMemoryContext(1000000L); + DriverYieldSignal driverYieldSignal = new DriverYieldSignal(); + AggregatedMemoryContext aggregatedMemoryContextForMerge = AggregatedMemoryContext.newSimpleAggregatedMemoryContext(); + AggregatedMemoryContext aggregatedMemoryContextForSpill = AggregatedMemoryContext.newSimpleAggregatedMemoryContext(); + TestingSpillContext spillContext = new TestingSpillContext(); + + SpillableGroupedTopNBuilder spillableGroupedTopNBuilder = new SpillableGroupedTopNBuilder( + types, + ImmutableList.of(BIGINT), + ImmutableList.of(0), + () -> new InMemoryGroupedTopNBuilder( + types, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 4, + produceRowNumbers, + revocableMemoryContext, + groupByHashSupplier.get()), + () -> new InMemoryGroupedTopNBuilder( + types, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 4, + produceRowNumbers, + revocableMemoryContext, + groupByHashSupplier.get()), + () -> immediateFuture(null), + 100_000, + userMemoryContext, + revocableMemoryContext, + aggregatedMemoryContextForMerge, + aggregatedMemoryContextForSpill, + spillContext, + driverYieldSignal, + spillerFactory); + + List inputPages = generatePages(100, 2, 100); + + // add input + for (int i = 0; i < 3; i++) { + spillableGroupedTopNBuilder.processPage(inputPages.get(i)).process(); + } + spillableGroupedTopNBuilder.updateMemoryReservations(); + + // get output + WorkProcessor outputPages = spillableGroupedTopNBuilder.buildResult(); + + // assert that revocable memory was moved to user memory + assertEquals(revocableMemoryContext.getBytes(), 0); + assertEquals(userMemoryContext.getBytes(), spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes()); + + // get output page (only 1 page in this test case) + boolean isResAvailable = outputPages.process(); + assertTrue(isResAvailable); + Page resPage = outputPages.getResult(); + assertEquals(resPage.getPositionCount(), 200); + } + + @Test(dataProvider = "produceRowNumbers") + public void testThatBuilderYieldsDuringBuildResultAndResumesWhenUnblocked(boolean produceRowNumbers) + { + class MemoryFuture + { + ListenableFuture future; + + public void setFuture(ListenableFuture future) + { + this.future = future; + } + + public ListenableFuture getFuture() + { + return future; + } + } + + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + List types = ImmutableList.of(BIGINT, DOUBLE); + final MemoryFuture memoryWaitingFuture = new MemoryFuture(); + memoryWaitingFuture.setFuture(immediateFuture(null)); + Supplier groupByHashSupplier = () -> GroupByHash.createGroupByHash( + ImmutableList.of(types.get(0)), + Ints.toArray(ImmutableList.of(0)), + Optional.empty(), + 1, + false, + new JoinCompiler(createTestMetadataManager(), new FeaturesConfig()), + () -> memoryWaitingFuture.getFuture().isDone()); + + LocalMemoryContext userMemoryContext = new TestingMemoryContext(200L); + LocalMemoryContext revocableMemoryContext = new TestingMemoryContext(1000L); + DriverYieldSignal driverYieldSignal = new DriverYieldSignal(); + AggregatedMemoryContext aggregatedMemoryContextForMerge = AggregatedMemoryContext.newSimpleAggregatedMemoryContext(); + AggregatedMemoryContext aggregatedMemoryContextForSpill = AggregatedMemoryContext.newSimpleAggregatedMemoryContext(); + TestingSpillContext spillContext = new TestingSpillContext(); + SpillableGroupedTopNBuilder spillableGroupedTopNBuilder = new SpillableGroupedTopNBuilder( + types, + ImmutableList.of(BIGINT), + ImmutableList.of(0), + () -> new InMemoryGroupedTopNBuilder( + types, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 4, + produceRowNumbers, + revocableMemoryContext, + groupByHashSupplier.get()), + () -> new InMemoryGroupedTopNBuilder( + types, + new SimplePageWithPositionComparator(types, ImmutableList.of(1), ImmutableList.of(ASC_NULLS_LAST)), + 4, + produceRowNumbers, + revocableMemoryContext, + groupByHashSupplier.get()), + memoryWaitingFuture::getFuture, + 100_000, + userMemoryContext, + revocableMemoryContext, + aggregatedMemoryContextForMerge, + aggregatedMemoryContextForSpill, + spillContext, + driverYieldSignal, + spillerFactory); + + List inputPages = generatePages(1000, 10, 100); + + long emptyBuilderSize = spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes(); + + // add input + for (int i = 0; i < 20; i++) { + spillableGroupedTopNBuilder.processPage(inputPages.get(i)).process(); + spillableGroupedTopNBuilder.updateMemoryReservations(); + } + + // revoke + spillableGroupedTopNBuilder.startMemoryRevoke(); + spillableGroupedTopNBuilder.finishMemoryRevoke(); + + // assert that spill files were created + assertEquals(spillerFactory.getSpillsCount(), 1); + // assert that the memory was emptied + assertEquals(spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes(), emptyBuilderSize); + + assertEquals(userMemoryContext.getBytes(), 0); + // assert that input uses revocable memory and that spillable builder ensures revocable memory is updated with input builder memory + assertEquals(revocableMemoryContext.getBytes(), spillableGroupedTopNBuilder.getInputInMemoryGroupedTopNBuilder().getEstimatedSizeInBytes()); + assertEquals(userMemoryContext.getBytes(), 0); + + WorkProcessor result = spillableGroupedTopNBuilder.buildResult(); + + // Yield after producing first output Page + SettableFuture currentWaitingFuture = SettableFuture.create(); + memoryWaitingFuture.setFuture(currentWaitingFuture); + assertTrue(!memoryWaitingFuture.getFuture().isDone()); + + // try to get output and assert that none is available + boolean isResAvailble = result.process(); + assertFalse(isResAvailble); + + // unblock + currentWaitingFuture.set(null); + + // output should be available + isResAvailble = result.process(); + assertTrue(isResAvailble); + } + + private static GroupByHash createGroupByHash(List partitionTypes, List partitionChannels) + { + return GroupByHash.createGroupByHash( + partitionTypes, + Ints.toArray(partitionChannels), + Optional.empty(), + 1, + false, + new JoinCompiler(createTestMetadataManager(), new FeaturesConfig()), + NOOP); + } + + private static List generatePages(int groupCount, int rowsPerGroup, int rowsPerPage) + { + //create input + List types = ImmutableList.of(BIGINT, DOUBLE, VARCHAR); + RowPagesBuilder pagesBuilder = rowPagesBuilder(types); + int nextVal = 0; + int nextGroup = 0; + int totalRows = 0; + for (int i = 0; i < groupCount; i++) { + for (int j = 0; j < rowsPerGroup; j++) { + pagesBuilder.row(nextGroup++, nextVal++, "Unit test written during times of increased intensity"); + + if (totalRows++ % rowsPerPage == 0) { + pagesBuilder.pageBreak(); + } + } + } + return pagesBuilder.build(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java index b79b63fa66770..d2699c6aac624 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java @@ -22,8 +22,10 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; +import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -40,9 +42,13 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.operator.BenchmarkInMemoryGroupedTopNBuilder.createSequentialInputPages; import static com.facebook.presto.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; import static com.facebook.presto.operator.GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEquals; +import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEqualsIgnoreOrder; +import static com.facebook.presto.operator.OperatorAssertion.toMaterializedResult; import static com.facebook.presto.operator.TopNRowNumberOperator.TopNRowNumberOperatorFactory; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; @@ -121,7 +127,10 @@ public void testPartitioned(boolean hashEnabled) false, Optional.empty(), 10, - joinCompiler); + 0, + joinCompiler, + null, + false); MaterializedResult expected = resultBuilder(driverContext.getSession(), DOUBLE, BIGINT, BIGINT) .row(0.3, 1L, 1L) @@ -169,7 +178,10 @@ public void testUnPartitioned(boolean partial) partial, Optional.empty(), 10, - joinCompiler); + 0, + joinCompiler, + null, + false); MaterializedResult expected; if (partial) { @@ -190,6 +202,7 @@ public void testUnPartitioned(boolean partial) assertOperatorEquals(operatorFactory, driverContext, input, expected); } + @Test public void testMemoryReservationYield() { Type type = BIGINT; @@ -208,7 +221,10 @@ public void testMemoryReservationYield() false, Optional.empty(), 10, - joinCompiler); + 0, + joinCompiler, + null, + false); // get result with yield; pick a relatively small buffer for heaps GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash( @@ -230,4 +246,78 @@ public void testMemoryReservationYield() } assertEquals(count, 6_000 * 600); } + + @Test + public void testSpillableTopNRowNumberOperatorProducesCorrectOutputIFSpilledDuringAddInput() + { + List input = createSequentialInputPages(1000, ImmutableList.of(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE), 200, 300, 42); + + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + driverContext = TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) + .setQueryMaxMemory(new DataSize(200, DataSize.Unit.KILOBYTE)) + .setQueryMaxTotalMemory(new DataSize(200, DataSize.Unit.KILOBYTE)) + .setMaxRevocableMemory(new DataSize(200, DataSize.Unit.KILOBYTE)) + .build() + .addPipelineContext(0, true, true, false) + .addDriverContext(); + joinCompiler = new JoinCompiler(MetadataManager.createTestMetadataManager(), new FeaturesConfig()); + + TopNRowNumberOperatorFactory operatorFactory = new TopNRowNumberOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE), + ImmutableList.of(0, 1, 2, 3, 4), + ImmutableList.of(0), + ImmutableList.of(BIGINT), + ImmutableList.of(1), + ImmutableList.of(SortOrder.ASC_NULLS_LAST), + 3, + true, + Optional.empty(), + 10, + 10000, + joinCompiler, + new DummySpillerFactory(), + true); + + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, toMaterializedResult(driverContext.getSession(), ImmutableList.of(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE), input), true); + } + + @Test + public void testSpillableTopNRowNumberOperatorProducesCorrectOutputIfNOSPILLDuringAddInput() + { + List input = createSequentialInputPages(1000, ImmutableList.of(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE), 200, 300, 42); + + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + driverContext = TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) + .setQueryMaxMemory(new DataSize(200, DataSize.Unit.KILOBYTE)) + .setQueryMaxTotalMemory(new DataSize(200, DataSize.Unit.KILOBYTE)) + .setMaxRevocableMemory(new DataSize(200, DataSize.Unit.KILOBYTE)) + .build() + .addPipelineContext(0, true, true, false) + .addDriverContext(); + joinCompiler = new JoinCompiler(MetadataManager.createTestMetadataManager(), new FeaturesConfig()); + + TopNRowNumberOperatorFactory operatorFactory = new TopNRowNumberOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE), + ImmutableList.of(0, 1, 2, 3, 4), + ImmutableList.of(0), + ImmutableList.of(BIGINT), + ImmutableList.of(1), + ImmutableList.of(SortOrder.ASC_NULLS_LAST), + 3, + true, + Optional.empty(), + 10, + 10000, + joinCompiler, + new DummySpillerFactory(), + true); + + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, toMaterializedResult(driverContext.getSession(), ImmutableList.of(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE), input)); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 9570f87240605..5369d2bafc626 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -105,7 +105,9 @@ public void testDefaults() .setOrderByAggregationSpillEnabled(true) .setWindowSpillEnabled(true) .setOrderBySpillEnabled(true) + .setTopNSpillEnabled(true) .setAggregationOperatorUnspillMemoryLimit(DataSize.valueOf("4MB")) + .setTopNOperatorUnspillMemoryLimit(DataSize.valueOf("4MB")) .setSpillerSpillPaths("") .setSpillerThreads(4) .setSpillMaxUsedSpaceThreshold(0.9) @@ -286,7 +288,9 @@ public void testExplicitPropertyMappings() .put("experimental.order-by-aggregation-spill-enabled", "false") .put("experimental.window-spill-enabled", "false") .put("experimental.order-by-spill-enabled", "false") + .put("experimental.topn-spill-enabled", "false") .put("experimental.aggregation-operator-unspill-memory-limit", "100MB") + .put("experimental.topn-operator-unspill-memory-limit", "100MB") .put("experimental.spiller-spill-path", "/tmp/custom/spill/path1,/tmp/custom/spill/path2") .put("experimental.spiller-threads", "42") .put("experimental.spiller-max-used-space-threshold", "0.8") @@ -440,7 +444,9 @@ public void testExplicitPropertyMappings() .setOrderByAggregationSpillEnabled(false) .setWindowSpillEnabled(false) .setOrderBySpillEnabled(false) + .setTopNSpillEnabled(false) .setAggregationOperatorUnspillMemoryLimit(DataSize.valueOf("100MB")) + .setTopNOperatorUnspillMemoryLimit(DataSize.valueOf("100MB")) .setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2") .setSpillerThreads(42) .setSpillMaxUsedSpaceThreshold(0.8) diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkSpilledTopNQueries.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkSpilledTopNQueries.java new file mode 100644 index 0000000000000..24aee81571e0c --- /dev/null +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkSpilledTopNQueries.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark; + +import com.facebook.presto.testing.QueryRunner; + +import java.util.HashMap; + +import static com.facebook.presto.spark.PrestoSparkQueryRunner.createSpilledHivePrestoSparkQueryRunner; +import static io.airlift.tpch.TpchTable.getTables; + +public class TestPrestoSparkSpilledTopNQueries + extends TestPrestoSparkTopNQueries +{ + @Override + protected QueryRunner createQueryRunner() + { + HashMap additionalProperties = new HashMap<>(); + additionalProperties.put("experimental.topn-spill-enabled", "true"); + additionalProperties.put("experimental.spiller.single-stream-spiller-choice", "TEMP_STORAGE"); + additionalProperties.put("experimental.memory-revoking-threshold", "0.0"); // revoke always + additionalProperties.put("experimental.memory-revoking-target", "0.0"); + additionalProperties.put("query.max-memory", "110kB"); + + return createSpilledHivePrestoSparkQueryRunner(getTables(), additionalProperties); + } +} diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkTopNQueries.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkTopNQueries.java new file mode 100644 index 0000000000000..0d70307eb9dd4 --- /dev/null +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkTopNQueries.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestTopNQueries; + +import static com.facebook.presto.spark.PrestoSparkQueryRunner.createHivePrestoSparkQueryRunner; + +public class TestPrestoSparkTopNQueries + extends AbstractTestTopNQueries +{ + @Override + protected QueryRunner createQueryRunner() + { + return createHivePrestoSparkQueryRunner(); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestTopNQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestTopNQueries.java new file mode 100644 index 0000000000000..c07f09a55bff2 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestTopNQueries.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import org.testng.annotations.Test; + +public abstract class AbstractTestTopNQueries + extends AbstractTestQueryFramework +{ + @Test + public void testUngroupedTopN() + { + assertQuery("SELECT custkey, totalprice from orders ORDER BY totalprice limit 3"); + } + + @Test + public void testGroupedTopN() + { + assertQuery( + "SELECT * FROM (SELECT " + + "custkey, " + + "totalprice, " + + "ROW_NUMBER() OVER (PARTITION BY custkey order by totalprice) rn " + + "from orders) " + + "where rn < 3"); + } + + @Test + public void testGroupedTopNRowNumber() + { + assertQuery( + "SELECT * FROM (SELECT " + + "custkey, " + + "totalprice, " + + "ROW_NUMBER() OVER (PARTITION BY custkey order by totalprice) rn " + + "from orders) " + + "where rn < 3"); + } + + @Test + public void testGroupedTopWithAggregationAndMultiChannelGrouping() + { + assertQuery( + "SELECT * FROM " + + "( SELECT " + + " regionkey, RANK() OVER (PARTITION BY regionkey ORDER BY nation_count) r FROM" + + " ( SELECT R.regionkey, count(distinct nationkey) nation_count FROM " + + " region R " + + " JOIN nation N ON R.regionkey=N.regionkey " + + " GROUP BY R.regionkey" + + " )" + + ") " + + " WHERE " + + "r <= 2"); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledTopNQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledTopNQueries.java new file mode 100644 index 0000000000000..79b04b50d0e6f --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledTopNQueries.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.nio.file.Paths; + +import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_TOTAL_MEMORY_PER_NODE; +import static com.facebook.presto.SystemSessionProperties.TOPN_SPILL_ENABLED; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; + +public class TestSpilledTopNQueries + extends AbstractTestTopNQueries + +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session defaultSession = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty(SystemSessionProperties.TASK_CONCURRENCY, "2") + .setSystemProperty(SystemSessionProperties.TOPN_OPERATOR_UNSPILL_MEMORY_LIMIT, "120kB") + .setSystemProperty(SystemSessionProperties.QUERY_MAX_MEMORY_PER_NODE, "1500kB") + .build(); + + ImmutableMap extraProperties = ImmutableMap.builder() + .put("experimental.spill-enabled", "true") + .put("experimental.topn-spill-enabled", "true") + .put("experimental.spiller-spill-path", Paths.get(System.getProperty("java.io.tmpdir"), "presto", "spills").toString()) + .put("experimental.spiller-max-used-space-threshold", "1.0") + .put("experimental.memory-revoking-threshold", "0.001") // revoke always + .put("experimental.memory-revoking-target", "0.0") + .build(); + + DistributedQueryRunner queryRunner = new DistributedQueryRunner(defaultSession, 2, extraProperties); + + try { + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + return queryRunner; + } + catch (Exception e) { + queryRunner.close(); + throw e; + } + } + + @Test + public void testDoesNotSpillTopNWhenDisabled() + { + Session session = Session.builder(getSession()) + .setSystemProperty(TOPN_SPILL_ENABLED, "false") + // set this low so that if we ran without spill the query would fail + .setSystemProperty(QUERY_MAX_TOTAL_MEMORY_PER_NODE, "50kB") + .build(); + assertQueryFails(session, + "SELECT orderpriority, custkey FROM orders ORDER BY orderpriority LIMIT 1000", "Query exceeded.*memory limit.*"); + } +}