diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index b45e14ac1151a..b0ffb482be7bf 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -128,6 +128,15 @@ public class Analysis private final Map, Type> types = new LinkedHashMap<>(); private final Map, Type> coercions = new LinkedHashMap<>(); + // Coercions needed for window function frame of type RANGE. + // These are coercions for the sort key, needed for frame bound calculation, identified by frame range offset expression. + // Frame definition might contain two different offset expressions (for start and end), each requiring different coercion of the sort key. + private final Map, Type> sortKeyCoercionsForFrameBoundCalculation = new LinkedHashMap<>(); + // Coercions needed for window function frame of type RANGE. + // These are coercions for the sort key, needed for comparison of the sort key with precomputed frame bound, identified by frame range offset expression. + private final Map, Type> sortKeyCoercionsForFrameBoundComparison = new LinkedHashMap<>(); + // Functions for calculating frame bounds for frame of type RANGE, identified by frame range offset expression. + private final Map, FunctionHandle> frameBoundCalculations = new LinkedHashMap<>(); private final Set> typeOnlyCoercions = new LinkedHashSet<>(); private final Map, List> relationCoercions = new LinkedHashMap<>(); private final Map, FunctionHandle> functionHandles = new LinkedHashMap<>(); @@ -553,10 +562,36 @@ public void addCoercion(Expression expression, Type type, boolean isTypeOnlyCoer } } - public void addCoercions(Map, Type> coercions, Set> typeOnlyCoercions) + public void addCoercions( + Map, Type> coercions, + Set> typeOnlyCoercions, + Map, Type> sortKeyCoercionsForFrameBoundCalculation, + Map, Type> sortKeyCoercionsForFrameBoundComparison) { this.coercions.putAll(coercions); this.typeOnlyCoercions.addAll(typeOnlyCoercions); + this.sortKeyCoercionsForFrameBoundCalculation.putAll(sortKeyCoercionsForFrameBoundCalculation); + this.sortKeyCoercionsForFrameBoundComparison.putAll(sortKeyCoercionsForFrameBoundComparison); + } + + public Type getSortKeyCoercionForFrameBoundCalculation(Expression frameOffset) + { + return sortKeyCoercionsForFrameBoundCalculation.get(NodeRef.of(frameOffset)); + } + + public Type getSortKeyCoercionForFrameBoundComparison(Expression frameOffset) + { + return sortKeyCoercionsForFrameBoundComparison.get(NodeRef.of(frameOffset)); + } + + public void addFrameBoundCalculations(Map, FunctionHandle> frameBoundCalculations) + { + this.frameBoundCalculations.putAll(frameBoundCalculations); + } + + public FunctionHandle getFrameBoundCalculation(Expression frameOffset) + { + return frameBoundCalculations.get(NodeRef.of(frameOffset)); } public Expression getHaving(QuerySpecification query) diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java index 66933a527c380..66b6d73302598 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java @@ -50,6 +50,8 @@ public enum SemanticErrorCode MISSING_ATTRIBUTE, INVALID_ORDINAL, INVALID_LITERAL, + MISSING_ORDER_BY, + INVALID_ORDER_BY, FUNCTION_NOT_FOUND, INVALID_FUNCTION_NAME, diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java b/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java index a0606adb2c229..b55c075d682f5 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java @@ -26,10 +26,10 @@ import static com.facebook.presto.common.type.Decimals.MAX_SHORT_PRECISION; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; -final class ShortDecimalType +public final class ShortDecimalType extends DecimalType { - ShortDecimalType(int precision, int scale) + public ShortDecimalType(int precision, int scale) { super(precision, scale, long.class); validatePrecisionScale(precision, scale, MAX_SHORT_PRECISION); diff --git a/presto-docs/src/main/sphinx/functions/window.rst b/presto-docs/src/main/sphinx/functions/window.rst index 1dab8acced88a..be06f9a072dd5 100644 --- a/presto-docs/src/main/sphinx/functions/window.rst +++ b/presto-docs/src/main/sphinx/functions/window.rst @@ -21,9 +21,9 @@ A ``frame`` is one of:: ``frame_start`` and ``frame_end`` can be any of:: UNBOUNDED PRECEDING - expression PRECEDING -- only allowed in ROWS mode + expression PRECEDING CURRENT ROW - expression FOLLOWING -- only allowed in ROWS mode + expression FOLLOWING UNBOUNDED FOLLOWING @@ -49,10 +49,15 @@ The window definition has 3 components: the first peer row of the current row, while a frame end of ``CURRENT ROW`` refers to the last peer row of the current row. - Frame starts and ends of ``expression PRECEDING`` or ``expression FOLLOWING`` are currently - only allowed in ``ROWS`` mode. They define the start or end of the frame as the specified number + In ``ROWS`` mode, frame starts and ends of ``expression PRECEDING`` or ``expression FOLLOWING`` + define the start or end of the frame as the specified number of rows before or after the current row. The ``expression`` must be of type ``INTEGER``. + In ``RANGE`` mode, frame starts and ends of ``expression PRECEDING`` or ``expression FOLLOWING`` + define the start or end of the frame as the value difference of the sort key from + the current row. The sort key must either be the same type of ``expression`` or can be coerced to the + same type as ``expression``. + If no frame is specified, a default frame of ``RANGE UNBOUNDED PRECEDING`` is used. Examples diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java index 4a47420ee316a..1fd2980a5c53f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -58,6 +58,7 @@ import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.BYTE; @@ -443,6 +444,12 @@ public PagesHashStrategy createPagesHashStrategy(List joinChannels, Opt groupByUsesEqualTo); } + public PagesIndexComparator createChannelComparator(int leftChannel, int rightChannel, SortOrder sortOrder) + { + checkArgument(types.get(leftChannel).equals(types.get(rightChannel)), "comparing channels of different types: %s and %s", types.get(leftChannel), types.get(rightChannel)); + return new SimpleChannelComparator(leftChannel, rightChannel, types.get(leftChannel), sortOrder); + } + public LookupSourceSupplier createLookupSourceSupplier( Session session, List joinChannels, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SimpleChannelComparator.java b/presto-main/src/main/java/com/facebook/presto/operator/SimpleChannelComparator.java new file mode 100644 index 0000000000000..1ff46c8423782 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/SimpleChannelComparator.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; + +import static com.facebook.presto.operator.SyntheticAddress.decodePosition; +import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.util.Objects.requireNonNull; + +public class SimpleChannelComparator + implements PagesIndexComparator +{ + private final int leftChannel; + private final int rightChannel; + private final SortOrder sortOrder; + private final Type sortType; + + public SimpleChannelComparator(int leftChannel, int rightChannel, Type sortType, SortOrder sortOrder) + { + this.leftChannel = leftChannel; + this.rightChannel = rightChannel; + this.sortOrder = requireNonNull(sortOrder, "sortOrder is null"); + this.sortType = requireNonNull(sortType, "sortType is null."); + } + + @Override + public int compareTo(PagesIndex pagesIndex, int leftPosition, int rightPosition) + { + long leftPageAddress = pagesIndex.getValueAddresses().get(leftPosition); + int leftBlockIndex = decodeSliceIndex(leftPageAddress); + int leftBlockPosition = decodePosition(leftPageAddress); + + long rightPageAddress = pagesIndex.getValueAddresses().get(rightPosition); + int rightBlockIndex = decodeSliceIndex(rightPageAddress); + int rightBlockPosition = decodePosition(rightPageAddress); + + try { + Block leftBlock = pagesIndex.getChannel(leftChannel).get(leftBlockIndex); + Block rightBlock = pagesIndex.getChannel(rightChannel).get(rightBlockIndex); + int result = sortOrder.compareBlockValue(sortType, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition); + + // sortOrder compares block values and adjusts the result by ASC and DESC. SimpleChannelComparator should + // return the simple comparison, so reverse it if it is DESC. + return sortOrder.isAscending() ? result : -result; + } + catch (Throwable throwable) { + throwIfUnchecked(throwable); + throw new PrestoException(GENERIC_INTERNAL_ERROR, throwable); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java index d1199becdf6e5..6813ba04fafa5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java @@ -21,6 +21,7 @@ import com.facebook.presto.operator.WorkProcessor.ProcessState; import com.facebook.presto.operator.WorkProcessor.Transformation; import com.facebook.presto.operator.WorkProcessor.TransformationState; +import com.facebook.presto.operator.window.FrameInfo; import com.facebook.presto.operator.window.FramedWindowFunction; import com.facebook.presto.operator.window.WindowPartition; import com.facebook.presto.spi.plan.PlanNodeId; @@ -29,6 +30,7 @@ import com.facebook.presto.sql.gen.OrderingCompiler; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.PeekingIterator; @@ -38,6 +40,8 @@ import javax.annotation.Nullable; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.atomic.AtomicReference; @@ -47,6 +51,9 @@ import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.operator.SpillingUtils.checkSpillSucceeded; import static com.facebook.presto.operator.WorkProcessor.TransformationState.needsMoreData; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.util.MergeSortedPages.mergeSortedPages; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkPositionIndex; @@ -270,7 +277,9 @@ public WindowOperator( preGroupedChannels, unGroupedPartitionChannels, preSortedChannels, - sortChannels); + sortChannels, + sortOrder, + windowFunctionDefinitions); if (spillEnabled) { PagesIndexWithHashStrategies mergedPagesIndexWithHashStrategies = new PagesIndexWithHashStrategies( @@ -282,7 +291,9 @@ public WindowOperator( ImmutableList.of(), // merged pages are pre sorted on all sort channels sortChannels, - sortChannels); + sortChannels, + sortOrder, + windowFunctionDefinitions); this.spillablePagesToPagesIndexes = Optional.of(new SpillablePagesToPagesIndexes( inMemoryPagesIndexWithHashStrategies, @@ -381,6 +392,7 @@ private static class PagesIndexWithHashStrategies final PagesHashStrategy preSortedPartitionHashStrategy; final PagesHashStrategy peerGroupHashStrategy; final int[] preGroupedPartitionChannels; + final Map frameBoundComparators; PagesIndexWithHashStrategies( PagesIndex.Factory pagesIndexFactory, @@ -389,7 +401,9 @@ private static class PagesIndexWithHashStrategies List preGroupedPartitionChannels, List unGroupedPartitionChannels, List preSortedChannels, - List sortChannels) + List sortChannels, + List sortOrder, + List windowFunctionDefinitions) { this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedPartitionChannels, OptionalInt.empty()); @@ -397,6 +411,79 @@ private static class PagesIndexWithHashStrategies this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, OptionalInt.empty()); this.preGroupedPartitionChannels = Ints.toArray(preGroupedPartitionChannels); + this.frameBoundComparators = createFrameBoundComparators(pagesIndex, windowFunctionDefinitions, sortOrder); + } + } + + /** + * Create comparators necessary for seeking frame start or frame end for window functions with frame type RANGE. + * Whenever a frame bound is specified as RANGE X PRECEDING or RANGE X FOLLOWING, + * a dedicated comparator is created to compare sort key values with expected frame bound values. + */ + private static Map createFrameBoundComparators(PagesIndex pagesIndex, + List windowFunctionDefinitions, + List sortOrders) + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + + for (int i = 0; i < windowFunctionDefinitions.size(); i++) { + FrameInfo frameInfo = windowFunctionDefinitions.get(i).getFrameInfo(); + if (frameInfo.getType() == RANGE) { + if (frameInfo.getStartType() == PRECEDING || frameInfo.getStartType() == FOLLOWING) { + // Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY + checkState(sortOrders != null && sortOrders.size() == 1, "Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY."); + SortOrder sortOrder = sortOrders.get(0); + PagesIndexComparator comparator = pagesIndex.createChannelComparator(frameInfo.getSortKeyChannelForStartComparison(), frameInfo.getStartChannel(), sortOrder); + builder.put(new FrameBoundKey(i, FrameBoundKey.Type.START), comparator); + } + if (frameInfo.getEndType() == PRECEDING || frameInfo.getEndType() == FOLLOWING) { + // Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY + checkState(sortOrders != null && sortOrders.size() == 1, "Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY."); + SortOrder sortOrder = sortOrders.get(0); + PagesIndexComparator comparator = pagesIndex.createChannelComparator(frameInfo.getSortKeyChannelForEndComparison(), frameInfo.getEndChannel(), sortOrder); + builder.put(new FrameBoundKey(i, FrameBoundKey.Type.END), comparator); + } + } + } + + return builder.build(); + } + + public static class FrameBoundKey + { + private final int functionIndex; + private final Type type; + + public enum Type + { + START, + END; + } + + public FrameBoundKey(int functionIndex, Type type) + { + this.functionIndex = functionIndex; + this.type = requireNonNull(type, "type is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FrameBoundKey that = (FrameBoundKey) o; + return functionIndex == that.functionIndex && + type == that.type; + } + + @Override + public int hashCode() + { + return Objects.hash(functionIndex, type); } } @@ -501,7 +588,14 @@ public ProcessState process() int partitionEnd = findGroupEnd(pagesIndex, pagesIndexWithHashStrategies.unGroupedPartitionHashStrategy, partitionStart); - WindowPartition partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, pagesIndexWithHashStrategies.peerGroupHashStrategy); + WindowPartition partition = new WindowPartition( + pagesIndex, + partitionStart, + partitionEnd, + outputChannels, + windowFunctions, + pagesIndexWithHashStrategies.peerGroupHashStrategy, + pagesIndexWithHashStrategies.frameBoundComparators); windowInfo.addPartition(partition); partitionStart = partitionEnd; return ProcessState.ofResult(partition); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java index 86c4504fa1a5b..97a3eeb7f8883 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java @@ -15,6 +15,7 @@ import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType; +import com.facebook.presto.sql.tree.SortItem.Ordering; import java.util.Objects; import java.util.Optional; @@ -27,21 +28,33 @@ public class FrameInfo private final WindowType type; private final BoundType startType; private final int startChannel; + private final int sortKeyChannelForStartComparison; private final BoundType endType; private final int endChannel; + private final int sortKeyChannelForEndComparison; + private final int sortKeyChannel; + private final Optional ordering; public FrameInfo( WindowType type, BoundType startType, Optional startChannel, + Optional sortKeyChannelForStartComparison, BoundType endType, - Optional endChannel) + Optional endChannel, + Optional sortKeyChannelForEndComparison, + Optional sortKeyChannel, + Optional ordering) { this.type = requireNonNull(type, "type is null"); this.startType = requireNonNull(startType, "startType is null"); this.startChannel = requireNonNull(startChannel, "startChannel is null").orElse(-1); + this.sortKeyChannelForStartComparison = requireNonNull(sortKeyChannelForStartComparison, "sortKeyChannelForStartComparison is null").orElse(-1); this.endType = requireNonNull(endType, "endType is null"); this.endChannel = requireNonNull(endChannel, "endChannel is null").orElse(-1); + this.sortKeyChannelForEndComparison = requireNonNull(sortKeyChannelForEndComparison, "sortKeyChannelForEndComparison is null").orElse(-1); + this.sortKeyChannel = requireNonNull(sortKeyChannel, "sortKeyChannel is null").orElse(-1); + this.ordering = requireNonNull(ordering, "ordering is null"); } public WindowType getType() @@ -59,6 +72,11 @@ public int getStartChannel() return startChannel; } + public int getSortKeyChannelForStartComparison() + { + return sortKeyChannelForStartComparison; + } + public BoundType getEndType() { return endType; @@ -69,10 +87,25 @@ public int getEndChannel() return endChannel; } + public int getSortKeyChannelForEndComparison() + { + return sortKeyChannelForEndComparison; + } + + public int getSortKeyChannel() + { + return sortKeyChannel; + } + + public Optional getOrdering() + { + return ordering; + } + @Override public int hashCode() { - return Objects.hash(type, startType, startChannel, endType, endChannel); + return Objects.hash(type, startType, startChannel, sortKeyChannelForStartComparison, endType, endChannel, sortKeyChannelForEndComparison, sortKeyChannel, ordering); } @Override @@ -90,9 +123,13 @@ public boolean equals(Object obj) return Objects.equals(this.type, other.type) && Objects.equals(this.startType, other.startType) && + Objects.equals(this.sortKeyChannelForStartComparison, other.sortKeyChannelForStartComparison) && Objects.equals(this.startChannel, other.startChannel) && Objects.equals(this.endType, other.endType) && - Objects.equals(this.endChannel, other.endChannel); + Objects.equals(this.endChannel, other.endChannel) && + Objects.equals(this.sortKeyChannelForEndComparison, other.sortKeyChannelForEndComparison) && + Objects.equals(this.sortKeyChannel, other.sortKeyChannel) && + Objects.equals(this.ordering, other.ordering); } @Override @@ -102,8 +139,12 @@ public String toString() .add("type", type) .add("startType", startType) .add("startChannel", startChannel) + .add("sortKeyChannelForStartComparison", sortKeyChannelForStartComparison) .add("endType", endType) .add("endChannel", endChannel) + .add("sortKeyChannelForEndComparison", sortKeyChannelForEndComparison) + .add("sortKeyChannel", sortKeyChannel) + .add("ordering", ordering) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java index 22c72c3f5b720..fd211bd5c3689 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java @@ -16,18 +16,29 @@ import com.facebook.presto.common.PageBuilder; import com.facebook.presto.operator.PagesHashStrategy; import com.facebook.presto.operator.PagesIndex; +import com.facebook.presto.operator.PagesIndexComparator; +import com.facebook.presto.operator.WindowOperator.FrameBoundKey; import com.facebook.presto.spi.function.WindowIndex; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; +import com.facebook.presto.sql.tree.SortItem.Ordering; import com.google.common.collect.ImmutableList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import static com.facebook.presto.operator.WindowOperator.FrameBoundKey.Type.END; +import static com.facebook.presto.operator.WindowOperator.FrameBoundKey.Type.START; import static com.facebook.presto.spi.StandardErrorCode.INVALID_WINDOW_FRAME; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.toIntExact; @@ -40,19 +51,31 @@ public final class WindowPartition private final int[] outputChannels; private final List windowFunctions; + + // Recently computed frame bounds for functions with frame type RANGE. + // When computing frame start and frame end for a row, frame bounds for the previous row + // are used as the starting point. Then they are moved backward or forward based on the sort order + // until the matching position for a current row is found. + // This approach is efficient in case when frame offset values are constant. It was chosen + // based on the assumption that in most use cases frame offset is constant rather than + // row-dependent. + private final Map recentRanges; private final PagesHashStrategy peerGroupHashStrategy; + private final Map frameBoundComparators; private int peerGroupStart; private int peerGroupEnd; private int currentPosition; - public WindowPartition(PagesIndex pagesIndex, + public WindowPartition( + PagesIndex pagesIndex, int partitionStart, int partitionEnd, int[] outputChannels, List windowFunctions, - PagesHashStrategy peerGroupHashStrategy) + PagesHashStrategy peerGroupHashStrategy, + Map frameBoundComparators) { this.pagesIndex = pagesIndex; this.partitionStart = partitionStart; @@ -60,6 +83,7 @@ public WindowPartition(PagesIndex pagesIndex, this.outputChannels = outputChannels; this.windowFunctions = ImmutableList.copyOf(windowFunctions); this.peerGroupHashStrategy = peerGroupHashStrategy; + this.frameBoundComparators = frameBoundComparators; // reset functions for new partition WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, partitionStart, partitionEnd); @@ -69,6 +93,28 @@ public WindowPartition(PagesIndex pagesIndex, currentPosition = partitionStart; updatePeerGroup(); + + recentRanges = initializeRangeCache(partitionStart, partitionEnd, peerGroupEnd, windowFunctions); + } + + private static Map initializeRangeCache(int partitionStart, int partitionEnd, int peerGroupEnd, List windowFunctions) + { + Map ranges = new HashMap<>(); + Range initialPeerRange = new Range(0, peerGroupEnd - partitionStart - 1); + Range initialUnboundedRange = new Range(0, partitionEnd - partitionStart - 1); + for (int i = 0; i < windowFunctions.size(); i++) { + FrameInfo frame = windowFunctions.get(i).getFrame(); + if (frame.getType() == RANGE) { + if (frame.getEndType() == UNBOUNDED_FOLLOWING) { + ranges.put(i, initialUnboundedRange); + } + else { + ranges.put(i, initialPeerRange); + } + } + } + + return ranges; } public int getPartitionStart() @@ -103,8 +149,9 @@ public void processNextRow(PageBuilder pageBuilder) updatePeerGroup(); } - for (FramedWindowFunction framedFunction : windowFunctions) { - Range range = getFrameRange(framedFunction.getFrame()); + for (int i = 0; i < windowFunctions.size(); i++) { + FramedWindowFunction framedFunction = windowFunctions.get(i); + Range range = getFrameRange(framedFunction.getFrame(), i); framedFunction.getFunction().processRow( pageBuilder.getBlockBuilder(channel), peerGroupStart - partitionStart, @@ -149,6 +196,29 @@ private void updatePeerGroup() } } + private Range getFrameRange(FrameInfo frameInfo, int functionIndex) + { + switch (frameInfo.getType()) { + case RANGE: + Range range = getFrameRange( + frameInfo, + recentRanges.get(functionIndex), + frameBoundComparators.get(new FrameBoundKey(functionIndex, START)), + frameBoundComparators.get(new FrameBoundKey(functionIndex, END))); + // handle empty frame. If the frame is out of partition bounds, record the nearest valid frame as the 'recentRange' for the next row. + if (emptyFrame(range)) { + recentRanges.put(functionIndex, nearestValidFrame(range)); + return new Range(-1, -1); + } + recentRanges.put(functionIndex, range); + return range; + case ROWS: + return getFrameRange(frameInfo); + default: + throw new IllegalArgumentException("Unsupported frame type: " + frameInfo.getType()); + } + } + private Range getFrameRange(FrameInfo frameInfo) { int rowPosition = currentPosition - partitionStart; @@ -172,9 +242,6 @@ else if (frameInfo.getStartType() == PRECEDING) { else if (frameInfo.getStartType() == FOLLOWING) { frameStart = following(rowPosition, endPosition, getStartValue(frameInfo)); } - else if (frameInfo.getType() == RANGE) { - frameStart = peerGroupStart - partitionStart; - } else { frameStart = rowPosition; } @@ -189,9 +256,6 @@ else if (frameInfo.getEndType() == PRECEDING) { else if (frameInfo.getEndType() == FOLLOWING) { frameEnd = following(rowPosition, endPosition, getEndValue(frameInfo)); } - else if (frameInfo.getType() == RANGE) { - frameEnd = peerGroupEnd - partitionStart - 1; - } else { frameEnd = rowPosition; } @@ -199,6 +263,230 @@ else if (frameInfo.getType() == RANGE) { return new Range(frameStart, frameEnd); } + private Range getFrameRange(FrameInfo frameInfo, Range recentRange, PagesIndexComparator startComparator, PagesIndexComparator endComparator) + { + // full partition + if ((frameInfo.getStartType() == UNBOUNDED_PRECEDING && frameInfo.getEndType() == UNBOUNDED_FOLLOWING)) { + return new Range(0, partitionEnd - partitionStart - 1); + } + + // frame defined by peer group + if ((frameInfo.getStartType() == CURRENT_ROW && frameInfo.getEndType() == CURRENT_ROW) || + (frameInfo.getStartType() == CURRENT_ROW && frameInfo.getEndType() == UNBOUNDED_FOLLOWING) || + (frameInfo.getStartType() == UNBOUNDED_PRECEDING && frameInfo.getEndType() == CURRENT_ROW)) { + // same peer group as recent row + if (currentPosition == partitionStart || pagesIndex.positionEqualsPosition(peerGroupHashStrategy, currentPosition - 1, currentPosition)) { + return recentRange; + } + // next peer group + return new Range( + frameInfo.getStartType() == UNBOUNDED_PRECEDING ? 0 : peerGroupStart - partitionStart, + frameInfo.getEndType() == UNBOUNDED_FOLLOWING ? partitionEnd - partitionStart - 1 : peerGroupEnd - partitionStart - 1); + } + + // at this point, frame definition has at least one of: X PRECEDING, Y FOLLOWING + // 1. leading or trailing nulls: frame consists of nulls peer group, possibly extended to partition start / end. + // according to Spec, behavior of "X PRECEDING", "X FOLLOWING" frame boundaries is similar to "CURRENT ROW" for null values. + if (pagesIndex.isNull(frameInfo.getSortKeyChannel(), currentPosition)) { + return new Range( + frameInfo.getStartType() == UNBOUNDED_PRECEDING ? 0 : peerGroupStart - partitionStart, + frameInfo.getEndType() == UNBOUNDED_FOLLOWING ? partitionEnd - partitionStart - 1 : peerGroupEnd - partitionStart - 1); + } + + // 2. non-null value in current row. Find frame boundaries starting from recentRange + int frameStart; + switch (frameInfo.getStartType()) { + case UNBOUNDED_PRECEDING: + frameStart = 0; + break; + case CURRENT_ROW: + frameStart = peerGroupStart - partitionStart; + break; + case PRECEDING: + frameStart = getFrameStartPreceding(recentRange.getStart(), frameInfo, startComparator); + break; + case FOLLOWING: + // note: this is the only case where frameStart might get out of partition bound + frameStart = getFrameStartFollowing(recentRange.getStart(), frameInfo, startComparator); + break; + default: + // start type cannot be UNBOUNDED_FOLLOWING + throw new IllegalArgumentException("Unsupported frame start type: " + frameInfo.getStartType()); + } + + int frameEnd; + switch (frameInfo.getEndType()) { + case UNBOUNDED_FOLLOWING: + frameEnd = partitionEnd - partitionStart - 1; + break; + case CURRENT_ROW: + frameEnd = peerGroupEnd - partitionStart - 1; + break; + case PRECEDING: + // note: this is the only case where frameEnd might get out of partition bound + frameEnd = getFrameEndPreceding(recentRange.getEnd(), frameInfo, endComparator); + break; + case FOLLOWING: + frameEnd = getFrameEndFollowing(recentRange.getEnd(), frameInfo, endComparator); + break; + default: + // end type cannot be UNBOUNDED_PRECEDING + throw new IllegalArgumentException("Unsupported frame end type: " + frameInfo.getStartType()); + } + + return new Range(frameStart, frameEnd); + } + + private int getFrameStartPreceding(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + int sortKeyChannel = frameInfo.getSortKeyChannelForStartComparison(); + Ordering ordering = frameInfo.getOrdering().get(); + + // If the recent frame start points at a null, it means that we are now processing first non-null position. + // For frame start "X PRECEDING", the frame starts at the first null for all null values, and it never includes nulls for non-null values. + if (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + recent)) { + return currentPosition - partitionStart; + } + + return seek( + comparator, + sortKeyChannel, + recent, + -1, + ordering == DESCENDING, + 0, + p -> false); + } + + private int getFrameStartFollowing(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + int sortKeyChannel = frameInfo.getSortKeyChannelForStartComparison(); + Ordering ordering = frameInfo.getOrdering().get(); + + int position = recent; + + // If the recent frame start points at the beginning of partition and it is null, it means that we are now processing first non-null position. + // frame start for first non-null position - leave section of leading nulls + if (recent == 0 && pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart)) { + position = currentPosition - partitionStart; + } + // leave section of trailing nulls + while (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + position)) { + position--; + } + + return seek( + comparator, + sortKeyChannel, + position, + -1, + ordering == DESCENDING, + 0, + p -> p >= partitionEnd - partitionStart || pagesIndex.isNull(sortKeyChannel, partitionStart + p)); + } + + private int getFrameEndPreceding(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + int sortKeyChannel = frameInfo.getSortKeyChannelForEndComparison(); + Ordering ordering = frameInfo.getOrdering().get(); + + int position = recent; + + // leave section of leading nulls + while (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + position)) { + position++; + } + + return seek( + comparator, + sortKeyChannel, + position, + 1, + ordering == ASCENDING, + partitionEnd - 1 - partitionStart, + p -> p < 0 || pagesIndex.isNull(sortKeyChannel, partitionStart + p)); + } + + private int getFrameEndFollowing(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + Ordering ordering = frameInfo.getOrdering().get(); + int sortKeyChannel = frameInfo.getSortKeyChannelForEndComparison(); + + int position = recent; + + // frame end for first non-null position - leave section of leading nulls + if (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + recent)) { + position = currentPosition - partitionStart; + } + + return seek( + comparator, + sortKeyChannel, + position, + 1, + ordering == ASCENDING, + partitionEnd - 1 - partitionStart, + p -> false); + } + + private int compare(PagesIndexComparator comparator, int left, int right, boolean reverse) + { + int result = comparator.compareTo(pagesIndex, left, right); + + if (reverse) { + return -result; + } + + return result; + } + + // This method assumes that `sortKeyChannel` is not null at `position` + private int seek(PagesIndexComparator comparator, int sortKeyChannel, int position, int step, boolean reverse, int limit, Predicate bound) + { + int comparison = compare(comparator, partitionStart + position, currentPosition, reverse); + while (comparison < 0) { + position -= step; + + if (bound.test(position)) { + return position; + } + + comparison = compare(comparator, partitionStart + position, currentPosition, reverse); + } + while (true) { + if (position == limit || pagesIndex.isNull(sortKeyChannel, partitionStart + position + step)) { + break; + } + int newComparison = compare(comparator, partitionStart + position + step, currentPosition, reverse); + if (newComparison >= 0) { + position += step; + } + else { + break; + } + } + + return position; + } + + private boolean emptyFrame(Range range) + { + return range.getStart() > range.getEnd() || + range.getStart() >= partitionEnd - partitionStart || + range.getEnd() < 0; + } + + /** + * Return the nearest valid frame. A frame is valid if its start and end are within partition. + * Note: A valid frame might be empty i.e. its end might be before its start. + */ + private Range nearestValidFrame(Range range) + { + return new Range( + Math.min(partitionEnd - partitionStart - 1, range.getStart()), + Math.max(0, range.getEnd())); + } + private boolean emptyFrame(FrameInfo frameInfo, int rowPosition, int endPosition) { BoundType startType = frameInfo.getStartType(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index ae18f8014ba84..434c3dd136628 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; +import com.facebook.presto.common.ErrorCode; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.function.OperatorType; @@ -21,6 +22,7 @@ import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.common.type.CharType; import com.facebook.presto.common.type.DecimalParseResult; +import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.DistinctType; import com.facebook.presto.common.type.FunctionType; @@ -68,6 +70,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Extract; import com.facebook.presto.sql.tree.FieldReference; +import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.GroupingOperation; @@ -88,6 +91,7 @@ import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; @@ -104,11 +108,13 @@ import com.facebook.presto.sql.tree.TimestampLiteral; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; +import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import io.airlift.slice.SliceUtf8; @@ -125,7 +131,9 @@ import java.util.Set; import java.util.function.Function; +import static com.facebook.presto.common.function.OperatorType.ADD; import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT; +import static com.facebook.presto.common.function.OperatorType.SUBTRACT; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DateType.DATE; @@ -145,6 +153,7 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE; import static com.facebook.presto.metadata.FunctionAndTypeManager.qualifyObjectName; +import static com.facebook.presto.spi.StandardErrorCode.OPERATOR_NOT_FOUND; import static com.facebook.presto.spi.StandardWarningCode.SEMANTIC_WARNING; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions; @@ -157,8 +166,10 @@ import static com.facebook.presto.sql.analyzer.FunctionArgumentCheckerForAccessControlUtils.resolveSubfield; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MULTIPLE_FIELDS_FROM_SUBQUERY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA; @@ -167,6 +178,12 @@ import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.tree.Extract.Field.TIMEZONE_HOUR; import static com.facebook.presto.sql.tree.Extract.Field.TIMEZONE_MINUTE; +import static com.facebook.presto.sql.tree.FrameBound.Type.FOLLOWING; +import static com.facebook.presto.sql.tree.FrameBound.Type.PRECEDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; +import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; +import static com.facebook.presto.sql.tree.WindowFrame.Type.ROWS; import static com.facebook.presto.type.ArrayParametricType.ARRAY; import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; @@ -200,6 +217,17 @@ public class ExpressionAnalyzer private final Set> existsSubqueries = new LinkedHashSet<>(); private final Map, Type> expressionCoercions = new LinkedHashMap<>(); private final Set> typeOnlyCoercions = new LinkedHashSet<>(); + + // Coercions needed for window function frame of type RANGE. + // These are coercions for the sort key, needed for frame bound calculation, identified by frame range offset expression. + // Frame definition might contain two different offset expressions (for start and end), each requiring different coercion of the sort key. + private final Map, Type> sortKeyCoercionsForFrameBoundCalculation = new LinkedHashMap<>(); + // Coercions needed for window function frame of type RANGE. + // These are coercions for the sort key, needed for comparison of the sort key with precomputed frame bound, identified by frame range offset expression. + private final Map, Type> sortKeyCoercionsForFrameBoundComparison = new LinkedHashMap<>(); + // Functions for calculating frame bounds for frame of type RANGE, identified by frame range offset expression. + private final Map, FunctionHandle> frameBoundCalculations = new LinkedHashMap<>(); + private final Set> subqueryInPredicates = new LinkedHashSet<>(); private final Map, FieldId> columnReferences = new LinkedHashMap<>(); private final Map, Type> expressionTypes = new LinkedHashMap<>(); @@ -277,6 +305,21 @@ public Set> getTypeOnlyCoercions() return unmodifiableSet(typeOnlyCoercions); } + public Map, Type> getSortKeyCoercionsForFrameBoundCalculation() + { + return unmodifiableMap(sortKeyCoercionsForFrameBoundCalculation); + } + + public Map, Type> getSortKeyCoercionsForFrameBoundComparison() + { + return unmodifiableMap(sortKeyCoercionsForFrameBoundComparison); + } + + public Map, FunctionHandle> getFrameBoundCalculations() + { + return unmodifiableMap(frameBoundCalculations); + } + public Set> getSubqueryInPredicates() { return unmodifiableSet(subqueryInPredicates); @@ -901,7 +944,8 @@ protected Type visitNullLiteral(NullLiteral node, StackableAstVisitorContext context) { if (node.getWindow().isPresent()) { - for (Expression expression : node.getWindow().get().getPartitionBy()) { + Window window = node.getWindow().get(); + for (Expression expression : window.getPartitionBy()) { process(expression, context); Type type = getExpressionType(expression); if (!type.isComparable()) { @@ -909,7 +953,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext context, Window window) + { + if (!window.getOrderBy().isPresent()) { + throw new SemanticException(MISSING_ORDER_BY, window, "Window frame of type RANGE PRECEDING or FOLLOWING requires ORDER BY"); + } + OrderBy orderBy = window.getOrderBy().get(); + if (orderBy.getSortItems().size() != 1) { + throw new SemanticException(INVALID_ORDER_BY, orderBy, "Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY (actual: %s)", orderBy.getSortItems().size()); + } + Expression sortKey = Iterables.getOnlyElement(orderBy.getSortItems()).getSortKey(); + Type sortKeyType = getExpressionType(sortKey); + if (!isNumericType(sortKeyType) && !isDateTimeType(sortKeyType)) { + throw new SemanticException(TYPE_MISMATCH, sortKey, "Window frame of type RANGE PRECEDING or FOLLOWING requires that sort item type be numeric, datetime or interval (actual: %s)", sortKeyType); + } + + Type offsetValueType = process(offsetValue, context); + + if (isNumericType(sortKeyType)) { + if (!isNumericType(offsetValueType)) { + throw new SemanticException(TYPE_MISMATCH, offsetValue, "Window frame RANGE value type (%s) not compatible with sort item type (%s)", offsetValueType, sortKeyType); + } + } + else { // isDateTimeType(sortKeyType) + if (offsetValueType != INTERVAL_DAY_TIME && offsetValueType != INTERVAL_YEAR_MONTH) { + throw new SemanticException(TYPE_MISMATCH, offsetValue, "Window frame RANGE value type (%s) not compatible with sort item type (%s)", offsetValueType, sortKeyType); + } + } + + // resolve function to calculate frame boundary value (add / subtract offset from sortKey) + SortItem.Ordering ordering = Iterables.getOnlyElement(orderBy.getSortItems()).getOrdering(); + OperatorType operatorType; + FunctionHandle function; + if ((boundType == PRECEDING && ordering == ASCENDING) || (boundType == FOLLOWING && ordering == DESCENDING)) { + operatorType = SUBTRACT; + } + else { + operatorType = ADD; + } + try { + function = functionAndTypeResolver.resolveOperator(operatorType, TypeSignatureProvider.fromTypes(sortKeyType, offsetValueType)); + } + catch (PrestoException e) { + ErrorCode errorCode = e.getErrorCode(); + if (errorCode.equals(OPERATOR_NOT_FOUND.toErrorCode())) { + throw new SemanticException(TYPE_MISMATCH, offsetValue, "Window frame RANGE value type (%s) not compatible with sort item type (%s)", offsetValueType, sortKeyType); + } + throw e; + } + + FunctionMetadata functionMetadata = functionAndTypeResolver.getFunctionMetadata(function); + Type expectedSortKeyType = functionAndTypeResolver.getType(functionMetadata.getArgumentTypes().get(0)); + + if (!expectedSortKeyType.equals(sortKeyType)) { + if (!functionAndTypeResolver.canCoerce(sortKeyType, expectedSortKeyType)) { + throw new SemanticException(TYPE_MISMATCH, sortKey, "Sort key must evaluate to a %s (actual: %s)", expectedSortKeyType, sortKeyType); + } + sortKeyCoercionsForFrameBoundCalculation.put(NodeRef.of(offsetValue), expectedSortKeyType); + } + + Type expectedOffsetValueType = functionAndTypeResolver.getType(functionMetadata.getArgumentTypes().get(1)); + if (!expectedOffsetValueType.equals(offsetValueType)) { + coerceType(offsetValue, offsetValueType, expectedOffsetValueType, format("Function %s argument 1", function)); + } + Type expectedFunctionResultType = functionAndTypeResolver.getType(functionMetadata.getReturnType()); + if (!expectedFunctionResultType.equals(sortKeyType)) { + if (!functionAndTypeResolver.canCoerce(sortKeyType, expectedFunctionResultType)) { + throw new SemanticException(TYPE_MISMATCH, sortKey, "Sort key must evaluate to a %s (actual: %s)", expectedFunctionResultType, sortKeyType); + } + sortKeyCoercionsForFrameBoundComparison.put(NodeRef.of(offsetValue), expectedFunctionResultType); + } + + frameBoundCalculations.put(NodeRef.of(offsetValue), function); + } + @Override protected Type visitAtTimeZone(AtTimeZone node, StackableAstVisitorContext context) { @@ -1727,10 +1861,14 @@ public static ExpressionAnalysis analyzeExpression( Map, Type> expressionTypes = analyzer.getExpressionTypes(); Map, Type> expressionCoercions = analyzer.getExpressionCoercions(); Set> typeOnlyCoercions = analyzer.getTypeOnlyCoercions(); + Map, Type> sortKeyCoercionsForFrameBoundCalculation = analyzer.getSortKeyCoercionsForFrameBoundCalculation(); + Map, Type> sortKeyCoercionsForFrameBoundComparison = analyzer.getSortKeyCoercionsForFrameBoundComparison(); + Map, FunctionHandle> frameBoundCalculations = analyzer.getFrameBoundCalculations(); Map, FunctionHandle> resolvedFunctions = analyzer.getResolvedFunctions(); analysis.addTypes(expressionTypes); - analysis.addCoercions(expressionCoercions, typeOnlyCoercions); + analysis.addCoercions(expressionCoercions, typeOnlyCoercions, sortKeyCoercionsForFrameBoundCalculation, sortKeyCoercionsForFrameBoundComparison); + analysis.addFrameBoundCalculations(frameBoundCalculations); analysis.addFunctionHandles(resolvedFunctions); analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); @@ -1896,4 +2034,15 @@ public static ExpressionAnalyzer createWithoutSubqueries( warningCollector, isDescribe); } + + public static boolean isNumericType(Type type) + { + return type.equals(BIGINT) || + type.equals(INTEGER) || + type.equals(SMALLINT) || + type.equals(TINYINT) || + type.equals(DOUBLE) || + type.equals(REAL) || + type instanceof DecimalType; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 024cf1ab4ad72..ca0f5cb42a1ae 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -267,7 +267,6 @@ import static com.facebook.presto.sql.tree.FrameBound.Type.PRECEDING; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; -import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -2180,12 +2179,6 @@ private void analyzeWindowFrame(WindowFrame frame) if ((startType == FOLLOWING) && (endType == CURRENT_ROW)) { throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame starting from FOLLOWING cannot end with CURRENT ROW"); } - if ((frame.getType() == RANGE) && ((startType == PRECEDING) || (endType == PRECEDING))) { - throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame RANGE PRECEDING is only supported with UNBOUNDED"); - } - if ((frame.getType() == RANGE) && ((startType == FOLLOWING) || (endType == FOLLOWING))) { - throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame RANGE FOLLOWING is only supported with UNBOUNDED"); - } } private void analyzeHaving(QuerySpecification node, Scope scope) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index b9705dbe82431..d146fd6dd7147 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -463,12 +463,18 @@ public Optional visitWindow(WindowNode node, Context context) .map(expression -> inlineAndCanonicalize(context.getExpressions(), expression)); Optional endValue = function.getFrame().getEndValue() .map(expression -> inlineAndCanonicalize(context.getExpressions(), expression)); + Optional sortKeyCoercedForFrameStartComparison = function.getFrame().getSortKeyCoercedForFrameStartComparison() + .map(expression -> inlineAndCanonicalize(context.getExpressions(), expression)); + Optional sortKeyCoercedForFrameEndComparison = function.getFrame().getSortKeyCoercedForFrameEndComparison() + .map(expression -> inlineAndCanonicalize(context.getExpressions(), expression)); WindowNode.Frame frame = new WindowNode.Frame( function.getFrame().getType(), function.getFrame().getStartType(), startValue, + sortKeyCoercedForFrameStartComparison, function.getFrame().getEndType(), endValue, + sortKeyCoercedForFrameEndComparison, startValue.map(ignored -> ""), endValue.map(ignored -> "")); WindowNode.Function newFunction = new WindowNode.Function( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index cd9254e58c0b3..79e677bed507e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -206,6 +206,7 @@ import com.facebook.presto.sql.planner.plan.WindowNode.Frame; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.VariableToChannelTranslator; +import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; @@ -318,6 +319,8 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; import static com.facebook.presto.util.Reflection.constructorMethodHandle; import static com.facebook.presto.util.SpatialJoinUtils.ST_CONTAINS; import static com.facebook.presto.util.SpatialJoinUtils.ST_CROSSES; @@ -1116,17 +1119,40 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext ImmutableList.Builder windowFunctionOutputVariablesBuilder = ImmutableList.builder(); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { Optional frameStartChannel = Optional.empty(); + Optional sortKeyChannelForStartComparison = Optional.empty(); Optional frameEndChannel = Optional.empty(); + Optional sortKeyChannelForEndComparison = Optional.empty(); + Optional sortKeyChannel = Optional.empty(); + Optional ordering = Optional.empty(); Frame frame = entry.getValue().getFrame(); if (frame.getStartValue().isPresent()) { frameStartChannel = Optional.of(source.getLayout().get(frame.getStartValue().get())); } + if (frame.getSortKeyCoercedForFrameStartComparison().isPresent()) { + sortKeyChannelForStartComparison = Optional.of(source.getLayout().get(frame.getSortKeyCoercedForFrameStartComparison().get())); + } if (frame.getEndValue().isPresent()) { frameEndChannel = Optional.of(source.getLayout().get(frame.getEndValue().get())); } + if (frame.getSortKeyCoercedForFrameEndComparison().isPresent()) { + sortKeyChannelForEndComparison = Optional.of(source.getLayout().get(frame.getSortKeyCoercedForFrameEndComparison().get())); + } + if (node.getOrderingScheme().isPresent()) { + sortKeyChannel = Optional.of(sortChannels.get(0)); + ordering = Optional.of(sortOrder.get(0).isAscending() ? ASCENDING : DESCENDING); + } - FrameInfo frameInfo = new FrameInfo(frame.getType(), frame.getStartType(), frameStartChannel, frame.getEndType(), frameEndChannel); + FrameInfo frameInfo = new FrameInfo( + frame.getType(), + frame.getStartType(), + frameStartChannel, + sortKeyChannelForStartComparison, + frame.getEndType(), + frameEndChannel, + sortKeyChannelForEndComparison, + sortKeyChannel, + ordering); WindowNode.Function function = entry.getValue(); CallExpression call = function.getFunctionCall(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 1997fdc5ee693..6e56a24b4c9fe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -14,12 +14,15 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.Assignments; @@ -50,22 +53,28 @@ import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GroupingOperation; +import com.facebook.presto.sql.tree.IfExpression; +import com.facebook.presto.sql.tree.IntervalLiteral; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.SortItem; +import com.facebook.presto.sql.tree.StringLiteral; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; @@ -76,22 +85,28 @@ import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.isSkipRedundantSort; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.isNumericType; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder; import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toBoundType; @@ -100,10 +115,21 @@ import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.sql.tree.IntervalLiteral.IntervalField.DAY; +import static com.facebook.presto.sql.tree.IntervalLiteral.IntervalField.YEAR; +import static com.facebook.presto.sql.tree.IntervalLiteral.Sign.POSITIVE; +import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; +import static com.facebook.presto.sql.tree.WindowFrame.Type.ROWS; +import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; +import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.stream; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; class QueryPlanner @@ -368,6 +394,40 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression projections.build())); } + /** + * Creates a projection with any additional coercions by identity of the provided expressions. + * + * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed + */ + public static PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), x -> castToRowExpression(asSymbolReference(x))))); + ImmutableMap.Builder, VariableReferenceExpression> mappings = ImmutableMap.builder(); + for (Expression expression : expressions) { + Type coercion = analysis.getCoercion(expression); + if (coercion != null) { + Type type = analysis.getType(expression); + VariableReferenceExpression variable = variableAllocator.newVariable(expression, coercion); + assignments.put(variable, castToRowExpression(new Cast( + subPlan.rewrite(expression), + coercion.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(type, coercion)))); + mappings.put(NodeRef.of(expression), variable); + } + else { + mappings.put(NodeRef.of(expression), subPlan.translate(expression)); + } + } + subPlan = subPlan.withNewRoot( + new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + assignments.build())); + return new PlanAndMappings(subPlan, mappings.build()); + } + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { ImmutableMap.Builder projections = ImmutableMap.builder(); @@ -738,39 +798,80 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio Window window = windowFunction.getWindow().get(); // Extract frame - WindowFrame.Type frameType = WindowFrame.Type.RANGE; + WindowFrame.Type frameType = RANGE; FrameBound.Type frameStartType = FrameBound.Type.UNBOUNDED_PRECEDING; FrameBound.Type frameEndType = FrameBound.Type.CURRENT_ROW; - Expression frameStart = null; - Expression frameEnd = null; + Optional startValue = Optional.empty(); + Optional endValue = Optional.empty(); if (window.getFrame().isPresent()) { WindowFrame frame = window.getFrame().get(); frameType = frame.getType(); frameStartType = frame.getStart().getType(); - frameStart = frame.getStart().getValue().orElse(null); + startValue = frame.getStart().getValue(); if (frame.getEnd().isPresent()) { frameEndType = frame.getEnd().get().getType(); - frameEnd = frame.getEnd().get().getValue().orElse(null); + endValue = frame.getEnd().get().getValue(); } } // Pre-project inputs - ImmutableList.Builder inputs = ImmutableList.builder() + ImmutableList.Builder inputsBuilder = ImmutableList.builder() .addAll(windowFunction.getArguments()) .addAll(window.getPartitionBy()) .addAll(Iterables.transform(getSortItemsFromOrderBy(window.getOrderBy()), SortItem::getSortKey)); - if (frameStart != null) { - inputs.add(frameStart); + if (startValue.isPresent()) { + inputsBuilder.add(startValue.get()); } - if (frameEnd != null) { - inputs.add(frameEnd); + if (endValue.isPresent()) { + inputsBuilder.add(endValue.get()); } - subPlan = subPlan.appendProjections(inputs.build(), variableAllocator, idAllocator); + ImmutableList inputs = inputsBuilder.build(); + subPlan = subPlan.appendProjections(inputs, variableAllocator, idAllocator); + + // Add projection to coerce inputs to their site-specific types. + // This is important because the same lexical expression may need to be coerced + // in different ways if it's referenced by multiple arguments to the window function. + // For example, given v::integer, + // avg(v) OVER (ORDER BY v) + // Needs to be rewritten as + // avg(CAST(v AS double)) OVER (ORDER BY v) + PlanAndMappings coercions = coerce(subPlan, inputs, analysis, idAllocator, variableAllocator, metadata); + subPlan = coercions.getSubPlan(); + + // For frame of type RANGE, append casts and functions necessary for frame bound calculations + Optional frameStart = Optional.empty(); + Optional frameEnd = Optional.empty(); + Optional sortKeyCoercedForFrameStartComparison = Optional.empty(); + Optional sortKeyCoercedForFrameEndComparison = Optional.empty(); + + if (window.getFrame().isPresent() && window.getFrame().get().getType() == RANGE) { + // record sortKey coercions for reuse + Map sortKeyCoercions = new HashMap<>(); + + // process frame start + FrameBoundPlanAndSymbols plan = planFrameBound(subPlan, coercions, startValue, window, sortKeyCoercions); + subPlan = plan.getSubPlan(); + frameStart = plan.getFrameBoundSymbol(); + sortKeyCoercedForFrameStartComparison = plan.getSortKeyCoercedForFrameBoundComparison(); + + // process frame end + plan = planFrameBound(subPlan, coercions, endValue, window, sortKeyCoercions); + subPlan = plan.getSubPlan(); + frameEnd = plan.getFrameBoundSymbol(); + sortKeyCoercedForFrameEndComparison = plan.getSortKeyCoercedForFrameBoundComparison(); + } + else if (window.getFrame().isPresent() && window.getFrame().get().getType() == ROWS) { + frameStart = window.getFrame().get().getStart().getValue().map(coercions::get); + frameEnd = window.getFrame().get().getEnd().flatMap(FrameBound::getValue).map(coercions::get); + } + else if (window.getFrame().isPresent()) { + throw new IllegalArgumentException("unexpected window frame type: " + window.getFrame().get().getType()); + } // Rewrite PARTITION BY in terms of pre-projected inputs ImmutableList.Builder partitionByVariables = ImmutableList.builder(); @@ -787,23 +888,17 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio } // Rewrite frame bounds in terms of pre-projected inputs - Optional frameStartVariable = Optional.empty(); - Optional frameEndVariable = Optional.empty(); - if (frameStart != null) { - frameStartVariable = Optional.of(subPlan.translate(frameStart)); - } - if (frameEnd != null) { - frameEndVariable = Optional.of(subPlan.translate(frameEnd)); - } WindowNode.Frame frame = new WindowNode.Frame( toWindowType(frameType), toBoundType(frameStartType), - frameStartVariable, + frameStart, + sortKeyCoercedForFrameStartComparison, toBoundType(frameEndType), - frameEndVariable, - Optional.ofNullable(frameStart).map(Expression::toString), - Optional.ofNullable(frameEnd).map(Expression::toString)); + frameEnd, + sortKeyCoercedForFrameEndComparison, + startValue.map(Expression::toString), + endValue.map(Expression::toString)); TranslationMap outputTranslations = subPlan.copyTranslations(); @@ -872,6 +967,134 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio return subPlan; } + private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMappings coercions, Optional frameOffset, Window window, Map sortKeyCoercions) + { + Optional frameBoundCalculationFunction = frameOffset.map(analysis::getFrameBoundCalculation); + + // Empty frameBoundCalculationFunction indicates that frame bound type is CURRENT ROW or UNBOUNDED. + // Handling it doesn't require any additional symbols. + if (!frameBoundCalculationFunction.isPresent()) { + return new FrameBoundPlanAndSymbols(subPlan, Optional.empty(), Optional.empty()); + } + + // Present frameBoundCalculationFunction indicates that frame bound type is PRECEDING or FOLLOWING. + // It requires adding certain projections to the plan so that the operator can determine frame bounds. + + // First, append filter to validate offset values. They mustn't be negative or null. + VariableReferenceExpression offsetSymbol = coercions.get(frameOffset.get()); + Expression zeroOffset = zeroOfType(variableAllocator.getTypes().get(offsetSymbol)); + FunctionHandle fail = metadata.getFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.fail"), fromTypes(VARCHAR)); + Expression predicate = new IfExpression( + new ComparisonExpression( + GREATER_THAN_OR_EQUAL, + new SymbolReference(offsetSymbol.getName()), + zeroOffset), + TRUE_LITERAL, + new Cast( + new FunctionCall( + QualifiedName.of("presto", "default", "fail"), + ImmutableList.of(new Cast(new StringLiteral("Window frame offset value must not be negative or null"), VARCHAR.getTypeSignature().toString()))), + BOOLEAN.getTypeSignature().toString())); + subPlan = subPlan.withNewRoot(new FilterNode( + getSourceLocation(window), + idAllocator.getNextId(), + subPlan.getRoot(), + castToRowExpression(predicate))); + + // Then, coerce the sortKey so that we can add / subtract the offset. + // Note: for that we cannot rely on the usual mechanism of using the coerce() method. The coerce() method can only handle one coercion for a node, + // while the sortKey node might require several different coercions, e.g. one for frame start and one for frame end. + Expression sortKey = Iterables.getOnlyElement(window.getOrderBy().get().getSortItems()).getSortKey(); + VariableReferenceExpression sortKeyCoercedForFrameBoundCalculation = coercions.get(sortKey); + Optional coercion = frameOffset.map(analysis::getSortKeyCoercionForFrameBoundCalculation); + if (coercion.isPresent()) { + Type expectedType = coercion.get(); + VariableReferenceExpression alreadyCoerced = sortKeyCoercions.get(expectedType); + if (alreadyCoerced != null) { + sortKeyCoercedForFrameBoundCalculation = alreadyCoerced; + } + else { + Expression cast = new Cast( + new SymbolReference(coercions.get(sortKey).getName()), + expectedType.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(analysis.getType(sortKey), expectedType)); + sortKeyCoercedForFrameBoundCalculation = variableAllocator.newVariable(cast, expectedType); + sortKeyCoercions.put(expectedType, sortKeyCoercedForFrameBoundCalculation); + subPlan = subPlan.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + Assignments.builder() + .putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), x -> castToRowExpression(asSymbolReference(x))))) + .put(sortKeyCoercedForFrameBoundCalculation, castToRowExpression(cast)) + .build())); + } + } + + // Next, pre-project the function which combines sortKey with the offset. + // Note: if frameOffset needs a coercion, it was added before by a call to coerce() method. + FunctionHandle function = frameBoundCalculationFunction.get(); + FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(function); + QualifiedObjectName name = functionMetadata.getName(); + Expression functionCall = new FunctionCall( + QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getObjectName()), + ImmutableList.of( + new SymbolReference(sortKeyCoercedForFrameBoundCalculation.getName()), + new SymbolReference(offsetSymbol.getName()))); + VariableReferenceExpression frameBoundVariable = variableAllocator.newVariable(functionCall, metadata.getFunctionAndTypeManager().getType(functionMetadata.getReturnType())); + subPlan = subPlan.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + Assignments.builder() + .putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), x -> castToRowExpression(asSymbolReference(x))))) + .put(frameBoundVariable, castToRowExpression(functionCall)) + .build())); + + // Finally, coerce the sortKey to the type of frameBound so that the operator can perform comparisons on them + Optional sortKeyCoercedForFrameBoundComparison = Optional.of(coercions.get(sortKey)); + coercion = frameOffset.map(analysis::getSortKeyCoercionForFrameBoundComparison); + if (coercion.isPresent()) { + Type expectedType = coercion.get(); + VariableReferenceExpression alreadyCoerced = sortKeyCoercions.get(expectedType); + if (alreadyCoerced != null) { + sortKeyCoercedForFrameBoundComparison = Optional.of(alreadyCoerced); + } + else { + Expression cast = new Cast( + new SymbolReference(coercions.get(sortKey).getName()), + expectedType.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(analysis.getType(sortKey), expectedType)); + VariableReferenceExpression castSymbol = variableAllocator.newVariable(cast, expectedType); + sortKeyCoercions.put(expectedType, castSymbol); + subPlan = subPlan.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + Assignments.builder() + .putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), x -> castToRowExpression(asSymbolReference(x))))) + .put(castSymbol, castToRowExpression(cast)) + .build())); + sortKeyCoercedForFrameBoundComparison = Optional.of(castSymbol); + } + } + + return new FrameBoundPlanAndSymbols(subPlan, Optional.of(frameBoundVariable), sortKeyCoercedForFrameBoundComparison); + } + + private Expression zeroOfType(Type type) + { + if (isNumericType(type)) { + return new Cast(new LongLiteral("0"), type.getTypeSignature().toString()); + } + if (type.equals(INTERVAL_DAY_TIME)) { + return new IntervalLiteral("0", POSITIVE, DAY); + } + if (type.equals(INTERVAL_YEAR_MONTH)) { + return new IntervalLiteral("0", POSITIVE, YEAR); + } + throw new IllegalArgumentException("unexpected type: " + type); + } + private PlanBuilder handleSubqueries(PlanBuilder subPlan, Node node, Iterable inputs) { for (Expression input : inputs) { @@ -970,4 +1193,65 @@ private static List toSymbolReferences(List, VariableReferenceExpression> mappings; + + public PlanAndMappings(PlanBuilder subPlan, Map, VariableReferenceExpression> mappings) + { + this.subPlan = subPlan; + this.mappings = mappings; + } + + public PlanBuilder getSubPlan() + { + return subPlan; + } + + public VariableReferenceExpression get(Expression expression) + { + return tryGet(expression) + .orElseThrow(() -> new IllegalArgumentException(format("No mapping for expression: %s (%s)", expression, System.identityHashCode(expression)))); + } + + public Optional tryGet(Expression expression) + { + VariableReferenceExpression result = mappings.get(NodeRef.of(expression)); + if (result != null) { + return Optional.of(result); + } + return Optional.empty(); + } + } + + private static class FrameBoundPlanAndSymbols + { + private final PlanBuilder subPlan; + private final Optional frameBoundSymbol; + private final Optional sortKeyCoercedForFrameBoundComparison; + + public FrameBoundPlanAndSymbols(PlanBuilder subPlan, Optional frameBoundSymbol, Optional sortKeyCoercedForFrameBoundComparison) + { + this.subPlan = subPlan; + this.frameBoundSymbol = frameBoundSymbol; + this.sortKeyCoercedForFrameBoundComparison = sortKeyCoercedForFrameBoundComparison; + } + + public PlanBuilder getSubPlan() + { + return subPlan; + } + + public Optional getFrameBoundSymbol() + { + return frameBoundSymbol; + } + + public Optional getSortKeyCoercedForFrameBoundComparison() + { + return sortKeyCoercedForFrameBoundComparison; + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java index aaca444f57e88..bca9d34b1539a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java @@ -67,6 +67,15 @@ public Type get(Expression expression) return type; } + public Type get(VariableReferenceExpression expression) + { + requireNonNull(expression, "expression is null"); + Type type = types.get(expression.getName()); + checkArgument(type != null, "no type found found for expression '%s'", expression); + + return type; + } + public Set allVariables() { return types.entrySet().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java index 5b6760c615933..1fa44c0624a45 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -187,7 +187,7 @@ private static boolean dependsOn(WindowNode parent, WindowNode child) .anyMatch(child.getCreatedVariable()::contains) || parent.getWindowFunctions().values().stream() .map(function -> function.getFrame()) - .map(frame -> ImmutableList.of(frame.getStartValue(), frame.getEndValue())) + .map(frame -> ImmutableList.of(frame.getStartValue(), frame.getEndValue(), frame.getSortKeyCoercedForFrameStartComparison(), frame.getSortKeyCoercedForFrameEndComparison())) .flatMap(Collection::stream) .anyMatch(x -> x.isPresent() && child.getCreatedVariable().contains(x.get())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java index 7695b19f06c22..ae71d603a58c8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java @@ -62,6 +62,8 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, referencedInputs.addAll(WindowNodeUtil.extractWindowFunctionUniqueVariables(windowFunction, variableAllocator.getTypes())); windowFunction.getFrame().getStartValue().ifPresent(referencedInputs::add); windowFunction.getFrame().getEndValue().ifPresent(referencedInputs::add); + windowFunction.getFrame().getSortKeyCoercedForFrameStartComparison().ifPresent(referencedInputs::add); + windowFunction.getFrame().getSortKeyCoercedForFrameEndComparison().ifPresent(referencedInputs::add); } PlanNode prunedWindowNode = new WindowNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 6d8bd4986add0..3c66ce231b55c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -422,6 +422,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext startValue; + // Sort key coerced to the same type of range start expression, for comparing and deciding frame start for range expression + private final Optional sortKeyCoercedForFrameStartComparison; private final BoundType endType; private final Optional endValue; + // Sort key coerced to the same type of range end expression, for comparing and deciding frame end for range expression + private final Optional sortKeyCoercedForFrameEndComparison; // This information is only used for printing the plan. private final Optional originalStartValue; @@ -257,25 +262,35 @@ public Frame( @JsonProperty("type") WindowType type, @JsonProperty("startType") BoundType startType, @JsonProperty("startValue") Optional startValue, + @JsonProperty("sortKeyCoercedForFrameStartComparison") Optional sortKeyCoercedForFrameStartComparison, @JsonProperty("endType") BoundType endType, @JsonProperty("endValue") Optional endValue, + @JsonProperty("sortKeyCoercedForFrameEndComparison") Optional sortKeyCoercedForFrameEndComparison, @JsonProperty("originalStartValue") Optional originalStartValue, @JsonProperty("originalEndValue") Optional originalEndValue) { this.startType = requireNonNull(startType, "startType is null"); this.startValue = requireNonNull(startValue, "startValue is null"); + this.sortKeyCoercedForFrameStartComparison = requireNonNull(sortKeyCoercedForFrameStartComparison, "sortKeyCoercedForFrameStartComparison is null"); this.endType = requireNonNull(endType, "endType is null"); this.endValue = requireNonNull(endValue, "endValue is null"); + this.sortKeyCoercedForFrameEndComparison = requireNonNull(sortKeyCoercedForFrameEndComparison, "sortKeyCoercedForFrameEndComparison is null"); this.type = requireNonNull(type, "type is null"); this.originalStartValue = requireNonNull(originalStartValue, "originalStartValue is null"); this.originalEndValue = requireNonNull(originalEndValue, "originalEndValue is null"); if (startValue.isPresent()) { checkArgument(originalStartValue.isPresent(), "originalStartValue must be present if startValue is present"); + if (type == RANGE) { + checkArgument(sortKeyCoercedForFrameStartComparison.isPresent(), "for frame of type RANGE, sortKeyCoercedForFrameStartComparison must be present if startValue is present"); + } } if (endValue.isPresent()) { checkArgument(originalEndValue.isPresent(), "originalEndValue must be present if endValue is present"); + if (type == RANGE) { + checkArgument(sortKeyCoercedForFrameEndComparison.isPresent(), "for frame of type RANGE, sortKeyCoercedForFrameEndComparison must be present if endValue is present"); + } } } @@ -297,6 +312,12 @@ public Optional getStartValue() return startValue; } + @JsonProperty + public Optional getSortKeyCoercedForFrameStartComparison() + { + return sortKeyCoercedForFrameStartComparison; + } + @JsonProperty public BoundType getEndType() { @@ -309,6 +330,12 @@ public Optional getEndValue() return endValue; } + @JsonProperty + public Optional getSortKeyCoercedForFrameEndComparison() + { + return sortKeyCoercedForFrameEndComparison; + } + @JsonProperty public Optional getOriginalStartValue() { @@ -334,14 +361,16 @@ public boolean equals(Object o) return type == frame.type && startType == frame.startType && Objects.equals(startValue, frame.startValue) && + Objects.equals(sortKeyCoercedForFrameStartComparison, frame.sortKeyCoercedForFrameStartComparison) && endType == frame.endType && - Objects.equals(endValue, frame.endValue); + Objects.equals(endValue, frame.endValue) && + Objects.equals(sortKeyCoercedForFrameEndComparison, frame.sortKeyCoercedForFrameEndComparison); } @Override public int hashCode() { - return Objects.hash(type, startType, startValue, endType, endValue, originalStartValue, originalEndValue); + return Objects.hash(type, startType, startValue, sortKeyCoercedForFrameStartComparison, endType, endValue, originalStartValue, originalEndValue, sortKeyCoercedForFrameEndComparison); } public enum WindowType diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index da3b8e5a8cc49..7529374e35345 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -198,6 +198,17 @@ public Void visitWindow(WindowNode node, Set boundV } checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputVariables()); + ImmutableList.Builder symbolsForFrameBoundsComparison = ImmutableList.builder(); + for (WindowNode.Frame frame : node.getFrames()) { + if (frame.getSortKeyCoercedForFrameStartComparison().isPresent()) { + symbolsForFrameBoundsComparison.add(frame.getSortKeyCoercedForFrameStartComparison().get()); + } + if (frame.getSortKeyCoercedForFrameEndComparison().isPresent()) { + symbolsForFrameBoundsComparison.add(frame.getSortKeyCoercedForFrameEndComparison().get()); + } + } + checkDependencies(inputs, symbolsForFrameBoundsComparison.build(), "Invalid node. Symbols for frame bound comparison (%s) not in source plan output (%s)", symbolsForFrameBoundsComparison.build(), node.getSource().getOutputVariables()); + for (WindowNode.Function function : node.getWindowFunctions().values()) { Set dependencies = WindowNodeUtil.extractWindowFunctionUniqueVariables(function, types); checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputVariables()); diff --git a/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java b/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java index 8c679953d3668..a8fb59a473396 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java @@ -81,6 +81,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Lists.newArrayList; import static java.lang.Float.floatToRawIntBits; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -366,39 +367,47 @@ private static MaterializedRow convertToTestTypes(MaterializedRow prestoRow) List convertedValues = new ArrayList<>(); for (int field = 0; field < prestoRow.getFieldCount(); field++) { Object prestoValue = prestoRow.getField(field); - Object convertedValue; - if (prestoValue instanceof SqlDate) { - convertedValue = LocalDate.ofEpochDay(((SqlDate) prestoValue).getDays()); - } - else if (prestoValue instanceof SqlTime) { - convertedValue = DateTimeFormatter.ISO_LOCAL_TIME.parse(prestoValue.toString(), LocalTime::from); - } - else if (prestoValue instanceof SqlTimeWithTimeZone) { - // Political timezone cannot be represented in OffsetTime and there isn't any better representation. - long millisUtc = ((SqlTimeWithTimeZone) prestoValue).getMillisUtc(); - ZoneOffset zone = toZoneOffset(((SqlTimeWithTimeZone) prestoValue).getTimeZoneKey()); - convertedValue = OffsetTime.of( - LocalTime.ofNanoOfDay(MILLISECONDS.toNanos(millisUtc) + SECONDS.toNanos(zone.getTotalSeconds())), - zone); - } - else if (prestoValue instanceof SqlTimestamp) { - convertedValue = SqlTimestamp.JSON_MILLIS_FORMATTER.parse(prestoValue.toString(), LocalDateTime::from); - } - else if (prestoValue instanceof SqlTimestampWithTimeZone) { - convertedValue = Instant.ofEpochMilli(((SqlTimestampWithTimeZone) prestoValue).getMillisUtc()) - .atZone(ZoneId.of(((SqlTimestampWithTimeZone) prestoValue).getTimeZoneKey().getId())); - } - else if (prestoValue instanceof SqlDecimal) { - convertedValue = ((SqlDecimal) prestoValue).toBigDecimal(); - } - else { - convertedValue = prestoValue; - } - convertedValues.add(convertedValue); + convertedValues.add(convertPrestoValueToTestType(prestoValue)); } return new MaterializedRow(prestoRow.getPrecision(), convertedValues); } + private static Object convertPrestoValueToTestType(Object prestoValue) + { + Object convertedValue; + if (prestoValue instanceof SqlDate) { + convertedValue = LocalDate.ofEpochDay(((SqlDate) prestoValue).getDays()); + } + else if (prestoValue instanceof SqlTime) { + convertedValue = DateTimeFormatter.ISO_LOCAL_TIME.parse(prestoValue.toString(), LocalTime::from); + } + else if (prestoValue instanceof SqlTimeWithTimeZone) { + // Political timezone cannot be represented in OffsetTime and there isn't any better representation. + long millisUtc = ((SqlTimeWithTimeZone) prestoValue).getMillisUtc(); + ZoneOffset zone = toZoneOffset(((SqlTimeWithTimeZone) prestoValue).getTimeZoneKey()); + convertedValue = OffsetTime.of( + LocalTime.ofNanoOfDay(MILLISECONDS.toNanos(millisUtc) + SECONDS.toNanos(zone.getTotalSeconds())), + zone); + } + else if (prestoValue instanceof SqlTimestamp) { + convertedValue = SqlTimestamp.JSON_MILLIS_FORMATTER.parse(prestoValue.toString(), LocalDateTime::from); + } + else if (prestoValue instanceof SqlTimestampWithTimeZone) { + convertedValue = Instant.ofEpochMilli(((SqlTimestampWithTimeZone) prestoValue).getMillisUtc()) + .atZone(ZoneId.of(((SqlTimestampWithTimeZone) prestoValue).getTimeZoneKey().getId())); + } + else if (prestoValue instanceof SqlDecimal) { + convertedValue = ((SqlDecimal) prestoValue).toBigDecimal(); + } + else if (prestoValue instanceof ArrayList) { + convertedValue = newArrayList(((ArrayList) prestoValue).stream().map(x -> convertPrestoValueToTestType(x)).toArray()); + } + else { + convertedValue = prestoValue; + } + return convertedValue; + } + private static ZoneOffset toZoneOffset(TimeZoneKey timeZoneKey) { requireNonNull(timeZoneKey, "timeZoneKey is null"); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java index 7c24303656894..7a7cf8e2f27d5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java @@ -73,7 +73,7 @@ @Test(singleThreaded = true) public class TestWindowOperator { - private static final FrameInfo UNBOUNDED_FRAME = new FrameInfo(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty()); + private static final FrameInfo UNBOUNDED_FRAME = new FrameInfo(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); public static final List ROW_NUMBER = ImmutableList.of( window(new ReflectionWindowFunctionSupplier<>("row_number", BIGINT, ImmutableList.of(), RowNumberFunction.class), BIGINT, UNBOUNDED_FRAME)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index 8c4dfd0e74ec7..9a96d9aeb0175 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -45,6 +45,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_FUNCTION_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; @@ -55,6 +56,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ATTRIBUTE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_COLUMN; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MULTIPLE_FIELDS_FROM_SUBQUERY; @@ -662,7 +664,7 @@ public void testWindowFunctionWithoutOverClause() } @Test - public void testInvalidWindowFrame() + public void testInvalidWindowFrameTypeRows() { assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS UNBOUNDED FOLLOWING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS 2 FOLLOWING)"); @@ -671,10 +673,6 @@ public void testInvalidWindowFrame() assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN CURRENT ROW AND 5 PRECEDING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN 2 FOLLOWING AND 5 PRECEDING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN 2 FOLLOWING AND CURRENT ROW)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE 2 PRECEDING)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN 2 PRECEDING AND CURRENT ROW)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN CURRENT ROW AND 5 FOLLOWING)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN 2 PRECEDING AND 5 FOLLOWING)"); assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS 5e-1 PRECEDING)"); assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS 'foo' PRECEDING)"); @@ -682,6 +680,63 @@ public void testInvalidWindowFrame() assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS BETWEEN CURRENT ROW AND 'foo' FOLLOWING)"); } + @Test + public void testWindowFrameTypeRange() + { + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE 2 FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE 5 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND 10 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND 3 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 1 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 10 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + + // this should pass the analysis but fail during execution + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN -x PRECEDING AND 0 * x FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CAST(null AS BIGINT) PRECEDING AND CAST(null AS BIGINT) FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(MISSING_ORDER_BY, "SELECT array_agg(x) OVER (RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(INVALID_ORDER_BY, "SELECT array_agg(x) OVER (ORDER BY x DESC, x ASC RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(TYPE_MISMATCH, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM (VALUES 'a') T(x)"); + + assertFails(TYPE_MISMATCH, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 'a' PRECEDING AND 'z' FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(TYPE_MISMATCH, "SELECT array_agg(x) OVER (ORDER BY x RANGE INTERVAL '1' day PRECEDING) FROM (VALUES INTERVAL '1' year) T(x)"); + + // window frame other than PRECEDING or FOLLOWING has no requirements regarding window ORDER BY clause + // ORDER BY is not required + analyze("SELECT array_agg(x) OVER (RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + // multiple sort keys and sort keys of types other than numeric or datetime are allowed + analyze("SELECT array_agg(x) OVER (ORDER BY y, z RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES (1, 'text', true)) T(x, y, z)"); + } + @Test public void testDistinctInWindowFunctionParameter() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index d1b48adff54c7..312cb48cf706c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -163,9 +163,11 @@ public void testValidWindow() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, DOUBLE, variableC), frame, false); @@ -416,9 +418,11 @@ public void testInvalidWindowFunctionCall() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableA), frame, false); @@ -448,9 +452,11 @@ public void testInvalidWindowFunctionSignature() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableC), frame, false); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestWindowFrameRange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestWindowFrameRange.java new file mode 100644 index 0000000000000..acf046783cf23 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestWindowFrameRange.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.ShortDecimalType; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.tree.DecimalLiteral; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.windowFrame; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; + +public class TestWindowFrameRange + extends BasePlanTest +{ + @Test + public void testFramePrecedingWithSortKeyCoercions() + { + @Language("SQL") String sql = "SELECT array_agg(key) OVER(ORDER BY key RANGE x PRECEDING) " + + "FROM (VALUES (1, 1.1), (2, 2.2)) T(key, x)"; + + PlanMatchPattern pattern = + anyTree( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of(), + ImmutableList.of("key"), + ImmutableMap.of("key", SortOrder.ASC_NULLS_LAST))) + .addFunction( + "array_agg_result", + functionCall("array_agg", ImmutableList.of("key")), + createTestFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.array_agg"), fromTypes(INTEGER)), + windowFrame( + RANGE, + PRECEDING, + Optional.of("frame_start_value"), + Optional.of(new ShortDecimalType(12, 1)), + Optional.of("key_for_frame_start_comparison"), + Optional.of(new ShortDecimalType(12, 1)), + CURRENT_ROW, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty())), + project(// coerce sort key to compare sort key values with frame start values + ImmutableMap.of("key_for_frame_start_comparison", expression("CAST(key AS decimal(12, 1))")), + project(// calculate frame start value (sort key - frame offset) + ImmutableMap.of("frame_start_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$subtract"), ImmutableList.of(new SymbolReference("key_for_frame_start_calculation"), new SymbolReference("x"))))), + project(// coerce sort key to calculate frame start values + ImmutableMap.of("key_for_frame_start_calculation", expression("CAST(key AS decimal(10, 0))")), + filter(// validate offset values + "IF((x >= CAST(0 AS DECIMAL(2,1))), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + anyTree( + values( + ImmutableList.of("key", "x"), + ImmutableList.of( + ImmutableList.of(new LongLiteral("1"), new DecimalLiteral("1.1")), + ImmutableList.of(new LongLiteral("2"), new DecimalLiteral("2.2"))))))))))); + + assertPlan(sql, CREATED, pattern); + } + + @Test + public void testFrameFollowingWithOffsetCoercion() + { + @Language("SQL") String sql = "SELECT array_agg(key) OVER(ORDER BY key RANGE BETWEEN CURRENT ROW AND x FOLLOWING) " + + "FROM (VALUES (1.1, 1), (2.2, 2)) T(key, x)"; + + PlanMatchPattern pattern = + anyTree( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of(), + ImmutableList.of("key"), + ImmutableMap.of("key", SortOrder.ASC_NULLS_LAST))) + .addFunction( + "array_agg_result", + functionCall("array_agg", ImmutableList.of("key")), + createTestFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), + QualifiedObjectName.valueOf("presto.default.array_agg"), fromTypes(createDecimalType(2, 1))), + windowFrame( + RANGE, + CURRENT_ROW, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + FOLLOWING, + Optional.of("frame_end_value"), + Optional.of(new ShortDecimalType(12, 1)), + Optional.of("key_for_frame_end_comparison"), + Optional.of(new ShortDecimalType(12, 1)))), + project(// coerce sort key to compare sort key values with frame end values + ImmutableMap.of("key_for_frame_end_comparison", expression("CAST(key AS decimal(12, 1))")), + project(// calculate frame end value (sort key + frame offset) + ImmutableMap.of("frame_end_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$add"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("offset"))))), + filter(// validate offset values + "IF((offset >= CAST(0 AS DECIMAL(10, 0))), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + project(// coerce offset value to calculate frame end values + ImmutableMap.of("offset", expression("CAST(x AS decimal(10, 0))")), + anyTree( + values( + ImmutableList.of("key", "x"), + ImmutableList.of( + ImmutableList.of(new DecimalLiteral("1.1"), new LongLiteral("1")), + ImmutableList.of(new DecimalLiteral("2.2"), new LongLiteral("2"))))))))))); + + assertPlan(sql, CREATED, pattern); + } + + @Test + public void testFramePrecedingFollowingNoCoercions() + { + @Language("SQL") String sql = "SELECT array_agg(key) OVER(ORDER BY key RANGE BETWEEN x PRECEDING AND y FOLLOWING) " + + "FROM (VALUES (1, 1, 1), (2, 2, 2)) T(key, x, y)"; + + PlanMatchPattern pattern = + anyTree( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of(), + ImmutableList.of("key"), + ImmutableMap.of("key", SortOrder.ASC_NULLS_LAST))) + .addFunction( + "array_agg_result", + functionCall("array_agg", ImmutableList.of("key")), + createTestFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.array_agg"), fromTypes(INTEGER)), + windowFrame( + RANGE, + PRECEDING, + Optional.of("frame_start_value"), + Optional.of(INTEGER), + Optional.of("key"), + Optional.of(INTEGER), + FOLLOWING, + Optional.of("frame_end_value"), + Optional.of(INTEGER), + Optional.of("key"), + Optional.of(INTEGER))), + project(// calculate frame end value (sort key + frame end offset) + ImmutableMap.of("frame_end_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$add"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("y"))))), + filter(// validate frame end offset values + "IF((y >= CAST(0 AS INTEGER)), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + project(// calculate frame start value (sort key - frame start offset) + ImmutableMap.of("frame_start_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$subtract"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("x"))))), + filter(// validate frame start offset values + "IF((x >= CAST(0 AS INTEGER)), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + anyTree( + values( + ImmutableList.of("key", "x", "y"), + ImmutableList.of( + ImmutableList.of(new LongLiteral("1"), new LongLiteral("1"), new LongLiteral("1")), + ImmutableList.of(new LongLiteral("2"), new LongLiteral("2"), new LongLiteral("2"))))))))))); + + assertPlan(sql, CREATED, pattern); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java index c4581700ada85..0f9d04333d2f1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java @@ -57,6 +57,13 @@ public ExpressionMatcher(String expression, ParsingOptions.DecimalLiteralTreatme this.expression = expression(requireNonNull(expression)); } + public ExpressionMatcher(Expression expression) + { + this.expression = requireNonNull(expression, "expression is null"); + this.sql = requireNonNull(expression).toString(); + this.decimalLiteralTreatment = ParsingOptions.DecimalLiteralTreatment.REJECT; + } + private Expression expression(String sql) { SqlParser parser = new SqlParser(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index 9d8525891c1ea..66a40a2e78c22 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -26,6 +26,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; +import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.IsNotNullPredicate; @@ -316,7 +317,8 @@ protected Boolean visitSymbolReference(SymbolReference actual, Node expected) if (!(expected instanceof SymbolReference)) { return false; } - return symbolAliases.get(((SymbolReference) expected).getName()).equals(actual); + return symbolAliases.get(((SymbolReference) expected).getName()).equals(actual) || + expected.equals(actual); } @Override @@ -451,6 +453,20 @@ protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Nod return process(actual.getDefaultValue(), expected.getDefaultValue()) && process(actual.getWhenClauses(), expected.getWhenClauses()); } + @Override + protected Boolean visitIfExpression(IfExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof IfExpression)) { + return false; + } + + IfExpression expected = (IfExpression) expectedExpression; + + return process(actual.getCondition(), expected.getCondition()) + && process(actual.getTrueValue(), expected.getTrueValue()) + && process(actual.getFalseValue(), expected.getFalseValue()); + } + private boolean process(List actuals, List expecteds) { if (actuals.size() != expecteds.size()) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index f47b5ef7dfa4b..66d9631601f3b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.AggregationNode; @@ -255,14 +256,37 @@ public static ExpectedValueProvider windowFrame( BoundType startType, Optional startValue, BoundType endType, - Optional endValue) + Optional endValue, + Optional sortKey) + { + return windowFrame(type, startType, startValue, Optional.empty(), sortKey, Optional.empty(), endType, endValue, Optional.empty(), sortKey, Optional.empty()); + } + + public static ExpectedValueProvider windowFrame( + WindowType type, + BoundType startType, + Optional startValue, + Optional startValueType, + Optional sortKeyForStartComparison, + Optional sortKeyForStartComparisonType, + BoundType endType, + Optional endValue, + Optional endValueType, + Optional sortKeyForEndComparison, + Optional sortKeyForEndComparisonType) { return new WindowFrameProvider( type, startType, startValue.map(SymbolAlias::new), + startValueType, + sortKeyForStartComparison.map(SymbolAlias::new), + sortKeyForStartComparisonType, endType, - endValue.map(SymbolAlias::new)); + endValue.map(SymbolAlias::new), + endValueType, + sortKeyForEndComparison.map(SymbolAlias::new), + sortKeyForEndComparisonType); } public static PlanMatchPattern window(Consumer windowMatcherBuilderConsumer, PlanMatchPattern source) @@ -711,6 +735,11 @@ public static ExpressionMatcher expression(String expression, ParsingOptions.Dec return new ExpressionMatcher(expression, decimalLiteralTreatment); } + public static ExpressionMatcher expression(Expression expression) + { + return new ExpressionMatcher(expression); + } + public PlanMatchPattern withOutputs(String... aliases) { return withOutputs(ImmutableList.copyOf(aliases)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java index 91724b27024e2..65275157980d1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java @@ -103,4 +103,21 @@ public static boolean matchSpecification(WindowNode.Specification actual, Window .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)))) .orElse(true); } + + public static boolean matchSpecification(WindowNode.Specification actual, SpecificationProvider expected) + { + return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) + .equals(expected.partitionBy.stream().map(SymbolAlias::toString).collect(toImmutableList())) && + actual.getOrderingScheme().map(orderingScheme -> orderingScheme.getOrderByVariables().stream() + .map(VariableReferenceExpression::getName) + .collect(toImmutableSet()) + .equals(expected.orderBy.stream() + .map(SymbolAlias::toString) + .collect(toImmutableSet())) && + orderingScheme.getOrderingsMap().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)) + .equals(expected.orderings.entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().toString(), Map.Entry::getValue)))) + .orElse(true); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java index ad44824c2215e..e0aa40f4f422d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java @@ -100,19 +100,19 @@ private static String toKey(String alias) private Map getUpdatedAssignments(Assignments assignments) { - ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); + Map updatedMap = new HashMap<>(); for (Map.Entry assignment : assignments.getMap().entrySet()) { for (Map.Entry existingAlias : map.entrySet()) { RowExpression expression = assignment.getValue(); if (isExpression(expression) && castToExpression(expression).equals(existingAlias.getValue())) { // Simple symbol rename - mapUpdate.put(existingAlias.getKey(), asSymbolReference(assignment.getKey())); + updatedMap.put(existingAlias.getKey(), asSymbolReference(assignment.getKey())); } else if (!isExpression(expression) && (expression instanceof VariableReferenceExpression) && ((VariableReferenceExpression) expression).getName().equals(existingAlias.getValue().getName())) { // Simple symbol rename - mapUpdate.put(existingAlias.getKey(), createSymbolReference(assignment.getKey())); + updatedMap.put(existingAlias.getKey(), createSymbolReference(assignment.getKey())); } else if (createSymbolReference(assignment.getKey()).equals(existingAlias.getValue())) { /* @@ -125,11 +125,11 @@ else if (createSymbolReference(assignment.getKey()).equals(existingAlias.getValu * At the beginning for the function, map contains { NEW_ALIAS: SymbolReference("expr_2" } * and the assignments map contains { expr_2 := }. */ - mapUpdate.put(existingAlias.getKey(), existingAlias.getValue()); + updatedMap.put(existingAlias.getKey(), existingAlias.getValue()); } } } - return mapUpdate.build(); + return ImmutableMap.copyOf(updatedMap); } /* diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java index a6a816ae0dae1..f9c551513e0ab 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; @@ -31,21 +32,40 @@ public class WindowFrameProvider private final WindowType type; private final BoundType startType; private final Optional startValue; + private final Optional sortKeyForStartComparison; private final BoundType endType; private final Optional endValue; + private final Optional sortKeyForEndComparison; + + private Optional startValueType; + private Optional sortKeyForStartComparisonType; + private Optional endValueType; + private Optional sortKeyForEndComparisonType; WindowFrameProvider( WindowType type, BoundType startType, Optional startValue, + Optional startValueType, + Optional sortKeyForStartComparison, + Optional sortKeyForStartComparisonType, BoundType endType, - Optional endValue) + Optional endValue, + Optional endValueType, + Optional sortKeyForEndComparison, + Optional sortKeyForEndComparisonType) { this.type = requireNonNull(type, "type is null"); this.startType = requireNonNull(startType, "startType is null"); this.startValue = requireNonNull(startValue, "startValue is null"); + this.startValueType = requireNonNull(startValueType, "startValueType is null"); + this.sortKeyForStartComparison = requireNonNull(sortKeyForStartComparison, "sortKeyForStartComparison is null"); + this.sortKeyForStartComparisonType = requireNonNull(sortKeyForStartComparisonType, "sortKeyForStartComparisonType is null"); this.endType = requireNonNull(endType, "endType is null"); this.endValue = requireNonNull(endValue, "endValue is null"); + this.endValueType = requireNonNull(endValueType, "endValueType is null"); + this.sortKeyForEndComparison = requireNonNull(sortKeyForEndComparison, "sortKeyForEndComparison is null"); + this.sortKeyForEndComparisonType = requireNonNull(sortKeyForEndComparisonType, "sortKeyForEndComparisonType is null"); } @Override @@ -59,13 +79,30 @@ public WindowNode.Frame getExpectedValue(SymbolAliases aliases) return new WindowNode.Frame( type, startType, - startValue.map(alias -> new VariableReferenceExpression(Optional.empty(), alias.toSymbol(aliases).getName(), BIGINT)), + toVariableReferenceExpression(aliases, startValue, startValueType), + toVariableReferenceExpression(aliases, sortKeyForStartComparison, sortKeyForStartComparisonType), endType, - endValue.map(alias -> new VariableReferenceExpression(Optional.empty(), alias.toSymbol(aliases).getName(), BIGINT)), + toVariableReferenceExpression(aliases, endValue, endValueType), + toVariableReferenceExpression(aliases, sortKeyForEndComparison, sortKeyForEndComparisonType), originalStartValue.map(Expression::toString), originalEndValue.map(Expression::toString)); } + private Optional toVariableReferenceExpression(SymbolAliases aliases, Optional symbolAlias, Optional type) + { + if (!symbolAlias.isPresent()) { + return Optional.empty(); + } + + String alias = symbolAlias.get().toSymbol(aliases).getName(); + Type variableType = type.orElseGet(() -> BIGINT); + if (alias.startsWith("field")) { + return Optional.of(new VariableReferenceExpression(Optional.empty(), symbolAlias.get().toString(), variableType)); + } + + return Optional.of(new VariableReferenceExpression(Optional.empty(), symbolAlias.get().toSymbol(aliases).getName(), variableType)); + } + @Override public String toString() { @@ -73,8 +110,14 @@ public String toString() .add("type", this.type) .add("startType", this.startType) .add("startValue", this.startValue) + .add("startValueType", this.startValueType) + .add("sortKeyForStartComparison", this.sortKeyForStartComparison) + .add("sortKeyForStartComparisonType", this.sortKeyForStartComparisonType) .add("endType", this.endType) .add("endValue", this.endValue) + .add("endValueType", this.endValueType) + .add("sortKeyForEndComparison", this.sortKeyForEndComparison) + .add("sortKeyForEndComparisonType", this.sortKeyForEndComparisonType) .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java index bb6f3cdb1242d..ac5d0e1036213 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java @@ -24,6 +24,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableSet; import java.util.List; @@ -99,9 +100,13 @@ public Optional getAssignedVariable(PlanNode node, } } else { - if (!expectedExpression.equals(castToExpression(actualExpression))) { - return false; + if (expectedExpression.equals(castToExpression(actualExpression))) { + return true; + } + if (castToExpression(actualExpression) instanceof SymbolReference) { + return expectedExpression.equals(symbolAliases.get((((SymbolReference) castToExpression(actualExpression))).getName())); } + return false; } } return true; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index b2c3bc9363b96..4cff3d3b6b284 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -85,7 +85,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses } if (!specification - .map(expectedSpecification -> matchSpecification(windowNode.getSpecification(), expectedSpecification.getExpectedValue(symbolAliases))) + .map(expectedSpecification -> matchSpecification(windowNode.getSpecification(), expectedSpecification.getExpectedValue(symbolAliases)) || + (expectedSpecification instanceof SpecificationProvider && matchSpecification(windowNode.getSpecification(), (SpecificationProvider) expectedSpecification))) .orElse(true)) { return NO_MATCH; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 1003023e202c1..1305a19f1d00e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -44,6 +44,7 @@ import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; @@ -59,17 +60,32 @@ public class TestMergeAdjacentWindows RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), CURRENT_ROW, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); + private static final WindowNode.Frame frameWithRangeOffset = new WindowNode.Frame( + RANGE, + PRECEDING, + Optional.of(new VariableReferenceExpression(Optional.empty(), "startValue", BIGINT)), + Optional.of(new VariableReferenceExpression(Optional.empty(), "sortKeyCoercedForFrameStartComparison", BIGINT)), + FOLLOWING, + Optional.of(new VariableReferenceExpression(Optional.empty(), "endValue", BIGINT)), + Optional.of(new VariableReferenceExpression(Optional.empty(), "sortKeyCoercedForFrameEndComparison", BIGINT)), + Optional.of("originalStartValue"), + Optional.of("originalEndValue")); + private static final WindowNode.Frame frameWithRowOffset = new WindowNode.Frame( ROWS, PRECEDING, Optional.of(new VariableReferenceExpression(Optional.empty(), "startValue", BIGINT)), + Optional.empty(), CURRENT_ROW, Optional.empty(), + Optional.empty(), Optional.of("startValue"), Optional.empty()); @@ -148,6 +164,21 @@ public void testDependentAdjacentWindowsIdenticalSpecifications() .doesNotFire(); } + @Test + public void testDependentAdjacentWindowsIdenticalSpecificationsWithRangeOffset() + { + tester().assertThat(new GatherAndMergeWindows.MergeAdjacentWindowsOverProjects(0)) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a", "sortkey"), + ImmutableMap.of(p.variable("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, frameWithRangeOffset, "a")), + p.window( + newWindowNodeSpecification(p, "a", "sortkey"), + ImmutableMap.of(p.variable("startValue"), newWindowNodeFunction("rank", RANK_FUNCTION_HANDLE)), + p.values(p.variable("a"), p.variable("sortkey"))))) + .doesNotFire(); + } + @Test public void testDependentAdjacentWindowsIdenticalSpecificationsWithOffset() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index 60ab90d96b8ef..00c06d7dfb528 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -49,6 +49,8 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.windowFrame; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.sql.relational.Expressions.call; @@ -62,7 +64,8 @@ public class TestPruneWindowColumns private static final FunctionHandle FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction(FUNCTION_NAME, fromTypes(BIGINT)); private static final List inputSymbolNameList = - ImmutableList.of("orderKey", "partitionKey", "hash", "startValue1", "startValue2", "endValue1", "endValue2", "input1", "input2", "unused"); + ImmutableList.of("orderKey", "partitionKey", "hash", "startValue1", "startValue2", "startValue3", "endValue1", "endValue2", "endValue3", "sortKeyForStartComparison3", + "sortKeyForEndComparison3", "input1", "input2", "input3", "unused"); private static final Set inputSymbolNameSet = ImmutableSet.copyOf(inputSymbolNameList); private static final ExpectedValueProvider frameProvider1 = windowFrame( @@ -70,14 +73,29 @@ public class TestPruneWindowColumns UNBOUNDED_PRECEDING, Optional.of("startValue1"), CURRENT_ROW, - Optional.of("endValue1")); + Optional.of("endValue1"), + Optional.of("orderKey")); private static final ExpectedValueProvider frameProvider2 = windowFrame( RANGE, UNBOUNDED_PRECEDING, Optional.of("startValue2"), CURRENT_ROW, - Optional.of("endValue2")); + Optional.of("endValue2"), + Optional.of("orderKey")); + + private static final ExpectedValueProvider frameProvider3 = windowFrame( + RANGE, + PRECEDING, + Optional.of("startValue3"), + Optional.of(BIGINT), + Optional.of("sortKeyForStartComparison3"), + Optional.of(BIGINT), + FOLLOWING, + Optional.of("endValue3"), + Optional.of(BIGINT), + Optional.of("sortKeyForEndComparison3"), + Optional.of(BIGINT)); @Test public void testWindowNotNeeded() @@ -95,12 +113,13 @@ public void testOneFunctionNotNeeded() { tester().assertThat(new PruneWindowColumns()) .on(p -> buildProjectedWindow(p, - symbol -> symbol.getName().equals("output2") || symbol.getName().equals("unused"), + symbol -> symbol.getName().equals("output2") || symbol.getName().equals("output3") || symbol.getName().equals("unused"), alwaysTrue())) .matches( strictProject( ImmutableMap.of( "output2", expression("output2"), + "output3", expression("output3"), "unused", expression("unused")), window(windowBuilder -> windowBuilder .prePartitionedInputs(ImmutableSet.of()) @@ -114,6 +133,11 @@ public void testOneFunctionNotNeeded() functionCall("min", ImmutableList.of("input2")), FUNCTION_HANDLE, frameProvider2) + .addFunction( + "output3", + functionCall("min", ImmutableList.of("input3")), + FUNCTION_HANDLE, + frameProvider3) .hashSymbol("hash"), strictProject( Maps.asMap( @@ -122,6 +146,38 @@ public void testOneFunctionNotNeeded() values(inputSymbolNameList))))); } + @Test + public void testTwoFunctionsNotNeeded() + { + tester().assertThat(new PruneWindowColumns()) + .on(p -> buildProjectedWindow(p, + symbol -> symbol.getName().equals("output3") || symbol.getName().equals("unused"), + alwaysTrue())) + .matches( + strictProject( + ImmutableMap.of( + "output3", expression("output3"), + "unused", expression("unused")), + window(windowBuilder -> windowBuilder + .prePartitionedInputs(ImmutableSet.of()) + .specification( + ImmutableList.of("partitionKey"), + ImmutableList.of("orderKey"), + ImmutableMap.of("orderKey", SortOrder.ASC_NULLS_FIRST)) + .preSortedOrderPrefix(0) + .addFunction( + "output3", + functionCall("min", ImmutableList.of("input3")), + FUNCTION_HANDLE, + frameProvider3) + .hashSymbol("hash"), + strictProject( + Maps.asMap( + Sets.difference(inputSymbolNameSet, ImmutableSet.of("input1", "startValue1", "endValue1", "input2", "startValue2", "endValue2")), + PlanMatchPattern::expression), + values(inputSymbolNameList))))); + } + @Test public void testAllColumnsNeeded() { @@ -157,7 +213,8 @@ public void testUnusedInputNotNeeded() strictProject( ImmutableMap.of( "output1", expression("output1"), - "output2", expression("output2")), + "output2", expression("output2"), + "output3", expression("output3")), window(windowBuilder -> windowBuilder .prePartitionedInputs(ImmutableSet.of()) .specification( @@ -175,6 +232,11 @@ public void testUnusedInputNotNeeded() functionCall("min", ImmutableList.of("input2")), FUNCTION_HANDLE, frameProvider2) + .addFunction( + "output3", + functionCall("min", ImmutableList.of("input3")), + FUNCTION_HANDLE, + frameProvider3) .hashSymbol("hash"), strictProject( Maps.asMap( @@ -193,15 +255,22 @@ private static PlanNode buildProjectedWindow( VariableReferenceExpression hash = p.variable("hash"); VariableReferenceExpression startValue1 = p.variable("startValue1"); VariableReferenceExpression startValue2 = p.variable("startValue2"); + VariableReferenceExpression startValue3 = p.variable("startValue3"); + VariableReferenceExpression sortKeyForStartComparison3 = p.variable("sortKeyForStartComparison3"); VariableReferenceExpression endValue1 = p.variable("endValue1"); VariableReferenceExpression endValue2 = p.variable("endValue2"); + VariableReferenceExpression endValue3 = p.variable("endValue3"); + VariableReferenceExpression sortKeyForEndComparison3 = p.variable("sortKeyForEndComparison3"); VariableReferenceExpression input1 = p.variable("input1"); VariableReferenceExpression input2 = p.variable("input2"); + VariableReferenceExpression input3 = p.variable("input3"); VariableReferenceExpression unused = p.variable("unused"); VariableReferenceExpression output1 = p.variable("output1"); VariableReferenceExpression output2 = p.variable("output2"); - List inputs = ImmutableList.of(orderKey, partitionKey, hash, startValue1, startValue2, endValue1, endValue2, input1, input2, unused); - List outputs = ImmutableList.builder().addAll(inputs).add(output1, output2).build(); + VariableReferenceExpression output3 = p.variable("output3"); + List inputs = ImmutableList.of(orderKey, partitionKey, hash, startValue1, startValue2, startValue3, endValue1, endValue2, endValue3, + sortKeyForStartComparison3, sortKeyForEndComparison3, input1, input2, input3, unused); + List outputs = ImmutableList.builder().addAll(inputs).add(output1, output2, output3).build(); List filteredInputs = inputs.stream().filter(sourceFilter).collect(toImmutableList()); @@ -223,8 +292,10 @@ private static PlanNode buildProjectedWindow( RANGE, UNBOUNDED_PRECEDING, Optional.of(startValue1), + Optional.of(orderKey), CURRENT_ROW, Optional.of(endValue1), + Optional.of(orderKey), Optional.of(new SymbolReference(startValue1.getName())).map(Expression::toString), Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)), false), @@ -235,10 +306,26 @@ private static PlanNode buildProjectedWindow( RANGE, UNBOUNDED_PRECEDING, Optional.of(startValue2), + Optional.of(orderKey), CURRENT_ROW, Optional.of(endValue2), + Optional.of(orderKey), Optional.of(new SymbolReference(startValue2.getName())).map(Expression::toString), Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)), + false), + output3, + new WindowNode.Function( + call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, input3), + new WindowNode.Frame( + RANGE, + PRECEDING, + Optional.of(startValue3), + Optional.of(sortKeyForStartComparison3), + FOLLOWING, + Optional.of(endValue3), + Optional.of(sortKeyForEndComparison3), + Optional.of(new SymbolReference(startValue3.getName())).map(Expression::toString), + Optional.of(new SymbolReference(endValue3.getName())).map(Expression::toString)), false)), hash, p.values( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 28413cb1ac36b..2e1e094bf8ea9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -59,9 +59,11 @@ public TestSwapAdjacentWindowsBySpecifications() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), CURRENT_ROW, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); functionHandle = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("avg", fromTypes(BIGINT)); @@ -156,8 +158,55 @@ public void dependentWindowsAreNotReorderedWithOffset() ROWS, PRECEDING, Optional.of(new VariableReferenceExpression(Optional.empty(), "startValue", BIGINT)), + Optional.empty(), CURRENT_ROW, Optional.empty(), + Optional.empty(), + Optional.of("startValue"), + Optional.empty()); + WindowNode.Function functionWithOffset = new WindowNode.Function( + call( + "avg", + functionHandle, + BIGINT, + ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "a", BIGINT))), + frameWithRowOffset, + false); + + tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) + .on(p -> + p.window(new WindowNode.Specification( + ImmutableList.of(p.variable("a")), + Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), + ImmutableMap.of(p.variable("avg_1"), functionWithOffset), + p.window(new WindowNode.Specification( + ImmutableList.of(p.variable("a"), p.variable("b")), + Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), + ImmutableMap.of(p.variable("startValue"), windowFunction), + p.values(p.variable("a"), p.variable("b"), p.variable("sortkey"))))) + .doesNotFire(); + } + + @Test + public void dependentWindowsWithRangeAreNotReordered() + { + FunctionHandle rankFunction = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("rank", ImmutableList.of()); + WindowNode.Function windowFunction = new WindowNode.Function( + call( + "rank", + rankFunction, + BIGINT, + ImmutableList.of()), + frame, + false); + WindowNode.Frame frameWithRowOffset = new WindowNode.Frame( + RANGE, + PRECEDING, + Optional.of(new VariableReferenceExpression(Optional.empty(), "startValue", BIGINT)), + Optional.of(new VariableReferenceExpression(Optional.empty(), "sortKeyCoercedForFrameStartComparison", BIGINT)), + CURRENT_ROW, + Optional.empty(), + Optional.empty(), Optional.of("startValue"), Optional.empty()); WindowNode.Function functionWithOffset = new WindowNode.Function( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java index 1ed909a145443..b9748bbce07ba 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java @@ -203,8 +203,9 @@ public void testIdenticalWindowSpecificationsABcpA() window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationB) .addFunction(functionCall("lag", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS, "ONE", "ZERO"))), - project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)"), "ZERO", expression("0.0E0")), - LINEITEM_TABLESCAN_DOQSS))))); + anyTree( + project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)"), "ZERO", expression("0.0E0")), + LINEITEM_TABLESCAN_DOQSS)))))); } @Test @@ -256,8 +257,9 @@ public void testIdenticalWindowSpecificationsAAcpA() .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))) .addFunction(functionCall("lag", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS, "ONE", "ZERO"))) .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)"), "ZERO", expression("0.0E0")), - LINEITEM_TABLESCAN_DOQS)))); + anyTree( + project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)"), "ZERO", expression("0.0E0")), + LINEITEM_TABLESCAN_DOQS))))); } @Test @@ -383,6 +385,35 @@ public void testMergeDifferentFramesWithDefault() LINEITEM_TABLESCAN_DOQS))); } + @Test + public void testMergeRangeFramesWithDefault() + { + Optional frameD = Optional.of(new WindowFrame( + WindowFrame.Type.RANGE, + new FrameBound(FrameBound.Type.CURRENT_ROW), + Optional.of(new FrameBound(FrameBound.Type.UNBOUNDED_FOLLOWING)))); + + ExpectedValueProvider specificationD = specification( + ImmutableList.of(SUPPKEY_ALIAS), + ImmutableList.of(ORDERKEY_ALIAS), + ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); + + @Language("SQL") String sql = "SELECT " + + "SUM(quantity) OVER (PARTITION BY suppkey ORDER BY orderkey) sum_quantity_C, " + + "AVG(quantity) OVER (PARTITION BY suppkey ORDER BY orderkey RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) avg_quantity_D, " + + "SUM(discount) OVER (PARTITION BY suppkey ORDER BY orderkey) sum_discount_C " + + "FROM lineitem"; + + assertUnitPlan(sql, + anyTree( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationD) + .addFunction(functionCall("avg", frameD, ImmutableList.of(QUANTITY_ALIAS))) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(DISCOUNT_ALIAS))) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + LINEITEM_TABLESCAN_DOQS))); + } + @Test public void testNotMergeAcrossJoinBranches() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index a41eae7d154f1..33e69c9670302 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -55,9 +55,11 @@ public void windowNodePruning() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); assertRuleApplication() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index a47fe5ad36c67..b96b3863fa74c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -339,7 +339,13 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0), new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(1), new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(2))), - new PruneUnreferencedOutputs()); + new PruneUnreferencedOutputs(), + new IterativeOptimizer( + new RuleStatsRecorder(), + getQueryRunner().getStatsCalculator(), + getQueryRunner().getEstimatedExchangesCostCalculator(), + ImmutableSet.of( + new RemoveRedundantIdentityProjections()))); assertPlan(sql, pattern, optimizers); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index b854d40d4e6af..91a799d615ffb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -107,9 +107,11 @@ public void testSerializationRoundtrip() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); PlanNodeId id = newId(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java index 6a15fc2d09279..1e47043973e05 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java @@ -69,6 +69,7 @@ public class TestVerifyNoOriginalExpression ComparisonExpression.Operator.EQUAL, new SymbolReference("count"), new Cast(new LongLiteral("5"), "bigint")); + private static final VariableReferenceExpression SORT_KEY = new VariableReferenceExpression(Optional.empty(), "count", BIGINT); private Metadata metadata; private PlanBuilder builder; @@ -113,8 +114,10 @@ public void testValidateForWindow() WindowNode.Frame.WindowType.RANGE, WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING, startValue, + Optional.of(SORT_KEY), WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING, endValue, + Optional.of(SORT_KEY), originalStartValue, originalEndValue); WindowNode.Function function = new WindowNode.Function( diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 40063e267374c..a45975b13356a 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -154,6 +154,12 @@ protected void assertQuery(@Language("SQL") String actual, @Language("SQL") Stri QueryAssertions.assertQuery(queryRunner, getSession(), actual, expectedQueryRunner, expected, false, false); } + protected void assertQueryWithSameQueryRunner(@Language("SQL") String actual, @Language("SQL") String expected) + { + checkArgument(!actual.equals(expected)); + QueryAssertions.assertQuery(queryRunner, getSession(), actual, queryRunner, expected, false, false); + } + protected void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected) { QueryAssertions.assertQuery(queryRunner, session, actual, expectedQueryRunner, expected, false, false); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java index d76377e86f79c..739d344c51918 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestWindowQueries.java @@ -644,4 +644,722 @@ public void testMultipleInstancesOfWindowFunction() "(5, 'A', 'e', 'c', null), " + "(6, 'A', null, 'e', 'e')"); } + + @Test + public void testNullsSortKey() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)", + "VALUES " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[2, 2, 3]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)", + "VALUES " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[2, 2, 3], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)", + "VALUES " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[3, 2, 2], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)", + "VALUES " + + "ARRAY[3, 2, 2], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)", + "VALUES " + + "ARRAY[1, 2, null, null], " + + "ARRAY[1, 2, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)", + "VALUES " + + "ARRAY[1, 2], " + + "ARRAY[1, 2], " + + "ARRAY[1, 2, null, null], " + + "ARRAY[1, 2, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 2], " + + "ARRAY[null, null, 1, 2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)", + "VALUES " + + "ARRAY[null, null, 1, 2], " + + "ARRAY[null, null, 1, 2], " + + "ARRAY[1, 2], " + + "ARRAY[1, 2]"); + } + + @Test + public void testNoValueFrameBounds() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 1], " + + "ARRAY[null, null, 1, 1], " + + "ARRAY[null, null, 1, 1, 2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[2]"); + } + + @Test + public void testMixedTypeFrameBoundsAscendingNullsFirst() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 1.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1, 2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[2], " + + "ARRAY[2], " + + "null"); + } + + @Test + public void testMixedTypeFrameBoundsAscendingNullsLast() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "null, " + + "null, " + + "ARRAY[1, 1], " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[1, 1, 2, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[1, 1, 2, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1, 2], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 0.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[2, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 0.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[2, null, null], " + + "ARRAY[2, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + } + + @Test + public void testMixedTypeFrameBoundsDescendingNullsFirst() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 2], " + + "ARRAY[null, null, 2]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 2], " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[null, null, 2, 1, 1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[2, 1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[2], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[null, null, 2, 1, 1], " + + "null, " + + "null, " + + "null"); + } + + @Test + public void testMixedTypeFrameBoundsDescendingNullsLast() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "null, " + + "ARRAY[2], " + + "ARRAY[2], " + + "ARRAY[2, 1, 1, null, null], " + + "ARRAY[2, 1, 1, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1, null, null], " + + "ARRAY[2, 1, 1, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[2, 1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 0.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[2], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 0.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[2, 1, 1, null, null], " + + "ARRAY[1, 1, null, null], " + + "ARRAY[1, 1, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)", + "VALUES " + + "ARRAY[cast(null as integer), cast(null as integer)], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + } + + @Test + public void testEmptyInput() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (SELECT 1 WHERE false) T(a)", + "SELECT ARRAY[1] WHERE false"); + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE UNBOUNDED PRECEDING) " + + "FROM (SELECT 1 WHERE false) T(a)", + "SELECT ARRAY[1] WHERE false"); + } + + @Test + public void testEmptyFrame() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 10 PRECEDING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null, " + + "null, " + + "null, " + + "null, " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 10 FOLLOWING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null, " + + "null, " + + "null, " + + "null, " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 0.5 FOLLOWING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, 2, 4) T(a)", + "VALUES " + + "ARRAY[2], " + + "null, " + + "null"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES 1.0, 1.1) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a NULLS LAST RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES 1.0, 1.1, null) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null, " + + "ARRAY[null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES 1.0, 1.1) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a NULLS FIRST RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES null, 1.0, 1.1) T(a)", + "VALUES " + + "ARRAY[cast(null as decimal(2,1))], " + + "null, " + + "null"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES 1, 2) T(a)", + "VALUES " + + "null, " + + "ARRAY[1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a NULLS FIRST RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES null, 1, 2) T(a)", + "VALUES " + + "ARRAY[null], " + + "null, " + + "ARRAY[1]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a NULLS FIRST RANGE BETWEEN 2 PRECEDING AND 1.5 PRECEDING) " + + "FROM (VALUES null, 1, 2) T(a)", + "VALUES " + + "ARRAY[cast(null as integer)], " + + "null, " + + "null"); + } + + @Test + public void testOnlyNulls() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES CAST(null AS integer), null, null) T(a)", + "VALUES " + + "ARRAY[cast(null as integer), cast(null as integer), cast(null as integer)], " + + "ARRAY[null, null, null], " + + "ARRAY[null, null, null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES CAST(null AS integer), null, null) T(a)", + "VALUES " + + "ARRAY[cast(null as integer), cast(null as integer), cast(null as integer)], " + + "ARRAY[null, null, null], " + + "ARRAY[null, null, null]"); + } + + @Test + public void testAllPartitionSameValues() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES 1, 1, 1) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null, " + + "null"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES 1, 1, 1) T(a)", + "VALUES " + + "CAST(null AS array), " + + "null, " + + "null"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 1, 1) T(a)", + "VALUES " + + "ARRAY[1, 1, 1], " + + "ARRAY[1, 1, 1], " + + "ARRAY[1, 1, 1]"); + } + + @Test + public void testZeroOffset() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 0 PRECEDING AND 0 FOLLOWING) " + + "FROM (VALUES 1, 2, 1, null) T(a)", + "VALUES " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[2], " + + "ARRAY[null]"); + } + + @Test + public void testNonConstantOffset() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN x * 10 PRECEDING AND y / 10.0 FOLLOWING) " + + "FROM (VALUES (1, 0.1, 10), (2, 0.2, 20), (4, 0.4, 40)) T(a, x, y)", + "VALUES " + + "ARRAY[1, 2], " + + "ARRAY[1, 2, 4], " + + "ARRAY[1, 2, 4]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN x * 10 PRECEDING AND y / 10.0 FOLLOWING) " + + "FROM (VALUES (1, 0.1, 10), (2, 0.2, 20), (4, 0.4, 40), (null, 0.5, 50)) T(a, x, y)", + "VALUES " + + "ARRAY[1, 2], " + + "ARRAY[1, 2, 4], " + + "ARRAY[1, 2, 4], " + + "ARRAY[null]"); + } + + @Test + public void testInvalidOffset() + { + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a ASC RANGE x PRECEDING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a ASC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE x PRECEDING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE x PRECEDING) " + + "FROM (VALUES (1, 0.1), (2, null)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (2, null)) T(a, x)", + "Window frame offset value must not be negative or null"); + + // fail if offset is invalid for null sort key + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (null, null)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (null, -0.1)) T(a, x)", + "Window frame offset value must not be negative or null"); + + // test invalid offset of different types + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, BIGINT '-1')) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, INTEGER '-1')) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (SMALLINT '1', SMALLINT '-1')) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (TINYINT '1', TINYINT '-1')) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, -1.1e0)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, REAL '-1.1')) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, -1.0001)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' YEAR)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' MONTH)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' DAY)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' HOUR)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' MINUTE)) T(a, x)", + "Window frame offset value must not be negative or null"); + + assertQueryFails("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' SECOND)) T(a, x)", + "Window frame offset value must not be negative or null"); + } + + @Test + public void testWindowPartitioning() + { + assertQuery("SELECT a, p, array_agg(a) OVER(PARTITION BY p ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES (1, 'x'), (2, 'x'), (null, 'x'), (null, 'y'), (2, 'y')) T(a, p)", + "VALUES " + + "(null, 'x', ARRAY[null]), " + + "(1, 'x', ARRAY[1, 2]), " + + "(2, 'x', ARRAY[2]), " + + "(null, 'y', ARRAY[null]), " + + "(2, 'y', ARRAY[2])"); + + assertQuery("SELECT a, p, array_agg(a) OVER(PARTITION BY p ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES (1, 'x'), (2, 'x'), (null, 'x'), (null, 'y'), (2, 'y'), (null, null), (null, null), (1, null)) T(a, p)", + "VALUES " + + "(null, null, ARRAY[null, null]), " + + "(null, null, ARRAY[null, null]), " + + "(1, null, ARRAY[1]), " + + "(null, 'x', ARRAY[null]), " + + "(1, 'x', ARRAY[1, 2]), " + + "(2, 'x', ARRAY[2]), " + + "(null, 'y', ARRAY[null]), " + + "(2, 'y', ARRAY[2])"); + } + + @Test + public void testTypes() + { + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN DOUBLE '0.5' PRECEDING AND TINYINT '1' FOLLOWING) " + + "FROM (VALUES 1, null, 2) T(a)", + "VALUES " + + "ARRAY[1, 2], " + + "ARRAY[2], " + + "ARRAY[null]"); + + assertQuery("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 0.5 PRECEDING AND 1.000 FOLLOWING) " + + "FROM (VALUES REAL '1', null, 2) T(a)", + "VALUES " + + "ARRAY[CAST('1' AS REAL), CAST('2' AS REAL)], " + + "ARRAY[CAST('2' AS REAL)], " + + "ARRAY[null]"); + + assertQuery("SELECT x, array_agg(x) OVER(ORDER BY x DESC RANGE BETWEEN interval '1' month PRECEDING AND interval '1' month FOLLOWING) " + + "FROM (VALUES DATE '2001-01-31', DATE '2001-08-25', DATE '2001-09-25', DATE '2001-09-26') T(x)", + "VALUES " + + "(DATE '2001-09-26', ARRAY[DATE '2001-09-26', DATE '2001-09-25']), " + + "(DATE '2001-09-25', ARRAY[DATE '2001-09-26', DATE '2001-09-25', DATE '2001-08-25']), " + + "(DATE '2001-08-25', ARRAY[DATE '2001-09-25', DATE '2001-08-25']), " + + "(DATE '2001-01-31', ARRAY[DATE '2001-01-31'])"); + + // January 31 + 1 month sets the frame bound to the last day of February. March 1 is out of range. + assertQuery("SELECT x, array_agg(x) OVER(ORDER BY x RANGE BETWEEN CURRENT ROW AND interval '1' month FOLLOWING) " + + "FROM (VALUES DATE '2001-01-31', DATE '2001-02-28', DATE '2001-03-01') T(x)", + "VALUES " + + "(DATE '2001-01-31', ARRAY[DATE '2001-01-31', DATE '2001-02-28']), " + + "(DATE '2001-02-28', ARRAY[DATE '2001-02-28', DATE '2001-03-01']), " + + "(DATE '2001-03-01', ARRAY[DATE '2001-03-01'])"); + + // H2 and Presto has some type conversion problem for Interval type, hence use the same query runner for this query + assertQueryWithSameQueryRunner("SELECT x, array_agg(x) OVER(ORDER BY x RANGE BETWEEN interval '1' year PRECEDING AND interval '1' month FOLLOWING) " + + "FROM (VALUES " + + "INTERVAL '1' month, " + + "INTERVAL '2' month, " + + "INTERVAL '5' year) T(x)", + "VALUES " + + "(INTERVAL '1' month, ARRAY[INTERVAL '1' month, INTERVAL '2' month]), " + + "(INTERVAL '2' month, ARRAY[INTERVAL '1' month, INTERVAL '2' month]), " + + "(INTERVAL '5' year, ARRAY[INTERVAL '5' year])"); + } + + @Test + public void testMultipleWindowFunctions() + { + assertQuery("SELECT x, array_agg(date) OVER(ORDER BY x RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), avg(number) OVER(ORDER BY x RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES " + + "(2, DATE '2222-01-01', 4.4), " + + "(1, DATE '1111-01-01', 2.2), " + + "(3, DATE '3333-01-01', 6.6)) T(x, date, number)", + "VALUES " + + "(1, ARRAY[DATE '1111-01-01', DATE '2222-01-01'], 3.3), " + + "(2, ARRAY[DATE '1111-01-01', DATE '2222-01-01', DATE '3333-01-01'], 4.4), " + + "(3, ARRAY[DATE '2222-01-01', DATE '3333-01-01'], 5.5)"); + + assertQuery("SELECT x, array_agg(a) OVER(ORDER BY x RANGE BETWEEN 2 PRECEDING AND CURRENT ROW), array_agg(a) OVER(ORDER BY x RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) " + + "FROM (VALUES " + + "(1.0, 1), " + + "(2.0, 2), " + + "(3.0, 3), " + + "(4.0, 4), " + + "(5.0, 5), " + + "(6.0, 6)) T(x, a)", + "VALUES " + + "(1.0, ARRAY[1], ARRAY[1, 2, 3]), " + + "(2.0, ARRAY[1, 2], ARRAY[2, 3, 4]), " + + "(3.0, ARRAY[1, 2, 3], ARRAY[3, 4, 5]), " + + "(4.0, ARRAY[2, 3, 4], ARRAY[4, 5, 6]), " + + "(5.0, ARRAY[3, 4, 5], ARRAY[5, 6]), " + + "(6.0, ARRAY[4, 5, 6], ARRAY[6])"); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java index 01d03f257e2db..e6c74183322c7 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.DistinctType; import com.facebook.presto.common.type.RowType; @@ -374,6 +375,10 @@ private static Object[] mapArrayValues(ArrayType arrayType, Object[] values) .toArray(); } + if (elementType instanceof DateType) { + return Arrays.stream(values).map(v -> v == null ? null : ((Date) v).toLocalDate()).toArray(); + } + return values; }