diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 6d1511df299c9..1211b67a2b5a2 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -172,6 +172,8 @@ public final class SystemSessionProperties public static final String PREFER_PARTIAL_AGGREGATION = "prefer_partial_aggregation"; public static final String PARTIAL_AGGREGATION_STRATEGY = "partial_aggregation_strategy"; public static final String PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD = "partial_aggregation_byte_reduction_threshold"; + public static final String ADAPTIVE_PARTIAL_AGGREGATION = "adaptive_partial_aggregation"; + public static final String ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold"; public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number"; public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate"; public static final String MAX_GROUPING_SETS = "max_grouping_sets"; @@ -960,6 +962,16 @@ public SystemSessionProperties( "Byte reduction ratio threshold at which to disable partial aggregation", featuresConfig.getPartialAggregationByteReductionThreshold(), false), + booleanProperty( + ADAPTIVE_PARTIAL_AGGREGATION, + "Enable adaptive partial aggregation", + featuresConfig.isAdaptivePartialAggregationEnabled(), + false), + doubleProperty( + ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD, + "Rows reduction ratio threshold at which to adaptively disable partial aggregation", + featuresConfig.getAdaptivePartialAggregationRowsReductionRatioThreshold(), + false), booleanProperty( OPTIMIZE_TOP_N_ROW_NUMBER, "Use top N row number optimization", @@ -2318,6 +2330,16 @@ public static double 
getPartialAggregationByteReductionThreshold(Session session return session.getSystemProperty(PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD, Double.class); } + public static boolean isAdaptivePartialAggregationEnabled(Session session) + { + return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION, Boolean.class); + } + + public static double getAdaptivePartialAggregationRowsReductionRatioThreshold(Session session) + { + return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD, Double.class); + } + public static boolean isOptimizeTopNRowNumber(Session session) { return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/CompletedWork.java b/presto-main/src/main/java/com/facebook/presto/operator/CompletedWork.java index c2ef8c7b353fd..8b636096bd8e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/CompletedWork.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/CompletedWork.java @@ -13,11 +13,14 @@ */ package com.facebook.presto.operator; +import javax.annotation.Nullable; + import static java.util.Objects.requireNonNull; public final class CompletedWork implements Work { + @Nullable private final T result; public CompletedWork(T value) @@ -25,12 +28,21 @@ public CompletedWork(T value) this.result = requireNonNull(value); } + /** + * This constructor can be used when the result is computed immediately and we do not need the yield machinery + */ + public CompletedWork() + { + this.result = null; + } + @Override public boolean process() { return true; } + @Nullable @Override public T getResult() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java index 33ece49e03702..5211267d16c2f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java +++ 
b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java @@ -22,6 +22,8 @@ import com.facebook.presto.operator.aggregation.builder.HashAggregationBuilder; import com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder; import com.facebook.presto.operator.aggregation.builder.SpillableHashAggregationBuilder; +import com.facebook.presto.operator.aggregation.partial.PartialAggregationController; +import com.facebook.presto.operator.aggregation.partial.SkipAggregationBuilder; import com.facebook.presto.operator.scalar.CombineHashFunction; import com.facebook.presto.spi.function.aggregation.Accumulator; import com.facebook.presto.spi.plan.AggregationNode.Step; @@ -38,11 +40,13 @@ import java.util.List; import java.util.Optional; import java.util.OptionalInt; +import java.util.OptionalLong; import java.util.stream.Collectors; import static com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes; import static com.facebook.presto.sql.planner.PlannerUtils.INITIAL_HASH_VALUE; import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -78,6 +82,7 @@ public static class HashAggregationOperatorFactory private final SpillerFactory spillerFactory; private final JoinCompiler joinCompiler; private final boolean useSystemMemory; + private final Optional partialAggregationController; private boolean closed; @@ -112,6 +117,7 @@ public HashAggregationOperatorFactory( expectedGroups, maxPartialMemory, false, + Optional.empty(), new DataSize(0, MEGABYTE), new DataSize(0, MEGABYTE), (types, spillContext, memoryContext) -> { @@ -136,6 +142,7 @@ public HashAggregationOperatorFactory( int expectedGroups, Optional maxPartialMemory, boolean 
spillEnabled, + Optional partialAggregationController, DataSize unspillMemoryLimit, SpillerFactory spillerFactory, JoinCompiler joinCompiler, @@ -155,6 +162,7 @@ public HashAggregationOperatorFactory( expectedGroups, maxPartialMemory, spillEnabled, + partialAggregationController, unspillMemoryLimit, DataSize.succinctBytes((long) (unspillMemoryLimit.toBytes() * MERGE_WITH_MEMORY_RATIO)), spillerFactory, @@ -178,6 +186,7 @@ public HashAggregationOperatorFactory( int expectedGroups, Optional maxPartialMemory, boolean spillEnabled, + Optional partialAggregationController, DataSize memoryLimitForMerge, DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, @@ -198,6 +207,7 @@ public HashAggregationOperatorFactory( this.expectedGroups = expectedGroups; this.maxPartialMemory = requireNonNull(maxPartialMemory, "maxPartialMemory is null"); this.spillEnabled = spillEnabled; + this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationController is null"); this.memoryLimitForMerge = requireNonNull(memoryLimitForMerge, "memoryLimitForMerge is null"); this.memoryLimitForMergeWithMemory = requireNonNull(memoryLimitForMergeWithMemory, "memoryLimitForMergeWithMemory is null"); this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null"); @@ -225,6 +235,7 @@ public Operator createOperator(DriverContext driverContext) expectedGroups, maxPartialMemory, spillEnabled, + partialAggregationController, memoryLimitForMerge, memoryLimitForMergeWithMemory, spillerFactory, @@ -257,6 +268,7 @@ public OperatorFactory duplicate() expectedGroups, maxPartialMemory, spillEnabled, + partialAggregationController.map(PartialAggregationController::duplicate), memoryLimitForMerge, memoryLimitForMergeWithMemory, spillerFactory, @@ -278,6 +290,7 @@ public OperatorFactory duplicate() private final int expectedGroups; private final Optional maxPartialMemory; private final boolean spillEnabled; + private final Optional 
partialAggregationController; private final DataSize memoryLimitForMerge; private final DataSize memoryLimitForMergeWithMemory; private final SpillerFactory spillerFactory; @@ -299,6 +312,10 @@ public OperatorFactory duplicate() // for yield when memory is not available private Work unfinishedWork; + private long inputBytesProcessed; + private long inputRowsProcessed; + private long uniqueRowsProduced; + public HashAggregationOperator( OperatorContext operatorContext, List groupByTypes, @@ -313,6 +330,7 @@ public HashAggregationOperator( int expectedGroups, Optional maxPartialMemory, boolean spillEnabled, + Optional partialAggregationController, DataSize memoryLimitForMerge, DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, @@ -337,6 +355,9 @@ public HashAggregationOperator( this.maxPartialMemory = requireNonNull(maxPartialMemory, "maxPartialMemory is null"); this.types = toTypes(groupByTypes, step, accumulatorFactories, hashChannel); this.spillEnabled = spillEnabled; + this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationController is null"); + checkArgument(!partialAggregationController.isPresent() || step.isOutputPartial(), + "partialAggregationController should only be present for partial aggregation"); this.memoryLimitForMerge = requireNonNull(memoryLimitForMerge, "memoryLimitForMerge is null"); this.memoryLimitForMergeWithMemory = requireNonNull(memoryLimitForMergeWithMemory, "memoryLimitForMergeWithMemory is null"); this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null"); @@ -402,7 +423,10 @@ public void addInput(Page page) if (unfinishedWork != null && unfinishedWork.process()) { unfinishedWork = null; } + aggregationBuilder.updateMemory(); + inputBytesProcessed += page.getSizeInBytes(); + inputRowsProcessed += page.getPositionCount(); } @Override @@ -470,7 +494,9 @@ public Page getOutput() return null; } - return outputPages.getResult(); + Page result = 
outputPages.getResult(); + uniqueRowsProduced += result.getPositionCount(); + return result; } @Override @@ -534,6 +560,16 @@ private int findLastSegmentStart(PagesHashStrategy pagesHashStrategy, Page page) private void closeAggregationBuilder() { + partialAggregationController.ifPresent( + controller -> controller.onFlush( + inputBytesProcessed, + inputRowsProcessed, + // Empty uniqueRowsProduced indicates to PartialAggregationController that partial agg is disabled + aggregationBuilder instanceof SkipAggregationBuilder ? OptionalLong.empty() : OptionalLong.of(uniqueRowsProduced))); + inputBytesProcessed = 0; + inputRowsProcessed = 0; + uniqueRowsProduced = 0; + outputPages = null; if (aggregationBuilder != null) { aggregationBuilder.recordHashCollisions(hashCollisionsCounter); @@ -563,7 +599,18 @@ private void initializeAggregationBuilderIfNeeded() return; } - if (step.isOutputPartial() || !spillEnabled) { + boolean partialAggregationDisabled = partialAggregationController + .map(PartialAggregationController::isPartialAggregationDisabled) + .orElse(false); + + if (step.isOutputPartial() && partialAggregationDisabled) { + aggregationBuilder = new SkipAggregationBuilder( + groupByChannels, + hashChannel, + accumulatorFactories, + operatorContext.localUserMemoryContext()); + } + else if (step.isOutputPartial() || !spillEnabled) { aggregationBuilder = new InMemoryHashAggregationBuilder( accumulatorFactories, step, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java new file mode 100644 index 0000000000000..917afe5dac25f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.partial; + +import io.airlift.units.DataSize; + +import java.util.OptionalLong; + +import static java.util.Objects.requireNonNull; + +public class PartialAggregationController +{ + /** + * Process enough pages to fill up the partial aggregation buffer, before considering disabling partial aggregation. + * With 16 MB as default partial agg buffer, this means we process at least 24 MB of input data before considering to disable partial agg. + * We use bytes instead of rows as the floor to disable partial aggregation due to issues with file skew when rows are small. We want to make sure + * the partial aggregation buffer is fully utilized before making the decision on disabling partial aggregation. + */ + private static final double DISABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO = 1.5; + /** + * Re-enable partial aggregation periodically, in case later data can be partially aggregated more effectively. 
+ */ + private static final double ENABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO = DISABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO * 200; + + private final DataSize maxPartialAggregationMemorySize; + private final double uniqueRowsRatioThreshold; + + private volatile boolean partialAggregationDisabled; + private long totalBytesProcessed; + private long totalRowsProcessed; + private long totalUniqueRowsProduced; + + public PartialAggregationController(DataSize maxPartialAggregationMemorySize, double uniqueRowsRatioThreshold) + { + this.maxPartialAggregationMemorySize = requireNonNull(maxPartialAggregationMemorySize, "maxPartialAggregationMemorySize is null"); + this.uniqueRowsRatioThreshold = uniqueRowsRatioThreshold; + } + + public boolean isPartialAggregationDisabled() + { + return partialAggregationDisabled; + } + + public synchronized void onFlush(long bytesProcessed, long rowsProcessed, OptionalLong uniqueRowsProduced) + { + if (!partialAggregationDisabled && !uniqueRowsProduced.isPresent()) { + // when partial aggregation has been re-enabled, ignore stats from disabled flushes + return; + } + + totalBytesProcessed += bytesProcessed; + totalRowsProcessed += rowsProcessed; + uniqueRowsProduced.ifPresent(value -> totalUniqueRowsProduced += value); + + if (!partialAggregationDisabled && shouldDisablePartialAggregation()) { + partialAggregationDisabled = true; + } + + if (partialAggregationDisabled && totalBytesProcessed >= maxPartialAggregationMemorySize.toBytes() * ENABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO) { + totalBytesProcessed = 0; + totalRowsProcessed = 0; + totalUniqueRowsProduced = 0; + partialAggregationDisabled = false; + } + } + + private boolean shouldDisablePartialAggregation() + { + return totalBytesProcessed >= maxPartialAggregationMemorySize.toBytes() * DISABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO + && ((double) totalUniqueRowsProduced / totalRowsProcessed) > uniqueRowsRatioThreshold; + } + + public PartialAggregationController 
duplicate() + { + return new PartialAggregationController(maxPartialAggregationMemorySize, uniqueRowsRatioThreshold); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java new file mode 100644 index 0000000000000..28ef06478e558 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java @@ -0,0 +1,187 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.aggregation.partial; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.LongArrayBlock; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.operator.CompletedWork; +import com.facebook.presto.operator.HashCollisionsCounter; +import com.facebook.presto.operator.UpdateMemory; +import com.facebook.presto.operator.Work; +import com.facebook.presto.operator.WorkProcessor; +import com.facebook.presto.operator.aggregation.AccumulatorFactory; +import com.facebook.presto.operator.aggregation.builder.HashAggregationBuilder; +import com.facebook.presto.spi.function.aggregation.GroupByIdBlock; +import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.util.concurrent.ListenableFuture; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * This class is an implementation of {@link HashAggregationBuilder} that does not aggregate input rows at all. + * It passes the input pages, augmented with initial accumulator state to the output. + * It can only be used at the partial aggregation step, as it relies on rows being aggregated at the final step. + * The reason to do this is for cases where partial aggregation is ineffective due to a large number of unique inputs. + * By using this builder, we can skip the expensive hash computation step which is not useful in these cases. + * And we cannot just send raw pages because the final aggregation step still expects the partial aggregation schema. 
+ */ +public class SkipAggregationBuilder + implements HashAggregationBuilder +{ + private final LocalMemoryContext memoryContext; + private final List groupedAccumulators; + + @Nullable + private Page currentPage; + private final int[] hashChannels; + + public SkipAggregationBuilder( + List groupByChannels, + Optional inputHashChannel, + List accumulatorFactories, + LocalMemoryContext memoryContext) + { + this.memoryContext = requireNonNull(memoryContext, "memoryContext is null"); + requireNonNull(accumulatorFactories, "accumulatorFactories is null"); + this.groupedAccumulators = accumulatorFactories.stream() + .map(accumulatorFactory -> accumulatorFactory.createGroupedAccumulator(UpdateMemory.NOOP)) + .collect(toImmutableList()); + this.hashChannels = new int[groupByChannels.size() + (inputHashChannel.isPresent() ? 1 : 0)]; + for (int i = 0; i < groupByChannels.size(); i++) { + hashChannels[i] = groupByChannels.get(i); + } + inputHashChannel.ifPresent(channelIndex -> hashChannels[groupByChannels.size()] = channelIndex); + } + + @Override + public Work processPage(Page page) + { + checkState(currentPage == null); + currentPage = page; + return new CompletedWork<>(); + } + + @Override + public WorkProcessor buildResult() + { + if (currentPage == null) { + return WorkProcessor.of(); + } + + Page result = buildOutputPage(currentPage); + currentPage = null; + return WorkProcessor.of(result); + } + + @Override + public boolean isFull() + { + return currentPage != null; + } + + @Override + public void updateMemory() + { + if (currentPage != null) { + memoryContext.setBytes(currentPage.getSizeInBytes()); + } + } + + @Override + public void recordHashCollisions(HashCollisionsCounter hashCollisionsCounter) + { + } + + @Override + public void close() + { + } + + @Override + public ListenableFuture startMemoryRevoke() + { + throw new UnsupportedOperationException("startMemoryRevoke not supported for SkipAggregationBuilder"); + } + + @Override + public void 
finishMemoryRevoke() + { + throw new UnsupportedOperationException("finishMemoryRevoke not supported for SkipAggregationBuilder"); + } + + private Page buildOutputPage(Page page) + { + populateInitialAccumulatorState(page); + BlockBuilder[] outputBuilders = serializeAccumulatorState(page.getPositionCount()); + return constructOutputPage(page, outputBuilders); + } + + private void populateInitialAccumulatorState(Page page) + { + int positionCount = page.getPositionCount(); + GroupByIdBlock groupByIdBlock = new GroupByIdBlock(positionCount, new LongArrayBlock(positionCount, Optional.empty(), fillConsecutive(positionCount))); + for (GroupedAccumulator groupedAccumulator : groupedAccumulators) { + groupedAccumulator.addInput(groupByIdBlock, page); + } + } + + private BlockBuilder[] serializeAccumulatorState(int positionCount) + { + BlockBuilder[] outputBuilders = new BlockBuilder[groupedAccumulators.size()]; + for (int i = 0; i < outputBuilders.length; i++) { + outputBuilders[i] = groupedAccumulators.get(i).getIntermediateType().createBlockBuilder(null, positionCount); + } + + for (int position = 0; position < positionCount; position++) { + for (int i = 0; i < groupedAccumulators.size(); i++) { + GroupedAccumulator groupedAccumulator = groupedAccumulators.get(i); + BlockBuilder output = outputBuilders[i]; + groupedAccumulator.evaluateIntermediate(position, output); + } + } + + return outputBuilders; + } + + private Page constructOutputPage(Page page, BlockBuilder[] outputBuilders) + { + Block[] outputBlocks = new Block[hashChannels.length + outputBuilders.length]; + for (int i = 0; i < hashChannels.length; i++) { + outputBlocks[i] = page.getBlock(hashChannels[i]); + } + for (int i = 0; i < outputBuilders.length; i++) { + outputBlocks[hashChannels.length + i] = outputBuilders[i].build(); + } + return new Page(page.getPositionCount(), outputBlocks); + } + + private static long[] fillConsecutive(int positionCount) + { + long[] longs = new long[positionCount]; + for (int i 
= 0; i < positionCount; i++) { + longs[i] = i; + } + return longs; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index c9fd4a8927e75..fc2909ac14d04 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -169,6 +169,8 @@ public class FeaturesConfig private boolean preferPartialAggregation = true; private PartialAggregationStrategy partialAggregationStrategy = PartialAggregationStrategy.ALWAYS; private double partialAggregationByteReductionThreshold = 0.5; + private boolean adaptivePartialAggregationEnabled; + private double adaptivePartialAggregationRowsReductionRatioThreshold = 0.8; private boolean optimizeTopNRowNumber = true; private boolean pushLimitThroughOuterJoin = true; private boolean optimizeConstantGroupingKeys = true; @@ -1060,6 +1062,30 @@ public FeaturesConfig setPartialAggregationByteReductionThreshold(double partial return this; } + public boolean isAdaptivePartialAggregationEnabled() + { + return adaptivePartialAggregationEnabled; + } + + @Config("experimental.adaptive-partial-aggregation") + public FeaturesConfig setAdaptivePartialAggregationEnabled(boolean adaptivePartialAggregationEnabled) + { + this.adaptivePartialAggregationEnabled = adaptivePartialAggregationEnabled; + return this; + } + + public double getAdaptivePartialAggregationRowsReductionRatioThreshold() + { + return adaptivePartialAggregationRowsReductionRatioThreshold; + } + + @Config("experimental.adaptive-partial-aggregation-rows-reduction-ratio-threshold") + public FeaturesConfig setAdaptivePartialAggregationRowsReductionRatioThreshold(double adaptivePartialAggregationRowsReductionRatioThreshold) + { + this.adaptivePartialAggregationRowsReductionRatioThreshold = adaptivePartialAggregationRowsReductionRatioThreshold; + return 
this; + } + public boolean isOptimizeTopNRowNumber() { return optimizeTopNRowNumber; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index a55465242a087..b0501ce123c66 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -110,6 +110,7 @@ import com.facebook.presto.operator.WindowOperator.WindowOperatorFactory; import com.facebook.presto.operator.aggregation.AccumulatorFactory; import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; +import com.facebook.presto.operator.aggregation.partial.PartialAggregationController; import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeFactory; import com.facebook.presto.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory; import com.facebook.presto.operator.exchange.LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory; @@ -246,6 +247,7 @@ import java.util.stream.IntStream; import static com.facebook.airlift.concurrent.MoreFutures.addSuccessCallback; +import static com.facebook.presto.SystemSessionProperties.getAdaptivePartialAggregationRowsReductionRatioThreshold; import static com.facebook.presto.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringMaxPerDriverRowCount; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringMaxPerDriverSize; @@ -257,6 +259,7 @@ import static com.facebook.presto.SystemSessionProperties.getTaskPartitionedWriterCount; import static com.facebook.presto.SystemSessionProperties.getTaskWriterCount; import static com.facebook.presto.SystemSessionProperties.getTopNOperatorUnspillMemoryLimit; +import static 
com.facebook.presto.SystemSessionProperties.isAdaptivePartialAggregationEnabled; import static com.facebook.presto.SystemSessionProperties.isAggregationSpillEnabled; import static com.facebook.presto.SystemSessionProperties.isDistinctAggregationSpillEnabled; import static com.facebook.presto.SystemSessionProperties.isEnableDynamicFiltering; @@ -3364,6 +3367,7 @@ private OperatorFactory createHashAggregationOperatorFactory( expectedGroups, maxPartialAggregationMemorySize, useSpill, + createPartialAggregationController(maxPartialAggregationMemorySize, step, session), unspillMemoryLimit, spillerFactory, joinCompiler, @@ -3382,6 +3386,17 @@ private boolean hasOrderBy(Map aggrega } } + private static Optional createPartialAggregationController( + Optional maxPartialAggregationMemorySize, + AggregationNode.Step step, + Session session) + { + if (maxPartialAggregationMemorySize.isPresent() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) { + return Optional.of(new PartialAggregationController(maxPartialAggregationMemorySize.get(), getAdaptivePartialAggregationRowsReductionRatioThreshold(session))); + } + return Optional.empty(); + } + private static TableFinisher createTableFinisher(Session session, Metadata metadata, ExecutionWriterTarget target) { return (fragments, statistics) -> { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java index 7bed98bdd1ead..33bdf9e9e1d8a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java @@ -149,6 +149,7 @@ private OperatorFactory createHashAggregationOperatorFactory(Optional h 100_000, Optional.of(new DataSize(16, MEGABYTE)), false, + Optional.empty(), 
succinctBytes(8), succinctBytes(Integer.MAX_VALUE), spillerFactory, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java index 40034088b35c0..e147d04b2b121 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java @@ -170,6 +170,7 @@ private OperatorFactory createHashAggregationOperatorFactory(Optional h 100_000, Optional.of(new DataSize(16, MEGABYTE)), false, + Optional.empty(), succinctBytes(8), succinctBytes(Integer.MAX_VALUE), spillerFactory, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java index 8ffe36a4c3d11..699869294c0c7 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java @@ -28,6 +28,7 @@ import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory; import com.facebook.presto.operator.aggregation.builder.HashAggregationBuilder; import com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder; +import com.facebook.presto.operator.aggregation.partial.PartialAggregationController; import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation; import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.PlanNodeId; @@ -61,6 +62,8 @@ import static com.facebook.airlift.testing.Assertions.assertGreaterThan; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import 
static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock; +import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; @@ -86,12 +89,14 @@ import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.airlift.units.DataSize.succinctBytes; +import static io.airlift.units.DataSize.succinctDataSize; import static java.lang.String.format; import static java.util.Collections.emptyIterator; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -191,6 +196,7 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boole 100_000, Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, + Optional.empty(), succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, @@ -245,6 +251,7 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna 100_000, Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, + Optional.empty(), succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, @@ -292,6 +299,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp 100_000, Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, + Optional.empty(), succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, @@ -374,6 +382,7 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boo 
100_000, Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, + Optional.empty(), succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, @@ -604,6 +613,7 @@ public void testMergeWithMemorySpill() 1, Optional.of(new DataSize(16, MEGABYTE)), true, + Optional.empty(), new DataSize(smallPagesSpillThresholdSize, Unit.BYTE), succinctBytes(Integer.MAX_VALUE), spillerFactory, @@ -649,6 +659,7 @@ public void testMemoryLimitInSpillWhenTriggerRehash() 1, Optional.of(new DataSize(16, MEGABYTE)), true, + Optional.empty(), new DataSize(smallPagesSpillThresholdSize, Unit.BYTE), succinctBytes(Integer.MAX_VALUE), spillerFactory, @@ -707,6 +718,7 @@ public void testSpillerFailure() 100_000, Optional.of(new DataSize(16, MEGABYTE)), true, + Optional.empty(), succinctBytes(8), succinctBytes(Integer.MAX_VALUE), new FailingSpillerFactory(), @@ -748,6 +760,7 @@ public void testMask() 1, Optional.of(new DataSize(16, MEGABYTE)), false, + Optional.empty(), new DataSize(16, MEGABYTE), new DataSize(16, MEGABYTE), new FailingSpillerFactory(), @@ -816,6 +829,131 @@ private void testMemoryTracking(boolean useSystemMemory) assertEquals(driverContext.getMemoryUsage(), 0); } + @Test + public void testAdaptivePartialAggregation() + { + List hashChannels = Ints.asList(0); + DataSize maxPartialMemory = succinctBytes(1); + PartialAggregationController partialAggregationController = new PartialAggregationController(maxPartialMemory, 0.8); + HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT), + hashChannels, + ImmutableList.of(), + ImmutableList.of(), + Step.PARTIAL, + false, + ImmutableList.of(generateAccumulatorFactory(LONG_SUM, ImmutableList.of(0), Optional.empty())), + Optional.empty(), + Optional.empty(), + 100, + Optional.of(maxPartialMemory), // We set partial agg buffer to be 1 byte to force it to flush after every page + false, + 
Optional.of(partialAggregationController), + new DataSize(0, MEGABYTE), + new DataSize(0, MEGABYTE), + new FailingSpillerFactory(), + joinCompiler, + false); + + // Partial Aggregation should be enabled at the start + assertFalse(partialAggregationController.isPartialAggregationDisabled()); + + // After the first input page, since the values are mostly distinct, adaptive partial agg should kick in and disable partial aggregation for the second page + List input = rowPagesBuilder(false, hashChannels, BIGINT) + .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 8)) + .addBlocksPage(createLongRepeatBlock(1, 10)) + .build(); + List expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8), createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 16)) // first page should be aggregated + .addBlocksPage(createLongRepeatBlock(1, 10), createLongRepeatBlock(1, 10)) // second page should NOT be aggregated + .build(); + assertOperatorEquals(operatorFactory, input, expected); + // The first flush should have triggered adaptivity and disabled partial aggregation. Now it is disabled for subsequent flushes. + assertTrue(partialAggregationController.isPartialAggregationDisabled()); + + // Now we create a second operator, but since we re-use the same factory, the PartialAggregationController should ensure we are still NOT aggregating + input = rowPagesBuilder(false, hashChannels, BIGINT) + .addBlocksPage(createLongRepeatBlock(1, 10)) + .addBlocksPage(createLongRepeatBlock(2, 10)) + .build(); + expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createLongRepeatBlock(1, 10), createLongRepeatBlock(1, 10)) // output page should not be aggregated + .addBlocksPage(createLongRepeatBlock(2, 10), createLongRepeatBlock(2, 10)) // output page should not be aggregated + .build(); + assertOperatorEquals(operatorFactory, input, expected); + + // By default, we re-enable partial agg every partial agg buffer * 1.5 * 200 bytes. 
+ // Since we've set our partial agg buffer to be 1 byte, this means we should re-enable partial aggregation every 300 bytes. + // At this point, we have processed 4 long blocks of 10 rows each. Each value in the long block is 8 bytes (for the long) + 1 byte (for the null flag) = 9 bytes. + // So 4 long blocks of 10 rows each = 90 * 4 = 360 bytes, which is over our threshold of 300 bytes. Thus, partial aggregation should be re-enabled at this point. + assertFalse(partialAggregationController.isPartialAggregationDisabled()); + + input = rowPagesBuilder(false, hashChannels, BIGINT) + .addBlocksPage(createLongRepeatBlock(1, 100)) + .addBlocksPage(createLongRepeatBlock(2, 100)) + .build(); + expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createLongsBlock(1), createLongsBlock(100)) + .addBlocksPage(createLongsBlock(2), createLongsBlock(200)) + .build(); + // Partial aggregation should show good efficiency since the values are repeating in the input. So we should keep partial aggregation on. 
+ assertOperatorEquals(operatorFactory, input, expected); + assertFalse(partialAggregationController.isPartialAggregationDisabled()); + } + + @Test + public void testAdaptivePartialAggregationIsTriggeredOnlyOnFlush() + { + List hashChannels = Ints.asList(0); + // We make partial aggregation controller to trigger after page flush by setting to 1 byte + PartialAggregationController partialAggregationController = new PartialAggregationController(succinctBytes(1), 0.8); + HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT), + hashChannels, + ImmutableList.of(), + ImmutableList.of(), + Step.PARTIAL, + false, + ImmutableList.of(generateAccumulatorFactory(LONG_SUM, ImmutableList.of(0), Optional.empty())), + Optional.empty(), + Optional.empty(), + 100, + Optional.of(succinctDataSize(16, MEGABYTE)), // We set partial agg buffer to be 16 MB, so that we will only flush after processing all pages + false, + Optional.of(partialAggregationController), + new DataSize(0, MEGABYTE), + new DataSize(0, MEGABYTE), + new FailingSpillerFactory(), + joinCompiler, + false); + + List input = rowPagesBuilder(false, hashChannels, BIGINT) + .addSequencePage(10, 0) + .addBlocksPage(createLongRepeatBlock(1, 2)) + .build(); + List expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), createLongsBlock(0, 3, 2, 3, 4, 5, 6, 7, 8, 9)) + .build(); + // Since first input page is unique values, partial agg would've been disabled for the second page. + // But because we wait for the flush, partial agg remains enabled for the second page. + assertOperatorEquals(operatorFactory, input, expected); + // After the flush, partial agg should be disabled because 10 / 12 values are unique, which is > 0.8 default uniqueness row ratio threshold. 
+ assertTrue(partialAggregationController.isPartialAggregationDisabled()); + } + + private void assertOperatorEquals(OperatorFactory operatorFactory, List inputPages, List expectedPages) + { + DriverContext driverContext = createDriverContext(1024); + MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BIGINT, BIGINT) + .pages(expectedPages) + .build(); + OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, inputPages, expected, false, ImmutableList.of(), false); + } + private DriverContext createDriverContext() { return createDriverContext(Integer.MAX_VALUE); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java index 3ad42a2951c60..f4d6047a5d5e0 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java @@ -77,6 +77,7 @@ public class TestHashAggregationOperatorInSegmentedAggregationMode 4, Optional.of(new DataSize(16, MEGABYTE)), false, + Optional.empty(), new DataSize(16, MEGABYTE), new DataSize(16, MEGABYTE), new DummySpillerFactory(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 478461ff93835..ce719546c4c63 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -157,6 +157,8 @@ public void testDefaults() .setPreferPartialAggregation(true) .setPartialAggregationStrategy(PartialAggregationStrategy.ALWAYS) .setPartialAggregationByteReductionThreshold(0.5) + 
.setAdaptivePartialAggregationEnabled(false) + .setAdaptivePartialAggregationRowsReductionRatioThreshold(0.8) .setOptimizeTopNRowNumber(true) .setOptimizeCaseExpressionPredicate(false) .setHistogramGroupImplementation(HistogramGroupImplementation.NEW) @@ -358,6 +360,8 @@ public void testExplicitPropertyMappings() .put("optimizer.prefer-partial-aggregation", "false") .put("optimizer.partial-aggregation-strategy", "automatic") .put("optimizer.partial-aggregation-byte-reduction-threshold", "0.8") + .put("experimental.adaptive-partial-aggregation", "true") + .put("experimental.adaptive-partial-aggregation-rows-reduction-ratio-threshold", "0.9") .put("optimizer.optimize-top-n-row-number", "false") .put("optimizer.optimize-case-expression-predicate", "true") .put("distributed-sort", "false") @@ -549,6 +553,8 @@ public void testExplicitPropertyMappings() .setPreferPartialAggregation(false) .setPartialAggregationStrategy(PartialAggregationStrategy.AUTOMATIC) .setPartialAggregationByteReductionThreshold(0.8) + .setAdaptivePartialAggregationEnabled(true) + .setAdaptivePartialAggregationRowsReductionRatioThreshold(0.9) .setOptimizeTopNRowNumber(false) .setOptimizeCaseExpressionPredicate(true) .setHistogramGroupImplementation(HistogramGroupImplementation.LEGACY)