Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.array.ObjectBigArray;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.spi.function.aggregation.GroupByIdBlock;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import it.unimi.dsi.fastutil.ints.IntArrayFIFOQueue;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
Expand All @@ -31,9 +35,12 @@
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.PrimitiveIterator;
import java.util.stream.IntStream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.operator.GroupByHash.createGroupByHash;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -55,7 +62,7 @@ public class GroupedTopNBuilder
private final int topN;
private final boolean produceRowNumber;
private final GroupByHash groupByHash;

private final OperatorContext operatorContext;
// a map of heaps, each of which records the top N rows
private final ObjectBigArray<RowHeap> groupedRows = new ObjectBigArray<>();
// a list of input pages, each of which has information of which row in which heap references which position
Expand All @@ -69,19 +76,55 @@ public class GroupedTopNBuilder
// keeps track sizes of input pages and heaps
private long memorySizeInBytes;
private int currentPageCount;
private LocalMemoryContext localUserMemoryContext;

/**
 * Convenience constructor that forwards every argument unchanged to the
 * twelve-argument constructor and supplies {@code UpdateMemory.NOOP} as the
 * trailing {@code updateMemory} argument.
 */
public GroupedTopNBuilder(
OperatorContext operatorContext,
List<Type> sourceTypes,
List<Type> partitionTypes,
List<Integer> partitionChannels,
Optional<Integer> hashChannel,
int expectedPositions,
boolean isDictionaryAggregationEnabled,
JoinCompiler joinCompiler,
PageWithPositionComparator comparator,
int topN,
boolean produceRowNumber)
{
// Delegate to the full constructor; the only value added here is the no-op memory callback.
this(
operatorContext,
sourceTypes,
partitionTypes,
partitionChannels,
hashChannel,
expectedPositions,
isDictionaryAggregationEnabled,
joinCompiler,
comparator,
topN,
produceRowNumber,
UpdateMemory.NOOP);
}

public GroupedTopNBuilder(
OperatorContext operatorContext,
List<Type> sourceTypes,
List<Type> partitionTypes,
List<Integer> partitionChannels,
Optional<Integer> hashChannel,
int expectedPositions,
boolean isDictionaryAggregationEnabled,
JoinCompiler joinCompiler,
Comment on lines +110 to +117
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an anti-pattern to pass operator-related fields into a non-operator class. We probably don't need to change anything in this class; it's up to the calling operators to spill. OrderByOperator has the closest logic.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's a good point. I borrowed this pattern from SpillableHashAggregationBuilder because I was implementing this for TopNRowNumberOperator as well, which does grouped TopN; in that respect, SpillableHashAggregationBuilder is closer to this problem. I can reference OrderByOperator for code patterns.

Copy link
Collaborator Author

@shrinidhijoshi shrinidhijoshi Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update from offline discussions:
Instead of moving GroupByHash into the GroupedTopNBuilder, we can create a GroupByHash Supplier/Factory in the Operator and pass it to the TopNBuilder.
This will avoid:

  1. leaking operator-related fields (operatorContext, user/revocableMemoryContext, etc.) and logic into TopNBuilder
  2. functions with a long list of arguments

PageWithPositionComparator comparator,
int topN,
boolean produceRowNumber,
GroupByHash groupByHash)
UpdateMemory updateMemory)
{
this.operatorContext = operatorContext;
this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]);
checkArgument(topN > 0, "topN must be > 0");
this.topN = topN;
this.produceRowNumber = produceRowNumber;
this.groupByHash = requireNonNull(groupByHash, "groupByHash is not null");

this.pageWithPositionComparator = requireNonNull(comparator, "comparator is null");
// Note: this comparator intentionally swaps left and right arguments to form a "reverse order" comparator
Expand All @@ -91,6 +134,21 @@ public GroupedTopNBuilder(
pageReferences.get(right.getPageId()).getPage(),
right.getPosition());
this.emptyPageReferenceSlots = new IntFIFOQueue();

if (!partitionChannels.isEmpty()) {
checkArgument(expectedPositions > 0, "expectedPositions must be > 0");
this.groupByHash = createGroupByHash(
partitionTypes,
Ints.toArray(partitionChannels),
hashChannel,
expectedPositions,
isDictionaryAggregationEnabled,
joinCompiler,
updateMemory);
}
else {
this.groupByHash = new NoChannelGroupByHash();
}
}

public Work<?> processPage(Page page)
Expand All @@ -105,7 +163,7 @@ public Work<?> processPage(Page page)

public Iterator<Page> buildResult()
{
return new ResultIterator();
return new ResultIterator(IntStream.range(0, groupByHash.getGroupCount()).iterator());
}

public long getEstimatedSizeInBytes()
Expand Down Expand Up @@ -398,10 +456,8 @@ private class ResultIterator
private static final int UNUSED_CAPACITY_DISPOSAL_THRESHOLD = 4096;

private final PageBuilder pageBuilder;
// we may have 0 groups if there is no input page processed
private final int groupCount = groupByHash.getGroupCount();
private final PrimitiveIterator.OfInt groupIds;

private int currentGroupNumber;
private long currentGroupSizeInBytes;

// the row number of the current position in the group
Expand All @@ -411,7 +467,7 @@ private class ResultIterator

private ObjectBigArray<Row> currentRows;

ResultIterator()
ResultIterator(PrimitiveIterator.OfInt groupIds)
{
if (produceRowNumber) {
pageBuilder = new PageBuilder(new ImmutableList.Builder<Type>().add(sourceTypes).add(BIGINT).build());
Expand All @@ -421,6 +477,7 @@ private class ResultIterator
}
// Populate the first group
currentRows = new ObjectBigArray<>();
this.groupIds = groupIds;
nextGroupedRows();
}

Expand Down Expand Up @@ -471,11 +528,10 @@ protected Page computeNext()

private void nextGroupedRows()
{
if (currentGroupNumber < groupCount) {
RowHeap rows = groupedRows.getAndSet(currentGroupNumber, null);
verify(rows != null && !rows.isEmpty(), "impossible to have inserted a group without a witness row");
if (this.groupIds.hasNext()) {
RowHeap rows = groupedRows.getAndSet(this.groupIds.nextInt(), null);
verify(rows != null && !rows.isEmpty(), "impossible to have inserted a group without a witness row. rows=%s for %s", rows, this);
currentGroupSizeInBytes = rows.getEstimatedSizeInBytes();
currentGroupNumber++;
currentGroupSize = rows.size();

// sort output rows in a big array in case there are too many rows
Expand All @@ -495,4 +551,9 @@ private void nextGroupedRows()
}
}
}

/**
 * Returns the {@code GroupByHash} stored in this object's {@code groupByHash} field.
 */
public GroupByHash getGroupByHash()
{
return groupByHash;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static java.util.Collections.emptyIterator;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -106,18 +107,23 @@ public TopNOperator(
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.localUserMemoryContext = operatorContext.localUserMemoryContext();
checkArgument(n >= 0, "n must be positive");

if (n == 0) {
finishing = true;
outputIterator = emptyIterator();
}
else {
topNBuilder = new GroupedTopNBuilder(
operatorContext,
types,
emptyList(),
emptyList(),
null,
0,
false,
null,
new SimplePageWithPositionComparator(types, sortChannels, sortOrders),
n,
false,
new NoChannelGroupByHash());
false);
}
}

Expand Down Expand Up @@ -152,7 +158,6 @@ public void addInput(Page page)
boolean done = topNBuilder.processPage(requireNonNull(page, "page is null")).process();
// there is no grouping so work will always be done
verify(done);
updateMemoryReservation();
}

@Override
Expand All @@ -174,15 +179,9 @@ public Page getOutput()
else {
outputIterator = emptyIterator();
}
updateMemoryReservation();
return output;
}

/**
 * Sets {@code localUserMemoryContext} to the size, in bytes, currently
 * estimated by {@code topNBuilder}.
 */
private void updateMemoryReservation()
{
localUserMemoryContext.setBytes(topNBuilder.getEstimatedSizeInBytes());
}

private boolean noMoreOutput()
{
return outputIterator != null && !outputIterator.hasNext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.operator.GroupByHash.createGroupByHash;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -130,7 +129,6 @@ public OperatorFactory duplicate()

private final int[] outputChannels;

private final GroupByHash groupByHash;
private final GroupedTopNBuilder groupedTopNBuilder;

private boolean finishing;
Expand Down Expand Up @@ -165,28 +163,21 @@ public TopNRowNumberOperator(

checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");

if (!partitionChannels.isEmpty()) {
checkArgument(expectedPositions > 0, "expectedPositions must be > 0");
groupByHash = createGroupByHash(
partitionTypes,
Ints.toArray(partitionChannels),
hashChannel,
expectedPositions,
isDictionaryAggregationEnabled(operatorContext.getSession()),
joinCompiler,
this::updateMemoryReservation);
}
else {
groupByHash = new NoChannelGroupByHash();
}

List<Type> types = toTypes(sourceTypes, outputChannels, generateRowNumber);

this.groupedTopNBuilder = new GroupedTopNBuilder(
operatorContext,
ImmutableList.copyOf(sourceTypes),
partitionTypes,
partitionChannels,
hashChannel,
expectedPositions,
isDictionaryAggregationEnabled(operatorContext.getSession()),
joinCompiler,
new SimplePageWithPositionComparator(types, sortChannels, sortOrders),
maxRowCountPerPartition,
generateRowNumber,
groupByHash);
this::updateMemoryReservation);
}

@Override
Expand Down Expand Up @@ -226,15 +217,13 @@ public void addInput(Page page)
if (unfinishedWork.process()) {
unfinishedWork = null;
}
updateMemoryReservation();
}

@Override
public Page getOutput()
{
if (unfinishedWork != null) {
boolean finished = unfinishedWork.process();
updateMemoryReservation();
if (!finished) {
return null;
}
Expand All @@ -254,13 +243,13 @@ public Page getOutput()
if (outputIterator.hasNext()) {
output = outputIterator.next().extractChannels(outputChannels);
}
updateMemoryReservation();
return output;
}

/**
 * Returns the capacity reported by the {@code GroupByHash} obtained from
 * {@code groupedTopNBuilder}. Marked {@code @VisibleForTesting}.
 *
 * @return the value of {@code getCapacity()} on that hash
 * @throws IllegalStateException if the builder returns a null {@code GroupByHash}
 */
@VisibleForTesting
public int getCapacity()
{
GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash();
// Check explicitly so a missing hash does not surface as a NullPointerException on the next line.
checkState(groupByHash != null);
return groupByHash.getCapacity();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -97,11 +99,35 @@ public void setup()
GroupByHash groupByHash;
if (groupCount > 1) {
groupByHash = new BigintGroupByHash(HASH_GROUP, true, groupCount, UpdateMemory.NOOP);
topNBuilder = new GroupedTopNBuilder(
null,
types,
ImmutableList.of(types.get(HASH_GROUP)),
ImmutableList.of(HASH_GROUP),
Optional.of(HASH_GROUP),
groupCount,
false,
null,
comparator,
topN,
false,
UpdateMemory.NOOP);
}
else {
groupByHash = new NoChannelGroupByHash();
topNBuilder = new GroupedTopNBuilder(
null,
types,
Collections.emptyList(),
Collections.emptyList(),
null,
groupCount,
false,
null,
comparator,
topN,
false,
UpdateMemory.NOOP);
}
topNBuilder = new GroupedTopNBuilder(types, comparator, topN, false, groupByHash);
}

public GroupedTopNBuilder getTopNBuilder()
Expand Down
Loading