-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Implement segmented aggregation execution #17886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
|
|
||
| import com.facebook.presto.common.Page; | ||
| import com.facebook.presto.common.PageBuilder; | ||
| import com.facebook.presto.common.block.Block; | ||
| import com.facebook.presto.common.type.BigintType; | ||
| import com.facebook.presto.common.type.Type; | ||
| import com.facebook.presto.operator.aggregation.Accumulator; | ||
|
|
@@ -29,18 +30,22 @@ | |
| import com.facebook.presto.sql.gen.JoinCompiler; | ||
| import com.google.common.annotations.VisibleForTesting; | ||
| import com.google.common.collect.ImmutableList; | ||
| import com.google.common.collect.ImmutableSet; | ||
| import com.google.common.primitives.Ints; | ||
| import com.google.common.util.concurrent.ListenableFuture; | ||
| import io.airlift.units.DataSize; | ||
|
|
||
| import java.util.List; | ||
| import java.util.Optional; | ||
| import java.util.OptionalInt; | ||
| import java.util.stream.Collectors; | ||
|
|
||
| import static com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes; | ||
| import static com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer.INITIAL_HASH_VALUE; | ||
| import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; | ||
| import static com.google.common.base.Preconditions.checkState; | ||
| import static com.google.common.base.Verify.verify; | ||
| import static com.google.common.collect.ImmutableList.toImmutableList; | ||
| import static io.airlift.units.DataSize.Unit.MEGABYTE; | ||
| import static java.util.Objects.requireNonNull; | ||
|
|
||
|
|
@@ -56,6 +61,8 @@ public static class HashAggregationOperatorFactory | |
| private final PlanNodeId planNodeId; | ||
| private final List<Type> groupByTypes; | ||
| private final List<Integer> groupByChannels; | ||
| // A subset of groupByChannels, containing channels that are already sorted. | ||
| private final List<Integer> preGroupedChannels; | ||
| private final List<Integer> globalAggregationGroupIds; | ||
| private final Step step; | ||
| private final boolean produceDefaultOutput; | ||
|
|
@@ -80,6 +87,7 @@ public HashAggregationOperatorFactory( | |
| PlanNodeId planNodeId, | ||
| List<? extends Type> groupByTypes, | ||
| List<Integer> groupByChannels, | ||
| List<Integer> preGroupedChannels, | ||
| List<Integer> globalAggregationGroupIds, | ||
| Step step, | ||
| List<AccumulatorFactory> accumulatorFactories, | ||
|
|
@@ -94,6 +102,7 @@ public HashAggregationOperatorFactory( | |
| planNodeId, | ||
| groupByTypes, | ||
| groupByChannels, | ||
| preGroupedChannels, | ||
| globalAggregationGroupIds, | ||
| step, | ||
| false, | ||
|
|
@@ -117,6 +126,7 @@ public HashAggregationOperatorFactory( | |
| PlanNodeId planNodeId, | ||
| List<? extends Type> groupByTypes, | ||
| List<Integer> groupByChannels, | ||
| List<Integer> preGroupedChannels, | ||
| List<Integer> globalAggregationGroupIds, | ||
| Step step, | ||
| boolean produceDefaultOutput, | ||
|
|
@@ -135,6 +145,7 @@ public HashAggregationOperatorFactory( | |
| planNodeId, | ||
| groupByTypes, | ||
| groupByChannels, | ||
| preGroupedChannels, | ||
| globalAggregationGroupIds, | ||
| step, | ||
| produceDefaultOutput, | ||
|
|
@@ -157,6 +168,7 @@ public HashAggregationOperatorFactory( | |
| PlanNodeId planNodeId, | ||
| List<? extends Type> groupByTypes, | ||
| List<Integer> groupByChannels, | ||
| List<Integer> preGroupedChannels, | ||
| List<Integer> globalAggregationGroupIds, | ||
| Step step, | ||
| boolean produceDefaultOutput, | ||
|
|
@@ -178,6 +190,7 @@ public HashAggregationOperatorFactory( | |
| this.groupIdChannel = requireNonNull(groupIdChannel, "groupIdChannel is null"); | ||
| this.groupByTypes = ImmutableList.copyOf(groupByTypes); | ||
| this.groupByChannels = ImmutableList.copyOf(groupByChannels); | ||
| this.preGroupedChannels = ImmutableList.copyOf(preGroupedChannels); | ||
| this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds); | ||
| this.step = step; | ||
| this.produceDefaultOutput = produceDefaultOutput; | ||
|
|
@@ -202,6 +215,7 @@ public Operator createOperator(DriverContext driverContext) | |
| operatorContext, | ||
| groupByTypes, | ||
| groupByChannels, | ||
| preGroupedChannels, | ||
| globalAggregationGroupIds, | ||
| step, | ||
| produceDefaultOutput, | ||
|
|
@@ -233,6 +247,7 @@ public OperatorFactory duplicate() | |
| planNodeId, | ||
| groupByTypes, | ||
| groupByChannels, | ||
| preGroupedChannels, | ||
| globalAggregationGroupIds, | ||
| step, | ||
| produceDefaultOutput, | ||
|
|
@@ -253,6 +268,7 @@ public OperatorFactory duplicate() | |
| private final OperatorContext operatorContext; | ||
| private final List<Type> groupByTypes; | ||
| private final List<Integer> groupByChannels; | ||
| private final int[] preGroupedChannels; | ||
| private final List<Integer> globalAggregationGroupIds; | ||
| private final Step step; | ||
| private final boolean produceDefaultOutput; | ||
|
|
@@ -267,6 +283,7 @@ public OperatorFactory duplicate() | |
| private final SpillerFactory spillerFactory; | ||
| private final JoinCompiler joinCompiler; | ||
| private final boolean useSystemMemory; | ||
| private final Optional<PagesHashStrategy> preGroupedHashStrategy; | ||
|
|
||
| private final List<Type> types; | ||
| private final HashCollisionsCounter hashCollisionsCounter; | ||
|
|
@@ -276,6 +293,8 @@ public OperatorFactory duplicate() | |
| private boolean inputProcessed; | ||
| private boolean finishing; | ||
| private boolean finished; | ||
| private Page firstUnfinishedSegment; | ||
| private Page remainingPageForSegmentedAggregation; | ||
|
|
||
| // for yield when memory is not available | ||
| private Work<?> unfinishedWork; | ||
|
|
@@ -284,6 +303,7 @@ public HashAggregationOperator( | |
| OperatorContext operatorContext, | ||
| List<Type> groupByTypes, | ||
| List<Integer> groupByChannels, | ||
| List<Integer> preGroupedChannels, | ||
| List<Integer> globalAggregationGroupIds, | ||
| Step step, | ||
| boolean produceDefaultOutput, | ||
|
|
@@ -306,6 +326,7 @@ public HashAggregationOperator( | |
|
|
||
| this.groupByTypes = ImmutableList.copyOf(groupByTypes); | ||
| this.groupByChannels = ImmutableList.copyOf(groupByChannels); | ||
| this.preGroupedChannels = Ints.toArray(requireNonNull(preGroupedChannels, "preGroupedChannels is null")); | ||
| this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds); | ||
| this.accumulatorFactories = ImmutableList.copyOf(accumulatorFactories); | ||
| this.hashChannel = requireNonNull(hashChannel, "hashChannel is null"); | ||
|
|
@@ -323,6 +344,13 @@ public HashAggregationOperator( | |
| this.hashCollisionsCounter = new HashCollisionsCounter(operatorContext); | ||
| operatorContext.setInfoSupplier(hashCollisionsCounter); | ||
| this.useSystemMemory = useSystemMemory; | ||
|
|
||
| checkState(ImmutableSet.copyOf(groupByChannels).containsAll(preGroupedChannels), "groupByChannels must include all channels in preGroupedChannels"); | ||
| this.preGroupedHashStrategy = preGroupedChannels.isEmpty() | ||
| ? Optional.empty() | ||
| : Optional.of(joinCompiler.compilePagesHashStrategyFactory( | ||
| preGroupedChannels.stream().map(groupByTypes::get).collect(toImmutableList()), preGroupedChannels, Optional.empty()) | ||
| .createPagesHashStrategy(groupByTypes.stream().map(type -> ImmutableList.<Block>of()).collect(toImmutableList()), OptionalInt.empty())); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -348,13 +376,15 @@ public boolean isFinished() | |
| // - 2. Current page has been processed. | ||
| // - 3. Aggregation builder has not been triggered or has finished processing. | ||
| // - 4. If this is partial aggregation then it must have not reached the memory limit. | ||
| // - 5. If running in segmented aggregation mode, there must be no remaining page to process. | ||
| @Override | ||
| public boolean needsInput() | ||
| { | ||
| return !finishing | ||
| && unfinishedWork == null | ||
| && outputPages == null | ||
| && !partialAggregationReachedMemoryLimit(); | ||
| && !partialAggregationReachedMemoryLimit() | ||
| && remainingPageForSegmentedAggregation == null; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -366,10 +396,10 @@ public void addInput(Page page) | |
| inputProcessed = true; | ||
|
|
||
| initializeAggregationBuilderIfNeeded(); | ||
| processInputPage(page); | ||
|
|
||
| // process the current page; save the unfinished work if we are waiting for memory | ||
| unfinishedWork = aggregationBuilder.processPage(page); | ||
| if (unfinishedWork.process()) { | ||
| if (unfinishedWork != null && unfinishedWork.process()) { | ||
| unfinishedWork = null; | ||
| } | ||
| aggregationBuilder.updateMemory(); | ||
|
|
@@ -436,6 +466,7 @@ public Page getOutput() | |
|
|
||
| if (outputPages.isFinished()) { | ||
| closeAggregationBuilder(); | ||
| processRemainingPageForSegmentedAggregation(); | ||
| return null; | ||
| } | ||
|
|
||
|
|
@@ -454,6 +485,53 @@ public HashAggregationBuilder getAggregationBuilder() | |
| return aggregationBuilder; | ||
| } | ||
|
|
||
| private void processInputPage(Page page) | ||
| { | ||
| // 1. normal aggregation | ||
| if (!preGroupedHashStrategy.isPresent()) { | ||
| unfinishedWork = aggregationBuilder.processPage(page); | ||
| return; | ||
| } | ||
|
|
||
| // 2. segmented aggregation | ||
| if (firstUnfinishedSegment == null) { | ||
| // If this is the first page, treat the first segment in this page as the current segment. | ||
| firstUnfinishedSegment = page.getRegion(0, 1); | ||
| } | ||
|
|
||
| Page pageOnPreGroupedChannels = page.extractChannels(preGroupedChannels); | ||
| int lastRowInPage = page.getPositionCount() - 1; | ||
| int lastSegmentStart = findLastSegmentStart(preGroupedHashStrategy.get(), pageOnPreGroupedChannels); | ||
| if (lastSegmentStart == 0) { | ||
| // The whole page is in one segment. | ||
| if (preGroupedHashStrategy.get().rowEqualsRow(0, firstUnfinishedSegment.extractChannels(preGroupedChannels), 0, pageOnPreGroupedChannels)) { | ||
| // All rows in this page belong to the previous unfinished segment, process the whole page. | ||
| unfinishedWork = aggregationBuilder.processPage(page); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assume we don't have to care about the pages for a segment once it's done, right? Do we need to close the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, that's how it is implemented currently. Once it has fully processed at least one segment, the
Regarding the memory usage, I thought the memory comparison had been covered by the manual test attached in the PR description. Could you please elaborate a bit on how to run a benchmark against the memory usage? Any pointers I can refer to? Thanks! |
||
| } | ||
| else { | ||
| // If the current page starts with a new segment, flush before processing it. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC, we are here if the first row in the page is the same as the last unfinished segment. In that case, the current page doesn't start with a new segment, right? Did I miss anything? |
||
| remainingPageForSegmentedAggregation = page; | ||
| } | ||
| } | ||
| else { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle the smaller branch first and return to avoid indents for the larger branch. Reducing indents can improve the readability of the code. |
||
| // If the current segment ends in the current page, flush it with all the segments (if exist) except the last segment of the current page. | ||
| unfinishedWork = aggregationBuilder.processPage(page.getRegion(0, lastSegmentStart)); | ||
| remainingPageForSegmentedAggregation = page.getRegion(lastSegmentStart, lastRowInPage - lastSegmentStart + 1); | ||
| } | ||
| // Record the last segment. | ||
| firstUnfinishedSegment = page.getRegion(lastRowInPage, 1); | ||
| } | ||
|
|
||
| private int findLastSegmentStart(PagesHashStrategy pagesHashStrategy, Page page) | ||
| { | ||
| for (int i = page.getPositionCount() - 1; i > 0; i--) { | ||
|
zacw7 marked this conversation as resolved.
Outdated
|
||
| if (!pagesHashStrategy.rowEqualsRow(i - 1, page, i, page)) { | ||
| return i; | ||
| } | ||
| } | ||
| return 0; | ||
| } | ||
|
|
||
| private void closeAggregationBuilder() | ||
| { | ||
| outputPages = null; | ||
|
|
@@ -468,6 +546,16 @@ private void closeAggregationBuilder() | |
| operatorContext.localRevocableMemoryContext().setBytes(0); | ||
| } | ||
|
|
||
| private void processRemainingPageForSegmentedAggregation() | ||
| { | ||
| // Running in segmented aggregation mode, reopen the aggregation builder and process the remaining page. | ||
| if (remainingPageForSegmentedAggregation != null) { | ||
| initializeAggregationBuilderIfNeeded(); | ||
| unfinishedWork = aggregationBuilder.processPage(remainingPageForSegmentedAggregation); | ||
| remainingPageForSegmentedAggregation = null; | ||
| } | ||
| } | ||
|
|
||
| private void initializeAggregationBuilderIfNeeded() | ||
| { | ||
| if (aggregationBuilder != null) { | ||
|
|
@@ -509,9 +597,10 @@ private void initializeAggregationBuilderIfNeeded() | |
| // Flush if one of the following is true: | ||
| // - received finish() signal (no more input to come). | ||
| // - it is a partial aggregation and has reached memory limit | ||
| // - running in segmented aggregation mode and at least one segment has been fully processed | ||
| private boolean shouldFlush() | ||
| { | ||
| return finishing || partialAggregationReachedMemoryLimit(); | ||
| return finishing || partialAggregationReachedMemoryLimit() || remainingPageForSegmentedAggregation != null; | ||
| } | ||
|
|
||
| private boolean partialAggregationReachedMemoryLimit() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. Whether to use the pre-grouped channels or not is actually determined at the beginning of the execution or planning phase, so there is no need to perform the if/else check for every input page. Instead, it would be good to abstract the design a bit. For example, can we have a base abstract hash operator with two implementations: Hash and Segmented? Then we can leave the branches within the different implementations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great point.
I thought about this idea as well before implementing the PR. Not sure if it's worth it to restructure the whole operator given that Segmented aggregation is essentially still Hash aggregation just with a few tricks. The difference is not that big. WDYT? @kewang1024