diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedWindowQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedWindowQueries.java new file mode 100644 index 0000000000000..7c028e1c80b9a --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedWindowQueries.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.presto.tests.AbstractTestWindowQueries; + +import static com.facebook.presto.hive.HiveQueryRunner.createQueryRunner; +import static io.airlift.tpch.TpchTable.getTables; + +public class TestHiveDistributedWindowQueries + extends AbstractTestWindowQueries +{ + public TestHiveDistributedWindowQueries() + { + super(() -> createQueryRunner(getTables())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java index 4990ae12b0da3..f4f2d47e81393 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java @@ -264,6 +264,12 @@ public AggregatedMemoryContext aggregateUserMemoryContext() return new InternalAggregatedMemoryContext(operatorMemoryContext.aggregateUserMemoryContext(), memoryFuture, this::updatePeakMemoryReservations, false); } + // caller shouldn't close this context as it's managed by the OperatorContext + public AggregatedMemoryContext aggregateRevocableMemoryContext() + { + return new InternalAggregatedMemoryContext(operatorMemoryContext.aggregateRevocableMemoryContext(), memoryFuture, () -> {}, false); + } + // caller should close this context as it's a new context public AggregatedMemoryContext newAggregateSystemMemoryContext() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java index 548739e13ff4a..c71b1c8cb439b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java @@ -19,14 +19,24 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.operator.WorkProcessor.ProcessState; +import com.facebook.presto.operator.WorkProcessor.Transformation; +import com.facebook.presto.operator.WorkProcessor.TransformationState; import com.facebook.presto.operator.window.FramedWindowFunction; import com.facebook.presto.operator.window.WindowPartition; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spiller.Spiller; +import com.facebook.presto.spiller.SpillerFactory; +import com.facebook.presto.sql.gen.OrderingCompiler; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.PeekingIterator; import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; + +import javax.annotation.Nullable; import java.util.List; import java.util.Optional; @@ -35,12 +45,18 @@ import java.util.function.BiPredicate; import java.util.stream.Stream; +import static com.facebook.airlift.concurrent.MoreFutures.checkSuccess; import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.operator.WorkProcessor.TransformationState.needsMoreData; +import static com.facebook.presto.util.MergeSortedPages.mergeSortedPages; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkPositionIndex; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.concat; +import static com.google.common.collect.Iterators.peekingIterator; +import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; @@ -63,6 +79,9 @@ public static class WindowOperatorFactory private final int expectedPositions; private boolean closed; private final PagesIndex.Factory pagesIndexFactory; + private final boolean spillEnabled; + private final SpillerFactory spillerFactory; + private final OrderingCompiler orderingCompiler; public WindowOperatorFactory( int operatorId, @@ -76,7 +95,10 @@ public WindowOperatorFactory( List sortOrder, int preSortedChannelPrefix, int expectedPositions, - PagesIndex.Factory pagesIndexFactory) + PagesIndex.Factory pagesIndexFactory, + boolean spillEnabled, + SpillerFactory spillerFactory, + OrderingCompiler orderingCompiler) { requireNonNull(sourceTypes, "sourceTypes is null"); requireNonNull(planNodeId, "planNodeId is null"); @@ -88,6 +110,8 @@ public WindowOperatorFactory( requireNonNull(sortChannels, "sortChannels is null"); requireNonNull(sortOrder, "sortOrder is null"); requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + requireNonNull(spillerFactory, "spillerFactory is null"); + requireNonNull(orderingCompiler, "orderingCompiler is null"); checkArgument(sortChannels.size() == sortOrder.size(), "Must have same number of sort channels as sort orders"); checkArgument(preSortedChannelPrefix <= sortChannels.size(), "Cannot have more pre-sorted channels than specified sorted channels"); checkArgument(preSortedChannelPrefix == 0 || ImmutableSet.copyOf(preGroupedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedChannelPrefix can only be greater than zero if all partition channels are pre-grouped"); @@ -104,6 +128,9 @@ public WindowOperatorFactory( this.sortOrder = ImmutableList.copyOf(sortOrder); this.preSortedChannelPrefix = preSortedChannelPrefix; this.expectedPositions = expectedPositions; + this.spillEnabled = spillEnabled; + this.spillerFactory = spillerFactory; + this.orderingCompiler = orderingCompiler; } @Override @@ -123,7 +150,10 @@ public Operator createOperator(DriverContext driverContext) sortOrder, preSortedChannelPrefix, expectedPositions, - pagesIndexFactory); + pagesIndexFactory, + spillEnabled, + spillerFactory, + orderingCompiler); } @Override @@ -147,44 +177,26 @@ public OperatorFactory duplicate() sortOrder, preSortedChannelPrefix, expectedPositions, - pagesIndexFactory); + pagesIndexFactory, + spillEnabled, + spillerFactory, + orderingCompiler); } } - private enum State - { - NEEDS_INPUT, - HAS_OUTPUT, - FINISHING, - FINISHED - } - private final OperatorContext operatorContext; + private final List outputTypes; private final int[] outputChannels; private final List windowFunctions; - private final List orderChannels; - private final List ordering; - private final LocalMemoryContext localUserMemoryContext; - - private final int[] preGroupedChannels; - - private final PagesHashStrategy preGroupedPartitionHashStrategy; - private final PagesHashStrategy unGroupedPartitionHashStrategy; - private final PagesHashStrategy preSortedPartitionHashStrategy; - private final PagesHashStrategy peerGroupHashStrategy; - - private final PagesIndex pagesIndex; - - private final PageBuilder pageBuilder; - private final WindowInfo.DriverWindowInfoBuilder windowInfo; private final AtomicReference> driverWindowInfo = new AtomicReference<>(Optional.empty()); - private State state = State.NEEDS_INPUT; - - private WindowPartition partition; + private final Optional spillablePagesToPagesIndexes; + private final WorkProcessor outputPages; + @Nullable private Page pendingInput; + private boolean operatorFinishing; public WindowOperator( OperatorContext operatorContext, @@ -197,7 +209,10 @@ public WindowOperator( List sortOrder, int preSortedChannelPrefix, int expectedPositions, - PagesIndex.Factory pagesIndexFactory) + PagesIndex.Factory pagesIndexFactory, + boolean spillEnabled, + SpillerFactory spillerFactory, + OrderingCompiler orderingCompiler) { requireNonNull(operatorContext, "operatorContext is null"); requireNonNull(outputChannels, "outputChannels is null"); @@ -208,48 +223,88 @@ public WindowOperator( requireNonNull(sortChannels, "sortChannels is null"); requireNonNull(sortOrder, "sortOrder is null"); requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + requireNonNull(spillerFactory, "spillerFactory is null"); checkArgument(sortChannels.size() == sortOrder.size(), "Must have same number of sort channels as sort orders"); checkArgument(preSortedChannelPrefix <= sortChannels.size(), "Cannot have more pre-sorted channels than specified sorted channels"); checkArgument(preSortedChannelPrefix == 0 || ImmutableSet.copyOf(preGroupedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedChannelPrefix can only be greater than zero if all partition channels are pre-grouped"); this.operatorContext = operatorContext; - this.localUserMemoryContext = operatorContext.localUserMemoryContext(); this.outputChannels = Ints.toArray(outputChannels); this.windowFunctions = windowFunctionDefinitions.stream() .map(functionDefinition -> new FramedWindowFunction(functionDefinition.createWindowFunction(), functionDefinition.getFrameInfo())) .collect(toImmutableList()); - List types = Stream.concat( + this.outputTypes = Stream.concat( outputChannels.stream() .map(sourceTypes::get), windowFunctionDefinitions.stream() .map(WindowFunctionDefinition::getType)) .collect(toImmutableList()); - this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); - this.preGroupedChannels = Ints.toArray(preGroupedChannels); - this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedChannels, OptionalInt.empty()); List unGroupedPartitionChannels = partitionChannels.stream() .filter(channel -> !preGroupedChannels.contains(channel)) .collect(toImmutableList()); - this.unGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(unGroupedPartitionChannels, OptionalInt.empty()); List preSortedChannels = sortChannels.stream() .limit(preSortedChannelPrefix) .collect(toImmutableList()); - this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); - this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, OptionalInt.empty()); - this.pageBuilder = new PageBuilder(types); + List unGroupedOrderChannels = ImmutableList.copyOf(concat(unGroupedPartitionChannels, sortChannels)); + List unGroupedOrdering = ImmutableList.copyOf(concat(nCopies(unGroupedPartitionChannels.size(), ASC_NULLS_LAST), sortOrder)); + List orderChannels; + List ordering; if (preSortedChannelPrefix > 0) { // This already implies that set(preGroupedChannels) == set(partitionChannels) (enforced with checkArgument) - this.orderChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedChannelPrefix)); - this.ordering = ImmutableList.copyOf(Iterables.skip(sortOrder, preSortedChannelPrefix)); + orderChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedChannelPrefix)); + ordering = ImmutableList.copyOf(Iterables.skip(sortOrder, preSortedChannelPrefix)); } else { // Otherwise, we need to sort by the unGroupedPartitionChannels and all original sort channels - this.orderChannels = ImmutableList.copyOf(concat(unGroupedPartitionChannels, sortChannels)); - this.ordering = ImmutableList.copyOf(concat(nCopies(unGroupedPartitionChannels.size(), ASC_NULLS_LAST), sortOrder)); + orderChannels = unGroupedOrderChannels; + ordering = unGroupedOrdering; + } + + PagesIndexWithHashStrategies inMemoryPagesIndexWithHashStrategies = new PagesIndexWithHashStrategies( + pagesIndexFactory, + sourceTypes, + expectedPositions, + preGroupedChannels, + unGroupedPartitionChannels, + preSortedChannels, + sortChannels); + + if (spillEnabled) { + PagesIndexWithHashStrategies mergedPagesIndexWithHashStrategies = new PagesIndexWithHashStrategies( + pagesIndexFactory, + sourceTypes, + expectedPositions, + // merged pages are grouped on all partition channels + partitionChannels, + ImmutableList.of(), + // merged pages are pre sorted on all sort channels + sortChannels, + sortChannels); + + this.spillablePagesToPagesIndexes = Optional.of(new SpillablePagesToPagesIndexes( + inMemoryPagesIndexWithHashStrategies, + mergedPagesIndexWithHashStrategies, + sourceTypes, + orderChannels, + ordering, + spillerFactory, + orderingCompiler.compilePageWithPositionComparator(sourceTypes, unGroupedOrderChannels, unGroupedOrdering))); + + this.outputPages = WorkProcessor.create(new PagesSource()) + .flatTransform(spillablePagesToPagesIndexes.get()) + .flatMap(this::pagesIndexToWindowPartitions) + .transform(new WindowPartitionsToOutputPages()); + } + else { + this.spillablePagesToPagesIndexes = Optional.empty(); + this.outputPages = WorkProcessor.create(new PagesSource()) + .transform(new PagesToPagesIndexes(inMemoryPagesIndexWithHashStrategies, orderChannels, ordering)) + .flatMap(this::pagesIndexToWindowPartitions) + .transform(new WindowPartitionsToOutputPages()); } windowInfo = new WindowInfo.DriverWindowInfoBuilder(); @@ -270,32 +325,24 @@ public OperatorContext getOperatorContext() @Override public void finish() { - if (state == State.FINISHING || state == State.FINISHED) { - return; - } - if (state == State.NEEDS_INPUT) { - // Since was waiting for more input, prepare what we have for output since we will not be getting any more input - finishPagesIndex(); - } - state = State.FINISHING; + operatorFinishing = true; } @Override public boolean isFinished() { - return state == State.FINISHED; + return outputPages.isFinished(); } @Override public boolean needsInput() { - return state == State.NEEDS_INPUT; + return pendingInput == null && !operatorFinishing; } @Override public void addInput(Page page) { - checkState(state == State.NEEDS_INPUT, "Operator can not take input at this time"); requireNonNull(page, "page is null"); checkState(pendingInput == null, "Operator already has pending input"); @@ -304,146 +351,467 @@ public void addInput(Page page) } pendingInput = page; - if (processPendingInput()) { - state = State.HAS_OUTPUT; + } + + @Override + public Page getOutput() + { + if (!outputPages.process()) { + return null; + } + + if (outputPages.isFinished()) { + return null; } - localUserMemoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes()); + + return outputPages.getResult(); } - /** - * @return true if a full group has been buffered after processing the pendingInput, false otherwise - */ - private boolean processPendingInput() + @Override + public ListenableFuture startMemoryRevoke() + { + return spillablePagesToPagesIndexes.get().spill(); + } + + @Override + public void finishMemoryRevoke() { - checkState(pendingInput != null); - pendingInput = updatePagesIndex(pendingInput); + spillablePagesToPagesIndexes.get().finishRevokeMemory(); + } - // If we have unused input or are finishing, then we have buffered a full group - if (pendingInput != null || state == State.FINISHING) { - finishPagesIndex(); - return true; + private static class PagesIndexWithHashStrategies + { + final PagesIndex pagesIndex; + final PagesHashStrategy preGroupedPartitionHashStrategy; + final PagesHashStrategy unGroupedPartitionHashStrategy; + final PagesHashStrategy preSortedPartitionHashStrategy; + final PagesHashStrategy peerGroupHashStrategy; + final int[] preGroupedPartitionChannels; + + PagesIndexWithHashStrategies( + PagesIndex.Factory pagesIndexFactory, + List sourceTypes, + int expectedPositions, + List preGroupedPartitionChannels, + List unGroupedPartitionChannels, + List preSortedChannels, + List sortChannels) + { + this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); + this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedPartitionChannels, OptionalInt.empty()); + this.unGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(unGroupedPartitionChannels, OptionalInt.empty()); + this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); + this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, OptionalInt.empty()); + this.preGroupedPartitionChannels = Ints.toArray(preGroupedPartitionChannels); } - else { - return false; + } + + private class PagesSource + implements WorkProcessor.Process + { + @Override + public ProcessState process() + { + if (operatorFinishing && pendingInput == null) { + return ProcessState.finished(); + } + + if (pendingInput != null) { + Page result = pendingInput; + pendingInput = null; + return ProcessState.ofResult(result); + } + + return ProcessState.yield(); } } - /** - * @return the unused section of the page, or null if fully applied. - * pagesIndex guaranteed to have at least one row after this method returns - */ - private Page updatePagesIndex(Page page) + private class PagesToPagesIndexes + implements Transformation { - checkArgument(page.getPositionCount() > 0); + final PagesIndexWithHashStrategies pagesIndexWithHashStrategies; + final List orderChannels; + final List ordering; + final LocalMemoryContext memoryContext; + + boolean resetPagesIndex; + int pendingInputPosition; + + PagesToPagesIndexes( + PagesIndexWithHashStrategies pagesIndexWithHashStrategies, + List orderChannels, + List ordering) + { + this.pagesIndexWithHashStrategies = pagesIndexWithHashStrategies; + this.orderChannels = orderChannels; + this.ordering = ordering; + this.memoryContext = operatorContext.aggregateUserMemoryContext().newLocalMemoryContext(PagesToPagesIndexes.class.getSimpleName()); + } - // TODO: Fix pagesHashStrategy to allow specifying channels for comparison, it currently requires us to rearrange the right side blocks in consecutive channel order - Page preGroupedPage = rearrangePage(page, preGroupedChannels); - if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(preGroupedPartitionHashStrategy, 0, 0, preGroupedPage)) { - // Find the position where the pre-grouped columns change - int groupEnd = findGroupEnd(preGroupedPage, preGroupedPartitionHashStrategy, 0); + @Override + public TransformationState process(Optional pendingInputOptional) + { + if (resetPagesIndex) { + pagesIndexWithHashStrategies.pagesIndex.clear(); + updateMemoryUsage(); + resetPagesIndex = false; + } - // Add the section of the page that contains values for the current group - pagesIndex.addPage(page.getRegion(0, groupEnd)); + boolean finishing = !pendingInputOptional.isPresent(); + if (finishing && pagesIndexWithHashStrategies.pagesIndex.getPositionCount() == 0) { + memoryContext.close(); + return TransformationState.finished(); + } - if (page.getPositionCount() - groupEnd > 0) { - // Save the remaining page, which may contain multiple partitions - return page.getRegion(groupEnd, page.getPositionCount() - groupEnd); + if (!finishing) { + Page pendingInput = pendingInputOptional.get(); + pendingInputPosition = updatePagesIndex(pagesIndexWithHashStrategies, pendingInput, pendingInputPosition, Optional.empty()); + updateMemoryUsage(); } - else { - // Page fully consumed - return null; + + // If we have unused input or are finishing, then we have buffered a full group + if (finishing || pendingInputPosition < pendingInputOptional.get().getPositionCount()) { + sortPagesIndexIfNecessary(pagesIndexWithHashStrategies, orderChannels, ordering); + resetPagesIndex = true; + return TransformationState.ofResult(pagesIndexWithHashStrategies, false); } + + pendingInputPosition = 0; + return TransformationState.needsMoreData(); } - else { - // We had previous results buffered, but the new page starts with new group values - return page; + + void updateMemoryUsage() + { + memoryContext.setBytes(pagesIndexWithHashStrategies.pagesIndex.getEstimatedSize().toBytes()); } } - private static Page rearrangePage(Page page, int[] channels) + private WorkProcessor pagesIndexToWindowPartitions(PagesIndexWithHashStrategies pagesIndexWithHashStrategies) { - Block[] newBlocks = new Block[channels.length]; - for (int i = 0; i < channels.length; i++) { - newBlocks[i] = page.getBlock(channels[i]); - } - return new Page(page.getPositionCount(), newBlocks); + PagesIndex pagesIndex = pagesIndexWithHashStrategies.pagesIndex; + + // pagesIndex contains the full grouped & sorted data for one or more partitions + + windowInfo.addIndex(pagesIndex); + + return WorkProcessor.create(new WorkProcessor.Process() + { + int partitionStart; + + @Override + public ProcessState process() + { + if (partitionStart == pagesIndex.getPositionCount()) { + return ProcessState.finished(); + } + + int partitionEnd = findGroupEnd(pagesIndex, pagesIndexWithHashStrategies.unGroupedPartitionHashStrategy, partitionStart); + + WindowPartition partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, pagesIndexWithHashStrategies.peerGroupHashStrategy); + windowInfo.addPartition(partition); + partitionStart = partitionEnd; + return ProcessState.ofResult(partition); + } + }); } - @Override - public Page getOutput() + private class WindowPartitionsToOutputPages + implements Transformation { - if (state == State.NEEDS_INPUT || state == State.FINISHED) { - return null; + final PageBuilder pageBuilder; + + WindowPartitionsToOutputPages() + { + pageBuilder = new PageBuilder(outputTypes); } - Page page = extractOutput(); - localUserMemoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes()); - return page; + @Override + public TransformationState process(Optional partitionOptional) + { + boolean finishing = !partitionOptional.isPresent(); + if (finishing) { + if (pageBuilder.isEmpty()) { + return TransformationState.finished(); + } + + // Output the remaining page if we have anything buffered + Page page = pageBuilder.build(); + pageBuilder.reset(); + return TransformationState.ofResult(page, false); + } + + WindowPartition partition = partitionOptional.get(); + while (!pageBuilder.isFull() && partition.hasNext()) { + partition.processNextRow(pageBuilder); + } + if (!pageBuilder.isFull()) { + return needsMoreData(); + } + + Page page = pageBuilder.build(); + pageBuilder.reset(); + return TransformationState.ofResult(page, !partition.hasNext()); + } } - private Page extractOutput() + private class SpillablePagesToPagesIndexes + implements Transformation> { - // INVARIANT: pagesIndex contains the full grouped & sorted data for one or more partitions - - // Iterate through the positions sequentially until we have one full page - while (!pageBuilder.isFull()) { - if (partition == null || !partition.hasNext()) { - int partitionStart = partition == null ? 0 : partition.getPartitionEnd(); - - if (partitionStart >= pagesIndex.getPositionCount()) { - // Finished all of the partitions in the current pagesIndex - partition = null; - pagesIndex.clear(); - - // Try to extract more partitions from the pendingInput - if (pendingInput != null && processPendingInput()) { - partitionStart = 0; - } - else if (state == State.FINISHING) { - state = State.FINISHED; - // Output the remaining page if we have anything buffered - if (!pageBuilder.isEmpty()) { - Page page = pageBuilder.build(); - pageBuilder.reset(); - return page; - } - return null; - } - else { - state = State.NEEDS_INPUT; - return null; - } + final PagesIndexWithHashStrategies inMemoryPagesIndexWithHashStrategies; + final PagesIndexWithHashStrategies mergedPagesIndexWithHashStrategies; + final List sourceTypes; + final List orderChannels; + final List ordering; + final LocalMemoryContext localRevocableMemoryContext; + final LocalMemoryContext localUserMemoryContext; + final SpillerFactory spillerFactory; + final PageWithPositionComparator pageWithPositionComparator; + + boolean spillingWhenConvertingRevocableMemory; + boolean resetPagesIndex; + int pendingInputPosition; + + Optional currentSpillGroupRowPage; + Optional spiller; + // Spill can be trigger by Driver, by us or both. `spillInProgress` is not empty when spill was triggered but not `finishMemoryRevoke()` yet + Optional> spillInProgress = Optional.empty(); + + SpillablePagesToPagesIndexes( + PagesIndexWithHashStrategies inMemoryPagesIndexWithHashStrategies, + PagesIndexWithHashStrategies mergedPagesIndexWithHashStrategies, + List sourceTypes, + List orderChannels, + List ordering, + SpillerFactory spillerFactory, + PageWithPositionComparator pageWithPositionComparator) + { + this.inMemoryPagesIndexWithHashStrategies = inMemoryPagesIndexWithHashStrategies; + this.mergedPagesIndexWithHashStrategies = mergedPagesIndexWithHashStrategies; + this.sourceTypes = sourceTypes; + this.orderChannels = orderChannels; + this.ordering = ordering; + this.localUserMemoryContext = operatorContext.aggregateUserMemoryContext().newLocalMemoryContext(SpillablePagesToPagesIndexes.class.getSimpleName()); + this.localRevocableMemoryContext = operatorContext.aggregateRevocableMemoryContext().newLocalMemoryContext(SpillablePagesToPagesIndexes.class.getSimpleName()); + this.spillerFactory = spillerFactory; + this.pageWithPositionComparator = pageWithPositionComparator; + + this.currentSpillGroupRowPage = Optional.empty(); + this.spiller = Optional.empty(); + } + + @Override + public TransformationState> process(Optional pendingInputOptional) + { + if (spillingWhenConvertingRevocableMemory) { + // Spill could already be finished by Driver (via WindowOperator#finishMemoryRevoke), but finishRevokeMemory will take care of that + finishRevokeMemory(); + spillingWhenConvertingRevocableMemory = false; + return fullGroupBuffered(); + } + + if (resetPagesIndex) { + inMemoryPagesIndexWithHashStrategies.pagesIndex.clear(); + currentSpillGroupRowPage = Optional.empty(); + + closeSpiller(); + + updateMemoryUsage(false); + resetPagesIndex = false; + } + + boolean finishing = !pendingInputOptional.isPresent(); + if (finishing && inMemoryPagesIndexWithHashStrategies.pagesIndex.getPositionCount() == 0 && !spiller.isPresent()) { + localRevocableMemoryContext.close(); + localUserMemoryContext.close(); + closeSpiller(); + return TransformationState.finished(); + } + + if (!finishing) { + Page pendingInput = pendingInputOptional.get(); + pendingInputPosition = updatePagesIndex(inMemoryPagesIndexWithHashStrategies, pendingInput, pendingInputPosition, currentSpillGroupRowPage); + } + + // If we have unused input or are finishing, then we have buffered a full group + if (finishing || pendingInputPosition < pendingInputOptional.get().getPositionCount()) { + return fullGroupBuffered(); + } + + updateMemoryUsage(true); + pendingInputPosition = 0; + return needsMoreData(); + } + + void closeSpiller() + { + spiller.ifPresent(Spiller::close); + spiller = Optional.empty(); + } + + TransformationState> fullGroupBuffered() + { + // Convert revocable memory to user memory as inMemoryPagesIndexWithHashStrategies holds on to memory so we no longer can revoke + if (localRevocableMemoryContext.getBytes() > 0) { + long currentRevocableBytes = localRevocableMemoryContext.getBytes(); + localRevocableMemoryContext.setBytes(0); + if (!localUserMemoryContext.trySetBytes(localUserMemoryContext.getBytes() + currentRevocableBytes)) { + // TODO: this might fail (even though we have just released memory), but we don't + // have a proper way to atomically convert memory reservations + localRevocableMemoryContext.setBytes(currentRevocableBytes); + spillingWhenConvertingRevocableMemory = true; + return TransformationState.blocked(spill()); } + } - int partitionEnd = findGroupEnd(pagesIndex, unGroupedPartitionHashStrategy, partitionStart); - partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, peerGroupHashStrategy); - windowInfo.addPartition(partition); + sortPagesIndexIfNecessary(inMemoryPagesIndexWithHashStrategies, orderChannels, ordering); + resetPagesIndex = true; + return TransformationState.ofResult(unspill(), false); + } + + ListenableFuture spill() + { + if (spillInProgress.isPresent()) { + // Spill can be triggered first in SpillablePagesToPagesIndexes#process(..) and then by Driver (via WindowOperator#startMemoryRevoke) + return spillInProgress.get(); + } + + if (localRevocableMemoryContext.getBytes() == 0) { + // This must be stale revoke request + spillInProgress = Optional.of(immediateFuture(null)); + return spillInProgress.get(); + } + + if (!spiller.isPresent()) { + spiller = Optional.of(spillerFactory.create( + sourceTypes, + operatorContext.getSpillContext(), + operatorContext.newAggregateSystemMemoryContext())); + } + + verify(inMemoryPagesIndexWithHashStrategies.pagesIndex.getPositionCount() > 0); + sortPagesIndexIfNecessary(inMemoryPagesIndexWithHashStrategies, orderChannels, ordering); + PeekingIterator sortedPages = peekingIterator(inMemoryPagesIndexWithHashStrategies.pagesIndex.getSortedPages()); + Page anyPage = sortedPages.peek(); + verify(anyPage.getPositionCount() != 0, "PagesIndex.getSortedPages returned an empty page"); + currentSpillGroupRowPage = Optional.of(anyPage.getSingleValuePage(/* any */0)); + spillInProgress = Optional.of(spiller.get().spill(sortedPages)); + + return spillInProgress.get(); + } + + void finishRevokeMemory() + { + if (!spillInProgress.isPresent()) { + // Same spill iteration can be finished first by Driver (via WindowOperator#finishMemoryRevoke) and then by SpillablePagesToPagesIndexes#process(..) + return; + } + + checkSuccess(spillInProgress.get(), "spilling failed"); + spillInProgress = Optional.empty(); + + // No memory to reclaim + if (localRevocableMemoryContext.getBytes() == 0) { + return; + } + + inMemoryPagesIndexWithHashStrategies.pagesIndex.clear(); + updateMemoryUsage(false); + } + + WorkProcessor unspill() + { + if (!spiller.isPresent()) { + return WorkProcessor.fromIterable(ImmutableList.of(inMemoryPagesIndexWithHashStrategies)); } - partition.processNextRow(pageBuilder); + List> sortedStreams = ImmutableList.>builder() + .addAll(spiller.get().getSpills().stream() + .map(WorkProcessor::fromIterator) + .collect(toImmutableList())) + .add(WorkProcessor.fromIterator(inMemoryPagesIndexWithHashStrategies.pagesIndex.getSortedPages())) + .build(); + + WorkProcessor mergedPages = mergeSortedPages( + sortedStreams, + pageWithPositionComparator, + sourceTypes, + operatorContext.aggregateUserMemoryContext(), + operatorContext.getDriverContext().getYieldSignal()); + + return mergedPages.transform(new PagesToPagesIndexes(mergedPagesIndexWithHashStrategies, ImmutableList.of(), ImmutableList.of())); } - Page page = pageBuilder.build(); - pageBuilder.reset(); - return page; + void updateMemoryUsage(boolean revocablePagesIndex) + { + long pagesIndexBytes = inMemoryPagesIndexWithHashStrategies.pagesIndex.getEstimatedSize().toBytes(); + if (revocablePagesIndex) { + verify(inMemoryPagesIndexWithHashStrategies.pagesIndex.getPositionCount() > 0); + localUserMemoryContext.setBytes(0); + localRevocableMemoryContext.setBytes(pagesIndexBytes); + } + else { + localRevocableMemoryContext.setBytes(0L); + localUserMemoryContext.setBytes(pagesIndexBytes); + } + } } - private void sortPagesIndexIfNecessary() + private int updatePagesIndex(PagesIndexWithHashStrategies pagesIndexWithHashStrategies, Page page, int startPosition, Optional currentSpillGroupRowPage) { - if (pagesIndex.getPositionCount() > 1 && !orderChannels.isEmpty()) { - int startPosition = 0; - while (startPosition < pagesIndex.getPositionCount()) { - int endPosition = findGroupEnd(pagesIndex, preSortedPartitionHashStrategy, startPosition); - pagesIndex.sort(orderChannels, ordering, startPosition, endPosition); - startPosition = endPosition; + checkArgument(page.getPositionCount() > startPosition); + + // TODO: Fix pagesHashStrategy to allow specifying channels for comparison, it currently requires us to rearrange the right side blocks in consecutive channel order + Page preGroupedPage = rearrangePage(page, pagesIndexWithHashStrategies.preGroupedPartitionChannels); + + PagesIndex pagesIndex = pagesIndexWithHashStrategies.pagesIndex; + PagesHashStrategy preGroupedPartitionHashStrategy = pagesIndexWithHashStrategies.preGroupedPartitionHashStrategy; + if (currentSpillGroupRowPage.isPresent()) { + if (!preGroupedPartitionHashStrategy.rowEqualsRow(0, rearrangePage(currentSpillGroupRowPage.get(), pagesIndexWithHashStrategies.preGroupedPartitionChannels), startPosition, preGroupedPage)) { + return startPosition; + } + } + + if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(preGroupedPartitionHashStrategy, 0, startPosition, preGroupedPage)) { + // Find the position where the pre-grouped columns change + int groupEnd = findGroupEnd(preGroupedPage, preGroupedPartitionHashStrategy, startPosition); + + // Add the section of the page that contains values for the current group + pagesIndex.addPage(page.getRegion(startPosition, groupEnd - startPosition)); + + if (page.getPositionCount() - groupEnd > 0) { + // Save the remaining page, which may contain multiple partitions + return groupEnd; + } + else { + // Page fully consumed + return page.getPositionCount(); } } + else { + // We had previous results buffered, but the remaining page starts with new group values + return startPosition; + } } - private void finishPagesIndex() + private static Page rearrangePage(Page page, int[] channels) { - sortPagesIndexIfNecessary(); - windowInfo.addIndex(pagesIndex); + Block[] newBlocks = new Block[channels.length]; + for (int i = 0; i < channels.length; i++) { + newBlocks[i] = page.getBlock(channels[i]); + } + return new Page(page.getPositionCount(), newBlocks); + } + + private void sortPagesIndexIfNecessary(PagesIndexWithHashStrategies pagesIndexWithHashStrategies, List orderChannels, List ordering) + { + if (pagesIndexWithHashStrategies.pagesIndex.getPositionCount() > 1 && !orderChannels.isEmpty()) { + int startPosition = 0; + while (startPosition < pagesIndexWithHashStrategies.pagesIndex.getPositionCount()) { + int endPosition = findGroupEnd(pagesIndexWithHashStrategies.pagesIndex, pagesIndexWithHashStrategies.preSortedPartitionHashStrategy, startPosition); + pagesIndexWithHashStrategies.pagesIndex.sort(orderChannels, ordering, startPosition, endPosition); + startPosition = endPosition; + } + } } // Assumes input grouped on relevant pagesHashStrategy columns @@ -497,5 +865,6 @@ static int findEndPosition(int startPosition, int endPosition, BiPredicate> getSpills() @Override public void close() { + spills.clear(); } }; } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java b/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java index 133ac6a4ee7c6..a6e9f55b269f0 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java @@ -253,7 +253,17 @@ public static void assertOperatorEqualsIgnoreOrder( List input, MaterializedResult expected) { - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, false, Optional.empty()); + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, false); + } + + public static void assertOperatorEqualsIgnoreOrder( + OperatorFactory operatorFactory, + DriverContext driverContext, + List input, + MaterializedResult expected, + boolean revokeMemoryWhenAddingPages) + { + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, false, Optional.empty(), revokeMemoryWhenAddingPages); } public static void assertOperatorEqualsIgnoreOrder( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java index fa817befc22d4..7c24303656894 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java @@ -27,13 +27,17 @@ import com.facebook.presto.operator.window.ReflectionWindowFunctionSupplier; import com.facebook.presto.operator.window.RowNumberFunction; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spiller.SpillerFactory; +import com.facebook.presto.sql.gen.OrderingCompiler; import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.airlift.units.DataSize; import io.airlift.units.DataSize.Unit; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; @@ -42,6 +46,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.testing.Assertions.assertGreaterThan; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -50,6 +55,7 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEquals; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEqualsIgnoreOrder; +import static com.facebook.presto.operator.OperatorAssertion.toMaterializedResult; import static com.facebook.presto.operator.OperatorAssertion.toPages; import static com.facebook.presto.operator.WindowFunctionDefinition.window; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; @@ -57,9 +63,12 @@ import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; +import static io.airlift.units.DataSize.succinctBytes; +import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; @Test(singleThreaded = true) public class TestWindowOperator @@ -86,16 +95,14 @@ public class TestWindowOperator private ExecutorService executor; private ScheduledExecutorService scheduledExecutor; - private DriverContext driverContext; + private DummySpillerFactory spillerFactory; @BeforeMethod public void setUp() { executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); - driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) - .addPipelineContext(0, true, true, false) - .addDriverContext(); + spillerFactory = new DummySpillerFactory(); } @AfterMethod @@ -103,10 +110,56 @@ public void tearDown() { executor.shutdownNow(); scheduledExecutor.shutdownNow(); + spillerFactory = null; } - @Test - public void testRowNumber() + @DataProvider + public static Object[][] spillEnabled() + { + return new Object[][] { + {false, false, 0}, + {true, false, 8}, + {true, true, 8}, + {true, false, 0}, + {true, true, 0}}; + } + + @Test(dataProvider = "spillEnabled") + public void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + { + // make operator produce multiple pages during finish phase + int numberOfRows = 80_000; + List input = rowPagesBuilder(BIGINT, DOUBLE) + .addSequencePage(numberOfRows, 0, 0) + .build(); + + WindowOperatorFactory operatorFactory = createFactoryUnbounded( + ImmutableList.of(BIGINT, DOUBLE), + Ints.asList(1, 0), + ROW_NUMBER, + Ints.asList(), + Ints.asList(0), + ImmutableList.copyOf(new SortOrder[] {SortOrder.DESC_NULLS_FIRST}), + spillEnabled); + + DriverContext driverContext = createDriverContext(memoryLimit); + MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), DOUBLE, BIGINT, BIGINT); + for (int i = 0; i < numberOfRows; ++i) { + expectedBuilder.row((double) numberOfRows - i - 1, (long) numberOfRows - i - 1, (long) i + 1); + } + MaterializedResult expected = expectedBuilder.build(); + + List pages = toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages); + assertGreaterThan(pages.size(), 1, "Expected more than one output page"); + + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), pages); + assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows()); + + assertTrue(spillEnabled == (spillerFactory.getSpillsCount() > 0), format("Spill state mismatch. Expected spill: %s, spill count: %s", spillEnabled, spillerFactory.getSpillsCount())); + } + + @Test(dataProvider = "spillEnabled") + public void testRowNumber(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, DOUBLE) .row(2L, 0.3) @@ -123,8 +176,10 @@ public void testRowNumber() ROW_NUMBER, Ints.asList(), Ints.asList(0), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), DOUBLE, BIGINT, BIGINT) .row(-0.1, -1L, 1L) .row(0.3, 2L, 2L) @@ -133,11 +188,11 @@ public void testRowNumber() .row(0.1, 6L, 5L) .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testRowNumberPartition() + @Test(dataProvider = "spillEnabled") + public void testRowNumberPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, BIGINT, DOUBLE, BOOLEAN) .row("b", -1L, -0.1, true) @@ -154,8 +209,10 @@ public void testRowNumberPartition() ROW_NUMBER, Ints.asList(0), Ints.asList(1), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, DOUBLE, BOOLEAN, BIGINT) .row("a", 2L, 0.3, false, 1L) .row("a", 4L, 0.2, true, 2L) @@ -164,7 +221,7 @@ public void testRowNumberPartition() .row("b", 5L, 0.4, false, 2L) .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } @Test @@ -188,8 +245,10 @@ public void testRowNumberArbitrary() ROW_NUMBER, Ints.asList(), Ints.asList(), - ImmutableList.copyOf(new SortOrder[] {})); + ImmutableList.copyOf(new SortOrder[] {}), + false); + DriverContext driverContext = createDriverContext(); MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT) .row(1L, 1L) .row(3L, 2L) @@ -204,6 +263,45 @@ public void testRowNumberArbitrary() assertOperatorEquals(operatorFactory, driverContext, input, expected); } + @Test + public void testRowNumberArbitraryWithSpill() + { + List input = rowPagesBuilder(BIGINT) + .row(1L) + .row(3L) + .row(5L) + .row(7L) + .pageBreak() + .row(2L) + .row(4L) + .row(6L) + .row(8L) + .build(); + + WindowOperatorFactory operatorFactory = createFactoryUnbounded( + ImmutableList.of(BIGINT), + Ints.asList(0), + ROW_NUMBER, + Ints.asList(), + Ints.asList(), + ImmutableList.copyOf(new SortOrder[] {}), + true); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT) + .row(1L, 1L) + .row(2L, 2L) + .row(3L, 3L) + .row(4L, 4L) + .row(5L, 5L) + .row(6L, 6L) + .row(7L, 7L) + .row(8L, 8L) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, input, expected); + } + @Test(expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node user memory limit of 10B.*") public void testMemoryLimit() { @@ -225,13 +323,14 @@ public void testMemoryLimit() ROW_NUMBER, Ints.asList(), Ints.asList(0), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + false); toPages(operatorFactory, driverContext, input); } - @Test - public void testFirstValuePartition() + @Test(dataProvider = "spillEnabled") + public void testFirstValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("b", "A1", 1L, true, "") @@ -249,8 +348,10 @@ public void testFirstValuePartition() FIRST_VALUE, Ints.asList(0), Ints.asList(2), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("a", "A2", 1L, false, "A2") .row("a", "B1", 2L, true, "A2") @@ -260,11 +361,11 @@ public void testFirstValuePartition() .row("c", "A3", 1L, true, "A3") .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testLastValuePartition() + @Test(dataProvider = "spillEnabled") + public void testLastValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("b", "A1", 1L, true, "") @@ -276,13 +377,15 @@ public void testLastValuePartition() .row("c", "A3", 1L, true, "") .build(); + DriverContext driverContext = createDriverContext(memoryLimit); WindowOperatorFactory operatorFactory = createFactoryUnbounded( ImmutableList.of(VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR), Ints.asList(0, 1, 2, 3), LAST_VALUE, Ints.asList(0), Ints.asList(2), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("a", "A2", 1L, false, "C2") @@ -292,11 +395,11 @@ public void testLastValuePartition() .row("b", "C1", 2L, false, "C1") .row("c", "A3", 1L, true, "A3") .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testNthValuePartition() + @Test(dataProvider = "spillEnabled") + public void testNthValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BIGINT, BOOLEAN, VARCHAR) .row("b", "A1", 1L, 2L, true, "") @@ -314,8 +417,10 @@ public void testNthValuePartition() NTH_VALUE, Ints.asList(0), Ints.asList(2), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("a", "A2", 1L, false, "C2") .row("a", "B1", 2L, true, "B1") @@ -325,11 +430,11 @@ public void testNthValuePartition() .row("c", "A3", 1L, true, null) .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testLagPartition() + @Test(dataProvider = "spillEnabled") + public void testLagPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BIGINT, VARCHAR, BOOLEAN, VARCHAR) .row("b", "A1", 1L, 1L, "D", true, "") @@ -347,8 +452,10 @@ public void testLagPartition() LAG, Ints.asList(0), Ints.asList(2), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("a", "A2", 1L, false, "D") .row("a", "B1", 2L, true, "D") @@ -358,11 +465,11 @@ public void testLagPartition() .row("c", "A3", 1L, true, "D") .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testLeadPartition() + @Test(dataProvider = "spillEnabled") + public void testLeadPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BIGINT, VARCHAR, BOOLEAN, VARCHAR) .row("b", "A1", 1L, 1L, "D", true, "") @@ -380,8 +487,10 @@ public void testLeadPartition() LEAD, Ints.asList(0), Ints.asList(2), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST})); + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("a", "A2", 1L, false, "C2") .row("a", "B1", 2L, true, "D") @@ -391,11 +500,11 @@ public void testLeadPartition() .row("c", "A3", 1L, true, "D") .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testPartiallyPreGroupedPartitionWithEmptyInput() + @Test(dataProvider = "spillEnabled") + public void testPartiallyPreGroupedPartitionWithEmptyInput(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -410,16 +519,18 @@ public void testPartiallyPreGroupedPartitionWithEmptyInput() Ints.asList(1), Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), - 0); + 0, + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, VARCHAR, BIGINT, VARCHAR, BIGINT) .build(); - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testPartiallyPreGroupedPartition() + @Test(dataProvider = "spillEnabled") + public void testPartiallyPreGroupedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -442,8 +553,10 @@ public void testPartiallyPreGroupedPartition() Ints.asList(1), Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), - 0); + 0, + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, VARCHAR, BIGINT, VARCHAR, BIGINT) .row(1L, "a", 100L, "A", 1L) .row(2L, "a", 101L, "B", 1L) @@ -453,11 +566,11 @@ public void testPartiallyPreGroupedPartition() .row(1L, "c", 105L, "F", 1L) .build(); - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected); + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testFullyPreGroupedPartition() + @Test(dataProvider = "spillEnabled") + public void testFullyPreGroupedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -481,8 +594,10 @@ public void testFullyPreGroupedPartition() Ints.asList(0, 1), Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), - 0); + 0, + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, VARCHAR, BIGINT, VARCHAR, BIGINT) .row(1L, "a", 100L, "A", 1L) .row(2L, "a", 101L, "B", 1L) @@ -493,11 +608,11 @@ public void testFullyPreGroupedPartition() .row(3L, "c", 106L, "G", 1L) .build(); - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected); + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testFullyPreGroupedAndPartiallySortedPartition() + @Test(dataProvider = "spillEnabled") + public void testFullyPreGroupedAndPartiallySortedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -522,8 +637,10 @@ public void testFullyPreGroupedAndPartiallySortedPartition() Ints.asList(0, 1), Ints.asList(3, 2), ImmutableList.of(SortOrder.ASC_NULLS_LAST, SortOrder.ASC_NULLS_LAST), - 1); + 1, + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, VARCHAR, BIGINT, VARCHAR, BIGINT) .row(1L, "a", 100L, "A", 1L) .row(2L, "a", 100L, "A", 1L) @@ -535,11 +652,11 @@ public void testFullyPreGroupedAndPartiallySortedPartition() .row(3L, "c", 100L, "A", 1L) .build(); - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected); + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test - public void testFullyPreGroupedAndFullySortedPartition() + @Test(dataProvider = "spillEnabled") + public void testFullyPreGroupedAndFullySortedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -564,8 +681,10 @@ public void testFullyPreGroupedAndFullySortedPartition() Ints.asList(0, 1), Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), - 1); + 1, + spillEnabled); + DriverContext driverContext = createDriverContext(memoryLimit); MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, VARCHAR, BIGINT, VARCHAR, BIGINT) .row(1L, "a", 100L, "A", 1L) .row(2L, "a", 101L, "A", 1L) @@ -578,7 +697,7 @@ public void testFullyPreGroupedAndFullySortedPartition() .build(); // Since fully grouped and sorted already, should respect original input order - assertOperatorEquals(operatorFactory, driverContext, input, expected); + assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } @Test @@ -613,13 +732,14 @@ private static void assertFindEndPosition(String values, int expected) assertEquals(WindowOperator.findEndPosition(0, array.length, (first, second) -> array[first] == array[second]), expected); } - private static WindowOperatorFactory createFactoryUnbounded( + private WindowOperatorFactory createFactoryUnbounded( List sourceTypes, List outputChannels, List functions, List partitionChannels, List sortChannels, - List sortOrder) + List sortOrder, + boolean spillEnabled) { return createFactoryUnbounded( sourceTypes, @@ -629,7 +749,37 @@ private static WindowOperatorFactory createFactoryUnbounded( ImmutableList.of(), sortChannels, sortOrder, - 0); + 0, + spillEnabled); + } + + public WindowOperatorFactory createFactoryUnbounded( + List sourceTypes, + List outputChannels, + List functions, + List partitionChannels, + List preGroupedChannels, + List sortChannels, + List sortOrder, + int preSortedChannelPrefix, + boolean spillEnabled) + { + return new WindowOperatorFactory( + 0, + new PlanNodeId("test"), + sourceTypes, + outputChannels, + functions, + partitionChannels, + preGroupedChannels, + sortChannels, + sortOrder, + preSortedChannelPrefix, + 10, + new PagesIndex.TestingFactory(false), + spillEnabled, + spillerFactory, + new OrderingCompiler()); } public static WindowOperatorFactory createFactoryUnbounded( @@ -640,7 +790,9 @@ public static WindowOperatorFactory createFactoryUnbounded( List preGroupedChannels, List sortChannels, List sortOrder, - int preSortedChannelPrefix) + int preSortedChannelPrefix, + SpillerFactory spillerFactory, + boolean spillEnabled) { return new WindowOperatorFactory( 0, @@ -654,6 +806,23 @@ public static WindowOperatorFactory createFactoryUnbounded( sortOrder, preSortedChannelPrefix, 10, - new PagesIndex.TestingFactory(false)); + new PagesIndex.TestingFactory(false), + spillEnabled, + spillerFactory, + new OrderingCompiler()); + } + + private DriverContext createDriverContext() + { + return createDriverContext(Long.MAX_VALUE); + } + + private DriverContext createDriverContext(long memoryLimit) + { + return TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) + .setMemoryPoolSize(succinctBytes(memoryLimit)) + .build() + .addPipelineContext(0, true, true, false) + .addDriverContext(); } } diff --git a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java index 80189a364e436..254771cd27925 100644 --- a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java +++ b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java @@ -112,6 +112,11 @@ public AggregatedMemoryContext aggregateUserMemoryContext() return userAggregateMemoryContext; } + public AggregatedMemoryContext aggregateRevocableMemoryContext() + { + return revocableAggregateMemoryContext; + } + public AggregatedMemoryContext newAggregateSystemMemoryContext() { return systemAggregateMemoryContext.newAggregatedMemoryContext(); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 4889637f591f8..068f701e23281 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -17,7 +17,6 @@ import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.SqlTimestampWithTimeZone; -import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.metadata.FunctionListBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.SqlFunction; @@ -62,7 +61,6 @@ import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.INFORMATION_SCHEMA; import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; import static com.facebook.presto.operator.scalar.InvokeFunction.INVOKE_FUNCTION; @@ -468,32 +466,6 @@ public void testTryMapTransformValueFunction() assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); } - @Test - public void testRowFieldAccessorInWindowFunction() - { - assertQuery("SELECT a.col0, " + - "SUM(a.col1[1].col1) OVER(PARTITION BY a.col2.col0), " + - "SUM(a.col2.col1) OVER(PARTITION BY a.col2.col0) FROM " + - "(VALUES " + - "ROW(CAST(ROW(1.0, ARRAY[row(31, 14.5E0), row(12, 4.2E0)], row(3, 4.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + - "ROW(CAST(ROW(2.2, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + - "ROW(CAST(ROW(2.2, ARRAY[row(41, 17.1E0), row(45, 4.2E0)], row(7, 16.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + - "ROW(CAST(ROW(2.2, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + - "ROW(CAST(ROW(3.1, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double))))) t(a) ", - "SELECT * FROM VALUES (1.0, 14.5, 4.0), (2.2, 39.3, 18.0), (2.2, 39.3, 18.0), (2.2, 17.1, 16.0), (3.1, 39.3, 18.0)"); - - assertQuery("SELECT a.col1[1].col0, " + - "SUM(a.col0) OVER(PARTITION BY a.col1[1].col0), " + - "SUM(a.col1[1].col1) OVER(PARTITION BY a.col1[1].col0), " + - "SUM(a.col2.col1) OVER(PARTITION BY a.col1[1].col0) FROM " + - "(VALUES " + - "ROW(CAST(ROW(1.0, ARRAY[row(31, 14.5E0), row(12, 4.2E0)], row(3, 4.0E0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + - "ROW(CAST(ROW(3.1, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + - "ROW(CAST(ROW(2.2, ARRAY[row(31, 14.2E0), row(22, 5.2E0)], row(5, 4.0E0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double))))) t(a) " + - "WHERE a.col1[2].col1 > a.col2.col0", - "SELECT * FROM VALUES (31, 3.2, 28.7, 8.0), (31, 3.2, 28.7, 8.0)"); - } - @Test public void testRowFieldAccessorInJoin() { @@ -905,19 +877,6 @@ public void testDistinctHaving() "HAVING COUNT(DISTINCT clerk) > 1"); } - @Test - public void testDistinctWindow() - { - MaterializedResult actual = computeActual( - "SELECT RANK() OVER (PARTITION BY orderdate ORDER BY COUNT(DISTINCT clerk)) rnk " + - "FROM orders " + - "GROUP BY orderdate, custkey " + - "ORDER BY rnk " + - "LIMIT 1"); - MaterializedResult expected = resultBuilder(getSession(), BIGINT).row(1L).build(); - assertEquals(actual, expected); - } - @Test public void testDistinctLimit() { @@ -1264,29 +1223,6 @@ public void testGroupingWithFortyArguments() assertQuery(query, "VALUES (0), (822283861886), (995358664191)"); } - @Test - public void testGroupingInWindowFunction() - { - assertQuery( - "SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) AS g, " + - " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + - " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) AS r " + - "FROM orders " + - "GROUP BY ROLLUP (orderkey, custkey) " + - "ORDER BY orderkey, custkey " + - "LIMIT 10", - "VALUES (1, 370, 172799.49, 0, 1), " + - " (1, NULL, 172799.49, 1, 1), " + - " (2, 781, 38426.09, 0, 1), " + - " (2, NULL, 38426.09, 1, 2), " + - " (3, 1234, 205654.30, 0, 1), " + - " (3, NULL, 205654.30, 1, 3), " + - " (4, 1369, 56000.91, 0, 1), " + - " (4, NULL, 56000.91, 1, 4), " + - " (5, 445, 105367.67, 0, 1), " + - " (5, NULL, 105367.67, 1, 5)"); - } - @Test public void testGroupingInTableSubquery() { @@ -2940,163 +2876,6 @@ public void testMinByN() "SELECT orderkey FROM orders ORDER BY totalprice ASC LIMIT 2"); } - @Test - public void testWindowImplicitCoercion() - { - assertQueryOrdered( - "SELECT orderkey, 1e0 / row_number() OVER (ORDER BY orderkey) FROM orders LIMIT 2", - "VALUES (1, 1.0), (2, 0.5)"); - } - - @Test - public void testWindowsSameOrdering() - { - MaterializedResult actual = computeActual("SELECT " + - "sum(quantity) OVER(PARTITION BY suppkey ORDER BY orderkey)," + - "min(tax) OVER(PARTITION BY suppkey ORDER BY shipdate)" + - "FROM lineitem " + - "ORDER BY 1 " + - "LIMIT 10"); - - MaterializedResult expected = resultBuilder(getSession(), DOUBLE, DOUBLE) - .row(1.0, 0.0) - .row(2.0, 0.0) - .row(2.0, 0.0) - .row(3.0, 0.0) - .row(3.0, 0.0) - .row(4.0, 0.0) - .row(4.0, 0.0) - .row(5.0, 0.0) - .row(5.0, 0.0) - .row(5.0, 0.0) - .build(); - - assertEquals(actual, expected); - } - - @Test - public void testWindowsPrefixPartitioning() - { - MaterializedResult actual = computeActual("SELECT " + - "max(tax) OVER(PARTITION BY suppkey, tax ORDER BY receiptdate)," + - "sum(quantity) OVER(PARTITION BY suppkey ORDER BY orderkey)" + - "FROM lineitem " + - "ORDER BY 2, 1 " + - "LIMIT 10"); - - MaterializedResult expected = resultBuilder(getSession(), DOUBLE, DOUBLE) - .row(0.06, 1.0) - .row(0.02, 2.0) - .row(0.06, 2.0) - .row(0.02, 3.0) - .row(0.08, 3.0) - .row(0.03, 4.0) - .row(0.03, 4.0) - .row(0.02, 5.0) - .row(0.03, 5.0) - .row(0.07, 5.0) - .build(); - - assertEquals(actual, expected); - } - - @Test - public void testWindowsDifferentPartitions() - { - MaterializedResult actual = computeActual("SELECT " + - "sum(quantity) OVER(PARTITION BY suppkey ORDER BY orderkey)," + - "count(discount) OVER(PARTITION BY partkey ORDER BY receiptdate)," + - "min(tax) OVER(PARTITION BY suppkey, tax ORDER BY receiptdate)" + - "FROM lineitem " + - "ORDER BY 1, 2 " + - "LIMIT 10"); - - MaterializedResult expected = resultBuilder(getSession(), DOUBLE, BIGINT, DOUBLE) - .row(1.0, 10L, 0.06) - .row(2.0, 4L, 0.06) - .row(2.0, 16L, 0.02) - .row(3.0, 3L, 0.08) - .row(3.0, 38L, 0.02) - .row(4.0, 10L, 0.03) - .row(4.0, 10L, 0.03) - .row(5.0, 9L, 0.03) - .row(5.0, 13L, 0.07) - .row(5.0, 15L, 0.02) - .build(); - - assertEquals(actual, expected); - } - - @Test - public void testWindowsConstantExpression() - { - assertQueryOrdered( - "SELECT " + - "sum(size) OVER(PARTITION BY type ORDER BY brand)," + - "lag(partkey, 1) OVER(PARTITION BY type ORDER BY name)" + - "FROM part " + - "ORDER BY 1, 2 " + - "LIMIT 10", - "VALUES " + - "(1, 315), " + - "(1, 881), " + - "(1, 1009), " + - "(3, 1087), " + - "(3, 1187), " + - "(3, 1529), " + - "(4, 969), " + - "(5, 151), " + - "(5, 505), " + - "(5, 872)"); - } - - @Test - public void testDependentWindows() - { - // For such query as below generated plan has two adjacent window nodes where second depends on output of first. - - String sql = "WITH " + - "t1 AS (" + - "SELECT extendedprice FROM lineitem ORDER BY orderkey, partkey LIMIT 2)," + - "t2 AS (" + - "SELECT extendedprice, sum(extendedprice) OVER() AS x FROM t1)," + - "t3 AS (" + - "SELECT max(x) OVER() FROM t2) " + - "SELECT * FROM t3"; - - assertQuery(sql, "VALUES 59645.36, 59645.36"); - } - - @Test - public void testWindowFunctionWithoutParameters() - { - MaterializedResult actual = computeActual("SELECT count() over(partition by custkey) FROM orders WHERE custkey < 3 ORDER BY custkey"); - - MaterializedResult expected = resultBuilder(getSession(), BIGINT) - .row(9L) - .row(9L) - .row(9L) - .row(9L) - .row(9L) - .row(9L) - .row(9L) - .row(9L) - .row(9L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .row(10L) - .build(); - - assertEquals(actual, expected); - } - @Test public void testHaving() { @@ -3159,107 +2938,6 @@ public void testColumnAliases() "SELECT custkey, orderstatus, totalprice + 1 FROM orders"); } - @Test - public void testWindowFunctionWithImplicitCoercion() - { - assertQuery("SELECT *, 1.0 * sum(x) OVER () FROM (VALUES 1) t(x)", "SELECT 1, 1.0"); - } - - @SuppressWarnings("PointlessArithmeticExpression") - @Test - public void testWindowFunctionsExpressions() - { - assertQueryOrdered( - "SELECT orderkey, orderstatus " + - ", row_number() OVER (ORDER BY orderkey * 2) * " + - " row_number() OVER (ORDER BY orderkey DESC) + 100 " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) x " + - "ORDER BY orderkey LIMIT 5", - "VALUES " + - "(1, 'O', 110), " + - "(2, 'O', 118), " + - "(3, 'F', 124), " + - "(4, 'O', 128), " + - "(5, 'F', 130)"); - } - - @Test - public void testWindowFunctionsFromAggregate() - { - MaterializedResult actual = computeActual("" + - "SELECT * FROM (\n" + - " SELECT orderstatus, clerk, sales\n" + - " , rank() OVER (PARTITION BY x.orderstatus ORDER BY sales DESC) rnk\n" + - " FROM (\n" + - " SELECT orderstatus, clerk, sum(totalprice) sales\n" + - " FROM orders\n" + - " GROUP BY orderstatus, clerk\n" + - " ) x\n" + - ") x\n" + - "WHERE rnk <= 2\n" + - "ORDER BY orderstatus, rnk"); - - MaterializedResult expected = resultBuilder(getSession(), VARCHAR, VARCHAR, DOUBLE, BIGINT) - .row("F", "Clerk#000000090", 2784836.61, 1L) - .row("F", "Clerk#000000084", 2674447.15, 2L) - .row("O", "Clerk#000000500", 2569878.29, 1L) - .row("O", "Clerk#000000050", 2500162.92, 2L) - .row("P", "Clerk#000000071", 841820.99, 1L) - .row("P", "Clerk#000001000", 643679.49, 2L) - .build(); - - assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows()); - } - - @Test - public void testOrderByWindowFunction() - { - assertQueryOrdered( - "SELECT orderkey, row_number() OVER (ORDER BY orderkey) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY 2 DESC " + - "LIMIT 5", - "VALUES (34, 10), " + - "(33, 9), " + - "(32, 8), " + - "(7, 7), " + - "(6, 6)"); - } - - @Test - public void testSameWindowFunctionsTwoCoerces() - { - MaterializedResult actual = computeActual("" + - "SELECT 12.0E0 * row_number() OVER ()/row_number() OVER(),\n" + - "row_number() OVER()\n" + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10)\n" + - "ORDER BY 2 DESC\n" + - "LIMIT 5"); - - MaterializedResult expected = resultBuilder(getSession(), DOUBLE, BIGINT) - .row(12.0, 10L) - .row(12.0, 9L) - .row(12.0, 8L) - .row(12.0, 7L) - .row(12.0, 6L) - .build(); - - assertEquals(actual, expected); - - actual = computeActual("" + - "SELECT (MAX(x.a) OVER () - x.a) * 100.0E0 / MAX(x.a) OVER ()\n" + - "FROM (VALUES 1, 2, 3, 4) x(a)"); - - expected = resultBuilder(getSession(), DOUBLE) - .row(75.0) - .row(50.0) - .row(25.0) - .row(0.0) - .build(); - - assertEquals(actual, expected); - } - @Test public void testRowNumberNoOptimization() { @@ -3524,50 +3202,6 @@ public void testRowNumberPropertyDerivation() "(34, 'O', 21)"); } - @Test - public void testWindowMapAgg() - { - MaterializedResult actual = computeActual("" + - "SELECT map_agg(orderkey, orderpriority) OVER(PARTITION BY orderstatus) FROM\n" + - "(SELECT * FROM orders ORDER BY orderkey LIMIT 5) t"); - MaterializedResult expected = resultBuilder(getSession(), mapType(BIGINT, VarcharType.createVarcharType(1))) - .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) - .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) - .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) - .row(ImmutableMap.of(3L, "5-LOW", 5L, "5-LOW")) - .row(ImmutableMap.of(3L, "5-LOW", 5L, "5-LOW")) - .build(); - assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); - } - - @Test - public void testWindowPropertyDerivation() - { - assertQuery( - "SELECT orderstatus, orderkey, " + - "SUM(s) OVER (PARTITION BY orderstatus), " + - "SUM(s) OVER (PARTITION BY orderstatus, orderkey), " + - "SUM(s) OVER (PARTITION BY orderstatus ORDER BY orderkey), " + - "SUM(s) OVER (ORDER BY orderstatus, orderkey) " + - "FROM ( " + - " SELECT orderkey, orderstatus, SUM(orderkey) OVER (ORDER BY orderstatus, orderkey) s " + - " FROM ( " + - " SELECT * FROM orders ORDER BY orderkey LIMIT 10 " + - " ) " + - ")", - "VALUES " + - "('F', 3, 72, 3, 3, 3), " + - "('F', 5, 72, 8, 11, 11), " + - "('F', 6, 72, 14, 25, 25), " + - "('F', 33, 72, 47, 72, 72), " + - "('O', 1, 433, 48, 48, 120), " + - "('O', 2, 433, 50, 98, 170), " + - "('O', 4, 433, 54, 152, 224), " + - "('O', 7, 433, 61, 213, 285), " + - "('O', 32, 433, 93, 306, 378), " + - "('O', 34, 433, 127, 433, 505)"); - } - @Test public void testTopNUnpartitionedWindow() { @@ -3694,200 +3328,6 @@ public void testTopNPartitionedWindowWithEqualityFilter() "VALUES (2, 'O'), (2, 'F'), (2, 'P')"); } - @Test - public void testWindowFunctionWithGroupBy() - { - MaterializedResult actual = computeActual("" + - "SELECT *, rank() OVER (PARTITION BY x)\n" + - "FROM (SELECT 'foo' x)\n" + - "GROUP BY 1"); - - MaterializedResult expected = resultBuilder(getSession(), createVarcharType(3), BIGINT) - .row("foo", 1L) - .build(); - - assertEquals(actual, expected); - } - - @Test - public void testPartialPrePartitionedWindowFunction() - { - assertQueryOrdered("" + - "SELECT orderkey, COUNT(*) OVER (PARTITION BY orderkey, custkey) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY orderkey LIMIT 5", - "VALUES (1, 1), " + - "(2, 1), " + - "(3, 1), " + - "(4, 1), " + - "(5, 1)"); - } - - @Test - public void testFullPrePartitionedWindowFunction() - { - assertQueryOrdered( - "SELECT orderkey, COUNT(*) OVER (PARTITION BY orderkey) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY orderkey LIMIT 5", - "VALUES (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)"); - } - - @Test - public void testPartialPreSortedWindowFunction() - { - assertQueryOrdered( - "SELECT orderkey, COUNT(*) OVER (ORDER BY orderkey, custkey) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY orderkey LIMIT 5", - "VALUES (1, 1), " + - "(2, 2), " + - "(3, 3), " + - "(4, 4), " + - "(5, 5)"); - } - - @Test - public void testFullPreSortedWindowFunction() - { - assertQueryOrdered( - "SELECT orderkey, COUNT(*) OVER (ORDER BY orderkey) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY orderkey LIMIT 5", - "VALUES (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)"); - } - - @Test - public void testFullyPartitionedAndPartiallySortedWindowFunction() - { - assertQueryOrdered( - "SELECT orderkey, custkey, orderPriority, COUNT(*) OVER (PARTITION BY orderkey ORDER BY custkey, orderPriority) " + - "FROM (SELECT * FROM orders ORDER BY orderkey, custkey LIMIT 10) " + - "ORDER BY orderkey LIMIT 5", - "VALUES (1, 370, '5-LOW', 1), " + - "(2, 781, '1-URGENT', 1), " + - "(3, 1234, '5-LOW', 1), " + - "(4, 1369, '5-LOW', 1), " + - "(5, 445, '5-LOW', 1)"); - } - - @Test - public void testFullyPartitionedAndFullySortedWindowFunction() - { - assertQueryOrdered( - "SELECT orderkey, custkey, COUNT(*) OVER (PARTITION BY orderkey ORDER BY custkey) " + - "FROM (SELECT * FROM orders ORDER BY orderkey, custkey LIMIT 10) " + - "ORDER BY orderkey LIMIT 5", - "VALUES (1, 370, 1), " + - "(2, 781, 1), " + - "(3, 1234, 1), " + - "(4, 1369, 1), " + - "(5, 445, 1)"); - } - - @Test - public void testOrderByWindowFunctionWithNulls() - { - // Nulls first - assertQueryOrdered( - "SELECT orderkey, row_number() OVER (ORDER BY nullif(orderkey, 3) NULLS FIRST) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY 2 ASC " + - "LIMIT 5", - "VALUES (3, 1), " + - "(1, 2), " + - "(2, 3), " + - "(4, 4)," + - "(5, 5)"); - - // Nulls last - String nullsLastExpected = "VALUES (3, 10), " + - "(34, 9), " + - "(33, 8), " + - "(32, 7), " + - "(7, 6)"; - assertQueryOrdered( - "SELECT orderkey, row_number() OVER (ORDER BY nullif(orderkey, 3) NULLS LAST) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY 2 DESC " + - "LIMIT 5", - nullsLastExpected); - - // and nulls last should be the default - assertQueryOrdered( - "SELECT orderkey, row_number() OVER (ORDER BY nullif(orderkey, 3)) " + - "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + - "ORDER BY 2 DESC " + - "LIMIT 5", - nullsLastExpected); - } - - @Test - public void testValueWindowFunctions() - { - assertQueryOrdered( - "SELECT * FROM ( " + - " SELECT orderkey, orderstatus " + - " , first_value(orderkey + 1000) OVER (PARTITION BY orderstatus ORDER BY orderkey) fvalue " + - " , nth_value(orderkey + 1000, 2) OVER (PARTITION BY orderstatus ORDER BY orderkey " + - " ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) nvalue " + - " FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) x " + - " ) x " + - "ORDER BY orderkey LIMIT 5", - "VALUES " + - "(1, 'O', 1001, 1002), " + - "(2, 'O', 1001, 1002), " + - "(3, 'F', 1003, 1005), " + - "(4, 'O', 1001, 1002), " + - "(5, 'F', 1003, 1005)"); - } - - @Test - public void testWindowFrames() - { - MaterializedResult actual = computeActual("SELECT * FROM (\n" + - " SELECT orderkey, orderstatus\n" + - " , sum(orderkey + 1000) OVER (PARTITION BY orderstatus ORDER BY orderkey\n" + - " ROWS BETWEEN mod(custkey, 2) PRECEDING AND custkey / 500 FOLLOWING)\n" + - " FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) x\n" + - " ) x\n" + - "ORDER BY orderkey LIMIT 5"); - - MaterializedResult expected = resultBuilder(getSession(), BIGINT, VARCHAR, BIGINT) - .row(1L, "O", 1001L) - .row(2L, "O", 3007L) - .row(3L, "F", 3014L) - .row(4L, "O", 4045L) - .row(5L, "F", 2008L) - .build(); - - assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows()); - } - - @Test - public void testWindowNoChannels() - { - MaterializedResult actual = computeActual("SELECT rank() OVER ()\n" + - "FROM (SELECT * FROM orders LIMIT 10)\n" + - "LIMIT 3"); - - MaterializedResult expected = resultBuilder(getSession(), BIGINT) - .row(1L) - .row(1L) - .row(1L) - .build(); - - assertEquals(actual, expected); - } - - @Test - public void testInvalidWindowFunction() - { - assertQueryFails("SELECT abs(x) OVER ()\n" + - "FROM (VALUES (1), (2), (3)) t(x)", - "line 1:1: Not a window function: abs"); - } - @Test public void testScalarFunction() { @@ -4235,20 +3675,6 @@ public void testDuplicateFields() "SELECT orderkey, orderkey FROM orders"); } - @Test - public void testDuplicateColumnsInWindowOrderByClause() - { - MaterializedResult actual = computeActual("SELECT a, row_number() OVER (ORDER BY a ASC, a DESC) FROM (VALUES 3, 2, 1) t(a)"); - - MaterializedResult expected = resultBuilder(getSession(), BIGINT, BIGINT) - .row(1, 1L) - .row(2, 2L) - .row(3, 3L) - .build(); - - assertEqualsIgnoreOrder(actual, expected); - } - @Test public void testWildcardFromSubquery() { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java new file mode 100644 index 0000000000000..ceb910a06d5d6 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java @@ -0,0 +1,652 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.testing.MaterializedResult; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static com.facebook.presto.tests.QueryAssertions.assertEqualsIgnoreOrder; +import static com.facebook.presto.tests.StructuralTestUtil.mapType; + +public class AbstractTestWindowQueries + extends AbstractTestQueryFramework +{ + public AbstractTestWindowQueries(QueryRunnerSupplier supplier) + { + super(supplier); + } + + @Test + public void testRowFieldAccessorInWindowFunction() + { + assertQuery("SELECT a.col0, " + + "SUM(a.col1[1].col1) OVER(PARTITION BY a.col2.col0), " + + "SUM(a.col2.col1) OVER(PARTITION BY a.col2.col0) FROM " + + "(VALUES " + + "ROW(CAST(ROW(1.0, ARRAY[row(31, 14.5E0), row(12, 4.2E0)], row(3, 4.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + + "ROW(CAST(ROW(2.2, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + + "ROW(CAST(ROW(2.2, ARRAY[row(41, 17.1E0), row(45, 4.2E0)], row(7, 16.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + + "ROW(CAST(ROW(2.2, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + + "ROW(CAST(ROW(3.1, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(ROW(col0 integer, col1 double)), col2 row(col0 integer, col1 double))))) t(a) ", + "SELECT * FROM VALUES (1.0, 14.5, 4.0), (2.2, 39.3, 18.0), (2.2, 39.3, 18.0), (2.2, 17.1, 16.0), (3.1, 39.3, 18.0)"); + + assertQuery("SELECT a.col1[1].col0, " + + "SUM(a.col0) OVER(PARTITION BY a.col1[1].col0), " + + "SUM(a.col1[1].col1) OVER(PARTITION BY a.col1[1].col0), " + + "SUM(a.col2.col1) OVER(PARTITION BY a.col1[1].col0) FROM " + + "(VALUES " + + "ROW(CAST(ROW(1.0, ARRAY[row(31, 14.5E0), row(12, 4.2E0)], row(3, 4.0E0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + + "ROW(CAST(ROW(3.1, ARRAY[row(41, 13.1E0), row(32, 4.2E0)], row(6, 6.0E0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double)))), " + + "ROW(CAST(ROW(2.2, ARRAY[row(31, 14.2E0), row(22, 5.2E0)], row(5, 4.0E0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double))))) t(a) " + + "WHERE a.col1[2].col1 > a.col2.col0", + "SELECT * FROM VALUES (31, 3.2, 28.7, 8.0), (31, 3.2, 28.7, 8.0)"); + } + + @Test + public void testDistinctWindow() + { + MaterializedResult actual = computeActual( + "SELECT RANK() OVER (PARTITION BY orderdate ORDER BY COUNT(DISTINCT clerk)) rnk " + + "FROM orders " + + "GROUP BY orderdate, custkey " + + "ORDER BY rnk " + + "LIMIT 1"); + MaterializedResult expected = resultBuilder(getSession(), BIGINT).row(1L).build(); + assertEquals(actual, expected); + } + + @Test + public void testGroupingInWindowFunction() + { + assertQuery( + "SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) AS g, " + + " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + + " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) AS r " + + "FROM orders " + + "GROUP BY ROLLUP (orderkey, custkey) " + + "ORDER BY orderkey, custkey " + + "LIMIT 10", + "VALUES (1, 370, 172799.49, 0, 1), " + + " (1, NULL, 172799.49, 1, 1), " + + " (2, 781, 38426.09, 0, 1), " + + " (2, NULL, 38426.09, 1, 2), " + + " (3, 1234, 205654.30, 0, 1), " + + " (3, NULL, 205654.30, 1, 3), " + + " (4, 1369, 56000.91, 0, 1), " + + " (4, NULL, 56000.91, 1, 4), " + + " (5, 445, 105367.67, 0, 1), " + + " (5, NULL, 105367.67, 1, 5)"); + } + + @Test + public void testWindowImplicitCoercion() + { + assertQueryOrdered( + "SELECT orderkey, 1e0 / row_number() OVER (ORDER BY orderkey) FROM orders LIMIT 2", + "VALUES (1, 1.0), (2, 0.5)"); + } + + @Test + public void testWindowsSameOrdering() + { + MaterializedResult actual = computeActual("SELECT " + + "sum(quantity) OVER(PARTITION BY suppkey ORDER BY orderkey)," + + "min(tax) OVER(PARTITION BY suppkey ORDER BY shipdate)" + + "FROM lineitem " + + "ORDER BY 1 " + + "LIMIT 10"); + + MaterializedResult expected = resultBuilder(getSession(), DOUBLE, DOUBLE) + .row(1.0, 0.0) + .row(2.0, 0.0) + .row(2.0, 0.0) + .row(3.0, 0.0) + .row(3.0, 0.0) + .row(4.0, 0.0) + .row(4.0, 0.0) + .row(5.0, 0.0) + .row(5.0, 0.0) + .row(5.0, 0.0) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testWindowsPrefixPartitioning() + { + MaterializedResult actual = computeActual("SELECT " + + "max(tax) OVER(PARTITION BY suppkey, tax ORDER BY receiptdate)," + + "sum(quantity) OVER(PARTITION BY suppkey ORDER BY orderkey)" + + "FROM lineitem " + + "ORDER BY 2, 1 " + + "LIMIT 10"); + + MaterializedResult expected = resultBuilder(getSession(), DOUBLE, DOUBLE) + .row(0.06, 1.0) + .row(0.02, 2.0) + .row(0.06, 2.0) + .row(0.02, 3.0) + .row(0.08, 3.0) + .row(0.03, 4.0) + .row(0.03, 4.0) + .row(0.02, 5.0) + .row(0.03, 5.0) + .row(0.07, 5.0) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testWindowsDifferentPartitions() + { + MaterializedResult actual = computeActual("SELECT " + + "sum(quantity) OVER(PARTITION BY suppkey ORDER BY orderkey)," + + "count(discount) OVER(PARTITION BY partkey ORDER BY receiptdate)," + + "min(tax) OVER(PARTITION BY suppkey, tax ORDER BY receiptdate)" + + "FROM lineitem " + + "ORDER BY 1, 2 " + + "LIMIT 10"); + + MaterializedResult expected = resultBuilder(getSession(), DOUBLE, BIGINT, DOUBLE) + .row(1.0, 10L, 0.06) + .row(2.0, 4L, 0.06) + .row(2.0, 16L, 0.02) + .row(3.0, 3L, 0.08) + .row(3.0, 38L, 0.02) + .row(4.0, 10L, 0.03) + .row(4.0, 10L, 0.03) + .row(5.0, 9L, 0.03) + .row(5.0, 13L, 0.07) + .row(5.0, 15L, 0.02) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testWindowsConstantExpression() + { + assertQueryOrdered( + "SELECT " + + "sum(size) OVER(PARTITION BY type ORDER BY brand)," + + "lag(partkey, 1) OVER(PARTITION BY type ORDER BY name)" + + "FROM part " + + "ORDER BY 1, 2 " + + "LIMIT 10", + "VALUES " + + "(1, 315), " + + "(1, 881), " + + "(1, 1009), " + + "(3, 1087), " + + "(3, 1187), " + + "(3, 1529), " + + "(4, 969), " + + "(5, 151), " + + "(5, 505), " + + "(5, 872)"); + } + + @Test + public void testDependentWindows() + { + // For such query as below generated plan has two adjacent window nodes where second depends on output of first. + + String sql = "WITH " + + "t1 AS (" + + "SELECT extendedprice FROM lineitem ORDER BY orderkey, partkey LIMIT 2)," + + "t2 AS (" + + "SELECT extendedprice, sum(extendedprice) OVER() AS x FROM t1)," + + "t3 AS (" + + "SELECT max(x) OVER() FROM t2) " + + "SELECT * FROM t3"; + + assertQuery(sql, "VALUES 59645.36, 59645.36"); + } + + @Test + public void testWindowFunctionWithoutParameters() + { + MaterializedResult actual = computeActual("SELECT count() over(partition by custkey) FROM orders WHERE custkey < 3 ORDER BY custkey"); + + MaterializedResult expected = resultBuilder(getSession(), BIGINT) + .row(9L) + .row(9L) + .row(9L) + .row(9L) + .row(9L) + .row(9L) + .row(9L) + .row(9L) + .row(9L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .row(10L) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testWindowFunctionWithImplicitCoercion() + { + assertQuery("SELECT *, 1.0 * sum(x) OVER () FROM (VALUES 1) t(x)", "SELECT 1, 1.0"); + } + + @SuppressWarnings("PointlessArithmeticExpression") + @Test + public void testWindowFunctionsExpressions() + { + assertQueryOrdered( + "SELECT orderkey, orderstatus " + + ", row_number() OVER (ORDER BY orderkey * 2) * " + + " row_number() OVER (ORDER BY orderkey DESC) + 100 " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) x " + + "ORDER BY orderkey LIMIT 5", + "VALUES " + + "(1, 'O', 110), " + + "(2, 'O', 118), " + + "(3, 'F', 124), " + + "(4, 'O', 128), " + + "(5, 'F', 130)"); + } + + @Test + public void testWindowFunctionsFromAggregate() + { + MaterializedResult actual = computeActual("" + + "SELECT * FROM (\n" + + " SELECT orderstatus, clerk, sales\n" + + " , rank() OVER (PARTITION BY x.orderstatus ORDER BY sales DESC) rnk\n" + + " FROM (\n" + + " SELECT orderstatus, clerk, sum(totalprice) sales\n" + + " FROM orders\n" + + " GROUP BY orderstatus, clerk\n" + + " ) x\n" + + ") x\n" + + "WHERE rnk <= 2\n" + + "ORDER BY orderstatus, rnk"); + + MaterializedResult expected = resultBuilder(getSession(), VARCHAR, VARCHAR, DOUBLE, BIGINT) + .row("F", "Clerk#000000090", 2784836.61, 1L) + .row("F", "Clerk#000000084", 2674447.15, 2L) + .row("O", "Clerk#000000500", 2569878.29, 1L) + .row("O", "Clerk#000000050", 2500162.92, 2L) + .row("P", "Clerk#000000071", 841820.99, 1L) + .row("P", "Clerk#000001000", 643679.49, 2L) + .build(); + + assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows()); + } + + @Test + public void testOrderByWindowFunction() + { + assertQueryOrdered( + "SELECT orderkey, row_number() OVER (ORDER BY orderkey) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY 2 DESC " + + "LIMIT 5", + "VALUES (34, 10), " + + "(33, 9), " + + "(32, 8), " + + "(7, 7), " + + "(6, 6)"); + } + + @Test + public void testSameWindowFunctionsTwoCoerces() + { + MaterializedResult actual = computeActual("" + + "SELECT 12.0E0 * row_number() OVER ()/row_number() OVER(),\n" + + "row_number() OVER()\n" + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10)\n" + + "ORDER BY 2 DESC\n" + + "LIMIT 5"); + + MaterializedResult expected = resultBuilder(getSession(), DOUBLE, BIGINT) + .row(12.0, 10L) + .row(12.0, 9L) + .row(12.0, 8L) + .row(12.0, 7L) + .row(12.0, 6L) + .build(); + + assertEquals(actual, expected); + + actual = computeActual("" + + "SELECT (MAX(x.a) OVER () - x.a) * 100.0E0 / MAX(x.a) OVER ()\n" + + "FROM (VALUES 1, 2, 3, 4) x(a)"); + + expected = resultBuilder(getSession(), DOUBLE) + .row(75.0) + .row(50.0) + .row(25.0) + .row(0.0) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testWindowMapAgg() + { + MaterializedResult actual = computeActual("" + + "SELECT map_agg(orderkey, orderpriority) OVER(PARTITION BY orderstatus) FROM\n" + + "(SELECT * FROM orders ORDER BY orderkey LIMIT 5) t"); + MaterializedResult expected = resultBuilder(getSession(), mapType(BIGINT, createVarcharType(1))) + .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) + .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) + .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) + .row(ImmutableMap.of(3L, "5-LOW", 5L, "5-LOW")) + .row(ImmutableMap.of(3L, "5-LOW", 5L, "5-LOW")) + .build(); + assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); + } + + @Test + public void testWindowPropertyDerivation() + { + assertQuery( + "SELECT orderstatus, orderkey, " + + "SUM(s) OVER (PARTITION BY orderstatus), " + + "SUM(s) OVER (PARTITION BY orderstatus, orderkey), " + + "SUM(s) OVER (PARTITION BY orderstatus ORDER BY orderkey), " + + "SUM(s) OVER (ORDER BY orderstatus, orderkey) " + + "FROM ( " + + " SELECT orderkey, orderstatus, SUM(orderkey) OVER (ORDER BY orderstatus, orderkey) s " + + " FROM ( " + + " SELECT * FROM orders ORDER BY orderkey LIMIT 10 " + + " ) " + + ")", + "VALUES " + + "('F', 3, 72, 3, 3, 3), " + + "('F', 5, 72, 8, 11, 11), " + + "('F', 6, 72, 14, 25, 25), " + + "('F', 33, 72, 47, 72, 72), " + + "('O', 1, 433, 48, 48, 120), " + + "('O', 2, 433, 50, 98, 170), " + + "('O', 4, 433, 54, 152, 224), " + + "('O', 7, 433, 61, 213, 285), " + + "('O', 32, 433, 93, 306, 378), " + + "('O', 34, 433, 127, 433, 505)"); + } + + @Test + public void testWindowFunctionWithGroupBy() + { + MaterializedResult actual = computeActual("" + + "SELECT *, rank() OVER (PARTITION BY x)\n" + + "FROM (SELECT 'foo' x)\n" + + "GROUP BY 1"); + + MaterializedResult expected = resultBuilder(getSession(), createVarcharType(3), BIGINT) + .row("foo", 1L) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testPartialPrePartitionedWindowFunction() + { + assertQueryOrdered("" + + "SELECT orderkey, COUNT(*) OVER (PARTITION BY orderkey, custkey) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY orderkey LIMIT 5", + "VALUES (1, 1), " + + "(2, 1), " + + "(3, 1), " + + "(4, 1), " + + "(5, 1)"); + } + + @Test + public void testFullPrePartitionedWindowFunction() + { + assertQueryOrdered( + "SELECT orderkey, COUNT(*) OVER (PARTITION BY orderkey) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY orderkey LIMIT 5", + "VALUES (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)"); + } + + @Test + public void testPartialPreSortedWindowFunction() + { + assertQueryOrdered( + "SELECT orderkey, COUNT(*) OVER (ORDER BY orderkey, custkey) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY orderkey LIMIT 5", + "VALUES (1, 1), " + + "(2, 2), " + + "(3, 3), " + + "(4, 4), " + + "(5, 5)"); + } + + @Test + public void testFullPreSortedWindowFunction() + { + assertQueryOrdered( + "SELECT orderkey, COUNT(*) OVER (ORDER BY orderkey) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY orderkey LIMIT 5", + "VALUES (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)"); + } + + @Test + public void testFullyPartitionedAndPartiallySortedWindowFunction() + { + assertQueryOrdered( + "SELECT orderkey, custkey, orderPriority, COUNT(*) OVER (PARTITION BY orderkey ORDER BY custkey, orderPriority) " + + "FROM (SELECT * FROM orders ORDER BY orderkey, custkey LIMIT 10) " + + "ORDER BY orderkey LIMIT 5", + "VALUES (1, 370, '5-LOW', 1), " + + "(2, 781, '1-URGENT', 1), " + + "(3, 1234, '5-LOW', 1), " + + "(4, 1369, '5-LOW', 1), " + + "(5, 445, '5-LOW', 1)"); + } + + @Test + public void testFullyPartitionedAndFullySortedWindowFunction() + { + assertQueryOrdered( + "SELECT orderkey, custkey, COUNT(*) OVER (PARTITION BY orderkey ORDER BY custkey) " + + "FROM (SELECT * FROM orders ORDER BY orderkey, custkey LIMIT 10) " + + "ORDER BY orderkey LIMIT 5", + "VALUES (1, 370, 1), " + + "(2, 781, 1), " + + "(3, 1234, 1), " + + "(4, 1369, 1), " + + "(5, 445, 1)"); + } + + @Test + public void testOrderByWindowFunctionWithNulls() + { + // Nulls first + assertQueryOrdered( + "SELECT orderkey, row_number() OVER (ORDER BY nullif(orderkey, 3) NULLS FIRST) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY 2 ASC " + + "LIMIT 5", + "VALUES (3, 1), " + + "(1, 2), " + + "(2, 3), " + + "(4, 4)," + + "(5, 5)"); + + // Nulls last + String nullsLastExpected = "VALUES (3, 10), " + + "(34, 9), " + + "(33, 8), " + + "(32, 7), " + + "(7, 6)"; + assertQueryOrdered( + "SELECT orderkey, row_number() OVER (ORDER BY nullif(orderkey, 3) NULLS LAST) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY 2 DESC " + + "LIMIT 5", + nullsLastExpected); + + // and nulls last should be the default + assertQueryOrdered( + "SELECT orderkey, row_number() OVER (ORDER BY nullif(orderkey, 3)) " + + "FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) " + + "ORDER BY 2 DESC " + + "LIMIT 5", + nullsLastExpected); + } + + @Test + public void testValueWindowFunctions() + { + assertQueryOrdered( + "SELECT * FROM ( " + + " SELECT orderkey, orderstatus " + + " , first_value(orderkey + 1000) OVER (PARTITION BY orderstatus ORDER BY orderkey) fvalue " + + " , nth_value(orderkey + 1000, 2) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + " ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) nvalue " + + " FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) x " + + " ) x " + + "ORDER BY orderkey LIMIT 5", + "VALUES " + + "(1, 'O', 1001, 1002), " + + "(2, 'O', 1001, 1002), " + + "(3, 'F', 1003, 1005), " + + "(4, 'O', 1001, 1002), " + + "(5, 'F', 1003, 1005)"); + } + + @Test + public void testWindowFrames() + { + MaterializedResult actual = computeActual("SELECT * FROM (\n" + + " SELECT orderkey, orderstatus\n" + + " , sum(orderkey + 1000) OVER (PARTITION BY orderstatus ORDER BY orderkey\n" + + " ROWS BETWEEN mod(custkey, 2) PRECEDING AND custkey / 500 FOLLOWING)\n" + + " FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 10) x\n" + + " ) x\n" + + "ORDER BY orderkey LIMIT 5"); + + MaterializedResult expected = resultBuilder(getSession(), BIGINT, VARCHAR, BIGINT) + .row(1L, "O", 1001L) + .row(2L, "O", 3007L) + .row(3L, "F", 3014L) + .row(4L, "O", 4045L) + .row(5L, "F", 2008L) + .build(); + + assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows()); + } + + @Test + public void testWindowNoChannels() + { + MaterializedResult actual = computeActual("SELECT rank() OVER ()\n" + + "FROM (SELECT * FROM orders LIMIT 10)\n" + + "LIMIT 3"); + + MaterializedResult expected = resultBuilder(getSession(), BIGINT) + .row(1L) + .row(1L) + .row(1L) + .build(); + + assertEquals(actual, expected); + } + + @Test + public void testInvalidWindowFunction() + { + assertQueryFails("SELECT abs(x) OVER ()\n" + + "FROM (VALUES (1), (2), (3)) t(x)", + "line 1:1: Not a window function: abs"); + } + + @Test + public void testDuplicateColumnsInWindowOrderByClause() + { + MaterializedResult actual = computeActual("SELECT a, row_number() OVER (ORDER BY a ASC, a DESC) FROM (VALUES 3, 2, 1) t(a)"); + + MaterializedResult expected = resultBuilder(getSession(), BIGINT, BIGINT) + .row(1, 1L) + .row(2, 2L) + .row(3, 3L) + .build(); + + assertEqualsIgnoreOrder(actual, expected); + } + + @Test + public void testMultipleInstancesOfWindowFunction() + { + assertQueryOrdered( + "SELECT a, b, c, " + + "lag(c, 1) RESPECT NULLS OVER (PARTITION BY b ORDER BY a), " + + "lag(c, 1) IGNORE NULLS OVER (PARTITION BY b ORDER BY a) " + + "FROM ( VALUES " + + "(1, 'A', 'a'), " + + "(2, 'A', NULL), " + + "(3, 'A', 'c'), " + + "(4, 'A', NULL), " + + "(5, 'A', 'e'), " + + "(6, 'A', NULL)" + + ") t(a, b, c)", + "VALUES " + + "(1, 'A', 'a', null, null), " + + "(2, 'A', null, 'a', 'a'), " + + "(3, 'A', 'c', null, 'a'), " + + "(4, 'A', null, 'c', 'c'), " + + "(5, 'A', 'e', null, 'c'), " + + "(6, 'A', null, 'e', 'e')"); + + assertQueryOrdered( + "SELECT a, b, c, " + + "lag(c, 1) IGNORE NULLS OVER (PARTITION BY b ORDER BY a), " + + "lag(c, 1) RESPECT NULLS OVER (PARTITION BY b ORDER BY a) " + + "FROM ( VALUES " + + "(1, 'A', 'a'), " + + "(2, 'A', NULL), " + + "(3, 'A', 'c'), " + + "(4, 'A', NULL), " + + "(5, 'A', 'e'), " + + "(6, 'A', NULL)" + + ") t(a, b, c)", + "VALUES " + + "(1, 'A', 'a', null, null), " + + "(2, 'A', null, 'a', 'a'), " + + "(3, 'A', 'c', 'a', null), " + + "(4, 'A', null, 'c', 'c'), " + + "(5, 'A', 'e', 'c', null), " + + "(6, 'A', null, 'e', 'e')"); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledWindowQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledWindowQueries.java new file mode 100644 index 0000000000000..e645abd83c9c9 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledWindowQueries.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +public class TestSpilledWindowQueries + extends AbstractTestWindowQueries +{ + public TestSpilledWindowQueries() + { + super(TestDistributedSpilledQueries::createQueryRunner); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestWindowQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestWindowQueries.java new file mode 100644 index 0000000000000..3199b43c34a88 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestWindowQueries.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.tests.tpch.TpchQueryRunnerBuilder; + +public class TestWindowQueries + extends AbstractTestWindowQueries +{ + public TestWindowQueries() + { + super(() -> TpchQueryRunnerBuilder.builder().build()); + } +}