Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.type.Type;
import jakarta.annotation.Nullable;

import java.util.Iterator;
import java.util.List;
Expand All @@ -42,7 +43,8 @@ public class GroupedTopNRankBuilder
private final List<Type> sourceTypes;
private final boolean produceRanking;
private final int[] groupByChannels;
private final GroupByHash groupByHash;
@Nullable
private GroupByHash groupByHash; // null after output starts
private final PageWithPositionComparator comparator;
private final RowReferencePageManager pageManager = new RowReferencePageManager();
private final GroupedTopNRankAccumulator groupedTopNRankAccumulator;
Expand Down Expand Up @@ -102,6 +104,9 @@ public long hashCode(long rowId)
@Override
public Work<?> processPage(Page page)
{
if (groupByHash == null) {
throw new IllegalStateException("already producing results");
}
return new TransformWork<>(
groupByHash.getGroupIds(page.getColumns(groupByChannels)),
groupIds -> {
Expand All @@ -113,14 +118,19 @@ public Work<?> processPage(Page page)
@Override
public Iterator<Page> buildResult()
{
return new ResultIterator();
if (groupByHash == null) {
throw new IllegalStateException("already producing results");
}
int groupIdCount = groupByHash.getGroupCount();
groupByHash = null;
return new ResultIterator(groupIdCount);
}

@Override
public long getEstimatedSizeInBytes()
{
return INSTANCE_SIZE
+ groupByHash.getEstimatedSize()
+ (groupByHash == null ? 0L : groupByHash.getEstimatedSize())
+ pageManager.sizeOf()
+ groupedTopNRankAccumulator.sizeOf();
}
Expand Down Expand Up @@ -148,15 +158,16 @@ private class ResultIterator
extends AbstractIterator<Page>
{
private final PageBuilder pageBuilder;
private final int groupIdCount = groupByHash.getGroupCount();
private final int groupIdCount;
private int currentGroupId = -1;
private final LongBigArray rowIdOutput = new LongBigArray();
private final LongBigArray rankingOutput = new LongBigArray();
private long currentGroupSize;
private int currentIndexInGroup;

ResultIterator()
private ResultIterator(int groupIdCount)
{
this.groupIdCount = groupIdCount;
ImmutableList.Builder<Type> sourceTypesBuilders = ImmutableList.<Type>builder().addAll(sourceTypes);
if (produceRanking) {
sourceTypesBuilders.add(BIGINT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.type.Type;
import jakarta.annotation.Nullable;

import java.util.Iterator;
import java.util.List;
Expand All @@ -42,7 +43,8 @@ public class GroupedTopNRowNumberBuilder
private final List<Type> sourceTypes;
private final boolean produceRowNumber;
private final int[] groupByChannels;
private final GroupByHash groupByHash;
@Nullable
private GroupByHash groupByHash; // null after output starts
private final RowReferencePageManager pageManager = new RowReferencePageManager();
private final GroupedTopNRowNumberAccumulator groupedTopNRowNumberAccumulator;
private final PageWithPositionComparator comparator;
Expand Down Expand Up @@ -77,6 +79,9 @@ public GroupedTopNRowNumberBuilder(
@Override
public Work<?> processPage(Page page)
{
if (groupByHash == null) {
throw new IllegalStateException("already producing output");
}
return new TransformWork<>(
groupByHash.getGroupIds(page.getColumns(groupByChannels)),
groupIds -> {
Expand All @@ -88,14 +93,19 @@ public Work<?> processPage(Page page)
@Override
public Iterator<Page> buildResult()
{
return new ResultIterator();
if (groupByHash == null) {
throw new IllegalStateException("already producing output");
}
int groupIdCount = groupByHash.getGroupCount();
groupByHash = null;
return new ResultIterator(groupIdCount);
}

@Override
public long getEstimatedSizeInBytes()
{
return INSTANCE_SIZE
+ groupByHash.getEstimatedSize()
+ (groupByHash == null ? 0L : groupByHash.getEstimatedSize())
+ pageManager.sizeOf()
+ groupedTopNRowNumberAccumulator.sizeOf();
}
Expand All @@ -119,6 +129,7 @@ private void processPage(Page newPage, int groupCount, int[] groupIds)
pageManager.compactIfNeeded();
}

@Nullable
@VisibleForTesting
GroupByHash getGroupByHash()
{
Expand All @@ -129,14 +140,15 @@ private class ResultIterator
extends AbstractIterator<Page>
{
private final PageBuilder pageBuilder;
private final int groupIdCount = groupByHash.getGroupCount();
private final int groupIdCount;
private int currentGroupId = -1;
private final LongBigArray rowIdOutput = new LongBigArray();
private long currentGroupSize;
private int currentIndexInGroup;

ResultIterator()
private ResultIterator(int groupIdCount)
{
this.groupIdCount = groupIdCount;
ImmutableList.Builder<Type> sourceTypesBuilders = ImmutableList.<Type>builder().addAll(sourceTypes);
if (produceRowNumber) {
sourceTypesBuilders.add(BIGINT);
Expand Down