-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Make partial aggregation adaptive #11011
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,13 +13,21 @@ | |
| */ | ||
| package io.trino.operator; | ||
|
|
||
| import javax.annotation.Nullable; | ||
|
|
||
| import static java.util.Objects.requireNonNull; | ||
|
|
||
| public final class CompletedWork<T> | ||
| implements Work<T> | ||
| { | ||
| @Nullable | ||
| private final T result; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mark |
||
|
|
||
| public CompletedWork() | ||
| { | ||
| this.result = null; | ||
| } | ||
|
|
||
| public CompletedWork(T value) | ||
| { | ||
| this.result = requireNonNull(value); | ||
|
|
@@ -31,6 +39,7 @@ public boolean process() | |
| return true; | ||
| } | ||
|
|
||
| @Nullable | ||
| @Override | ||
| public T getResult() | ||
| { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,8 @@ | |
| import io.trino.operator.aggregation.builder.HashAggregationBuilder; | ||
| import io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder; | ||
| import io.trino.operator.aggregation.builder.SpillableHashAggregationBuilder; | ||
| import io.trino.operator.aggregation.partial.PartialAggregationController; | ||
| import io.trino.operator.aggregation.partial.SkipAggregationBuilder; | ||
| import io.trino.operator.scalar.CombineHashFunction; | ||
| import io.trino.spi.Page; | ||
| import io.trino.spi.PageBuilder; | ||
|
|
@@ -36,6 +38,7 @@ | |
| import java.util.List; | ||
| import java.util.Optional; | ||
|
|
||
| import static com.google.common.base.Preconditions.checkArgument; | ||
| import static com.google.common.base.Preconditions.checkState; | ||
| import static io.airlift.units.DataSize.Unit.MEGABYTE; | ||
| import static io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes; | ||
|
|
@@ -70,6 +73,7 @@ public static class HashAggregationOperatorFactory | |
| private final SpillerFactory spillerFactory; | ||
| private final JoinCompiler joinCompiler; | ||
| private final BlockTypeOperators blockTypeOperators; | ||
| private final Optional<PartialAggregationController> partialAggregationController; | ||
|
|
||
| private boolean closed; | ||
|
|
||
|
|
@@ -87,7 +91,8 @@ public HashAggregationOperatorFactory( | |
| int expectedGroups, | ||
| Optional<DataSize> maxPartialMemory, | ||
| JoinCompiler joinCompiler, | ||
| BlockTypeOperators blockTypeOperators) | ||
| BlockTypeOperators blockTypeOperators, | ||
| Optional<PartialAggregationController> partialAggregationController) | ||
| { | ||
| this(operatorId, | ||
| planNodeId, | ||
|
|
@@ -108,7 +113,8 @@ public HashAggregationOperatorFactory( | |
| throw new UnsupportedOperationException(); | ||
| }, | ||
| joinCompiler, | ||
| blockTypeOperators); | ||
| blockTypeOperators, | ||
| partialAggregationController); | ||
| } | ||
|
|
||
| public HashAggregationOperatorFactory( | ||
|
|
@@ -128,7 +134,8 @@ public HashAggregationOperatorFactory( | |
| DataSize unspillMemoryLimit, | ||
| SpillerFactory spillerFactory, | ||
| JoinCompiler joinCompiler, | ||
| BlockTypeOperators blockTypeOperators) | ||
| BlockTypeOperators blockTypeOperators, | ||
| Optional<PartialAggregationController> partialAggregationController) | ||
| { | ||
| this(operatorId, | ||
| planNodeId, | ||
|
|
@@ -147,7 +154,8 @@ public HashAggregationOperatorFactory( | |
| DataSize.succinctBytes((long) (unspillMemoryLimit.toBytes() * MERGE_WITH_MEMORY_RATIO)), | ||
| spillerFactory, | ||
| joinCompiler, | ||
| blockTypeOperators); | ||
| blockTypeOperators, | ||
| partialAggregationController); | ||
| } | ||
|
|
||
| @VisibleForTesting | ||
|
|
@@ -169,7 +177,8 @@ public HashAggregationOperatorFactory( | |
| DataSize memoryLimitForMergeWithMemory, | ||
| SpillerFactory spillerFactory, | ||
| JoinCompiler joinCompiler, | ||
| BlockTypeOperators blockTypeOperators) | ||
| BlockTypeOperators blockTypeOperators, | ||
| Optional<PartialAggregationController> partialAggregationController) | ||
| { | ||
| this.operatorId = operatorId; | ||
| this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); | ||
|
|
@@ -189,6 +198,7 @@ public HashAggregationOperatorFactory( | |
| this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null"); | ||
| this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null"); | ||
| this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); | ||
| this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationController is null"); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -214,7 +224,8 @@ public Operator createOperator(DriverContext driverContext) | |
| memoryLimitForMergeWithMemory, | ||
| spillerFactory, | ||
| joinCompiler, | ||
| blockTypeOperators); | ||
| blockTypeOperators, | ||
| partialAggregationController); | ||
| return hashAggregationOperator; | ||
| } | ||
|
|
||
|
|
@@ -245,11 +256,13 @@ public OperatorFactory duplicate() | |
| memoryLimitForMergeWithMemory, | ||
| spillerFactory, | ||
| joinCompiler, | ||
| blockTypeOperators); | ||
| blockTypeOperators, | ||
| partialAggregationController.map(PartialAggregationController::duplicate)); | ||
| } | ||
| } | ||
|
|
||
| private final OperatorContext operatorContext; | ||
| private final Optional<PartialAggregationController> partialAggregationController; | ||
| private final List<Type> groupByTypes; | ||
| private final List<Integer> groupByChannels; | ||
| private final List<Integer> globalAggregationGroupIds; | ||
|
|
@@ -279,6 +292,8 @@ public OperatorFactory duplicate() | |
|
|
||
| // for yield when memory is not available | ||
| private Work<?> unfinishedWork; | ||
| private long numberOfInputRowsProcessed; | ||
| private long numberOfUniqueRowsProduced; | ||
|
|
||
| private HashAggregationOperator( | ||
| OperatorContext operatorContext, | ||
|
|
@@ -297,12 +312,15 @@ private HashAggregationOperator( | |
| DataSize memoryLimitForMergeWithMemory, | ||
| SpillerFactory spillerFactory, | ||
| JoinCompiler joinCompiler, | ||
| BlockTypeOperators blockTypeOperators) | ||
| BlockTypeOperators blockTypeOperators, | ||
| Optional<PartialAggregationController> partialAggregationController) | ||
| { | ||
| this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); | ||
| this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationControl is null"); | ||
| requireNonNull(step, "step is null"); | ||
| requireNonNull(aggregatorFactories, "aggregatorFactories is null"); | ||
| requireNonNull(operatorContext, "operatorContext is null"); | ||
| checkArgument(partialAggregationController.isEmpty() || step.isOutputPartial(), "partialAggregationController should be present only for partial aggregation"); | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: restore newline
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. restored after the |
||
| this.groupByTypes = ImmutableList.copyOf(groupByTypes); | ||
| this.groupByChannels = ImmutableList.copyOf(groupByChannels); | ||
|
|
@@ -368,8 +386,14 @@ public void addInput(Page page) | |
| inputProcessed = true; | ||
|
|
||
| if (aggregationBuilder == null) { | ||
| // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. | ||
| if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { | ||
| boolean partialAggregationDisabled = partialAggregationController | ||
| .map(PartialAggregationController::isPartialAggregationDisabled) | ||
| .orElse(false); | ||
| if (step.isOutputPartial() && partialAggregationDisabled) { | ||
| aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. future improvement. It would be great to actually collect metrics
This can be returned via @lukasz-stec Maybe create an issue for that?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. This would allow easier monitoring of the adaptation. |
||
| } | ||
| else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { | ||
|
sopel39 marked this conversation as resolved.
Outdated
|
||
| // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. | ||
| aggregationBuilder = new InMemoryHashAggregationBuilder( | ||
| aggregatorFactories, | ||
| step, | ||
|
|
@@ -418,6 +442,7 @@ public void addInput(Page page) | |
| unfinishedWork = null; | ||
| } | ||
| aggregationBuilder.updateMemory(); | ||
| numberOfInputRowsProcessed += page.getPositionCount(); | ||
| } | ||
|
|
||
| private boolean isSpillable() | ||
|
|
@@ -490,7 +515,9 @@ public Page getOutput() | |
| return null; | ||
| } | ||
|
|
||
| return outputPages.getResult(); | ||
| Page result = outputPages.getResult(); | ||
| numberOfUniqueRowsProduced += result.getPositionCount(); | ||
| return result; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -516,6 +543,10 @@ private void closeAggregationBuilder() | |
| aggregationBuilder = null; | ||
| } | ||
| memoryContext.setBytes(0); | ||
| partialAggregationController.ifPresent( | ||
| controller -> controller.onFlush(numberOfInputRowsProcessed, numberOfUniqueRowsProduced)); | ||
| numberOfInputRowsProcessed = 0; | ||
| numberOfUniqueRowsProduced = 0; | ||
| } | ||
|
|
||
| private Page getGlobalAggregationOutput() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /* | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package io.trino.operator.aggregation.partial; | ||
|
|
||
| import io.trino.operator.HashAggregationOperator; | ||
|
|
||
| /** | ||
| * Controls whenever partial aggregation is enabled across all {@link HashAggregationOperator}s | ||
| * for a particular plan node on a single node. | ||
| * Partial aggregation is disabled once enough rows has been processed ({@link #minNumberOfRowsProcessed}) | ||
| * and the ratio between output(unique) and input rows is too high (> {@link #uniqueRowsRatioThreshold}). | ||
| * TODO https://github.com/trinodb/trino/issues/11361 add support to adaptively re-enable partial aggregation. | ||
| * <p> | ||
| * The class is thread safe and objects of this class are used potentially by multiple threads/drivers simultaneously. | ||
| * Different threads either: | ||
| * - modify fields via synchronized {@link #onFlush}. | ||
| * - read volatile {@link #partialAggregationDisabled} (volatile here gives visibility). | ||
| */ | ||
| public class PartialAggregationController | ||
| { | ||
| private final long minNumberOfRowsProcessed; | ||
| private final double uniqueRowsRatioThreshold; | ||
|
|
||
| private volatile boolean partialAggregationDisabled; | ||
| private long totalRowProcessed; | ||
| private long totalUniqueRowsProduced; | ||
|
|
||
| public PartialAggregationController(long minNumberOfRowsProcessedToDisable, double uniqueRowsRatioThreshold) | ||
| { | ||
| this.minNumberOfRowsProcessed = minNumberOfRowsProcessedToDisable; | ||
| this.uniqueRowsRatioThreshold = uniqueRowsRatioThreshold; | ||
| } | ||
|
|
||
| public boolean isPartialAggregationDisabled() | ||
| { | ||
| return partialAggregationDisabled; | ||
| } | ||
|
|
||
| public synchronized void onFlush(long rowsProcessed, long uniqueRowsProduced) | ||
| { | ||
| if (partialAggregationDisabled) { | ||
| return; | ||
| } | ||
|
|
||
| totalRowProcessed += rowsProcessed; | ||
| totalUniqueRowsProduced += uniqueRowsProduced; | ||
| if (shouldDisablePartialAggregation()) { | ||
| partialAggregationDisabled = true; | ||
| } | ||
| } | ||
|
|
||
| private boolean shouldDisablePartialAggregation() | ||
| { | ||
| return totalRowProcessed >= minNumberOfRowsProcessed | ||
| && ((double) totalUniqueRowsProduced / totalRowProcessed) > uniqueRowsRatioThreshold; | ||
| } | ||
|
|
||
| public PartialAggregationController duplicate() | ||
| { | ||
| return new PartialAggregationController(minNumberOfRowsProcessed, uniqueRowsRatioThreshold); | ||
| } | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.