Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.prestosql.operator.SyntheticAddress.decodePosition;
Expand Down Expand Up @@ -453,6 +454,12 @@ public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Opt
blockTypeOperators);
}

public PagesIndexComparator createChannelComparator(int leftChannel, int rightChannel)
{
checkArgument(types.get(leftChannel).equals(types.get(rightChannel)), "comparing channels of different types: %s and %s", types.get(leftChannel), types.get(rightChannel));
return new SimpleChannelComparator(leftChannel, rightChannel, blockTypeOperators.getComparisonOperator(types.get(leftChannel)));
}

public LookupSourceSupplier createLookupSourceSupplier(
Session session,
List<Integer> joinChannels,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.operator;

import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.Block;
import io.prestosql.type.BlockTypeOperators.BlockPositionComparison;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.prestosql.operator.SyntheticAddress.decodePosition;
import static io.prestosql.operator.SyntheticAddress.decodeSliceIndex;
import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static java.util.Objects.requireNonNull;

public class SimpleChannelComparator
implements PagesIndexComparator
{
private final int leftChannel;
private final int rightChannel;
private final BlockPositionComparison comparator;

public SimpleChannelComparator(int leftChannel, int rightChannel, BlockPositionComparison comparator)
{
this.leftChannel = leftChannel;
this.rightChannel = rightChannel;
this.comparator = requireNonNull(comparator, "comparator is null");
}

@Override
public int compareTo(PagesIndex pagesIndex, int leftPosition, int rightPosition)
{
long leftPageAddress = pagesIndex.getValueAddresses().getLong(leftPosition);
int leftBlockIndex = decodeSliceIndex(leftPageAddress);
int leftBlockPosition = decodePosition(leftPageAddress);

long rightPageAddress = pagesIndex.getValueAddresses().getLong(rightPosition);
int rightBlockIndex = decodeSliceIndex(rightPageAddress);
int rightBlockPosition = decodePosition(rightPageAddress);

try {
Block leftBlock = pagesIndex.getChannel(leftChannel).get(leftBlockIndex);
Block rightBlock = pagesIndex.getChannel(rightChannel).get(rightBlockIndex);
return (int) comparator.compare(leftBlock, leftBlockPosition, rightBlock, rightBlockPosition);
}
catch (Throwable throwable) {
throwIfUnchecked(throwable);
throw new PrestoException(GENERIC_INTERNAL_ERROR, throwable);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.PeekingIterator;
Expand All @@ -25,6 +26,7 @@
import io.prestosql.operator.WorkProcessor.ProcessState;
import io.prestosql.operator.WorkProcessor.Transformation;
import io.prestosql.operator.WorkProcessor.TransformationState;
import io.prestosql.operator.window.FrameInfo;
import io.prestosql.operator.window.FramedWindowFunction;
import io.prestosql.operator.window.WindowPartition;
import io.prestosql.spi.Page;
Expand All @@ -37,6 +39,8 @@
import io.prestosql.sql.planner.plan.PlanNodeId;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -53,6 +57,9 @@
import static io.airlift.concurrent.MoreFutures.checkSuccess;
import static io.prestosql.operator.WorkProcessor.TransformationState.needsMoreData;
import static io.prestosql.spi.connector.SortOrder.ASC_NULLS_LAST;
import static io.prestosql.sql.tree.FrameBound.Type.FOLLOWING;
import static io.prestosql.sql.tree.FrameBound.Type.PRECEDING;
import static io.prestosql.sql.tree.WindowFrame.Type.RANGE;
import static io.prestosql.util.MergeSortedPages.mergeSortedPages;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -266,7 +273,8 @@ public WindowOperator(
preGroupedChannels,
unGroupedPartitionChannels,
preSortedChannels,
sortChannels);
sortChannels,
windowFunctionDefinitions);

if (spillEnabled) {
PagesIndexWithHashStrategies mergedPagesIndexWithHashStrategies = new PagesIndexWithHashStrategies(
Expand All @@ -278,7 +286,8 @@ public WindowOperator(
ImmutableList.of(),
// merged pages are pre sorted on all sort channels
sortChannels,
sortChannels);
sortChannels,
windowFunctionDefinitions);

this.spillablePagesToPagesIndexes = Optional.of(new SpillablePagesToPagesIndexes(
inMemoryPagesIndexWithHashStrategies,
Expand Down Expand Up @@ -386,6 +395,7 @@ private static class PagesIndexWithHashStrategies
final PagesHashStrategy preSortedPartitionHashStrategy;
final PagesHashStrategy peerGroupHashStrategy;
final int[] preGroupedPartitionChannels;
final Map<FrameBoundKey, PagesIndexComparator> frameBoundComparators;

PagesIndexWithHashStrategies(
PagesIndex.Factory pagesIndexFactory,
Expand All @@ -394,14 +404,80 @@ private static class PagesIndexWithHashStrategies
List<Integer> preGroupedPartitionChannels,
List<Integer> unGroupedPartitionChannels,
List<Integer> preSortedChannels,
List<Integer> sortChannels)
List<Integer> sortChannels,
List<WindowFunctionDefinition> windowFunctionDefinitions)
{
this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions);
this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedPartitionChannels, OptionalInt.empty());
this.unGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(unGroupedPartitionChannels, OptionalInt.empty());
this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty());
this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, OptionalInt.empty());
this.preGroupedPartitionChannels = Ints.toArray(preGroupedPartitionChannels);
this.frameBoundComparators = createFrameBoundComparators(pagesIndex, windowFunctionDefinitions);
}
}

/**
* Create comparators necessary for seeking frame start or frame end for window functions with frame type RANGE.
* Whenever a frame bound is specified as RANGE X PRECEDING or RANGE X FOLLOWING,
* a dedicated comparator is created to compare sort key values with expected frame bound values.
*/
private static Map<FrameBoundKey, PagesIndexComparator> createFrameBoundComparators(PagesIndex pagesIndex, List<WindowFunctionDefinition> windowFunctionDefinitions)
{
ImmutableMap.Builder<FrameBoundKey, PagesIndexComparator> builder = ImmutableMap.builder();

for (int i = 0; i < windowFunctionDefinitions.size(); i++) {
FrameInfo frameInfo = windowFunctionDefinitions.get(i).getFrameInfo();
if (frameInfo.getType() == RANGE) {
if (frameInfo.getStartType() == PRECEDING || frameInfo.getStartType() == FOLLOWING) {
PagesIndexComparator comparator = pagesIndex.createChannelComparator(frameInfo.getSortKeyChannelForStartComparison(), frameInfo.getStartChannel());
builder.put(new FrameBoundKey(i, FrameBoundKey.Type.START), comparator);
}
if (frameInfo.getEndType() == PRECEDING || frameInfo.getEndType() == FOLLOWING) {
PagesIndexComparator comparator = pagesIndex.createChannelComparator(frameInfo.getSortKeyChannelForEndComparison(), frameInfo.getEndChannel());
builder.put(new FrameBoundKey(i, FrameBoundKey.Type.END), comparator);
}
}
}

return builder.build();
}

public static class FrameBoundKey
{
private final int functionIndex;
private final Type type;

public enum Type
{
START,
END;
}

public FrameBoundKey(int functionIndex, Type type)
{
this.functionIndex = functionIndex;
this.type = requireNonNull(type, "type is null");
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FrameBoundKey that = (FrameBoundKey) o;
return functionIndex == that.functionIndex &&
type == that.type;
}

@Override
public int hashCode()
{
return Objects.hash(functionIndex, type);
}
}

Expand Down Expand Up @@ -485,7 +561,7 @@ public ProcessState<WindowPartition> process()

int partitionEnd = findGroupEnd(pagesIndex, pagesIndexWithHashStrategies.unGroupedPartitionHashStrategy, partitionStart);

WindowPartition partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, pagesIndexWithHashStrategies.peerGroupHashStrategy);
WindowPartition partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, pagesIndexWithHashStrategies.peerGroupHashStrategy, pagesIndexWithHashStrategies.frameBoundComparators);
windowInfo.addPartition(partition);
partitionStart = partitionEnd;
return ProcessState.ofResult(partition);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.prestosql.operator.window;

import io.prestosql.sql.tree.FrameBound;
import io.prestosql.sql.tree.SortItem.Ordering;
import io.prestosql.sql.tree.WindowFrame;

import java.util.Objects;
Expand All @@ -27,21 +28,33 @@ public class FrameInfo
private final WindowFrame.Type type;
private final FrameBound.Type startType;
private final int startChannel;
private final int sortKeyChannelForStartComparison;
private final FrameBound.Type endType;
private final int endChannel;
private final int sortKeyChannelForEndComparison;
private final int sortKeyChannel;
private final Optional<Ordering> ordering;

public FrameInfo(
WindowFrame.Type type,
FrameBound.Type startType,
Optional<Integer> startChannel,
Optional<Integer> sortKeyChannelForStartComparison,
FrameBound.Type endType,
Optional<Integer> endChannel)
Optional<Integer> endChannel,
Optional<Integer> sortKeyChannelForEndComparison,
Optional<Integer> sortKeyChannel,
Optional<Ordering> ordering)
{
this.type = requireNonNull(type, "type is null");
this.startType = requireNonNull(startType, "startType is null");
this.startChannel = requireNonNull(startChannel, "startChannel is null").orElse(-1);
this.sortKeyChannelForStartComparison = requireNonNull(sortKeyChannelForStartComparison, "sortKeyChannelForStartComparison is null").orElse(-1);
this.endType = requireNonNull(endType, "endType is null");
this.endChannel = requireNonNull(endChannel, "endChannel is null").orElse(-1);
this.sortKeyChannelForEndComparison = requireNonNull(sortKeyChannelForEndComparison, "sortKeyChannelForEndComparison is null").orElse(-1);
this.sortKeyChannel = requireNonNull(sortKeyChannel, "sortKeyChannel is null").orElse(-1);
this.ordering = requireNonNull(ordering, "ordering is null");
}

public WindowFrame.Type getType()
Expand All @@ -59,6 +72,11 @@ public int getStartChannel()
return startChannel;
}

public int getSortKeyChannelForStartComparison()
{
return sortKeyChannelForStartComparison;
}

public FrameBound.Type getEndType()
{
return endType;
Expand All @@ -69,10 +87,25 @@ public int getEndChannel()
return endChannel;
}

public int getSortKeyChannelForEndComparison()
{
return sortKeyChannelForEndComparison;
}

public int getSortKeyChannel()
{
return sortKeyChannel;
}

public Optional<Ordering> getOrdering()
{
return ordering;
}

@Override
public int hashCode()
{
return Objects.hash(type, startType, startChannel, endType, endChannel);
return Objects.hash(type, startType, startChannel, sortKeyChannelForStartComparison, endType, endChannel, sortKeyChannelForEndComparison, sortKeyChannel, ordering);
}

@Override
Expand All @@ -91,8 +124,12 @@ public boolean equals(Object obj)
return this.type == other.type &&
this.startType == other.startType &&
Objects.equals(this.startChannel, other.startChannel) &&
Objects.equals(this.sortKeyChannelForStartComparison, other.sortKeyChannelForStartComparison) &&
this.endType == other.endType &&
Objects.equals(this.endChannel, other.endChannel);
Objects.equals(this.endChannel, other.endChannel) &&
Objects.equals(this.sortKeyChannelForEndComparison, other.sortKeyChannelForEndComparison) &&
Objects.equals(this.sortKeyChannel, other.sortKeyChannel) &&
Objects.equals(this.ordering, other.ordering);
}

@Override
Expand All @@ -102,8 +139,12 @@ public String toString()
.add("type", type)
.add("startType", startType)
.add("startChannel", startChannel)
.add("sortKeyChannelForStartComparison", sortKeyChannelForStartComparison)
.add("endType", endType)
.add("endChannel", endChannel)
.add("sortKeyChannelForEndComparison", sortKeyChannelForEndComparison)
.add("sortKeyChannel", sortKeyChannel)
.add("ordering", ordering)
.toString();
}
}
Loading