diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java b/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java index a0606adb2c229..b55c075d682f5 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/ShortDecimalType.java @@ -26,10 +26,10 @@ import static com.facebook.presto.common.type.Decimals.MAX_SHORT_PRECISION; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; -final class ShortDecimalType +public final class ShortDecimalType extends DecimalType { - ShortDecimalType(int precision, int scale) + public ShortDecimalType(int precision, int scale) { super(precision, scale, long.class); validatePrecisionScale(precision, scale, MAX_SHORT_PRECISION); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java index 2364575c099f8..e35ff56499263 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java @@ -35,6 +35,7 @@ public BuiltInFunctionHandle(@JsonProperty("signature") Signature signature) checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); } + @Override @JsonProperty public Signature getSignature() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java index 4a47420ee316a..1fd2980a5c53f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -58,6 +58,7 @@ import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.BYTE; @@ -443,6 +444,12 @@ public PagesHashStrategy createPagesHashStrategy(List joinChannels, Opt groupByUsesEqualTo); } + public PagesIndexComparator createChannelComparator(int leftChannel, int rightChannel, SortOrder sortOrder) + { + checkArgument(types.get(leftChannel).equals(types.get(rightChannel)), "comparing channels of different types: %s and %s", types.get(leftChannel), types.get(rightChannel)); + return new SimpleChannelComparator(leftChannel, rightChannel, types.get(leftChannel), sortOrder); + } + public LookupSourceSupplier createLookupSourceSupplier( Session session, List joinChannels, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SimpleChannelComparator.java b/presto-main/src/main/java/com/facebook/presto/operator/SimpleChannelComparator.java new file mode 100644 index 0000000000000..1ff46c8423782 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/SimpleChannelComparator.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; + +import static com.facebook.presto.operator.SyntheticAddress.decodePosition; +import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.util.Objects.requireNonNull; + +public class SimpleChannelComparator + implements PagesIndexComparator +{ + private final int leftChannel; + private final int rightChannel; + private final SortOrder sortOrder; + private final Type sortType; + + public SimpleChannelComparator(int leftChannel, int rightChannel, Type sortType, SortOrder sortOrder) + { + this.leftChannel = leftChannel; + this.rightChannel = rightChannel; + this.sortOrder = requireNonNull(sortOrder, "sortOrder is null"); + this.sortType = requireNonNull(sortType, "sortType is null."); + } + + @Override + public int compareTo(PagesIndex pagesIndex, int leftPosition, int rightPosition) + { + long leftPageAddress = pagesIndex.getValueAddresses().get(leftPosition); + int leftBlockIndex = decodeSliceIndex(leftPageAddress); + int leftBlockPosition = decodePosition(leftPageAddress); + + long rightPageAddress = pagesIndex.getValueAddresses().get(rightPosition); + int rightBlockIndex = decodeSliceIndex(rightPageAddress); + int rightBlockPosition = decodePosition(rightPageAddress); + + try { + Block leftBlock = pagesIndex.getChannel(leftChannel).get(leftBlockIndex); + Block rightBlock = pagesIndex.getChannel(rightChannel).get(rightBlockIndex); + int result = sortOrder.compareBlockValue(sortType, leftBlock, leftBlockPosition, rightBlock, rightBlockPosition); + + // sortOrder compares block values and adjusts the result by ASC and DESC. SimpleChannelComparator should + // return the simple comparison, so reverse it if it is DESC. + return sortOrder.isAscending() ? result : -result; + } + catch (Throwable throwable) { + throwIfUnchecked(throwable); + throw new PrestoException(GENERIC_INTERNAL_ERROR, throwable); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java index e3f4b8f685ac1..1d55d5d87231a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java @@ -21,6 +21,7 @@ import com.facebook.presto.operator.WorkProcessor.ProcessState; import com.facebook.presto.operator.WorkProcessor.Transformation; import com.facebook.presto.operator.WorkProcessor.TransformationState; +import com.facebook.presto.operator.window.FrameInfo; import com.facebook.presto.operator.window.FramedWindowFunction; import com.facebook.presto.operator.window.WindowPartition; import com.facebook.presto.spi.plan.PlanNodeId; @@ -29,6 +30,7 @@ import com.facebook.presto.sql.gen.OrderingCompiler; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.PeekingIterator; @@ -38,6 +40,8 @@ import javax.annotation.Nullable; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.atomic.AtomicReference; @@ -47,6 +51,9 @@ import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.operator.SpillingUtils.checkSpillSucceeded; import static com.facebook.presto.operator.WorkProcessor.TransformationState.needsMoreData; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.util.MergeSortedPages.mergeSortedPages; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkPositionIndex; @@ -270,7 +277,9 @@ public WindowOperator( preGroupedChannels, unGroupedPartitionChannels, preSortedChannels, - sortChannels); + sortChannels, + sortOrder, + windowFunctionDefinitions); if (spillEnabled) { PagesIndexWithHashStrategies mergedPagesIndexWithHashStrategies = new PagesIndexWithHashStrategies( @@ -282,7 +291,9 @@ public WindowOperator( ImmutableList.of(), // merged pages are pre sorted on all sort channels sortChannels, - sortChannels); + sortChannels, + sortOrder, + windowFunctionDefinitions); this.spillablePagesToPagesIndexes = Optional.of(new SpillablePagesToPagesIndexes( inMemoryPagesIndexWithHashStrategies, @@ -381,6 +392,7 @@ private static class PagesIndexWithHashStrategies final PagesHashStrategy preSortedPartitionHashStrategy; final PagesHashStrategy peerGroupHashStrategy; final int[] preGroupedPartitionChannels; + final Map frameBoundComparators; PagesIndexWithHashStrategies( PagesIndex.Factory pagesIndexFactory, @@ -389,7 +401,9 @@ private static class PagesIndexWithHashStrategies List preGroupedPartitionChannels, List unGroupedPartitionChannels, List preSortedChannels, - List sortChannels) + List sortChannels, + List sortOrder, + List windowFunctionDefinitions) { this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedPartitionChannels, OptionalInt.empty()); @@ -397,6 +411,77 @@ private static class PagesIndexWithHashStrategies this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, OptionalInt.empty()); this.preGroupedPartitionChannels = Ints.toArray(preGroupedPartitionChannels); + this.frameBoundComparators = createFrameBoundComparators(pagesIndex, windowFunctionDefinitions, sortOrder); + } + } + + /** + * Create comparators necessary for seeking frame start or frame end for window functions with frame type RANGE. + * Whenever a frame bound is specified as RANGE X PRECEDING or RANGE X FOLLOWING, + * a dedicated comparator is created to compare sort key values with expected frame bound values. + */ + private static Map createFrameBoundComparators(PagesIndex pagesIndex, + List windowFunctionDefinitions, + List sortOrders) + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + + for (int i = 0; i < windowFunctionDefinitions.size(); i++) { + FrameInfo frameInfo = windowFunctionDefinitions.get(i).getFrameInfo(); + if (frameInfo.getType() == RANGE) { + if (frameInfo.getStartType() == PRECEDING || frameInfo.getStartType() == FOLLOWING) { + // Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY + SortOrder sortOrder = sortOrders.get(0); + PagesIndexComparator comparator = pagesIndex.createChannelComparator(frameInfo.getSortKeyChannelForStartComparison(), frameInfo.getStartChannel(), sortOrder); + builder.put(new FrameBoundKey(i, FrameBoundKey.Type.START), comparator); + } + if (frameInfo.getEndType() == PRECEDING || frameInfo.getEndType() == FOLLOWING) { + // Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY + SortOrder sortOrder = sortOrders.get(0); + PagesIndexComparator comparator = pagesIndex.createChannelComparator(frameInfo.getSortKeyChannelForEndComparison(), frameInfo.getEndChannel(), sortOrder); + builder.put(new FrameBoundKey(i, FrameBoundKey.Type.END), comparator); + } + } + } + + return builder.build(); + } + + public static class FrameBoundKey + { + private final int functionIndex; + private final Type type; + + public enum Type + { + START, + END; + } + + public FrameBoundKey(int functionIndex, Type type) + { + this.functionIndex = functionIndex; + this.type = requireNonNull(type, "type is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FrameBoundKey that = (FrameBoundKey) o; + return functionIndex == that.functionIndex && + type == that.type; + } + + @Override + public int hashCode() + { + return Objects.hash(functionIndex, type); } } @@ -501,7 +586,14 @@ public ProcessState process() int partitionEnd = findGroupEnd(pagesIndex, pagesIndexWithHashStrategies.unGroupedPartitionHashStrategy, partitionStart); - WindowPartition partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, pagesIndexWithHashStrategies.peerGroupHashStrategy); + WindowPartition partition = new WindowPartition( + pagesIndex, + partitionStart, + partitionEnd, + outputChannels, + windowFunctions, + pagesIndexWithHashStrategies.peerGroupHashStrategy, + pagesIndexWithHashStrategies.frameBoundComparators); windowInfo.addPartition(partition); partitionStart = partitionEnd; return ProcessState.ofResult(partition); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java index 86c4504fa1a5b..97a3eeb7f8883 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/window/FrameInfo.java @@ -15,6 +15,7 @@ import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType; +import com.facebook.presto.sql.tree.SortItem.Ordering; import java.util.Objects; import java.util.Optional; @@ -27,21 +28,33 @@ public class FrameInfo private final WindowType type; private final BoundType startType; private final int startChannel; + private final int sortKeyChannelForStartComparison; private final BoundType endType; private final int endChannel; + private final int sortKeyChannelForEndComparison; + private final int sortKeyChannel; + private final Optional ordering; public FrameInfo( WindowType type, BoundType startType, Optional startChannel, + Optional sortKeyChannelForStartComparison, BoundType endType, - Optional endChannel) + Optional endChannel, + Optional sortKeyChannelForEndComparison, + Optional sortKeyChannel, + Optional ordering) { this.type = requireNonNull(type, "type is null"); this.startType = requireNonNull(startType, "startType is null"); this.startChannel = requireNonNull(startChannel, "startChannel is null").orElse(-1); + this.sortKeyChannelForStartComparison = requireNonNull(sortKeyChannelForStartComparison, "sortKeyChannelForStartComparison is null").orElse(-1); this.endType = requireNonNull(endType, "endType is null"); this.endChannel = requireNonNull(endChannel, "endChannel is null").orElse(-1); + this.sortKeyChannelForEndComparison = requireNonNull(sortKeyChannelForEndComparison, "sortKeyChannelForEndComparison is null").orElse(-1); + this.sortKeyChannel = requireNonNull(sortKeyChannel, "sortKeyChannel is null").orElse(-1); + this.ordering = requireNonNull(ordering, "ordering is null"); } public WindowType getType() @@ -59,6 +72,11 @@ public int getStartChannel() return startChannel; } + public int getSortKeyChannelForStartComparison() + { + return sortKeyChannelForStartComparison; + } + public BoundType getEndType() { return endType; @@ -69,10 +87,25 @@ public int getEndChannel() return endChannel; } + public int getSortKeyChannelForEndComparison() + { + return sortKeyChannelForEndComparison; + } + + public int getSortKeyChannel() + { + return sortKeyChannel; + } + + public Optional getOrdering() + { + return ordering; + } + @Override public int hashCode() { - return Objects.hash(type, startType, startChannel, endType, endChannel); + return Objects.hash(type, startType, startChannel, sortKeyChannelForStartComparison, endType, endChannel, sortKeyChannelForEndComparison, sortKeyChannel, ordering); } @Override @@ -90,9 +123,13 @@ public boolean equals(Object obj) return Objects.equals(this.type, other.type) && Objects.equals(this.startType, other.startType) && + Objects.equals(this.sortKeyChannelForStartComparison, other.sortKeyChannelForStartComparison) && Objects.equals(this.startChannel, other.startChannel) && Objects.equals(this.endType, other.endType) && - Objects.equals(this.endChannel, other.endChannel); + Objects.equals(this.endChannel, other.endChannel) && + Objects.equals(this.sortKeyChannelForEndComparison, other.sortKeyChannelForEndComparison) && + Objects.equals(this.sortKeyChannel, other.sortKeyChannel) && + Objects.equals(this.ordering, other.ordering); } @Override @@ -102,8 +139,12 @@ public String toString() .add("type", type) .add("startType", startType) .add("startChannel", startChannel) + .add("sortKeyChannelForStartComparison", sortKeyChannelForStartComparison) .add("endType", endType) .add("endChannel", endChannel) + .add("sortKeyChannelForEndComparison", sortKeyChannelForEndComparison) + .add("sortKeyChannel", sortKeyChannel) + .add("ordering", ordering) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java index 22c72c3f5b720..14910e2a7ad7d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java @@ -16,18 +16,29 @@ import com.facebook.presto.common.PageBuilder; import com.facebook.presto.operator.PagesHashStrategy; import com.facebook.presto.operator.PagesIndex; +import com.facebook.presto.operator.PagesIndexComparator; +import com.facebook.presto.operator.WindowOperator.FrameBoundKey; import com.facebook.presto.spi.function.WindowIndex; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; +import com.facebook.presto.sql.tree.SortItem.Ordering; import com.google.common.collect.ImmutableList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import static com.facebook.presto.operator.WindowOperator.FrameBoundKey.Type.END; +import static com.facebook.presto.operator.WindowOperator.FrameBoundKey.Type.START; import static com.facebook.presto.spi.StandardErrorCode.INVALID_WINDOW_FRAME; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.toIntExact; @@ -40,19 +51,31 @@ public final class WindowPartition private final int[] outputChannels; private final List windowFunctions; + + // Recently computed frame bounds for functions with frame type RANGE. + // When computing frame start and frame end for a row, frame bounds for the previous row + // are used as the starting point. Then they are moved backward or forward based on the sort order + // until the matching position for a current row is found. + // This approach is efficient in case when frame offset values are constant. It was chosen + // based on the assumption that in most use cases frame offset is constant rather than + // row-dependent. + private final Map recentRanges; private final PagesHashStrategy peerGroupHashStrategy; + private final Map frameBoundComparators; private int peerGroupStart; private int peerGroupEnd; private int currentPosition; - public WindowPartition(PagesIndex pagesIndex, + public WindowPartition( + PagesIndex pagesIndex, int partitionStart, int partitionEnd, int[] outputChannels, List windowFunctions, - PagesHashStrategy peerGroupHashStrategy) + PagesHashStrategy peerGroupHashStrategy, + Map frameBoundComparators) { this.pagesIndex = pagesIndex; this.partitionStart = partitionStart; @@ -60,6 +83,7 @@ public WindowPartition(PagesIndex pagesIndex, this.outputChannels = outputChannels; this.windowFunctions = ImmutableList.copyOf(windowFunctions); this.peerGroupHashStrategy = peerGroupHashStrategy; + this.frameBoundComparators = frameBoundComparators; // reset functions for new partition WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, partitionStart, partitionEnd); @@ -69,6 +93,28 @@ public WindowPartition(PagesIndex pagesIndex, currentPosition = partitionStart; updatePeerGroup(); + + recentRanges = initializeRangeCache(partitionStart, partitionEnd, peerGroupEnd, windowFunctions); + } + + private static Map initializeRangeCache(int partitionStart, int partitionEnd, int peerGroupEnd, List windowFunctions) + { + Map ranges = new HashMap<>(); + Range initialPeerRange = new Range(0, peerGroupEnd - partitionStart - 1); + Range initialUnboundedRange = new Range(0, partitionEnd - partitionStart - 1); + for (int i = 0; i < windowFunctions.size(); i++) { + FrameInfo frame = windowFunctions.get(i).getFrame(); + if (frame.getType() == RANGE) { + if (frame.getEndType() == UNBOUNDED_FOLLOWING) { + ranges.put(i, initialUnboundedRange); + } + else { + ranges.put(i, initialPeerRange); + } + } + } + + return ranges; } public int getPartitionStart() @@ -103,8 +149,9 @@ public void processNextRow(PageBuilder pageBuilder) updatePeerGroup(); } - for (FramedWindowFunction framedFunction : windowFunctions) { - Range range = getFrameRange(framedFunction.getFrame()); + for (int i = 0; i < windowFunctions.size(); i++) { + FramedWindowFunction framedFunction = windowFunctions.get(i); + Range range = getFrameRange(framedFunction.getFrame(), i); framedFunction.getFunction().processRow( pageBuilder.getBlockBuilder(channel), peerGroupStart - partitionStart, @@ -149,6 +196,29 @@ private void updatePeerGroup() } } + private Range getFrameRange(FrameInfo frameInfo, int functionIndex) + { + switch (frameInfo.getType()) { + case RANGE: + Range range = getFrameRange( + frameInfo, + recentRanges.get(functionIndex), + frameBoundComparators.get(new FrameBoundKey(functionIndex, START)), + frameBoundComparators.get(new FrameBoundKey(functionIndex, END))); + // handle empty frame. If the frame is out of partition bounds, record the nearest valid frame as the 'recentRange' for the next row. + if (emptyFrame(range)) { + recentRanges.put(functionIndex, nearestValidFrame(range)); + return new Range(-1, -1); + } + recentRanges.put(functionIndex, range); + return range; + case ROWS: + return getFrameRange(frameInfo); + default: + throw new IllegalArgumentException("Unsupported frame type: " + frameInfo.getType()); + } + } + private Range getFrameRange(FrameInfo frameInfo) { int rowPosition = currentPosition - partitionStart; @@ -172,9 +242,6 @@ else if (frameInfo.getStartType() == PRECEDING) { else if (frameInfo.getStartType() == FOLLOWING) { frameStart = following(rowPosition, endPosition, getStartValue(frameInfo)); } - else if (frameInfo.getType() == RANGE) { - frameStart = peerGroupStart - partitionStart; - } else { frameStart = rowPosition; } @@ -189,16 +256,229 @@ else if (frameInfo.getEndType() == PRECEDING) { else if (frameInfo.getEndType() == FOLLOWING) { frameEnd = following(rowPosition, endPosition, getEndValue(frameInfo)); } - else if (frameInfo.getType() == RANGE) { + else { + frameEnd = rowPosition; + } + + return new Range(frameStart, frameEnd); + } + + private Range getFrameRange(FrameInfo frameInfo, Range recentRange, PagesIndexComparator startComparator, PagesIndexComparator endComparator) + { + // full partition + if ((frameInfo.getStartType() == UNBOUNDED_PRECEDING && frameInfo.getEndType() == UNBOUNDED_FOLLOWING)) { + return new Range(0, partitionEnd - partitionStart - 1); + } + + // frame defined by peer group + if ((frameInfo.getStartType() == CURRENT_ROW && frameInfo.getEndType() == CURRENT_ROW) || + (frameInfo.getStartType() == CURRENT_ROW && frameInfo.getEndType() == UNBOUNDED_FOLLOWING) || + (frameInfo.getStartType() == UNBOUNDED_PRECEDING && frameInfo.getEndType() == CURRENT_ROW)) { + // same peer group as recent row + if (currentPosition == partitionStart || pagesIndex.positionEqualsPosition(peerGroupHashStrategy, currentPosition - 1, currentPosition)) { + return recentRange; + } + // next peer group + return new Range( + frameInfo.getStartType() == UNBOUNDED_PRECEDING ? 0 : peerGroupStart - partitionStart, + frameInfo.getEndType() == UNBOUNDED_FOLLOWING ? partitionEnd - partitionStart - 1 : peerGroupEnd - partitionStart - 1); + } + + // at this point, frame definition has at least one of: X PRECEDING, Y FOLLOWING + // 1. leading or trailing nulls: frame consists of nulls peer group, possibly extended to partition start / end. + // according to Spec, behavior of "X PRECEDING", "X FOLLOWING" frame boundaries is similar to "CURRENT ROW" for null values. + if (pagesIndex.isNull(frameInfo.getSortKeyChannel(), currentPosition)) { + return new Range( + frameInfo.getStartType() == UNBOUNDED_PRECEDING ? 0 : peerGroupStart - partitionStart, + frameInfo.getEndType() == UNBOUNDED_FOLLOWING ? partitionEnd - partitionStart - 1 : peerGroupEnd - partitionStart - 1); + } + + // 2. non-null value in current row. Find frame boundaries starting from recentRange + int frameStart; + if (frameInfo.getStartType() == UNBOUNDED_PRECEDING) { + frameStart = 0; + } + else if (frameInfo.getStartType() == CURRENT_ROW) { + frameStart = peerGroupStart - partitionStart; + } + else if (frameInfo.getStartType() == PRECEDING) { + frameStart = getFrameStartPreceding(recentRange.getStart(), frameInfo, startComparator); + } + else { + // frameInfo.getStartType() == FOLLOWING + // note: this is the only case where frameStart might get out of partition bound + frameStart = getFrameStartFollowing(recentRange.getStart(), frameInfo, startComparator); + } + + int frameEnd; + if (frameInfo.getEndType() == UNBOUNDED_FOLLOWING) { + frameEnd = partitionEnd - partitionStart - 1; + } + else if (frameInfo.getEndType() == CURRENT_ROW) { frameEnd = peerGroupEnd - partitionStart - 1; } + else if (frameInfo.getEndType() == PRECEDING) { + // note: this is the only case where frameEnd might get out of partition bound + frameEnd = getFrameEndPreceding(recentRange.getEnd(), frameInfo, endComparator); + } else { - frameEnd = rowPosition; + // frameInfo.getEndType() == FOLLOWING + frameEnd = getFrameEndFollowing(recentRange.getEnd(), frameInfo, endComparator); } return new Range(frameStart, frameEnd); } + private int getFrameStartPreceding(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + int sortKeyChannel = frameInfo.getSortKeyChannelForStartComparison(); + Ordering ordering = frameInfo.getOrdering().get(); + + // If the recent frame start points at a null, it means that we are now processing first non-null position. + // For frame start "X PRECEDING", the frame starts at the first null for all null values, and it never includes nulls for non-null values. + if (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + recent)) { + return currentPosition - partitionStart; + } + + return seek( + comparator, + sortKeyChannel, + recent, + -1, + ordering == DESCENDING, + 0, + p -> false); + } + + private int getFrameStartFollowing(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + int sortKeyChannel = frameInfo.getSortKeyChannelForStartComparison(); + Ordering ordering = frameInfo.getOrdering().get(); + + int position = recent; + + // If the recent frame start points at the beginning of partition and it is null, it means that we are now processing first non-null position. + // frame start for first non-null position - leave section of leading nulls + if (recent == 0 && pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart)) { + position = currentPosition - partitionStart; + } + // leave section of trailing nulls + while (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + position)) { + position--; + } + + return seek( + comparator, + sortKeyChannel, + position, + -1, + ordering == DESCENDING, + 0, + p -> p >= partitionEnd - partitionStart || pagesIndex.isNull(sortKeyChannel, partitionStart + p)); + } + + private int getFrameEndPreceding(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + int sortKeyChannel = frameInfo.getSortKeyChannelForEndComparison(); + Ordering ordering = frameInfo.getOrdering().get(); + + int position = recent; + + // leave section of leading nulls + while (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + position)) { + position++; + } + + return seek( + comparator, + sortKeyChannel, + position, + 1, + ordering == ASCENDING, + partitionEnd - 1 - partitionStart, + p -> p < 0 || pagesIndex.isNull(sortKeyChannel, partitionStart + p)); + } + + private int getFrameEndFollowing(int recent, FrameInfo frameInfo, PagesIndexComparator comparator) + { + Ordering ordering = frameInfo.getOrdering().get(); + int sortKeyChannel = frameInfo.getSortKeyChannelForEndComparison(); + + int position = recent; + + // frame end for first non-null position - leave section of leading nulls + if (pagesIndex.isNull(frameInfo.getSortKeyChannel(), partitionStart + recent)) { + position = currentPosition - partitionStart; + } + + return seek( + comparator, + sortKeyChannel, + position, + 1, + ordering == ASCENDING, + partitionEnd - 1 - partitionStart, + p -> false); + } + + private int compare(PagesIndexComparator comparator, int left, int right, boolean reverse) + { + int result = comparator.compareTo(pagesIndex, left, right); + + if (reverse) { + return -result; + } + + return result; + } + + // This method assumes that `sortKeyChannel` is not null at `position` + private int seek(PagesIndexComparator comparator, int sortKeyChannel, int position, int step, boolean reverse, int limit, Predicate bound) + { + int comparison = compare(comparator, partitionStart + position, currentPosition, reverse); + while (comparison < 0) { + position -= step; + + if (bound.test(position)) { + return position; + } + + comparison = compare(comparator, partitionStart + position, currentPosition, reverse); + } + while (true) { + if (position == limit || pagesIndex.isNull(sortKeyChannel, partitionStart + position + step)) { + break; + } + int newComparison = compare(comparator, partitionStart + position + step, currentPosition, reverse); + if (newComparison >= 0) { + position += step; + } + else { + break; + } + } + + return position; + } + + private boolean emptyFrame(Range range) + { + return range.getStart() > range.getEnd() || + range.getStart() >= partitionEnd - partitionStart || + range.getEnd() < 0; + } + + /** + * Return the nearest valid frame. A frame is valid if its start and end are within partition. + * Note: A valid frame might be empty i.e. its end might be before its start. + */ + private Range nearestValidFrame(Range range) + { + return new Range( + Math.min(partitionEnd - partitionStart - 1, range.getStart()), + Math.max(0, range.getEnd())); + } + private boolean emptyFrame(FrameInfo frameInfo, int rowPosition, int endPosition) { BoundType startType = frameInfo.getStartType(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index 4da8b7dbf2cd6..617a60d453667 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -123,6 +123,9 @@ public class Analysis private final Map, Type> types = new LinkedHashMap<>(); private final Map, Type> coercions = new LinkedHashMap<>(); + private final Map, Type> sortKeyCoercionsForFrameBoundCalculation = new LinkedHashMap<>(); + private final Map, Type> sortKeyCoercionsForFrameBoundComparison = new LinkedHashMap<>(); + private final Map, FunctionHandle> frameBoundCalculations = new LinkedHashMap<>(); private final Set> typeOnlyCoercions = new LinkedHashSet<>(); private final Map, List> relationCoercions = new LinkedHashMap<>(); private final Map, FunctionHandle> functionHandles = new LinkedHashMap<>(); @@ -547,10 +550,36 @@ public void addCoercion(Expression expression, Type type, boolean isTypeOnlyCoer } } - public void addCoercions(Map, Type> coercions, Set> typeOnlyCoercions) + public void addCoercions( + Map, Type> coercions, + Set> typeOnlyCoercions, + Map, Type> sortKeyCoercionsForFrameBoundCalculation, + Map, Type> sortKeyCoercionsForFrameBoundComparison) { this.coercions.putAll(coercions); this.typeOnlyCoercions.addAll(typeOnlyCoercions); + this.sortKeyCoercionsForFrameBoundCalculation.putAll(sortKeyCoercionsForFrameBoundCalculation); + this.sortKeyCoercionsForFrameBoundComparison.putAll(sortKeyCoercionsForFrameBoundComparison); + } + + public Type getSortKeyCoercionForFrameBoundCalculation(Expression frameOffset) + { + return sortKeyCoercionsForFrameBoundCalculation.get(NodeRef.of(frameOffset)); + } + + public Type getSortKeyCoercionForFrameBoundComparison(Expression frameOffset) + { + return sortKeyCoercionsForFrameBoundComparison.get(NodeRef.of(frameOffset)); + } + + public void addFrameBoundCalculations(Map, FunctionHandle> frameBoundCalculations) + { + this.frameBoundCalculations.putAll(frameBoundCalculations); + } + + public FunctionHandle getFrameBoundCalculation(Expression frameOffset) + { + return frameBoundCalculations.get(NodeRef.of(frameOffset)); } public Expression getHaving(QuerySpecification query) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index 57292f53ca304..10c61cc1ddf77 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -19,6 +19,7 @@ import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.type.CharType; import com.facebook.presto.common.type.DecimalParseResult; +import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.FunctionType; import com.facebook.presto.common.type.RowType; @@ -32,6 +33,7 @@ import com.facebook.presto.metadata.OperatorNotFoundException; import com.facebook.presto.security.AccessControl; import com.facebook.presto.security.DenyAllAccessControl; +import com.facebook.presto.spi.ErrorCode; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.StandardErrorCode; @@ -40,6 +42,7 @@ import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionImplementationType; import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.sql.parser.SqlParser; @@ -66,6 +69,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Extract; import com.facebook.presto.sql.tree.FieldReference; +import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.GroupingOperation; @@ -86,6 +90,7 @@ import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; @@ -102,11 +107,13 @@ import com.facebook.presto.sql.tree.TimestampLiteral; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; +import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; import com.facebook.presto.transaction.TransactionId; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import io.airlift.slice.SliceUtf8; @@ -122,7 +129,9 @@ import java.util.Set; import java.util.function.Function; +import static com.facebook.presto.common.function.OperatorType.ADD; import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT; +import static com.facebook.presto.common.function.OperatorType.SUBTRACT; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DateType.DATE; @@ -142,6 +151,7 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.metadata.FunctionAndTypeManager.qualifyObjectName; +import static com.facebook.presto.spi.StandardErrorCode.OPERATOR_NOT_FOUND; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions; import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions; @@ -149,8 +159,10 @@ import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteralType; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MULTIPLE_FIELDS_FROM_SUBQUERY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA; @@ -159,6 +171,12 @@ import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.tree.Extract.Field.TIMEZONE_HOUR; import static com.facebook.presto.sql.tree.Extract.Field.TIMEZONE_MINUTE; +import static com.facebook.presto.sql.tree.FrameBound.Type.FOLLOWING; +import static com.facebook.presto.sql.tree.FrameBound.Type.PRECEDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; +import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; +import static com.facebook.presto.sql.tree.WindowFrame.Type.ROWS; import static com.facebook.presto.type.ArrayParametricType.ARRAY; import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; @@ -192,6 +210,17 @@ public class ExpressionAnalyzer private final Set> existsSubqueries = new LinkedHashSet<>(); private final Map, Type> expressionCoercions = new LinkedHashMap<>(); private final Set> typeOnlyCoercions = new LinkedHashSet<>(); + + // Coercions needed for window function frame of type RANGE. + // These are coercions for the sort key, needed for frame bound calculation, identified by frame range offset expression. + // Frame definition might contain two different offset expressions (for start and end), each requiring different coercion of the sort key. + private final Map, Type> sortKeyCoercionsForFrameBoundCalculation = new LinkedHashMap<>(); + // Coercions needed for window function frame of type RANGE. + // These are coercions for the sort key, needed for comparison of the sort key with precomputed frame bound, identified by frame range offset expression. + private final Map, Type> sortKeyCoercionsForFrameBoundComparison = new LinkedHashMap<>(); + // Functions for calculating frame bounds for frame of type RANGE, identified by frame range offset expression. + private final Map, FunctionHandle> frameBoundCalculations = new LinkedHashMap<>(); + private final Set> subqueryInPredicates = new LinkedHashSet<>(); private final Map, FieldId> columnReferences = new LinkedHashMap<>(); private final Map, Type> expressionTypes = new LinkedHashMap<>(); @@ -268,6 +297,21 @@ public Set> getTypeOnlyCoercions() return unmodifiableSet(typeOnlyCoercions); } + public Map, Type> getSortKeyCoercionsForFrameBoundCalculation() + { + return unmodifiableMap(sortKeyCoercionsForFrameBoundCalculation); + } + + public Map, Type> getSortKeyCoercionsForFrameBoundComparison() + { + return unmodifiableMap(sortKeyCoercionsForFrameBoundComparison); + } + + public Map, FunctionHandle> getFrameBoundCalculations() + { + return unmodifiableMap(frameBoundCalculations); + } + public Set> getSubqueryInPredicates() { return unmodifiableSet(subqueryInPredicates); @@ -865,7 +909,8 @@ protected Type visitNullLiteral(NullLiteral node, StackableAstVisitorContext context) { if (node.getWindow().isPresent()) { - for (Expression expression : node.getWindow().get().getPartitionBy()) { + Window window = node.getWindow().get(); + for (Expression expression : window.getPartitionBy()) { process(expression, context); Type type = getExpressionType(expression); if (!type.isComparable()) { @@ -873,7 +918,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext context, Window window) + { + if (!window.getOrderBy().isPresent()) { + throw new SemanticException(MISSING_ORDER_BY, window, "Window frame of type RANGE PRECEDING or FOLLOWING requires ORDER BY"); + } + OrderBy orderBy = window.getOrderBy().get(); + if (orderBy.getSortItems().size() != 1) { + throw new SemanticException(INVALID_ORDER_BY, orderBy, "Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY (actual: %s)", orderBy.getSortItems().size()); + } + Expression sortKey = Iterables.getOnlyElement(orderBy.getSortItems()).getSortKey(); + Type sortKeyType = getExpressionType(sortKey); + if (!isNumericType(sortKeyType) && !isDateTimeType(sortKeyType)) { + throw new SemanticException(TYPE_MISMATCH, sortKey, "Window frame of type RANGE PRECEDING or FOLLOWING requires that sort item type be numeric, datetime or interval (actual: %s)", sortKeyType); + } + + Type offsetValueType = process(offsetValue, context); + + if (isNumericType(sortKeyType)) { + if (!isNumericType(offsetValueType)) { + throw new SemanticException(TYPE_MISMATCH, offsetValue, "Window frame RANGE value type (%s) not compatible with sort item type (%s)", offsetValueType, sortKeyType); + } + } + else { // isDateTimeType(sortKeyType) + if (offsetValueType != INTERVAL_DAY_TIME && offsetValueType != INTERVAL_YEAR_MONTH) { + throw new SemanticException(TYPE_MISMATCH, offsetValue, "Window frame RANGE value type (%s) not compatible with sort item type (%s)", offsetValueType, sortKeyType); + } + } + + // resolve function to calculate frame boundary value (add / subtract offset from sortKey) + SortItem.Ordering ordering = Iterables.getOnlyElement(orderBy.getSortItems()).getOrdering(); + OperatorType operatorType; + FunctionHandle function; + if ((boundType == PRECEDING && ordering == ASCENDING) || (boundType == FOLLOWING && ordering == DESCENDING)) { + operatorType = SUBTRACT; + } + else { + operatorType = ADD; + } + try { + function = functionAndTypeManager.resolveOperator(operatorType, TypeSignatureProvider.fromTypes(sortKeyType, offsetValueType)); + } + catch (PrestoException e) { + ErrorCode errorCode = e.getErrorCode(); + if (errorCode.equals(OPERATOR_NOT_FOUND.toErrorCode())) { + throw new SemanticException(TYPE_MISMATCH, offsetValue, "Window frame RANGE value type (%s) not compatible with sort item type (%s)", offsetValueType, sortKeyType); + } + throw e; + } + Signature signature = function.getSignature(); + Type expectedSortKeyType = functionAndTypeManager.getType(signature.getArgumentTypes().get(0)); + if (!expectedSortKeyType.equals(sortKeyType)) { + if (!functionAndTypeManager.canCoerce(sortKeyType, expectedSortKeyType)) { + throw new SemanticException(TYPE_MISMATCH, sortKey, "Sort key must evaluate to a %s (actual: %s)", expectedSortKeyType, sortKeyType); + } + sortKeyCoercionsForFrameBoundCalculation.put(NodeRef.of(offsetValue), expectedSortKeyType); + } + Type expectedOffsetValueType = functionAndTypeManager.getType(signature.getArgumentTypes().get(1)); + if (!expectedOffsetValueType.equals(offsetValueType)) { + coerceType(offsetValue, offsetValueType, expectedOffsetValueType, format("Function %s argument 1", function)); + } + Type expectedFunctionResultType = functionAndTypeManager.getType(signature.getReturnType()); + if (!expectedFunctionResultType.equals(sortKeyType)) { + if (!functionAndTypeManager.canCoerce(sortKeyType, expectedFunctionResultType)) { + throw new SemanticException(TYPE_MISMATCH, sortKey, "Sort key must evaluate to a %s (actual: %s)", expectedFunctionResultType, sortKeyType); + } + sortKeyCoercionsForFrameBoundComparison.put(NodeRef.of(offsetValue), expectedFunctionResultType); + } + + frameBoundCalculations.put(NodeRef.of(offsetValue), function); + } + @Override protected Type visitAtTimeZone(AtTimeZone node, StackableAstVisitorContext context) { @@ -1607,10 +1739,14 @@ public static ExpressionAnalysis analyzeExpression( Map, Type> expressionTypes = analyzer.getExpressionTypes(); Map, Type> expressionCoercions = analyzer.getExpressionCoercions(); Set> typeOnlyCoercions = analyzer.getTypeOnlyCoercions(); + Map, Type> sortKeyCoercionsForFrameBoundCalculation = analyzer.getSortKeyCoercionsForFrameBoundCalculation(); + Map, Type> sortKeyCoercionsForFrameBoundComparison = analyzer.getSortKeyCoercionsForFrameBoundComparison(); + Map, FunctionHandle> frameBoundCalculations = analyzer.getFrameBoundCalculations(); Map, FunctionHandle> resolvedFunctions = analyzer.getResolvedFunctions(); analysis.addTypes(expressionTypes); - analysis.addCoercions(expressionCoercions, typeOnlyCoercions); + analysis.addCoercions(expressionCoercions, typeOnlyCoercions, sortKeyCoercionsForFrameBoundCalculation, sortKeyCoercionsForFrameBoundComparison); + analysis.addFrameBoundCalculations(frameBoundCalculations); analysis.addFunctionHandles(resolvedFunctions); analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); @@ -1776,4 +1912,15 @@ public static ExpressionAnalyzer createWithoutSubqueries( warningCollector, isDescribe); } + + public static boolean isNumericType(Type type) + { + return type.equals(BIGINT) || + type.equals(INTEGER) || + type.equals(SMALLINT) || + type.equals(TINYINT) || + type.equals(DOUBLE) || + type.equals(REAL) || + type instanceof DecimalType; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java index 66933a527c380..66b6d73302598 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java @@ -50,6 +50,8 @@ public enum SemanticErrorCode MISSING_ATTRIBUTE, INVALID_ORDINAL, INVALID_LITERAL, + MISSING_ORDER_BY, + INVALID_ORDER_BY, FUNCTION_NOT_FOUND, INVALID_FUNCTION_NAME, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 75036e66dc416..0bd84d202eadd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -255,7 +255,6 @@ import static com.facebook.presto.sql.tree.FrameBound.Type.PRECEDING; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; -import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.throwIfInstanceOf; @@ -1988,12 +1987,6 @@ private void analyzeWindowFrame(WindowFrame frame) if ((startType == FOLLOWING) && (endType == CURRENT_ROW)) { throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame starting from FOLLOWING cannot end with CURRENT ROW"); } - if ((frame.getType() == RANGE) && ((startType == PRECEDING) || (endType == PRECEDING))) { - throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame RANGE PRECEDING is only supported with UNBOUNDED"); - } - if ((frame.getType() == RANGE) && ((startType == FOLLOWING) || (endType == FOLLOWING))) { - throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame RANGE FOLLOWING is only supported with UNBOUNDED"); - } } private void analyzeHaving(QuerySpecification node, Scope scope) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 451a47b213cd8..142e2695c9a9c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -204,6 +204,7 @@ import com.facebook.presto.sql.planner.plan.WindowNode.Frame; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.VariableToChannelTranslator; +import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; @@ -312,6 +313,8 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; import static com.facebook.presto.util.Reflection.constructorMethodHandle; import static com.facebook.presto.util.SpatialJoinUtils.ST_CONTAINS; import static com.facebook.presto.util.SpatialJoinUtils.ST_CROSSES; @@ -1033,17 +1036,40 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext ImmutableList.Builder windowFunctionOutputVariablesBuilder = ImmutableList.builder(); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { Optional frameStartChannel = Optional.empty(); + Optional sortKeyChannelForStartComparison = Optional.empty(); Optional frameEndChannel = Optional.empty(); + Optional sortKeyChannelForEndComparison = Optional.empty(); + Optional sortKeyChannel = Optional.empty(); + Optional ordering = Optional.empty(); Frame frame = entry.getValue().getFrame(); if (frame.getStartValue().isPresent()) { frameStartChannel = Optional.of(source.getLayout().get(frame.getStartValue().get())); } + if (frame.getSortKeyCoercedForFrameStartComparison().isPresent()) { + sortKeyChannelForStartComparison = Optional.of(source.getLayout().get(frame.getSortKeyCoercedForFrameStartComparison().get())); + } if (frame.getEndValue().isPresent()) { frameEndChannel = Optional.of(source.getLayout().get(frame.getEndValue().get())); } + if (frame.getSortKeyCoercedForFrameEndComparison().isPresent()) { + sortKeyChannelForEndComparison = Optional.of(source.getLayout().get(frame.getSortKeyCoercedForFrameEndComparison().get())); + } + if (node.getOrderingScheme().isPresent()) { + sortKeyChannel = Optional.of(sortChannels.get(0)); + ordering = Optional.of(sortOrder.get(0).isAscending() ? ASCENDING : DESCENDING); + } - FrameInfo frameInfo = new FrameInfo(frame.getType(), frame.getStartType(), frameStartChannel, frame.getEndType(), frameEndChannel); + FrameInfo frameInfo = new FrameInfo( + frame.getType(), + frame.getStartType(), + frameStartChannel, + sortKeyChannelForStartComparison, + frame.getEndType(), + frameEndChannel, + sortKeyChannelForEndComparison, + sortKeyChannel, + ordering); WindowNode.Function function = entry.getValue(); CallExpression call = function.getFunctionCall(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java index 9e6df2979acd6..94f66cb6fc799 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java @@ -106,6 +106,7 @@ public PlanBuilder appendProjections(Iterable expressions, PlanVaria } ImmutableMap.Builder newTranslations = ImmutableMap.builder(); + for (Expression expression : expressions) { VariableReferenceExpression variable = variableAllocator.newVariable(expression, getAnalysis().getTypeWithCoercions(expression)); projections.put(variable, castToRowExpression(translations.rewrite(expression))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 3a705320740b9..2c63058adaf5e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -14,12 +14,14 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.Assignments; @@ -50,22 +52,28 @@ import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GroupingOperation; +import com.facebook.presto.sql.tree.IfExpression; +import com.facebook.presto.sql.tree.IntervalLiteral; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.SortItem; +import com.facebook.presto.sql.tree.StringLiteral; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; @@ -76,6 +84,7 @@ import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -85,13 +94,17 @@ import static com.facebook.presto.SystemSessionProperties.isSkipRedundantSort; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.isNumericType; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder; import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toBoundType; @@ -100,10 +113,20 @@ import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.sql.tree.IntervalLiteral.IntervalField.DAY; +import static com.facebook.presto.sql.tree.IntervalLiteral.IntervalField.YEAR; +import static com.facebook.presto.sql.tree.IntervalLiteral.Sign.POSITIVE; +import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; +import static com.facebook.presto.sql.tree.WindowFrame.Type.ROWS; +import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; +import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.stream; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; class QueryPlanner @@ -365,6 +388,40 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression projections.build())); } + /** + * Creates a projection with any additional coercions by identity of the provided expressions. + * + * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed + */ + public static PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putIdentities(subPlan.getRoot().getOutputVariables(), variable -> castToRowExpression(asSymbolReference(variable))); + ImmutableMap.Builder, VariableReferenceExpression> mappings = ImmutableMap.builder(); + for (Expression expression : expressions) { + Type coercion = analysis.getCoercion(expression); + if (coercion != null) { + Type type = analysis.getType(expression); + VariableReferenceExpression variable = variableAllocator.newVariable(expression, coercion); + assignments.put(variable, castToRowExpression(new Cast( + subPlan.rewrite(expression), + coercion.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(type, coercion)))); + mappings.put(NodeRef.of(expression), variable); + } + else { + mappings.put(NodeRef.of(expression), subPlan.translate(expression)); + } + } + subPlan = subPlan.withNewRoot( + new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + assignments.build())); + return new PlanAndMappings(subPlan, mappings.build()); + } + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { ImmutableMap.Builder projections = ImmutableMap.builder(); @@ -732,39 +789,73 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio Window window = windowFunction.getWindow().get(); // Extract frame - WindowFrame.Type frameType = WindowFrame.Type.RANGE; + WindowFrame.Type frameType = RANGE; FrameBound.Type frameStartType = FrameBound.Type.UNBOUNDED_PRECEDING; FrameBound.Type frameEndType = FrameBound.Type.CURRENT_ROW; - Expression frameStart = null; - Expression frameEnd = null; + Optional startValue = Optional.empty(); + Optional endValue = Optional.empty(); if (window.getFrame().isPresent()) { WindowFrame frame = window.getFrame().get(); frameType = frame.getType(); frameStartType = frame.getStart().getType(); - frameStart = frame.getStart().getValue().orElse(null); + startValue = frame.getStart().getValue(); if (frame.getEnd().isPresent()) { frameEndType = frame.getEnd().get().getType(); - frameEnd = frame.getEnd().get().getValue().orElse(null); + endValue = frame.getEnd().get().getValue(); } } // Pre-project inputs - ImmutableList.Builder inputs = ImmutableList.builder() + ImmutableList.Builder inputsBuilder = ImmutableList.builder() .addAll(windowFunction.getArguments()) .addAll(window.getPartitionBy()) .addAll(Iterables.transform(getSortItemsFromOrderBy(window.getOrderBy()), SortItem::getSortKey)); - if (frameStart != null) { - inputs.add(frameStart); + if (startValue.isPresent()) { + inputsBuilder.add(startValue.get()); } - if (frameEnd != null) { - inputs.add(frameEnd); + if (endValue.isPresent()) { + inputsBuilder.add(endValue.get()); } - subPlan = subPlan.appendProjections(inputs.build(), variableAllocator, idAllocator); + ImmutableList inputs = inputsBuilder.build(); + subPlan = subPlan.appendProjections(inputs, variableAllocator, idAllocator); + + PlanAndMappings coercions = coerce(subPlan, inputs, analysis, idAllocator, variableAllocator, metadata); + subPlan = coercions.getSubPlan(); + + // For frame of type RANGE, append casts and functions necessary for frame bound calculations + Optional frameStart = Optional.empty(); + Optional frameEnd = Optional.empty(); + Optional sortKeyCoercedForFrameStartComparison = Optional.empty(); + Optional sortKeyCoercedForFrameEndComparison = Optional.empty(); + + if (window.getFrame().isPresent() && window.getFrame().get().getType() == RANGE) { + // record sortKey coercions for reuse + Map sortKeyCoercions = new HashMap<>(); + + // process frame start + FrameBoundPlanAndSymbols plan = planFrameBound(subPlan, coercions, startValue, window, sortKeyCoercions); + subPlan = plan.getSubPlan(); + frameStart = plan.getFrameBoundSymbol(); + sortKeyCoercedForFrameStartComparison = plan.getSortKeyCoercedForFrameBoundComparison(); + + // process frame end + plan = planFrameBound(subPlan, coercions, endValue, window, sortKeyCoercions); + subPlan = plan.getSubPlan(); + frameEnd = plan.getFrameBoundSymbol(); + sortKeyCoercedForFrameEndComparison = plan.getSortKeyCoercedForFrameBoundComparison(); + } + else if (window.getFrame().isPresent() && window.getFrame().get().getType() == ROWS) { + frameStart = window.getFrame().get().getStart().getValue().map(coercions::get); + frameEnd = window.getFrame().get().getEnd().flatMap(FrameBound::getValue).map(coercions::get); + } + else if (window.getFrame().isPresent()) { + throw new IllegalArgumentException("unexpected window frame type: " + window.getFrame().get().getType()); + } // Rewrite PARTITION BY in terms of pre-projected inputs ImmutableList.Builder partitionByVariables = ImmutableList.builder(); @@ -781,23 +872,17 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio } // Rewrite frame bounds in terms of pre-projected inputs - Optional frameStartVariable = Optional.empty(); - Optional frameEndVariable = Optional.empty(); - if (frameStart != null) { - frameStartVariable = Optional.of(subPlan.translate(frameStart)); - } - if (frameEnd != null) { - frameEndVariable = Optional.of(subPlan.translate(frameEnd)); - } WindowNode.Frame frame = new WindowNode.Frame( toWindowType(frameType), toBoundType(frameStartType), - frameStartVariable, + frameStart, + sortKeyCoercedForFrameStartComparison, toBoundType(frameEndType), - frameEndVariable, - Optional.ofNullable(frameStart).map(Expression::toString), - Optional.ofNullable(frameEnd).map(Expression::toString)); + frameEnd, + sortKeyCoercedForFrameEndComparison, + startValue.map(Expression::toString), + endValue.map(Expression::toString)); TranslationMap outputTranslations = subPlan.copyTranslations(); @@ -865,6 +950,134 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio return subPlan; } + private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMappings coercions, Optional frameOffset, Window window, Map sortKeyCoercions) + { + Optional frameBoundCalculationFunction = frameOffset.map(analysis::getFrameBoundCalculation); + + // Empty frameBoundCalculationFunction indicates that frame bound type is CURRENT ROW or UNBOUNDED. + // Handling it doesn't require any additional symbols. + if (!frameBoundCalculationFunction.isPresent()) { + return new FrameBoundPlanAndSymbols(subPlan, Optional.empty(), Optional.empty()); + } + + // Present frameBoundCalculationFunction indicates that frame bound type is PRECEDING or FOLLOWING. + // It requires adding certain projections to the plan so that the operator can determine frame bounds. + + // First, append filter to validate offset values. They mustn't be negative or null. + VariableReferenceExpression offsetSymbol = coercions.get(frameOffset.get()); + Expression zeroOffset = zeroOfType(variableAllocator.getTypes().get(offsetSymbol)); + FunctionHandle fail = metadata.getFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.fail"), fromTypes(VARCHAR)); + Expression predicate = new IfExpression( + new ComparisonExpression( + GREATER_THAN_OR_EQUAL, + new SymbolReference(offsetSymbol.getName()), + zeroOffset), + TRUE_LITERAL, + new Cast( + new FunctionCall( + //fail.toQualifiedName(), + QualifiedName.of("presto", "default", "fail"), + ImmutableList.of(new Cast(new StringLiteral("Window frame offset value must not be negative or null"), VARCHAR.getTypeSignature().toString()))), + BOOLEAN.getTypeSignature().toString())); + subPlan = subPlan.withNewRoot(new FilterNode( + idAllocator.getNextId(), + subPlan.getRoot(), + castToRowExpression(predicate))); + + // Then, coerce the sortKey so that we can add / subtract the offset. + // Note: for that we cannot rely on the usual mechanism of using the coerce() method. The coerce() method can only handle one coercion for a node, + // while the sortKey node might require several different coercions, e.g. one for frame start and one for frame end. + Expression sortKey = Iterables.getOnlyElement(window.getOrderBy().get().getSortItems()).getSortKey(); + VariableReferenceExpression sortKeyCoercedForFrameBoundCalculation = coercions.get(sortKey); + Optional coercion = frameOffset.map(analysis::getSortKeyCoercionForFrameBoundCalculation); + if (coercion.isPresent()) { + Type expectedType = coercion.get(); + VariableReferenceExpression alreadyCoerced = sortKeyCoercions.get(expectedType); + if (alreadyCoerced != null) { + sortKeyCoercedForFrameBoundCalculation = alreadyCoerced; + } + else { + Expression cast = new Cast( + new SymbolReference(coercions.get(sortKey).getName()), + expectedType.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(analysis.getType(sortKey), expectedType)); + sortKeyCoercedForFrameBoundCalculation = variableAllocator.newVariable(cast, expectedType); + sortKeyCoercions.put(expectedType, sortKeyCoercedForFrameBoundCalculation); + subPlan = subPlan.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + Assignments.builder() + .putIdentities(subPlan.getRoot().getOutputVariables(), variable -> castToRowExpression(asSymbolReference(variable))) + .put(sortKeyCoercedForFrameBoundCalculation, castToRowExpression(cast)) + .build())); + } + } + + // Next, pre-project the function which combines sortKey with the offset. + // Note: if frameOffset needs a coercion, it was added before by a call to coerce() method. + FunctionHandle function = frameBoundCalculationFunction.get(); + QualifiedObjectName name = function.getSignature().getName(); + Expression functionCall = new FunctionCall( + //function.toQualifiedName(), + QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getObjectName()), + ImmutableList.of( + new SymbolReference(sortKeyCoercedForFrameBoundCalculation.getName()), + new SymbolReference(offsetSymbol.getName()))); + VariableReferenceExpression frameBoundVariable = variableAllocator.newVariable(functionCall, metadata.getFunctionAndTypeManager().getType(function.getSignature().getReturnType())); + subPlan = subPlan.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + Assignments.builder() + .putIdentities(subPlan.getRoot().getOutputVariables(), variable -> castToRowExpression(asSymbolReference(variable))) + .put(frameBoundVariable, castToRowExpression(functionCall)) + .build())); + + // Finally, coerce the sortKey to the type of frameBound so that the operator can perform comparisons on them + Optional sortKeyCoercedForFrameBoundComparison = Optional.of(coercions.get(sortKey)); + coercion = frameOffset.map(analysis::getSortKeyCoercionForFrameBoundComparison); + if (coercion.isPresent()) { + Type expectedType = coercion.get(); + VariableReferenceExpression alreadyCoerced = sortKeyCoercions.get(expectedType); + if (alreadyCoerced != null) { + sortKeyCoercedForFrameBoundComparison = Optional.of(alreadyCoerced); + } + else { + Expression cast = new Cast( + new SymbolReference(coercions.get(sortKey).getName()), + expectedType.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(analysis.getType(sortKey), expectedType)); + VariableReferenceExpression castSymbol = variableAllocator.newVariable(cast, expectedType); + sortKeyCoercions.put(expectedType, castSymbol); + subPlan = subPlan.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + Assignments.builder() + .putIdentities(subPlan.getRoot().getOutputVariables(), variable -> castToRowExpression(asSymbolReference(variable))) + .put(castSymbol, castToRowExpression(cast)) + .build())); + sortKeyCoercedForFrameBoundComparison = Optional.of(castSymbol); + } + } + + return new FrameBoundPlanAndSymbols(subPlan, Optional.of(frameBoundVariable), sortKeyCoercedForFrameBoundComparison); + } + + private Expression zeroOfType(Type type) + { + if (isNumericType(type)) { + return new Cast(new LongLiteral("0"), type.getTypeSignature().toString()); + } + if (type.equals(INTERVAL_DAY_TIME)) { + return new IntervalLiteral("0", POSITIVE, DAY); + } + if (type.equals(INTERVAL_YEAR_MONTH)) { + return new IntervalLiteral("0", POSITIVE, YEAR); + } + throw new IllegalArgumentException("unexpected type: " + type); + } + private PlanBuilder handleSubqueries(PlanBuilder subPlan, Node node, Iterable inputs) { for (Expression input : inputs) { @@ -961,4 +1174,61 @@ private static List toSymbolReferences(List, VariableReferenceExpression> mappings; + public PlanAndMappings(PlanBuilder subPlan, Map, VariableReferenceExpression> mappings) + { + this.subPlan = subPlan; + this.mappings = mappings; + } + public PlanBuilder getSubPlan() + { + return subPlan; + } + public VariableReferenceExpression get(Expression expression) + { + return tryGet(expression) + .orElseThrow(() -> new IllegalArgumentException(format("No mapping for expression: %s (%s)", expression, System.identityHashCode(expression)))); + } + public Optional tryGet(Expression expression) + { + VariableReferenceExpression result = mappings.get(NodeRef.of(expression)); + if (result != null) { + return Optional.of(result); + } + return Optional.empty(); + } + } + + private static class FrameBoundPlanAndSymbols + { + private final PlanBuilder subPlan; + private final Optional frameBoundSymbol; + private final Optional sortKeyCoercedForFrameBoundComparison; + + public FrameBoundPlanAndSymbols(PlanBuilder subPlan, Optional frameBoundSymbol, Optional sortKeyCoercedForFrameBoundComparison) + { + this.subPlan = subPlan; + this.frameBoundSymbol = frameBoundSymbol; + this.sortKeyCoercedForFrameBoundComparison = sortKeyCoercedForFrameBoundComparison; + } + + public PlanBuilder getSubPlan() + { + return subPlan; + } + + public Optional getFrameBoundSymbol() + { + return frameBoundSymbol; + } + + public Optional getSortKeyCoercedForFrameBoundComparison() + { + return sortKeyCoercedForFrameBoundComparison; + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java index c263eb4dd0a67..5dada04d4d26f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java @@ -66,6 +66,15 @@ public Type get(Expression expression) return type; } + public Type get(VariableReferenceExpression expression) + { + requireNonNull(expression, "expression is null"); + Type type = types.get(expression.getName()); + checkArgument(type != null, "no type found found for expression '%s'", expression); + + return type; + } + public Set allVariables() { return types.entrySet().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java index 12cc3e6b003f7..1f0546d64246a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java @@ -62,6 +62,8 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, referencedInputs.addAll(WindowNodeUtil.extractWindowFunctionUniqueVariables(windowFunction, variableAllocator.getTypes())); windowFunction.getFrame().getStartValue().ifPresent(referencedInputs::add); windowFunction.getFrame().getEndValue().ifPresent(referencedInputs::add); + windowFunction.getFrame().getSortKeyCoercedForFrameStartComparison().ifPresent(referencedInputs::add); + windowFunction.getFrame().getSortKeyCoercedForFrameEndComparison().ifPresent(referencedInputs::add); } PlanNode prunedWindowNode = new WindowNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 6d418aad1c022..b08d790ca5f70 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -416,6 +416,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext startValue; + private final Optional sortKeyCoercedForFrameStartComparison; private final BoundType endType; private final Optional endValue; + private final Optional sortKeyCoercedForFrameEndComparison; // This information is only used for printing the plan. private final Optional originalStartValue; @@ -235,25 +238,35 @@ public Frame( @JsonProperty("type") WindowType type, @JsonProperty("startType") BoundType startType, @JsonProperty("startValue") Optional startValue, + @JsonProperty("sortKeyCoercedForFrameStartComparison") Optional sortKeyCoercedForFrameStartComparison, @JsonProperty("endType") BoundType endType, @JsonProperty("endValue") Optional endValue, + @JsonProperty("sortKeyCoercedForFrameEndComparison") Optional sortKeyCoercedForFrameEndComparison, @JsonProperty("originalStartValue") Optional originalStartValue, @JsonProperty("originalEndValue") Optional originalEndValue) { this.startType = requireNonNull(startType, "startType is null"); this.startValue = requireNonNull(startValue, "startValue is null"); + this.sortKeyCoercedForFrameStartComparison = requireNonNull(sortKeyCoercedForFrameStartComparison, "sortKeyCoercedForFrameStartComparison is null"); this.endType = requireNonNull(endType, "endType is null"); this.endValue = requireNonNull(endValue, "endValue is null"); + this.sortKeyCoercedForFrameEndComparison = requireNonNull(sortKeyCoercedForFrameEndComparison, "sortKeyCoercedForFrameEndComparison is null"); this.type = requireNonNull(type, "type is null"); this.originalStartValue = requireNonNull(originalStartValue, "originalStartValue is null"); this.originalEndValue = requireNonNull(originalEndValue, "originalEndValue is null"); if (startValue.isPresent()) { checkArgument(originalStartValue.isPresent(), "originalStartValue must be present if startValue is present"); + if (type == RANGE) { + checkArgument(sortKeyCoercedForFrameStartComparison.isPresent(), "for frame of type RANGE, sortKeyCoercedForFrameStartComparison must be present if startValue is present"); + } } if (endValue.isPresent()) { checkArgument(originalEndValue.isPresent(), "originalEndValue must be present if endValue is present"); + if (type == RANGE) { + checkArgument(sortKeyCoercedForFrameEndComparison.isPresent(), "for frame of type RANGE, sortKeyCoercedForFrameEndComparison must be present if endValue is present"); + } } } @@ -275,6 +288,12 @@ public Optional getStartValue() return startValue; } + @JsonProperty + public Optional getSortKeyCoercedForFrameStartComparison() + { + return sortKeyCoercedForFrameStartComparison; + } + @JsonProperty public BoundType getEndType() { @@ -287,6 +306,12 @@ public Optional getEndValue() return endValue; } + @JsonProperty + public Optional getSortKeyCoercedForFrameEndComparison() + { + return sortKeyCoercedForFrameEndComparison; + } + @JsonProperty public Optional getOriginalStartValue() { @@ -312,14 +337,16 @@ public boolean equals(Object o) return type == frame.type && startType == frame.startType && Objects.equals(startValue, frame.startValue) && + Objects.equals(sortKeyCoercedForFrameStartComparison, frame.sortKeyCoercedForFrameStartComparison) && endType == frame.endType && - Objects.equals(endValue, frame.endValue); + Objects.equals(endValue, frame.endValue) && + Objects.equals(sortKeyCoercedForFrameEndComparison, frame.sortKeyCoercedForFrameEndComparison); } @Override public int hashCode() { - return Objects.hash(type, startType, startValue, endType, endValue, originalStartValue, originalEndValue); + return Objects.hash(type, startType, startValue, sortKeyCoercedForFrameStartComparison, endType, endValue, originalStartValue, originalEndValue, sortKeyCoercedForFrameEndComparison); } public enum WindowType diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 359bcd2530f6f..081a39a75375f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -197,6 +197,17 @@ public Void visitWindow(WindowNode node, Set boundV } checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputVariables()); + ImmutableList.Builder symbolsForFrameBoundsComparison = ImmutableList.builder(); + for (WindowNode.Frame frame : node.getFrames()) { + if (frame.getSortKeyCoercedForFrameStartComparison().isPresent()) { + symbolsForFrameBoundsComparison.add(frame.getSortKeyCoercedForFrameStartComparison().get()); + } + if (frame.getSortKeyCoercedForFrameEndComparison().isPresent()) { + symbolsForFrameBoundsComparison.add(frame.getSortKeyCoercedForFrameEndComparison().get()); + } + } + checkDependencies(inputs, symbolsForFrameBoundsComparison.build(), "Invalid node. Symbols for frame bound comparison (%s) not in source plan output (%s)", symbolsForFrameBoundsComparison.build(), node.getSource().getOutputVariables()); + for (WindowNode.Function function : node.getWindowFunctions().values()) { Set dependencies = WindowNodeUtil.extractWindowFunctionUniqueVariables(function, types); checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputVariables()); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java index 7c24303656894..7a7cf8e2f27d5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java @@ -73,7 +73,7 @@ @Test(singleThreaded = true) public class TestWindowOperator { - private static final FrameInfo UNBOUNDED_FRAME = new FrameInfo(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty()); + private static final FrameInfo UNBOUNDED_FRAME = new FrameInfo(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); public static final List ROW_NUMBER = ImmutableList.of( window(new ReflectionWindowFunctionSupplier<>("row_number", BIGINT, ImmutableList.of(), RowNumberFunction.class), BIGINT, UNBOUNDED_FRAME)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index cd1a10970eb8f..1530c6adfca9b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -43,6 +43,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_FUNCTION_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; @@ -53,6 +54,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ATTRIBUTE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_COLUMN; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ORDER_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MULTIPLE_FIELDS_FROM_SUBQUERY; @@ -656,7 +658,7 @@ public void testWindowFunctionWithoutOverClause() } @Test - public void testInvalidWindowFrame() + public void testInvalidWindowFrameTypeRows() { assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS UNBOUNDED FOLLOWING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS 2 FOLLOWING)"); @@ -665,10 +667,6 @@ public void testInvalidWindowFrame() assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN CURRENT ROW AND 5 PRECEDING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN 2 FOLLOWING AND 5 PRECEDING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN 2 FOLLOWING AND CURRENT ROW)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE 2 PRECEDING)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN 2 PRECEDING AND CURRENT ROW)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN CURRENT ROW AND 5 FOLLOWING)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN 2 PRECEDING AND 5 FOLLOWING)"); assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS 5e-1 PRECEDING)"); assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS 'foo' PRECEDING)"); @@ -676,6 +674,63 @@ public void testInvalidWindowFrame() assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS BETWEEN CURRENT ROW AND 'foo' FOLLOWING)"); } + @Test + public void testWindowFrameTypeRange() + { + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED FOLLOWING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE 2 FOLLOWING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + assertFails(INVALID_WINDOW_FRAME, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE UNBOUNDED PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE 5 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND 10 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND 3 PRECEDING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 5 PRECEDING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND CURRENT ROW) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 1 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 10 FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 2 FOLLOWING AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + + // this should pass the analysis but fail during execution + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN -x PRECEDING AND 0 * x FOLLOWING) FROM (VALUES 1) T(x)"); + analyze("SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN CAST(null AS BIGINT) PRECEDING AND CAST(null AS BIGINT) FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(MISSING_ORDER_BY, "SELECT array_agg(x) OVER (RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(INVALID_ORDER_BY, "SELECT array_agg(x) OVER (ORDER BY x DESC, x ASC RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(TYPE_MISMATCH, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM (VALUES 'a') T(x)"); + + assertFails(TYPE_MISMATCH, "SELECT array_agg(x) OVER (ORDER BY x RANGE BETWEEN 'a' PRECEDING AND 'z' FOLLOWING) FROM (VALUES 1) T(x)"); + + assertFails(TYPE_MISMATCH, "SELECT array_agg(x) OVER (ORDER BY x RANGE INTERVAL '1' day PRECEDING) FROM (VALUES INTERVAL '1' year) T(x)"); + + // window frame other than PRECEDING or FOLLOWING has no requirements regarding window ORDER BY clause + // ORDER BY is not required + analyze("SELECT array_agg(x) OVER (RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES 1) T(x)"); + // multiple sort keys and sort keys of types other than numeric or datetime are allowed + analyze("SELECT array_agg(x) OVER (ORDER BY y, z RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM (VALUES (1, 'text', true)) T(x, y, z)"); + } + @Test public void testDistinctInWindowFunctionParameter() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index a59fd402d8c26..7bc7aea158520 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -155,9 +155,11 @@ public void testValidWindow() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, DOUBLE, variableC), frame, false); @@ -296,9 +298,11 @@ public void testInvalidWindowFunctionCall() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableA), frame, false); @@ -327,9 +331,11 @@ public void testInvalidWindowFunctionSignature() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableC), frame, false); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestWindowFrameRange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestWindowFrameRange.java new file mode 100644 index 0000000000000..adb3f67e32543 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestWindowFrameRange.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.ShortDecimalType; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.tree.DecimalLiteral; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.LogicalPlanner.Stage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.windowFrame; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.FOLLOWING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING; +import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; + +public class TestWindowFrameRange + extends BasePlanTest +{ + @Test + public void testFramePrecedingWithSortKeyCoercions() + { + @Language("SQL") String sql = "SELECT array_agg(key) OVER(ORDER BY key RANGE x PRECEDING) " + + "FROM (VALUES (1, 1.1), (2, 2.2)) T(key, x)"; + + PlanMatchPattern pattern = + anyTree( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of(), + ImmutableList.of("key"), + ImmutableMap.of("key", SortOrder.ASC_NULLS_LAST))) + .addFunction( + "array_agg_result", + functionCall("array_agg", ImmutableList.of("key")), + createTestFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.array_agg"), fromTypes(INTEGER)), + windowFrame( + RANGE, + PRECEDING, + Optional.of("frame_start_value"), + Optional.of(new ShortDecimalType(12, 1)), + Optional.of("key_for_frame_start_comparison"), + Optional.of(new ShortDecimalType(12, 1)), + CURRENT_ROW, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty())), + project(// coerce sort key to compare sort key values with frame start values + ImmutableMap.of("key_for_frame_start_comparison", expression("CAST(key AS decimal(12, 1))")), + project(// calculate frame start value (sort key - frame offset) + ImmutableMap.of("frame_start_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$subtract"), ImmutableList.of(new SymbolReference("key_for_frame_start_calculation"), new SymbolReference("x"))))), + project(// coerce sort key to calculate frame start values + ImmutableMap.of("key_for_frame_start_calculation", expression("CAST(key AS decimal(10, 0))")), + filter(// validate offset values + "IF((x >= CAST(0 AS DECIMAL(2,1))), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + anyTree( + values( + ImmutableList.of("key", "x"), + ImmutableList.of( + ImmutableList.of(new LongLiteral("1"), new DecimalLiteral("1.1")), + ImmutableList.of(new LongLiteral("2"), new DecimalLiteral("2.2"))))))))))); + + assertPlan(sql, CREATED, pattern); + } + + @Test + public void testFrameFollowingWithOffsetCoercion() + { + @Language("SQL") String sql = "SELECT array_agg(key) OVER(ORDER BY key RANGE BETWEEN CURRENT ROW AND x FOLLOWING) " + + "FROM (VALUES (1.1, 1), (2.2, 2)) T(key, x)"; + + PlanMatchPattern pattern = + anyTree( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of(), + ImmutableList.of("key"), + ImmutableMap.of("key", SortOrder.ASC_NULLS_LAST))) + .addFunction( + "array_agg_result", + functionCall("array_agg", ImmutableList.of("key")), + createTestFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), + QualifiedObjectName.valueOf("presto.default.array_agg"), fromTypes(createDecimalType(2, 1))), + windowFrame( + RANGE, + CURRENT_ROW, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + FOLLOWING, + Optional.of("frame_end_value"), + Optional.of(new ShortDecimalType(12, 1)), + Optional.of("key_for_frame_end_comparison"), + Optional.of(new ShortDecimalType(12, 1)))), + project(// coerce sort key to compare sort key values with frame end values + ImmutableMap.of("key_for_frame_end_comparison", expression("CAST(key AS decimal(12, 1))")), + project(// calculate frame end value (sort key + frame offset) + ImmutableMap.of("frame_end_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$add"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("offset"))))), + filter(// validate offset values + "IF((offset >= CAST(0 AS DECIMAL(10, 0))), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + project(// coerce offset value to calculate frame end values + ImmutableMap.of("offset", expression("CAST(x AS decimal(10, 0))")), + anyTree( + values( + ImmutableList.of("key", "x"), + ImmutableList.of( + ImmutableList.of(new DecimalLiteral("1.1"), new LongLiteral("1")), + ImmutableList.of(new DecimalLiteral("2.2"), new LongLiteral("2"))))))))))); + + assertPlan(sql, CREATED, pattern); + } + + @Test + public void testFramePrecedingFollowingNoCoercions() + { + @Language("SQL") String sql = "SELECT array_agg(key) OVER(ORDER BY key RANGE BETWEEN x PRECEDING AND y FOLLOWING) " + + "FROM (VALUES (1, 1, 1), (2, 2, 2)) T(key, x, y)"; + + PlanMatchPattern pattern = + anyTree( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of(), + ImmutableList.of("key"), + ImmutableMap.of("key", SortOrder.ASC_NULLS_LAST))) + .addFunction( + "array_agg_result", + functionCall("array_agg", ImmutableList.of("key")), + createTestFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.array_agg"), fromTypes(INTEGER)), + windowFrame( + RANGE, + PRECEDING, + Optional.of("frame_start_value"), + Optional.of(INTEGER), + Optional.of("key"), + Optional.of(INTEGER), + FOLLOWING, + Optional.of("frame_end_value"), + Optional.of(INTEGER), + Optional.of("key"), + Optional.of(INTEGER))), + project(// calculate frame end value (sort key + frame end offset) + ImmutableMap.of("frame_end_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$add"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("y"))))), + filter(// validate frame end offset values + "IF((y >= CAST(0 AS INTEGER)), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + project(// calculate frame start value (sort key - frame start offset)pattern = {PlanMatchPattern@4789} "- anyTree\n - node(WindowNode)\n WindowMatcher{specification=SpecificationProvider{partitionBy=[], orderBy=[key], orderings={key=ASC_NULLS_LAST}}}\n bind array_agg_result -> WindowFunctionMatcher{callMaker=array_agg (key) , functionHandle=presto.default.array_agg(integer):array(integer), frameMaker=WindowFrameProvider{type=RANGE, startType=PRECEDING, startValue=Optional[frame_start_value], endType=FOLLOWING, endValue=Optional[frame_end_value]}}\n - node(ProjectNode)\n bind frame_end_value -> "$operator$add"("key", "y")\n - node(FilterNode)\n FilterMatcher{predicate=IF(("y" >= CAST(0 AS integer)), true, CAST("fail"(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))}\n - node(ProjectNode)\n bind frame_start_value -> "$operator$subtract"("key", "x")\n - node(FilterNode)\n FilterMatcher{predicate=IF(("x" >= CAST(0 AS integer)), t"… View + ImmutableMap.of("frame_start_value", expression(new FunctionCall(QualifiedName.of("presto", "default", "$operator$subtract"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("x"))))), + filter(// validate frame start offset values + "IF((x >= CAST(0 AS INTEGER)), " + + "true, " + + "CAST(presto.default.fail(CAST('Window frame offset value must not be negative or null' AS varchar)) AS boolean))", + anyTree( + values( + ImmutableList.of("key", "x", "y"), + ImmutableList.of( + ImmutableList.of(new LongLiteral("1"), new LongLiteral("1"), new LongLiteral("1")), + ImmutableList.of(new LongLiteral("2"), new LongLiteral("2"), new LongLiteral("2"))))))))))); + + assertPlan(sql, CREATED, pattern); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java index 79c2d33b4ae64..2101921810c29 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java @@ -47,6 +47,12 @@ public ExpressionMatcher(String expression) this.expression = expression(requireNonNull(expression)); } + public ExpressionMatcher(Expression expression) + { + this.expression = requireNonNull(expression, "expression is null"); + this.sql = requireNonNull(expression).toString(); + } + private Expression expression(String sql) { SqlParser parser = new SqlParser(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index 9d8525891c1ea..66a40a2e78c22 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -26,6 +26,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; +import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.IsNotNullPredicate; @@ -316,7 +317,8 @@ protected Boolean visitSymbolReference(SymbolReference actual, Node expected) if (!(expected instanceof SymbolReference)) { return false; } - return symbolAliases.get(((SymbolReference) expected).getName()).equals(actual); + return symbolAliases.get(((SymbolReference) expected).getName()).equals(actual) || + expected.equals(actual); } @Override @@ -451,6 +453,20 @@ protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Nod return process(actual.getDefaultValue(), expected.getDefaultValue()) && process(actual.getWhenClauses(), expected.getWhenClauses()); } + @Override + protected Boolean visitIfExpression(IfExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof IfExpression)) { + return false; + } + + IfExpression expected = (IfExpression) expectedExpression; + + return process(actual.getCondition(), expected.getCondition()) + && process(actual.getTrueValue(), expected.getTrueValue()) + && process(actual.getFalseValue(), expected.getFalseValue()); + } + private boolean process(List actuals, List expecteds) { if (actuals.size() != expecteds.size()) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index d3f83535833eb..cf979d282b896 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.AggregationNode; @@ -256,14 +257,37 @@ public static ExpectedValueProvider windowFrame( BoundType startType, Optional startValue, BoundType endType, - Optional endValue) + Optional endValue, + Optional sortKey) + { + return windowFrame(type, startType, startValue, Optional.empty(), sortKey, Optional.empty(), endType, endValue, Optional.empty(), sortKey, Optional.empty()); + } + + public static ExpectedValueProvider windowFrame( + WindowType type, + BoundType startType, + Optional startValue, + Optional startValueType, + Optional sortKeyForStartComparison, + Optional sortKeyForStartComparisonType, + BoundType endType, + Optional endValue, + Optional endValueType, + Optional sortKeyForEndComparison, + Optional sortKeyForEndComparisonType) { return new WindowFrameProvider( type, startType, startValue.map(SymbolAlias::new), + startValueType, + sortKeyForStartComparison.map(SymbolAlias::new), + sortKeyForStartComparisonType, endType, - endValue.map(SymbolAlias::new)); + endValue.map(SymbolAlias::new), + endValueType, + sortKeyForEndComparison.map(SymbolAlias::new), + sortKeyForEndComparisonType); } public static PlanMatchPattern window(Consumer windowMatcherBuilderConsumer, PlanMatchPattern source) @@ -692,6 +716,11 @@ public static ExpressionMatcher expression(String expression) return new ExpressionMatcher(expression); } + public static ExpressionMatcher expression(Expression expression) + { + return new ExpressionMatcher(expression); + } + public PlanMatchPattern withOutputs(String... aliases) { return withOutputs(ImmutableList.copyOf(aliases)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java index 66579552a9905..ec8df12490f87 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java @@ -103,4 +103,21 @@ public static boolean matchSpecification(WindowNode.Specification actual, Window .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)))) .orElse(true); } + + public static boolean matchSpecification(WindowNode.Specification actual, SpecificationProvider expected) + { + return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) + .equals(expected.partitionBy.stream().map(SymbolAlias::toString).collect(toImmutableList())) && + actual.getOrderingScheme().map(orderingScheme -> orderingScheme.getOrderByVariables().stream() + .map(VariableReferenceExpression::getName) + .collect(toImmutableSet()) + .equals(expected.orderBy.stream() + .map(SymbolAlias::toString) + .collect(toImmutableSet())) && + orderingScheme.getOrderingsMap().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)) + .equals(expected.orderings.entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().toString(), Map.Entry::getValue)))) + .orElse(true); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java index 281627ec290b3..6c064a8ebf9d6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java @@ -100,19 +100,20 @@ private static String toKey(String alias) private Map getUpdatedAssignments(Assignments assignments) { + Map> newMap = new HashMap<>(); ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); for (Map.Entry assignment : assignments.getMap().entrySet()) { for (Map.Entry existingAlias : map.entrySet()) { RowExpression expression = assignment.getValue(); if (isExpression(expression) && castToExpression(expression).equals(existingAlias.getValue())) { // Simple symbol rename - mapUpdate.put(existingAlias.getKey(), asSymbolReference(assignment.getKey())); + updateMap(existingAlias.getKey(), asSymbolReference(assignment.getKey()), newMap); } else if (!isExpression(expression) && (expression instanceof VariableReferenceExpression) && ((VariableReferenceExpression) expression).getName().equals(existingAlias.getValue().getName())) { // Simple symbol rename - mapUpdate.put(existingAlias.getKey(), createSymbolReference(assignment.getKey())); + updateMap(existingAlias.getKey(), createSymbolReference(assignment.getKey()), newMap); } else if (createSymbolReference(assignment.getKey()).equals(existingAlias.getValue())) { /* @@ -125,13 +126,38 @@ else if (createSymbolReference(assignment.getKey()).equals(existingAlias.getValu * At the beginning for the function, map contains { NEW_ALIAS: SymbolReference("expr_2" } * and the assignments map contains { expr_2 := }. */ - mapUpdate.put(existingAlias.getKey(), existingAlias.getValue()); + updateMap(existingAlias.getKey(), existingAlias.getValue(), newMap); } } } + for (Map.Entry> entry : newMap.entrySet()) { + mapUpdate.put(entry.getKey(), entry.getValue().get()); + } return mapUpdate.build(); } + private void updateMap(String key, SymbolReference symbolRef, Map> newMap) + { + /* + * Assignments: + * field_4: field_4 + * field_5: field_5 + * key: field_4 + * x: field_5 + * key_6: field_4 + * map: + * key: field_4 + * x: field_5 + * The updated map should not contain either or + * result updated map: + * key: key_6 + * x: field_5 + */ + if (!key.equals(symbolRef.getName()) || !newMap.containsKey(key)) { + newMap.put(key, Optional.of(symbolRef)); + } + } + /* * Return a new SymbolAliases that contains a map with the original bindings * updated based on assignments given that assignments is a map of diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java index 325b1f5f9d5bc..66d61211c2eb7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; @@ -31,21 +32,40 @@ public class WindowFrameProvider private final WindowType type; private final BoundType startType; private final Optional startValue; + private final Optional sortKeyForStartComparison; private final BoundType endType; private final Optional endValue; + private final Optional sortKeyForEndComparison; + + private Optional startValueType; + private Optional sortKeyForStartComparisonType; + private Optional endValueType; + private Optional sortKeyForEndComparisonType; WindowFrameProvider( WindowType type, BoundType startType, Optional startValue, + Optional startValueType, + Optional sortKeyForStartComparison, + Optional sortKeyForStartComparisonType, BoundType endType, - Optional endValue) + Optional endValue, + Optional endValueType, + Optional sortKeyForEndComparison, + Optional sortKeyForEndComparisonType) { this.type = requireNonNull(type, "type is null"); this.startType = requireNonNull(startType, "startType is null"); this.startValue = requireNonNull(startValue, "startValue is null"); + this.startValueType = requireNonNull(startValueType, "startValueType is null"); + this.sortKeyForStartComparison = requireNonNull(sortKeyForStartComparison, "sortKeyForStartComparison is null"); + this.sortKeyForStartComparisonType = requireNonNull(sortKeyForStartComparisonType, "sortKeyForStartComparisonType is null"); this.endType = requireNonNull(endType, "endType is null"); this.endValue = requireNonNull(endValue, "endValue is null"); + this.endValueType = requireNonNull(endValueType, "endValueType is null"); + this.sortKeyForEndComparison = requireNonNull(sortKeyForEndComparison, "sortKeyForEndComparison is null"); + this.sortKeyForEndComparisonType = requireNonNull(sortKeyForEndComparisonType, "sortKeyForEndComparisonType is null"); } @Override @@ -59,13 +79,36 @@ public WindowNode.Frame getExpectedValue(SymbolAliases aliases) return new WindowNode.Frame( type, startType, - startValue.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), BIGINT)), + toVariableReferenceExpression(aliases, startValue, startValueType), + //startValue.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), startValueType.orElseGet(() -> BIGINT))), + toVariableReferenceExpression(aliases, sortKeyForStartComparison, sortKeyForStartComparisonType), + //sortKeyForStartComparison.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), sortKeyForStartComparisonType.orElseGet(() -> BIGINT))), endType, - endValue.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), BIGINT)), + toVariableReferenceExpression(aliases, endValue, endValueType), + //endValue.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), endValueType.orElseGet(() -> BIGINT))), + toVariableReferenceExpression(aliases, sortKeyForEndComparison, sortKeyForEndComparisonType), + //sortKeyForEndComparison.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), sortKeyForEndComparisonType.orElseGet(() -> BIGINT))), originalStartValue.map(Expression::toString), originalEndValue.map(Expression::toString)); } + private Optional toVariableReferenceExpression(SymbolAliases aliases, Optional symbolAlias, Optional type) + { + if (!symbolAlias.isPresent()) { + return Optional.empty(); + } + + String alias = symbolAlias.get().toSymbol(aliases).getName(); + Type variableType = type.orElseGet(() -> BIGINT); + // This is really ugly! Because we translate differently between field and other expressions in TranslationMap.get(). + // We should handle them the same way here but we don't have information other than their names. Luckily this is test code. + if (alias.startsWith("field")) { + return Optional.of(new VariableReferenceExpression(symbolAlias.get().toString(), variableType)); + } + + return Optional.of(new VariableReferenceExpression(symbolAlias.get().toSymbol(aliases).getName(), variableType)); + } + @Override public String toString() { @@ -73,8 +116,14 @@ public String toString() .add("type", this.type) .add("startType", this.startType) .add("startValue", this.startValue) + .add("startValueType", this.startValueType) + .add("sortKeyForStartComparison", this.sortKeyForStartComparison) + .add("sortKeyForStartComparisonType", this.sortKeyForStartComparisonType) .add("endType", this.endType) .add("endValue", this.endValue) + .add("endValueType", this.endValueType) + .add("sortKeyForEndComparison", this.sortKeyForEndComparison) + .add("sortKeyForEndComparisonType", this.sortKeyForEndComparisonType) .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java index bb6f3cdb1242d..ac5d0e1036213 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java @@ -24,6 +24,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableSet; import java.util.List; @@ -99,9 +100,13 @@ public Optional getAssignedVariable(PlanNode node, } } else { - if (!expectedExpression.equals(castToExpression(actualExpression))) { - return false; + if (expectedExpression.equals(castToExpression(actualExpression))) { + return true; + } + if (castToExpression(actualExpression) instanceof SymbolReference) { + return expectedExpression.equals(symbolAliases.get((((SymbolReference) castToExpression(actualExpression))).getName())); } + return false; } } return true; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index b2c3bc9363b96..4cff3d3b6b284 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -85,7 +85,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses } if (!specification - .map(expectedSpecification -> matchSpecification(windowNode.getSpecification(), expectedSpecification.getExpectedValue(symbolAliases))) + .map(expectedSpecification -> matchSpecification(windowNode.getSpecification(), expectedSpecification.getExpectedValue(symbolAliases)) || + (expectedSpecification instanceof SpecificationProvider && matchSpecification(windowNode.getSpecification(), (SpecificationProvider) expectedSpecification))) .orElse(true)) { return NO_MATCH; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 579030675b7a3..ad5443e972792 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -53,9 +53,11 @@ public class TestMergeAdjacentWindows RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), CURRENT_ROW, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); private static final FunctionHandle SUM_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("sum", fromTypes(DOUBLE)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index 60ab90d96b8ef..0edede7b133d8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -70,14 +70,16 @@ public class TestPruneWindowColumns UNBOUNDED_PRECEDING, Optional.of("startValue1"), CURRENT_ROW, - Optional.of("endValue1")); + Optional.of("endValue1"), + Optional.of("orderKey")); private static final ExpectedValueProvider frameProvider2 = windowFrame( RANGE, UNBOUNDED_PRECEDING, Optional.of("startValue2"), CURRENT_ROW, - Optional.of("endValue2")); + Optional.of("endValue2"), + Optional.of("orderKey")); @Test public void testWindowNotNeeded() @@ -223,8 +225,10 @@ private static PlanNode buildProjectedWindow( RANGE, UNBOUNDED_PRECEDING, Optional.of(startValue1), + Optional.of(orderKey), CURRENT_ROW, Optional.of(endValue1), + Optional.of(orderKey), Optional.of(new SymbolReference(startValue1.getName())).map(Expression::toString), Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)), false), @@ -235,8 +239,10 @@ private static PlanNode buildProjectedWindow( RANGE, UNBOUNDED_PRECEDING, Optional.of(startValue2), + Optional.of(orderKey), CURRENT_ROW, Optional.of(endValue2), + Optional.of(orderKey), Optional.of(new SymbolReference(startValue2.getName())).map(Expression::toString), Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)), false)), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 84bc4109db11e..a2d33ab34f4e0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -54,9 +54,11 @@ public TestSwapAdjacentWindowsBySpecifications() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), CURRENT_ROW, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); functionHandle = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("avg", fromTypes(BIGINT)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index a41eae7d154f1..33e69c9670302 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -55,9 +55,11 @@ public void windowNodePruning() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); assertRuleApplication() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index a47fe5ad36c67..526cfbe8b1bb3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -225,8 +225,9 @@ public void testReorderAcrossProjectNodes() window(windowMatcherBuilder -> windowMatcherBuilder .specification(windowA) .addFunction(functionCall("lag", commonFrame, ImmutableList.of(QUANTITY_ALIAS, "ONE"))), - project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)")), - LINEITEM_TABLESCAN_DOQRST))))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly + project(ImmutableMap.of("ONE", expression("expr")), + project(ImmutableMap.of("expr", expression("CAST(1 AS bigint)")), + LINEITEM_TABLESCAN_DOQRST)))))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index d8271f87e277e..4a2f24b3bf527 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -105,9 +105,11 @@ public void testSerializationRoundtrip() RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); PlanNodeId id = newId(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java index 61bcb8cbb8fed..40d09db029b18 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyNoOriginalExpression.java @@ -69,6 +69,7 @@ public class TestVerifyNoOriginalExpression ComparisonExpression.Operator.EQUAL, new SymbolReference("count"), new Cast(new LongLiteral("5"), "bigint")); + private static final VariableReferenceExpression SORT_KEY = new VariableReferenceExpression("count", BIGINT); private Metadata metadata; private PlanBuilder builder; @@ -113,8 +114,10 @@ public void testValidateForWindow() WindowNode.Frame.WindowType.RANGE, WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING, startValue, + Optional.of(SORT_KEY), WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING, endValue, + Optional.of(SORT_KEY), originalStartValue, originalEndValue); WindowNode.Function function = new WindowNode.Function( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java b/presto-main/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java index dd99a717ba375..26e5a638d60d5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java @@ -22,16 +22,25 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.QueryRunner; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.AssertProvider; +import org.assertj.core.api.ListAssert; +import org.assertj.core.presentation.Representation; +import org.assertj.core.presentation.StandardRepresentation; import org.intellij.lang.annotations.Language; import java.io.Closeable; import java.util.List; +import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.stream.Collectors; import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder; +import static com.facebook.presto.sql.query.QueryAssertions.QueryAssert.newQueryAssert; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.base.Strings.nullToEmpty; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; @@ -58,6 +67,33 @@ public QueryRunner getQueryRunner() return runner; } + public Session.SessionBuilder sessionBuilder() + { + return Session.builder(runner.getDefaultSession()); + } + + public Session getDefaultSession() + { + return runner.getDefaultSession(); + } + + public AssertProvider query(@Language("SQL") String query) + { + return query(query, runner.getDefaultSession()); + } + + public AssertProvider query(@Language("SQL") String query, Session session) + { + return newQueryAssert(query, runner, session); + } + + /** + * @deprecated use {@link org.assertj.core.api.Assertions#assertThatThrownBy(ThrowableAssert.ThrowingCallable)}: + *
+     * assertThatThrownBy(() -> assertions.execute(sql))
+ * .hasMessage(...) + *
+ */ public void assertFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { try { @@ -83,12 +119,16 @@ public void assertQueryAndPlan( planValidator.accept(plan); } + /** + * @deprecated use {@link org.assertj.core.api.Assertions#assertThat} with {@link #query(String)} + */ + @Deprecated public void assertQuery(@Language("SQL") String actual, @Language("SQL") String expected) { assertQuery(actual, expected, false); } - public void assertQuery(@Language("SQL") String actual, @Language("SQL") String expected, boolean ensureOrdering) + private void assertQuery(@Language("SQL") String actual, @Language("SQL") String expected, boolean ensureOrdering) { MaterializedResult actualResults = null; try { @@ -126,4 +166,90 @@ public void close() { runner.close(); } + + public static class QueryAssert + extends AbstractAssert + { + private static final Representation ROWS_REPRESENTATION = new StandardRepresentation() + { + @Override + public String toStringOf(Object object) + { + if (object instanceof List) { + List list = (List) object; + return list.stream() + .map(this::toStringOf) + .collect(Collectors.joining(", ")); + } + if (object instanceof MaterializedRow) { + MaterializedRow row = (MaterializedRow) object; + + return row.getFields().stream() + .map(Object::toString) + .collect(Collectors.joining(", ", "(", ")")); + } + else { + return super.toStringOf(object); + } + } + }; + + private final QueryRunner runner; + private final Session session; + private boolean ordered; + + static AssertProvider newQueryAssert(String query, QueryRunner runner, Session session) + { + MaterializedResult result = runner.execute(session, query); + return () -> new QueryAssert(runner, session, result); + } + + public QueryAssert(QueryRunner runner, Session session, MaterializedResult actual) + { + super(actual, Object.class); + this.runner = runner; + this.session = session; + } + + public QueryAssert matches(BiFunction evaluator) + { + MaterializedResult expected = evaluator.apply(session, runner); + return isEqualTo(expected); + } + + public QueryAssert ordered() + { + ordered = true; + return this; + } + + public QueryAssert matches(@Language("SQL") String query) + { + MaterializedResult expected = runner.execute(session, query); + + return satisfies(actual -> { + assertThat(actual.getTypes()) + .as("Output types") + .isEqualTo(expected.getTypes()); + + ListAssert assertion = assertThat(actual.getMaterializedRows()) + .as("Rows") + .withRepresentation(ROWS_REPRESENTATION); + + if (ordered) { + assertion.containsExactlyElementsOf(expected.getMaterializedRows()); + } + else { + assertion.containsExactlyInAnyOrder(expected.getMaterializedRows().toArray(new MaterializedRow[0])); + } + }); + } + + public QueryAssert returnsEmptyResult() + { + return satisfies(actual -> { + assertThat(actual.getRowCount()).as("row count").isEqualTo(0); + }); + } + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/query/TestWindowFrameRange.java b/presto-main/src/test/java/com/facebook/presto/sql/query/TestWindowFrameRange.java new file mode 100644 index 0000000000000..22f03b459840a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/query/TestWindowFrameRange.java @@ -0,0 +1,756 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.query; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestWindowFrameRange +{ + private QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testNullsSortKey() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[2, 2, 3]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[1, 1, 2, 2, 3], " + + "ARRAY[2, 2, 3], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[3, 2, 2], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)")) + .matches("VALUES " + + "ARRAY[3, 2, 2], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[3, 2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1], " + + "ARRAY[2, 2, 1, 1], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)")) + .matches("VALUES " + + "ARRAY[1, 2, null, null], " + + "ARRAY[1, 2, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)")) + .matches("VALUES " + + "ARRAY[1, 2], " + + "ARRAY[1, 2], " + + "ARRAY[1, 2, null, null], " + + "ARRAY[1, 2, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 2], " + + "ARRAY[null, null, 1, 2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 1, 2], " + + "ARRAY[null, null, 1, 2], " + + "ARRAY[1, 2], " + + "ARRAY[1, 2]"); + } + + @Test + public void testNoValueFrameBounds() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 1], " + + "ARRAY[null, null, 1, 1], " + + "ARRAY[null, null, 1, 1, 2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[2]"); + } + + @Test + public void testMixedTypeFrameBoundsAscendingNullsFirst() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 1.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1, 2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[null, null, 1, 1, 2], " + + "ARRAY[2], " + + "ARRAY[2], " + + "null"); + } + + @Test + public void testMixedTypeFrameBoundsAscendingNullsLast() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "null, " + + "null, " + + "ARRAY[1, 1], " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[1, 1, 2, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[1, 1, 2, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1, 2], " + + "ARRAY[1, 1, 2], " + + "ARRAY[2], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1, 2], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 0.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[1, 1, 2, null, null], " + + "ARRAY[2, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 0.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[2, null, null], " + + "ARRAY[2, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + } + + @Test + public void testMixedTypeFrameBoundsDescendingNullsFirst() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 2], " + + "ARRAY[null, null, 2]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null, 2], " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[null, null, 2, 1, 1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[2, 1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[2], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS FIRST RANGE BETWEEN 1.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[null, null, 2, 1, 1], " + + "ARRAY[null, null, 2, 1, 1], " + + "null, " + + "null, " + + "null"); + } + + @Test + public void testMixedTypeFrameBoundsDescendingNullsLast() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 0.5 PRECEDING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "null, " + + "ARRAY[2], " + + "ARRAY[2], " + + "ARRAY[2, 1, 1, null, null], " + + "ARRAY[2, 1, 1, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1], " + + "ARRAY[2, 1, 1, null, null], " + + "ARRAY[2, 1, 1, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[2, 1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 0.5 PRECEDING AND CURRENT ROW) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[2], " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 0.5 PRECEDING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "ARRAY[2, 1, 1, null, null], " + + "ARRAY[1, 1, null, null], " + + "ARRAY[1, 1, null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1.5 FOLLOWING AND UNBOUNDED FOLLOWING) " + + "FROM (VALUES 1, null, null, 2, 1) T(a)")) + .matches("VALUES " + + "CAST(ARRAY[null, null] AS array(integer)), " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null], " + + "ARRAY[null, null]"); + } + + @Test + public void testEmptyInput() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (SELECT 1 WHERE false) T(a)")) + .returnsEmptyResult(); + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE UNBOUNDED PRECEDING) " + + "FROM (SELECT 1 WHERE false) T(a)")) + .returnsEmptyResult(); + } + + @Test + public void testEmptyFrame() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 10 PRECEDING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)")) + .matches("VALUES " + + "CAST(null AS array(integer)), " + + "null, " + + "null, " + + "null, " + + "null, " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC NULLS LAST RANGE BETWEEN 10 FOLLOWING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 2, 3, null, null, 2, 1, null, null) T(a)")) + .matches("VALUES " + + "CAST(null AS array(integer)), " + + "null, " + + "null, " + + "null, " + + "null, " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null], " + + "ARRAY[null, null, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 0.5 FOLLOWING AND 1.5 FOLLOWING) " + + "FROM (VALUES 1, 2, 4) T(a)")) + .matches("VALUES " + + "ARRAY[2], " + + "null, " + + "null"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES 1.0, 1.1) T(a)")) + .matches("VALUES " + + "CAST(null AS array(decimal(2, 1))), " + + "null"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a NULLS LAST RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES 1.0, 1.1, null) T(a)")) + .matches("VALUES " + + "CAST(null AS array(decimal(2, 1))), " + + "null, " + + "ARRAY[null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES 1.0, 1.1) T(a)")) + .matches("VALUES " + + "CAST(null AS array(decimal(2, 1))), " + + "null"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a NULLS FIRST RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES null, 1.0, 1.1) T(a)")) + .matches("VALUES " + + "CAST(ARRAY[null] AS array(decimal(2,1))), " + + "null, " + + "null"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES 1, 2) T(a)")) + .matches("VALUES " + + "null, " + + "ARRAY[1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a NULLS FIRST RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES null, 1, 2) T(a)")) + .matches("VALUES " + + "ARRAY[null], " + + "null, " + + "ARRAY[1]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a NULLS FIRST RANGE BETWEEN 2 PRECEDING AND 1.5 PRECEDING) " + + "FROM (VALUES null, 1, 2) T(a)")) + .matches("VALUES " + + "CAST(ARRAY[null] AS array(integer)), " + + "null, " + + "null"); + } + + @Test + public void testOnlyNulls() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES CAST(null AS integer), null, null) T(a)")) + .matches("VALUES " + + "CAST(ARRAY[null, null, null] AS array(integer)), " + + "ARRAY[null, null, null], " + + "ARRAY[null, null, null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES CAST(null AS integer), null, null) T(a)")) + .matches("VALUES " + + "CAST(ARRAY[null, null, null] AS array(integer)), " + + "ARRAY[null, null, null], " + + "ARRAY[null, null, null]"); + } + + @Test + public void testAllPartitionSameValues() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING) " + + "FROM (VALUES 1, 1, 1) T(a)")) + .matches("VALUES " + + "CAST(null AS array(integer)), " + + "null, " + + "null"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) " + + "FROM (VALUES 1, 1, 1) T(a)")) + .matches("VALUES " + + "CAST(null AS array(integer)), " + + "null, " + + "null"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES 1, 1, 1) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1, 1], " + + "ARRAY[1, 1, 1], " + + "ARRAY[1, 1, 1]"); + } + + @Test + public void testZeroOffset() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC NULLS LAST RANGE BETWEEN 0 PRECEDING AND 0 FOLLOWING) " + + "FROM (VALUES 1, 2, 1, null) T(a)")) + .matches("VALUES " + + "ARRAY[1, 1], " + + "ARRAY[1, 1], " + + "ARRAY[2], " + + "ARRAY[null]"); + } + + @Test + public void testNonConstantOffset() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN x * 10 PRECEDING AND y / 10.0 FOLLOWING) " + + "FROM (VALUES (1, 0.1, 10), (2, 0.2, 20), (4, 0.4, 40)) T(a, x, y)")) + .matches("VALUES " + + "ARRAY[1, 2], " + + "ARRAY[1, 2, 4], " + + "ARRAY[1, 2, 4]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN x * 10 PRECEDING AND y / 10.0 FOLLOWING) " + + "FROM (VALUES (1, 0.1, 10), (2, 0.2, 20), (4, 0.4, 40), (null, 0.5, 50)) T(a, x, y)")) + .matches("VALUES " + + "ARRAY[1, 2], " + + "ARRAY[1, 2, 4], " + + "ARRAY[1, 2, 4], " + + "ARRAY[null]"); + } + + @Test + public void testInvalidOffset() + { + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC RANGE x PRECEDING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a ASC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE x PRECEDING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (2, -0.2)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE x PRECEDING) " + + "FROM (VALUES (1, 0.1), (2, null)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (2, null)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + // fail if offset is invalid for null sort key + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (null, null)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a DESC RANGE BETWEEN 1 PRECEDING AND x FOLLOWING) " + + "FROM (VALUES (1, 0.1), (null, -0.1)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + // test invalid offset of different types + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, BIGINT '-1')) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, INTEGER '-1')) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (SMALLINT '1', SMALLINT '-1')) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (TINYINT '1', TINYINT '-1')) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, -1.1e0)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, REAL '-1.1')) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (1, -1.0001)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' YEAR)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' MONTH)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' DAY)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' HOUR)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' MINUTE)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + + assertThatThrownBy(() -> assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE x PRECEDING) " + + "FROM (VALUES (DATE '2001-01-31', INTERVAL '-1' SECOND)) T(a, x)")) + .hasMessage("Window frame offset value must not be negative or null"); + } + + @Test + public void testWindowPartitioning() + { + assertThat(assertions.query("SELECT a, p, array_agg(a) OVER(PARTITION BY p ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES (1, 'x'), (2, 'x'), (null, 'x'), (null, 'y'), (2, 'y')) T(a, p)")) + .matches("VALUES " + + "(null, 'x', ARRAY[null]), " + + "(1, 'x', ARRAY[1, 2]), " + + "(2, 'x', ARRAY[2]), " + + "(null, 'y', ARRAY[null]), " + + "(2, 'y', ARRAY[2])"); + + assertThat(assertions.query("SELECT a, p, array_agg(a) OVER(PARTITION BY p ORDER BY a ASC NULLS FIRST RANGE BETWEEN 0.5 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES (1, 'x'), (2, 'x'), (null, 'x'), (null, 'y'), (2, 'y'), (null, null), (null, null), (1, null)) T(a, p)")) + .matches("VALUES " + + "(null, null, ARRAY[null, null]), " + + "(null, null, ARRAY[null, null]), " + + "(1, null, ARRAY[1]), " + + "(null, 'x', ARRAY[null]), " + + "(1, 'x', ARRAY[1, 2]), " + + "(2, 'x', ARRAY[2]), " + + "(null, 'y', ARRAY[null]), " + + "(2, 'y', ARRAY[2])"); + } + + @Test + public void testTypes() + { + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN DOUBLE '0.5' PRECEDING AND TINYINT '1' FOLLOWING) " + + "FROM (VALUES 1, null, 2) T(a)")) + .matches("VALUES " + + "ARRAY[1, 2], " + + "ARRAY[2], " + + "ARRAY[null]"); + + assertThat(assertions.query("SELECT array_agg(a) OVER(ORDER BY a RANGE BETWEEN 0.5 PRECEDING AND 1.000 FOLLOWING) " + + "FROM (VALUES REAL '1', null, 2) T(a)")) + .matches("VALUES " + + "ARRAY[REAL '1', REAL '2'], " + + "ARRAY[REAL '2'], " + + "ARRAY[null]"); + + assertThat(assertions.query("SELECT x, array_agg(x) OVER(ORDER BY x DESC RANGE BETWEEN interval '1' month PRECEDING AND interval '1' month FOLLOWING) " + + "FROM (VALUES DATE '2001-01-31', DATE '2001-08-25', DATE '2001-09-25', DATE '2001-09-26') T(x)")) + .matches("VALUES " + + "(DATE '2001-09-26', ARRAY[DATE '2001-09-26', DATE '2001-09-25']), " + + "(DATE '2001-09-25', ARRAY[DATE '2001-09-26', DATE '2001-09-25', DATE '2001-08-25']), " + + "(DATE '2001-08-25', ARRAY[DATE '2001-09-25', DATE '2001-08-25']), " + + "(DATE '2001-01-31', ARRAY[DATE '2001-01-31'])"); + + // January 31 + 1 month sets the frame bound to the last day of February. March 1 is out of range. + assertThat(assertions.query("SELECT x, array_agg(x) OVER(ORDER BY x RANGE BETWEEN CURRENT ROW AND interval '1' month FOLLOWING) " + + "FROM (VALUES DATE '2001-01-31', DATE '2001-02-28', DATE '2001-03-01') T(x)")) + .matches("VALUES " + + "(DATE '2001-01-31', ARRAY[DATE '2001-01-31', DATE '2001-02-28']), " + + "(DATE '2001-02-28', ARRAY[DATE '2001-02-28', DATE '2001-03-01']), " + + "(DATE '2001-03-01', ARRAY[DATE '2001-03-01'])"); + + assertThat(assertions.query("SELECT x, array_agg(x) OVER(ORDER BY x RANGE BETWEEN interval '1' year PRECEDING AND interval '1' month FOLLOWING) " + + "FROM (VALUES " + + "INTERVAL '1' month, " + + "INTERVAL '2' month, " + + "INTERVAL '5' year) T(x)")) + .matches("VALUES " + + "(INTERVAL '1' month, ARRAY[INTERVAL '1' month, INTERVAL '2' month]), " + + "(INTERVAL '2' month, ARRAY[INTERVAL '1' month, INTERVAL '2' month]), " + + "(INTERVAL '5' year, ARRAY[INTERVAL '5' year])"); + } + + @Test + public void testMultipleWindowFunctions() + { + assertThat(assertions.query("SELECT x, array_agg(date) OVER(ORDER BY x RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), avg(number) OVER(ORDER BY x RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM (VALUES " + + "(2, DATE '2222-01-01', 4.4), " + + "(1, DATE '1111-01-01', 2.2), " + + "(3, DATE '3333-01-01', 6.6)) T(x, date, number)")) + .matches("VALUES " + + "(1, ARRAY[DATE '1111-01-01', DATE '2222-01-01'], 3.3), " + + "(2, ARRAY[DATE '1111-01-01', DATE '2222-01-01', DATE '3333-01-01'], 4.4), " + + "(3, ARRAY[DATE '2222-01-01', DATE '3333-01-01'], 5.5)"); + + assertThat(assertions.query("SELECT x, array_agg(a) OVER(ORDER BY x RANGE BETWEEN 2 PRECEDING AND CURRENT ROW), array_agg(a) OVER(ORDER BY x RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) " + + "FROM (VALUES " + + "(1.0, 1), " + + "(2.0, 2), " + + "(3.0, 3), " + + "(4.0, 4), " + + "(5.0, 5), " + + "(6.0, 6)) T(x, a)")) + .matches("VALUES " + + "(1.0, ARRAY[1], ARRAY[1, 2, 3]), " + + "(2.0, ARRAY[1, 2], ARRAY[2, 3, 4]), " + + "(3.0, ARRAY[1, 2, 3], ARRAY[3, 4, 5]), " + + "(4.0, ARRAY[2, 3, 4], ARRAY[4, 5, 6]), " + + "(5.0, ARRAY[3, 4, 5], ARRAY[5, 6]), " + + "(6.0, ARRAY[4, 5, 6], ARRAY[6])"); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionHandle.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionHandle.java index 4397676f63fbc..a294686798f94 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionHandle.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionHandle.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.function; import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.NotSupportedException; /** * FunctionHandle is a unique handle to identify the function implementation from namespaces. @@ -22,4 +23,8 @@ public interface FunctionHandle { CatalogSchemaName getCatalogSchemaName(); + default Signature getSignature() + { + throw new NotSupportedException("Method getSignature is not implemented."); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java index b421cfceb4c92..1d9758cc6ae27 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java @@ -26,6 +26,7 @@ import java.util.Map.Entry; import java.util.Set; import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collector; @@ -198,6 +199,20 @@ public Builder put(VariableReferenceExpression variable, RowExpression expressio return this; } + public Builder putIdentities(Iterable variables, Function toRowExpression) + { + for (VariableReferenceExpression variable : variables) { + putIdentity(variable, toRowExpression); + } + return this; + } + + public Builder putIdentity(VariableReferenceExpression variable, Function toRowExpression) + { + put(variable, toRowExpression.apply(variable)); + return this; + } + public Builder put(Entry assignment) { put(assignment.getKey(), assignment.getValue());