Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions core/trino-main/src/main/java/io/trino/operator/TopNOperator.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.trino.memory.context.MemoryTrackingContext;
import io.trino.operator.BasicWorkProcessorOperatorAdapter.BasicAdapterWorkProcessorOperatorFactory;
import io.trino.operator.WorkProcessor.TransformationState;
Expand All @@ -24,6 +25,7 @@
import io.trino.sql.planner.plan.PlanNodeId;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.operator.BasicWorkProcessorOperatorAdapter.createAdapterOperatorFactory;
Expand All @@ -42,9 +44,10 @@ public static OperatorFactory createOperatorFactory(
int n,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
TypeOperators typeOperators)
TypeOperators typeOperators,
Optional<DataSize> maxPartialMemory)
{
return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, types, n, sortChannels, sortOrders, typeOperators));
return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, types, n, sortChannels, sortOrders, typeOperators, maxPartialMemory));
}

private static class Factory
Expand All @@ -57,6 +60,7 @@ private static class Factory
private final List<Integer> sortChannels;
private final List<SortOrder> sortOrders;
private final TypeOperators typeOperators;
private final Optional<DataSize> maxPartialMemory;
private boolean closed;

private Factory(
Expand All @@ -66,7 +70,8 @@ private Factory(
int n,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
TypeOperators typeOperators)
TypeOperators typeOperators,
Optional<DataSize> maxPartialMemory)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
Expand All @@ -75,6 +80,7 @@ private Factory(
this.sortChannels = ImmutableList.copyOf(requireNonNull(sortChannels, "sortChannels is null"));
this.sortOrders = ImmutableList.copyOf(requireNonNull(sortOrders, "sortOrders is null"));
this.typeOperators = typeOperators;
this.maxPartialMemory = requireNonNull(maxPartialMemory, "maxPartialMemory is null");
}

@Override
Expand All @@ -90,7 +96,8 @@ public WorkProcessorOperator create(
n,
sortChannels,
sortOrders,
typeOperators);
typeOperators,
maxPartialMemory);
}

@Override
Expand Down Expand Up @@ -120,11 +127,10 @@ public void close()
@Override
public Factory duplicate()
{
return new Factory(operatorId, planNodeId, sourceTypes, n, sortChannels, sortOrders, typeOperators);
return new Factory(operatorId, planNodeId, sourceTypes, n, sortChannels, sortOrders, typeOperators, maxPartialMemory);
}
}

private final TopNProcessor topNProcessor;
private final WorkProcessor<Page> pages;

private TopNOperator(
Expand All @@ -134,21 +140,25 @@ private TopNOperator(
int n,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
TypeOperators typeOperators)
TypeOperators typeOperators,
Optional<DataSize> maxPartialMemory)
{
this.topNProcessor = new TopNProcessor(
requireNonNull(memoryTrackingContext, "memoryTrackingContext is null").aggregateUserMemoryContext(),
types,
n,
sortChannels,
sortOrders,
typeOperators);
requireNonNull(memoryTrackingContext, "memoryTrackingContext is null");

if (n == 0) {
pages = WorkProcessor.of();
}
else {
pages = sourcePages.transform(new TopNPages());
TopNProcessor topNProcessor = new TopNProcessor(
memoryTrackingContext.aggregateUserMemoryContext(),
types,
n,
sortChannels,
sortOrders,
typeOperators);
long maxPartialMemoryWithDefaultValueIfAbsent = requireNonNull(maxPartialMemory, "maxPartialMemory is null")
.map(DataSize::toBytes).orElse(Long.MAX_VALUE);
pages = sourcePages.transform(new TopNPages(topNProcessor, maxPartialMemoryWithDefaultValueIfAbsent));
}
}

Expand All @@ -158,18 +168,44 @@ public WorkProcessor<Page> getOutputPages()
return pages;
}

private class TopNPages
private static class TopNPages
implements WorkProcessor.Transformation<Page, Page>
{
private final TopNProcessor topNProcessor;
private final long maxPartialMemory;

private boolean isPartialFlushing;

/**
 * Creates a transformation that accumulates input pages in the given processor.
 *
 * @param topNProcessor processor that receives the input pages and produces the output pages
 * @param maxPartialMemory size threshold in bytes; once the processor's estimated size
 *         reaches it, the transformation switches to partial flushing
 * @throws NullPointerException if {@code topNProcessor} is null
 */
private TopNPages(TopNProcessor topNProcessor, long maxPartialMemory)
{
    // fail fast here rather than on the first page, matching the other constructors in this file
    this.topNProcessor = requireNonNull(topNProcessor, "topNProcessor is null");
    this.maxPartialMemory = maxPartialMemory;
}

/**
 * Reports whether the accumulated TopN state has reached the partial memory threshold.
 *
 * @return {@code true} when the processor's estimated size is at or above {@code maxPartialMemory}
 */
private boolean isBuilderFull()
{
    long estimatedSizeInBytes = topNProcessor.getEstimatedSizeInBytes();
    return estimatedSizeInBytes >= maxPartialMemory;
}

/**
 * Hands one input page to the processor and enters partial-flushing mode
 * if the builder has become full as a result.
 *
 * @param page the page to accumulate
 * @throws IllegalStateException if a partial flush is already in progress
 */
private void addPage(Page page)
{
    checkState(!isPartialFlushing, "TopN buffer is already full");
    topNProcessor.addInput(page);
    // the flag is known to be false at this point (checked above),
    // so assigning the result directly is the same as setting it only when full
    isPartialFlushing = isBuilderFull();
}

@Override
public TransformationState<Page> process(Page inputPage)
Copy link
Copy Markdown
Member

@sopel39 sopel39 Jan 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can manage flushing entirely within the process method, e.g.:

if (!isFlushing && inputPage != null) {
  addPage(inputPage);
  if (partial && isBuilderFull()) {
    isFlushing = true;
  } else {
    // accumulate more data
    return needsMoreData(); 
  }
}

// flushing or finishing (inputPage == null)
Page page = null;
while (page == null && !topNProcessor.noMoreOutput()) {
  page = topNProcessor.getOutput();
}

if (page != null) {
  return TransformationState.ofResult(page, false)
}

// all accumulated data have been outputted (topNProcessor.noMoreOutput() == true)
// and there will be no more input data
if (inputPage == null) {
   return finished();
}

// all accumulated data have been outputted, resume consuming pages
isFlushing = false;
return needsMoreData();

Copy link
Copy Markdown
Member Author

@JunhyungSong JunhyungSong Jan 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that the process method will never be called until the finish method is called in non-late-materialization mode. Even in late-materialization mode, it will keep sending input pages even if TopNOperator is in partial-flushing mode.

Copy link
Copy Markdown
Member

@sopel39 sopel39 Jan 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is process method will be never called until finish

We should just make TopNOperator a WorkProcessorOperator (like, for example, FilterAndProjectOperator).
Please run BenchmarkTopNOperator afterwards.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented a new WorkProcessor.Process for TopNOperator (similar to FilterAndProjectOperator). BenchmarkTopNOperator showed almost no difference.

{
if (inputPage != null) {
topNProcessor.addInput(inputPage);
return TransformationState.needsMoreData();
if (!isPartialFlushing && inputPage != null) {
addPage(inputPage);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can inline addPage. It's only used once

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to have it as a separate method.

if (!isPartialFlushing) {
return TransformationState.needsMoreData();
}
}
Comment thread
JunhyungSong marked this conversation as resolved.
Outdated

// no more input, return results
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment: // flushing or finishing

Page page = null;
while (page == null && !topNProcessor.noMoreOutput()) {
page = topNProcessor.getOutput();
Expand All @@ -179,6 +215,14 @@ public TransformationState<Page> process(Page inputPage)
return TransformationState.ofResult(page, false);
}

if (isPartialFlushing) {
checkState(inputPage != null, "inputPage that triggered partial flushing is null");
isPartialFlushing = false;
// resume receiving pages
return TransformationState.needsMoreData();
}

// all input pages are consumed
return TransformationState.finished();
}
}
Expand Down
40 changes: 27 additions & 13 deletions core/trino-main/src/main/java/io/trino/operator/TopNProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
Expand All @@ -35,39 +36,47 @@
*/
public class TopNProcessor
{
private final LocalMemoryContext localUserMemoryContext;
private final LocalMemoryContext localMemoryContext;
@Nullable
private final Supplier<GroupedTopNBuilder> topNBuilderSupplier;
Comment thread
JunhyungSong marked this conversation as resolved.
Outdated

@Nullable
private GroupedTopNBuilder topNBuilder;
@Nullable
private Iterator<Page> outputIterator;

public TopNProcessor(
AggregatedMemoryContext aggregatedMemoryContext,
List<Type> types,
int n,
List<Integer> sortChannels,
List<SortOrder> sortOrders, TypeOperators typeOperators)
List<SortOrder> sortOrders,
TypeOperators typeOperators)
{
requireNonNull(aggregatedMemoryContext, "aggregatedMemoryContext is null");
this.localUserMemoryContext = aggregatedMemoryContext.newLocalMemoryContext(TopNProcessor.class.getSimpleName());
this.localMemoryContext = aggregatedMemoryContext.newLocalMemoryContext(TopNProcessor.class.getSimpleName());
checkArgument(n >= 0, "n must be positive");

if (n == 0) {
outputIterator = emptyIterator();
topNBuilderSupplier = null;
}
else {
topNBuilder = new GroupedTopNRowNumberBuilder(
GroupByHash noChannelGroupByHash = new NoChannelGroupByHash();
PageWithPositionComparator comparator = new SimplePageWithPositionComparator(types, sortChannels, sortOrders, typeOperators);
topNBuilderSupplier = () -> new GroupedTopNRowNumberBuilder(
types,
new SimplePageWithPositionComparator(types, sortChannels, sortOrders, typeOperators),
comparator,
n,
false,
new NoChannelGroupByHash());
noChannelGroupByHash);
}
}

public void addInput(Page page)
{
requireNonNull(topNBuilder, "topNBuilder is null");
if (topNBuilder == null) {
topNBuilder = requireNonNull(topNBuilderSupplier.get(), "topNBuilderSupplier is null");
}
boolean done = topNBuilder.processPage(requireNonNull(page, "page is null")).process();
// there is no grouping so work will always be done
verify(done);
Expand All @@ -78,28 +87,33 @@ public Page getOutput()
{
if (outputIterator == null) {
// start flushing
outputIterator = topNBuilder.buildResult();
outputIterator = topNBuilder == null ? emptyIterator() : topNBuilder.buildResult();
}

Page output = null;
if (outputIterator.hasNext()) {
output = outputIterator.next();
}
else {
outputIterator = emptyIterator();
outputIterator = null;
topNBuilder = null;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After you build the result on line 90, it's not clear whether any further side effects on topNBuilder will be reflected in the output (even though that may not occur in practice). It looks more correct to an average programmer to clear topNBuilder immediately after calling buildResult, so we should do that (unless it somehow makes this code incorrect).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though buildResult() is called, topNBuilder is still maintained in outputIterator. So, it needs to be memory-accounted until outputIterator is nullified. This is the reason why other operators like HashAggregationOperator and TopNRankingOperator maintain their builder until flushing is completed.

}
updateMemoryReservation();
return output;
}

public boolean noMoreOutput()
{
return outputIterator != null && !outputIterator.hasNext();
return topNBuilder == null;
}

/**
 * Returns the estimated size of the current builder.
 *
 * @return the builder's estimated size in bytes, or {@code 0} when no builder is currently held
 */
public long getEstimatedSizeInBytes()
{
    if (topNBuilder == null) {
        return 0;
    }
    return topNBuilder.getEstimatedSizeInBytes();
}

private void updateMemoryReservation()
{
requireNonNull(topNBuilder, "topNBuilder is null");
localUserMemoryContext.setBytes(topNBuilder.getEstimatedSizeInBytes());
localMemoryContext.setBytes(getEstimatedSizeInBytes());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1633,11 +1633,18 @@ public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext cont
(int) node.getCount(),
sortChannels,
sortOrders,
plannerContext.getTypeOperators());
plannerContext.getTypeOperators(),
getMaxPartialTopNMemorySize(node.getStep()));

return new PhysicalOperation(operator, source.getLayout(), context, source);
}

/**
 * Resolves the memory threshold to pass to the TopN operator for the given step.
 *
 * @param step the step of the TopN node being planned
 * @return the session's max partial TopN memory when {@code step} is {@code PARTIAL}
 *         and the configured size is greater than zero; otherwise empty
 */
private Optional<DataSize> getMaxPartialTopNMemorySize(TopNNode.Step step)
{
    DataSize configuredLimit = SystemSessionProperties.getMaxPartialTopNMemory(session);
    if (step != TopNNode.Step.PARTIAL) {
        return Optional.empty();
    }
    boolean limitIsPositive = configuredLimit.compareTo(DataSize.ofBytes(0)) > 0;
    if (!limitIsPositive) {
        return Optional.empty();
    }
    return Optional.of(configuredLimit);
}

@Override
public PhysicalOperation visitSort(SortNode node, LocalExecutionPlanContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -97,7 +98,8 @@ public void setup()
Integer.valueOf(topN),
ImmutableList.of(0, 2),
ImmutableList.of(DESC_NULLS_LAST, ASC_NULLS_FIRST),
new TypeOperators());
new TypeOperators(),
Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)));
}

@TearDown
Expand Down
Loading