Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ protected List<? extends OperatorFactory> createOperatorFactories()
getColumnTypes("lineitem", "returnflag", "linestatus"),
Ints.asList(0, 1),
ImmutableList.of(),
ImmutableList.of(),
Step.SINGLE,
ImmutableList.of(
doubleSum.bind(ImmutableList.of(2), Optional.empty()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ protected List<? extends OperatorFactory> createOperatorFactories()
ImmutableList.of(tableTypes.get(0)),
Ints.asList(0),
ImmutableList.of(),
ImmutableList.of(),
Step.SINGLE,
ImmutableList.of(doubleSum.bind(ImmutableList.of(1), Optional.empty())),
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ public SystemSessionProperties(
SEGMENTED_AGGREGATION_ENABLED,
"Enable segmented aggregation.",
featuresConfig.isSegmentedAggregationEnabled(),
true),
false),
new PropertyMetadata<>(
AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY,
format("Set the strategy used to rewrite AGG IF to AGG FILTER. Options are %s",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.Accumulator;
Expand All @@ -29,18 +30,22 @@
import com.facebook.presto.sql.gen.JoinCompiler;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.stream.Collectors;

import static com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes;
import static com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer.INITIAL_HASH_VALUE;
import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.util.Objects.requireNonNull;

Expand All @@ -56,6 +61,8 @@ public static class HashAggregationOperatorFactory
private final PlanNodeId planNodeId;
private final List<Type> groupByTypes;
private final List<Integer> groupByChannels;
// A subset of groupByChannels, containing channels that are already sorted.
private final List<Integer> preGroupedChannels;
private final List<Integer> globalAggregationGroupIds;
private final Step step;
private final boolean produceDefaultOutput;
Expand All @@ -80,6 +87,7 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<Integer> preGroupedChannels,
List<Integer> globalAggregationGroupIds,
Step step,
List<AccumulatorFactory> accumulatorFactories,
Expand All @@ -94,6 +102,7 @@ public HashAggregationOperatorFactory(
planNodeId,
groupByTypes,
groupByChannels,
preGroupedChannels,
globalAggregationGroupIds,
step,
false,
Expand All @@ -117,6 +126,7 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<Integer> preGroupedChannels,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -135,6 +145,7 @@ public HashAggregationOperatorFactory(
planNodeId,
groupByTypes,
groupByChannels,
preGroupedChannels,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand All @@ -157,6 +168,7 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<Integer> preGroupedChannels,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -178,6 +190,7 @@ public HashAggregationOperatorFactory(
this.groupIdChannel = requireNonNull(groupIdChannel, "groupIdChannel is null");
this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
this.preGroupedChannels = ImmutableList.copyOf(preGroupedChannels);
this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
this.step = step;
this.produceDefaultOutput = produceDefaultOutput;
Expand All @@ -202,6 +215,7 @@ public Operator createOperator(DriverContext driverContext)
operatorContext,
groupByTypes,
groupByChannels,
preGroupedChannels,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand Down Expand Up @@ -233,6 +247,7 @@ public OperatorFactory duplicate()
planNodeId,
groupByTypes,
groupByChannels,
preGroupedChannels,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand All @@ -253,6 +268,7 @@ public OperatorFactory duplicate()
private final OperatorContext operatorContext;
private final List<Type> groupByTypes;
private final List<Integer> groupByChannels;
private final int[] preGroupedChannels;
private final List<Integer> globalAggregationGroupIds;
private final Step step;
private final boolean produceDefaultOutput;
Expand All @@ -267,6 +283,7 @@ public OperatorFactory duplicate()
private final SpillerFactory spillerFactory;
private final JoinCompiler joinCompiler;
private final boolean useSystemMemory;
private final Optional<PagesHashStrategy> preGroupedHashStrategy;

private final List<Type> types;
private final HashCollisionsCounter hashCollisionsCounter;
Expand All @@ -276,6 +293,8 @@ public OperatorFactory duplicate()
private boolean inputProcessed;
private boolean finishing;
private boolean finished;
private Page firstUnfinishedSegment;
private Page remainingPageForSegmentedAggregation;

// for yield when memory is not available
private Work<?> unfinishedWork;
Expand All @@ -284,6 +303,7 @@ public HashAggregationOperator(
OperatorContext operatorContext,
List<Type> groupByTypes,
List<Integer> groupByChannels,
List<Integer> preGroupedChannels,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -306,6 +326,7 @@ public HashAggregationOperator(

this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
this.preGroupedChannels = Ints.toArray(requireNonNull(preGroupedChannels, "preGroupedChannels is null"));
this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
this.accumulatorFactories = ImmutableList.copyOf(accumulatorFactories);
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
Expand All @@ -323,6 +344,13 @@ public HashAggregationOperator(
this.hashCollisionsCounter = new HashCollisionsCounter(operatorContext);
operatorContext.setInfoSupplier(hashCollisionsCounter);
this.useSystemMemory = useSystemMemory;

checkState(ImmutableSet.copyOf(groupByChannels).containsAll(preGroupedChannels), "groupByChannels must include all channels in preGroupedChannels");
this.preGroupedHashStrategy = preGroupedChannels.isEmpty()
? Optional.empty()
: Optional.of(joinCompiler.compilePagesHashStrategyFactory(
preGroupedChannels.stream().map(groupByTypes::get).collect(toImmutableList()), preGroupedChannels, Optional.empty())
.createPagesHashStrategy(groupByTypes.stream().map(type -> ImmutableList.<Block>of()).collect(toImmutableList()), OptionalInt.empty()));
}

@Override
Expand All @@ -348,13 +376,15 @@ public boolean isFinished()
// - 2. Current page has been processed.
// - 3. Aggregation builder has not been triggered or has finished processing.
// - 4. If this is partial aggregation then it must have not reached the memory limit.
// - 5. If running in segmented aggregation mode, there must be no remaining page to process.
@Override
public boolean needsInput()
{
return !finishing
&& unfinishedWork == null
&& outputPages == null
&& !partialAggregationReachedMemoryLimit();
&& !partialAggregationReachedMemoryLimit()
&& remainingPageForSegmentedAggregation == null;
}

@Override
Expand All @@ -366,10 +396,10 @@ public void addInput(Page page)
inputProcessed = true;

initializeAggregationBuilderIfNeeded();
processInputPage(page);

// process the current page; save the unfinished work if we are waiting for memory
unfinishedWork = aggregationBuilder.processPage(page);
if (unfinishedWork.process()) {
if (unfinishedWork != null && unfinishedWork.process()) {
unfinishedWork = null;
}
aggregationBuilder.updateMemory();
Expand Down Expand Up @@ -436,6 +466,7 @@ public Page getOutput()

if (outputPages.isFinished()) {
closeAggregationBuilder();
processRemainingPageForSegmentedAggregation();
return null;
}

Expand All @@ -454,6 +485,53 @@ public HashAggregationBuilder getAggregationBuilder()
return aggregationBuilder;
}

private void processInputPage(Page page)
{
// 1. normal aggregation
if (!preGroupedHashStrategy.isPresent()) {
unfinishedWork = aggregationBuilder.processPage(page);
return;
}

// 2. segmented aggregation
if (firstUnfinishedSegment == null) {
Comment on lines 491 to 497
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. Whether to use pre-grouped channel or not is actually determined at the beginning of the execution or planning phase. There is no need to check if/else for every page input. Instead, it would be good to abstract the design a bit. For example, can we have a base abstract hash operator with two implementations: Hash and Segmented. Then we can leave the branches within the different implementations

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether to use pre-grouped channel or not is actually determined at the beginning of the execution or planning phase. There is no need to check if/else for every page input.

Great point.

can we have a base abstract hash operator with two implementations: Hash and Segmented.

I thought about this idea as well before implementing the PR. Not sure if it's worth it to restructure the whole operator given that Segmented aggregation is essentially still Hash aggregation just with a few tricks. The difference is not that big. WDYT? @kewang1024

// If this is the first page, treat the first segment in this page as the current segment.
firstUnfinishedSegment = page.getRegion(0, 1);
}

Page pageOnPreGroupedChannels = page.extractChannels(preGroupedChannels);
int lastRowInPage = page.getPositionCount() - 1;
int lastSegmentStart = findLastSegmentStart(preGroupedHashStrategy.get(), pageOnPreGroupedChannels);
if (lastSegmentStart == 0) {
// The whole page is in one segment.
if (preGroupedHashStrategy.get().rowEqualsRow(0, firstUnfinishedSegment.extractChannels(preGroupedChannels), 0, pageOnPreGroupedChannels)) {
// All rows in this page belong to the previous unfinished segment, process the whole page.
unfinishedWork = aggregationBuilder.processPage(page);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we don't have to care about the pages for a segment once it's done, right? Do we need to close the aggregationBuilder to clear the memory after having processed a segment? It would be good to reflect the memory usage in your benchmark as well.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to close the aggregationBuilder to clear the memory after having processed a segment

Yes, that is how it is implemented currently. Once it has fully processed at least one segment, the aggregationBuilder will be closed and rebuilt if there are more segments to process.

to reflect the memory usage in your benchmark as well.

Regarding the memory usage, I thought the memory comparison had been covered by the manual test attached in the PR description. Could you please elaborate a bit on how to run a benchmark against the memory usage? Any pointers I can refer to? Thanks!

}
else {
// If the current page starts with a new segment, flush before processing it.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, we are here if the first row in the page is the same as the last unfinished segment. In that case, the current page does not start with a new segment, right? Did I miss anything?

remainingPageForSegmentedAggregation = page;
}
}
else {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle the smaller branch first and return to avoid indents for the larger branch. Reducing indents can improve the readability of the code.

// If the current segment ends in the current page, flush it with all the segments (if exist) except the last segment of the current page.
unfinishedWork = aggregationBuilder.processPage(page.getRegion(0, lastSegmentStart));
remainingPageForSegmentedAggregation = page.getRegion(lastSegmentStart, lastRowInPage - lastSegmentStart + 1);
}
// Record the last segment.
firstUnfinishedSegment = page.getRegion(lastRowInPage, 1);
}

// Walks the page from its last position towards the first, comparing each row with the row
// directly before it through pagesHashStrategy.rowEqualsRow. Returns the index of the first
// row found (scanning backwards) that differs from its predecessor, i.e. the position where
// the final run of equal rows begins. Returns 0 when no adjacent pair differs, which also
// covers pages with fewer than two positions because the loop body never runs for them.
private int findLastSegmentStart(PagesHashStrategy pagesHashStrategy, Page page)
{
for (int i = page.getPositionCount() - 1; i > 0; i--) {
Comment thread
zacw7 marked this conversation as resolved.
Outdated
// Row i does not equal row i - 1, so the last run of equal rows starts at i.
if (!pagesHashStrategy.rowEqualsRow(i - 1, page, i, page)) {
return i;
}
}
// Every adjacent pair compared equal: the whole page is a single run.
return 0;
}

private void closeAggregationBuilder()
{
outputPages = null;
Expand All @@ -468,6 +546,16 @@ private void closeAggregationBuilder()
operatorContext.localRevocableMemoryContext().setBytes(0);
}

// Replays the page that was held back in remainingPageForSegmentedAggregation, if there is one.
// The resulting work item is stored in unfinishedWork and the held-back page is cleared.
private void processRemainingPageForSegmentedAggregation()
{
    // No page was held back, so there is nothing to replay.
    if (remainingPageForSegmentedAggregation == null) {
        return;
    }

    // Segmented aggregation mode: make sure an aggregation builder is available again,
    // then hand it the held-back page.
    initializeAggregationBuilderIfNeeded();
    unfinishedWork = aggregationBuilder.processPage(remainingPageForSegmentedAggregation);

    // The page has been handed over; drop the reference.
    remainingPageForSegmentedAggregation = null;
}

private void initializeAggregationBuilderIfNeeded()
{
if (aggregationBuilder != null) {
Expand Down Expand Up @@ -509,9 +597,10 @@ private void initializeAggregationBuilderIfNeeded()
// Flush if one of the following is true:
// - received finish() signal (no more input to come).
// - it is a partial aggregation and has reached memory limit
// - running in segmented aggregation mode and at least one segment has been fully processed
private boolean shouldFlush()
{
return finishing || partialAggregationReachedMemoryLimit();
return finishing || partialAggregationReachedMemoryLimit() || remainingPageForSegmentedAggregation != null;
}

private boolean partialAggregationReachedMemoryLimit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2541,6 +2541,7 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl
aggregation.getAggregations(),
ImmutableSet.of(),
groupingVariables,
ImmutableList.of(),
PARTIAL,
Optional.empty(),
Optional.empty(),
Expand Down Expand Up @@ -2646,6 +2647,7 @@ public PhysicalOperation visitTableWriteMerge(TableWriterMergeNode node, LocalEx
aggregation.getAggregations(),
ImmutableSet.of(),
groupingVariables,
ImmutableList.of(),
INTERMEDIATE,
Optional.empty(),
Optional.empty(),
Expand Down Expand Up @@ -2700,6 +2702,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl
aggregation.getAggregations(),
ImmutableSet.of(),
groupingVariables,
ImmutableList.of(),
FINAL,
Optional.empty(),
Optional.empty(),
Expand Down Expand Up @@ -3075,6 +3078,7 @@ private PhysicalOperation planGroupByAggregation(
node.getAggregations(),
node.getGlobalGroupingSets(),
node.getGroupingKeys(),
node.getPreGroupedVariables(),
node.getStep(),
node.getHashVariable(),
node.getGroupIdVariable(),
Expand All @@ -3099,6 +3103,7 @@ private OperatorFactory createHashAggregationOperatorFactory(
Map<VariableReferenceExpression, Aggregation> aggregations,
Set<Integer> globalGroupingSets,
List<VariableReferenceExpression> groupbyVariables,
List<VariableReferenceExpression> preGroupedVariables,
Step step,
Optional<VariableReferenceExpression> hashVariable,
Optional<VariableReferenceExpression> groupIdVariable,
Expand Down Expand Up @@ -3167,11 +3172,13 @@ private OperatorFactory createHashAggregationOperatorFactory(
}
else {
Optional<Integer> hashChannel = hashVariable.map(variableChannelGetter(source));
List<Integer> preGroupedChannels = getChannelsForVariables(preGroupedVariables, source.getLayout());
return new HashAggregationOperatorFactory(
context.getNextOperatorId(),
planNodeId,
groupByTypes,
groupByChannels,
preGroupedChannels,
ImmutableList.copyOf(globalGroupingSets),
step,
hasDefaultOutput,
Expand Down
Loading