diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index d07f849b4b4af..846d92381d6f5 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -28,6 +28,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.SingleStreamSpillerChoice; @@ -56,6 +57,8 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.PARTITIONED; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.NONE; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.ALWAYS; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER; import static com.google.common.base.Preconditions.checkArgument; import static java.lang.Math.min; import static java.lang.String.format; @@ -140,6 +143,8 @@ public final class SystemSessionProperties public static final String DISTRIBUTED_SORT = "distributed_sort"; public static final String USE_MARK_DISTINCT = "use_mark_distinct"; public static final String PREFER_PARTIAL_AGGREGATION = "prefer_partial_aggregation"; + public static final String PARTIAL_AGGREGATION_STRATEGY = "partial_aggregation_strategy"; + 
public static final String PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD = "partial_aggregation_byte_reduction_threshold";
     public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number";
     public static final String MAX_GROUPING_SETS = "max_grouping_sets";
     public static final String LEGACY_UNNEST = "legacy_unnest";
@@ -720,7 +725,24 @@ public SystemSessionProperties(
                 booleanProperty(
                         PREFER_PARTIAL_AGGREGATION,
                         "Prefer splitting aggregations into partial and final stages",
-                        featuresConfig.isPreferPartialAggregation(),
+                        null,
+                        false),
+                new PropertyMetadata<>(
+                        PARTIAL_AGGREGATION_STRATEGY,
+                        format("Partial aggregation strategy to use. Options are %s",
+                                Stream.of(PartialAggregationStrategy.values())
+                                        .map(PartialAggregationStrategy::name)
+                                        .collect(joining(","))),
+                        VARCHAR,
+                        PartialAggregationStrategy.class,
+                        featuresConfig.getPartialAggregationStrategy(),
+                        false,
+                        value -> PartialAggregationStrategy.valueOf(((String) value).toUpperCase(java.util.Locale.ENGLISH)),
+                        PartialAggregationStrategy::name),
+                doubleProperty(
+                        PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD,
+                        "Byte reduction ratio threshold at which to disable partial aggregation",
+                        featuresConfig.getPartialAggregationByteReductionThreshold(),
                         false),
                 booleanProperty(
                         OPTIMIZE_TOP_N_ROW_NUMBER,
@@ -1442,9 +1464,25 @@ public static boolean useMarkDistinct(Session session)
         return session.getSystemProperty(USE_MARK_DISTINCT, Boolean.class);
     }
 
-    public static boolean preferPartialAggregation(Session session)
+    /**
+     * Resolves the effective partial aggregation strategy. A non-null {@code prefer_partial_aggregation}
+     * value takes precedence (true maps to ALWAYS, false to NEVER) over {@code partial_aggregation_strategy}.
+     */
+    public static PartialAggregationStrategy getPartialAggregationStrategy(Session session)
+    {
+        Boolean preferPartialAggregation = session.getSystemProperty(PREFER_PARTIAL_AGGREGATION, Boolean.class);
+        if (preferPartialAggregation != null) {
+            if (preferPartialAggregation) {
+                return ALWAYS;
+            }
+            return NEVER;
+        }
+        return session.getSystemProperty(PARTIAL_AGGREGATION_STRATEGY, PartialAggregationStrategy.class);
+    }
+
+    public static double getPartialAggregationByteReductionThreshold(Session
session) { - return session.getSystemProperty(PREFER_PARTIAL_AGGREGATION, Boolean.class); + return session.getSystemProperty(PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD, Double.class); } public static boolean isOptimizeTopNRowNumber(Session session) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java index 4bea17aafba16..13add075e9945 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java @@ -51,10 +51,14 @@ protected Optional doCalculate(ExchangeNode node, StatsPr { Optional estimate = Optional.empty(); double totalSize = 0; + boolean confident = true; for (int i = 0; i < node.getSources().size(); i++) { PlanNode source = node.getSources().get(i); PlanNodeStatsEstimate sourceStats = statsProvider.getStats(source); totalSize += sourceStats.getOutputSizeInBytes(); + if (!sourceStats.isConfident()) { + confident = false; + } PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputVariables(sourceStats, node.getInputs().get(i), node.getOutputVariables()); @@ -69,6 +73,7 @@ protected Optional doCalculate(ExchangeNode node, StatsPr verify(estimate.isPresent()); return Optional.of(buildFrom(estimate.get()) .setTotalSize(totalSize) + .setConfident(confident) .build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java index d333007c2c964..e5c6889226bf4 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -39,10 +39,11 @@ public class PlanNodeStatsEstimate { private static final double DEFAULT_DATA_SIZE_PER_COLUMN = 50; - private static final PlanNodeStatsEstimate UNKNOWN = new 
PlanNodeStatsEstimate(NaN, NaN, ImmutableMap.of()); + private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of()); private final double outputRowCount; private final double totalSize; + private final boolean confident; private final PMap variableStatistics; public static PlanNodeStatsEstimate unknown() @@ -54,16 +55,18 @@ public static PlanNodeStatsEstimate unknown() public PlanNodeStatsEstimate( @JsonProperty("outputRowCount") double outputRowCount, @JsonProperty("totalSize") double totalSize, + @JsonProperty("confident") boolean confident, @JsonProperty("variableStatistics") Map variableStatistics) { - this(outputRowCount, totalSize, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null"))); + this(outputRowCount, totalSize, confident, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null"))); } - private PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap variableStatistics) + private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean confident, PMap variableStatistics) { checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); this.outputRowCount = outputRowCount; this.totalSize = totalSize; + this.confident = confident; this.variableStatistics = variableStatistics; } @@ -83,6 +86,12 @@ public double getTotalSize() return totalSize; } + @JsonProperty + public boolean isConfident() + { + return confident; + } + /** * Only use when getting all columns and meanwhile do not want to * do per-column estimation. @@ -210,24 +219,26 @@ public static Builder builder() // we should propagate totalSize as default to simplify the relevant operations in rules that do not change this field. 
public static Builder buildFrom(PlanNodeStatsEstimate other) { - return new Builder(other.getOutputRowCount(), NaN, other.variableStatistics); + return new Builder(other.getOutputRowCount(), NaN, other.isConfident(), other.variableStatistics); } public static final class Builder { private double outputRowCount; private double totalSize; + private boolean confident; private PMap variableStatistics; public Builder() { - this(NaN, NaN, HashTreePMap.empty()); + this(NaN, NaN, false, HashTreePMap.empty()); } - private Builder(double outputRowCount, double totalSize, PMap variableStatistics) + private Builder(double outputRowCount, double totalSize, boolean confident, PMap variableStatistics) { this.outputRowCount = outputRowCount; this.totalSize = totalSize; + this.confident = confident; this.variableStatistics = variableStatistics; } @@ -243,6 +254,12 @@ public Builder setTotalSize(double totalSize) return this; } + public Builder setConfident(boolean confident) + { + this.confident = confident; + return this; + } + public Builder addVariableStatistics(VariableReferenceExpression variable, VariableStatsEstimate statistics) { variableStatistics = variableStatistics.plus(variable, statistics); @@ -263,7 +280,7 @@ public Builder removeVariableStatistics(VariableReferenceExpression variable) public PlanNodeStatsEstimate build() { - return new PlanNodeStatsEstimate(outputRowCount, totalSize, variableStatistics); + return new PlanNodeStatsEstimate(outputRowCount, totalSize, confident, variableStatistics); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java index 3f9fa4665d303..8f368e6df4d75 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java @@ -53,7 +53,8 @@ protected Optional doCalculate(ProjectNode node, StatsPro { PlanNodeStatsEstimate sourceStats 
= statsProvider.getStats(node.getSource()); PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder() - .setOutputRowCount(sourceStats.getOutputRowCount()); + .setOutputRowCount(sourceStats.getOutputRowCount()) + .setConfident(sourceStats.isConfident() && noChangeToSourceColumns(node)); for (Map.Entry entry : node.getAssignments().entrySet()) { RowExpression expression = entry.getValue(); @@ -66,4 +67,9 @@ protected Optional doCalculate(ProjectNode node, StatsPro } return Optional.of(calculatedStats.build()); } + + private boolean noChangeToSourceColumns(ProjectNode node) + { + return node.getOutputVariables().containsAll(node.getSource().getOutputVariables()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java index f85c03fc7539c..2bae731d8c66a 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java @@ -69,6 +69,7 @@ protected Optional doCalculate(TableScanNode node, StatsP return Optional.of(PlanNodeStatsEstimate.builder() .setOutputRowCount(tableStatistics.getRowCount().getValue()) .setTotalSize(tableStatistics.getTotalSize().getValue()) + .setConfident(true) .addVariableStatistics(outputVariableStats) .build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java index 03ce5c03c2fa3..a0514949371ee 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java @@ -63,7 +63,8 @@ public Pattern getPattern() public Optional calculate(ValuesNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) { PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); - 
statsBuilder.setOutputRowCount(node.getRows().size()); + statsBuilder.setOutputRowCount(node.getRows().size()) + .setConfident(true); for (int variableId = 0; variableId < node.getOutputVariables().size(); ++variableId) { VariableReferenceExpression variable = node.getOutputVariables().get(variableId); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 6519e108e2367..6431fb5c1bcc4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -143,6 +143,8 @@ public class FeaturesConfig private boolean parseDecimalLiteralsAsDouble; private boolean useMarkDistinct = true; private boolean preferPartialAggregation = true; + private PartialAggregationStrategy partialAggregationStrategy = PartialAggregationStrategy.ALWAYS; + private double partialAggregationByteReductionThreshold = 0.5; private boolean optimizeTopNRowNumber = true; private boolean pushLimitThroughOuterJoin = true; @@ -266,6 +268,13 @@ public enum SingleStreamSpillerChoice TEMP_STORAGE } + public enum PartialAggregationStrategy + { + ALWAYS, // Always do partial aggregation + NEVER, // Never do partial aggregation + AUTOMATIC // Let the optimizer decide for each aggregation + } + public double getCpuCostWeight() { return cpuCostWeight; @@ -755,6 +764,30 @@ public FeaturesConfig setPreferPartialAggregation(boolean value) return this; } + public PartialAggregationStrategy getPartialAggregationStrategy() + { + return partialAggregationStrategy; + } + + @Config("optimizer.partial-aggregation-strategy") + public FeaturesConfig setPartialAggregationStrategy(PartialAggregationStrategy partialAggregationStrategy) + { + this.partialAggregationStrategy = partialAggregationStrategy; + return this; + } + + public double getPartialAggregationByteReductionThreshold() + { + 
return partialAggregationByteReductionThreshold; + } + + @Config("optimizer.partial-aggregation-byte-reduction-threshold") + public FeaturesConfig setPartialAggregationByteReductionThreshold(double partialAggregationByteReductionThreshold) + { + this.partialAggregationByteReductionThreshold = partialAggregationByteReductionThreshold; + return this; + } + public boolean isOptimizeTopNRowNumber() { return optimizeTopNRowNumber; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 9322c98c42ff8..ffd2c09b92885 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; @@ -27,6 +29,7 @@ import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; @@ -41,12 +44,15 @@ import java.util.Optional; import java.util.stream.Collectors; -import static com.facebook.presto.SystemSessionProperties.preferPartialAggregation; +import static 
com.facebook.presto.SystemSessionProperties.getPartialAggregationByteReductionThreshold;
+import static com.facebook.presto.SystemSessionProperties.getPartialAggregationStrategy;
 import static com.facebook.presto.operator.aggregation.AggregationUtils.isDecomposable;
 import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
 import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
 import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
 import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.AUTOMATIC;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER;
 import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
 import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
 import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
@@ -101,7 +107,14 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
             return Result.ofPlanNode(split(aggregationNode, context));
         }
 
-        if (!decomposable || !preferPartialAggregation(context.getSession())) {
+        PartialAggregationStrategy partialAggregationStrategy = getPartialAggregationStrategy(context.getSession());
+        // In AUTOMATIC mode the rule backs off only for single-grouping-key aggregations judged not useful;
+        // the key-count check is evaluated first so the stats lookup is skipped when it cannot matter.
+        if (!decomposable ||
+                partialAggregationStrategy == NEVER ||
+                (partialAggregationStrategy == AUTOMATIC &&
+                        aggregationNode.getGroupingKeys().size() == 1 &&
+                        partialAggregationNotUseful(aggregationNode, exchangeNode, context))) {
             return Result.empty();
         }
 
@@ -265,6 +278,18 @@ private PlanNode split(AggregationNode node, Context context)
                 node.getGroupIdVariable());
     }
 
+    private boolean partialAggregationNotUseful(AggregationNode aggregationNode, ExchangeNode exchangeNode, Context context)
+    {
+        StatsProvider stats = context.getStatsProvider();
+        PlanNodeStatsEstimate exchangeStats = stats.getStats(exchangeNode);
+        PlanNodeStatsEstimate aggregationStats =
stats.getStats(aggregationNode);
+        double inputBytes = exchangeStats.getOutputSizeInBytes(exchangeNode.getOutputVariables());
+        double outputBytes = aggregationStats.getOutputSizeInBytes(aggregationNode.getOutputVariables());
+        double byteReductionThreshold = getPartialAggregationByteReductionThreshold(context.getSession());
+
+        return exchangeStats.isConfident() && outputBytes > inputBytes * byteReductionThreshold;
+    }
+
     private static boolean isLambda(RowExpression rowExpression)
     {
         return rowExpression instanceof LambdaDefinitionExpression;
diff --git a/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java
index 36b85d22aa271..b65a8d08510a5 100644
--- a/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java
+++ b/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java
@@ -50,6 +50,12 @@ public PlanNodeStatsAssertion totalSize(double expected)
         return this;
     }
 
+    public PlanNodeStatsAssertion confident(boolean expected)
+    {
+        assertEquals(actual.isConfident(), expected, "confident mismatch");
+        return this;
+    }
+
     public PlanNodeStatsAssertion outputRowsCountUnknown()
     {
         assertTrue(Double.isNaN(actual.getOutputRowCount()), "expected unknown outputRowsCount but got " + actual.getOutputRowCount());
@@ -87,6 +93,7 @@ public PlanNodeStatsAssertion variablesWithKnownStats(VariableReferenceExpressio
     public PlanNodeStatsAssertion equalTo(PlanNodeStatsEstimate expected)
     {
         assertEstimateEquals(actual.getOutputRowCount(), expected.getOutputRowCount(), "outputRowCount mismatch");
+        assertEquals(actual.isConfident(), expected.isConfident(), "confident mismatch");
 
         for (VariableReferenceExpression variable : union(expected.getVariablesWithKnownStatistics(), actual.getVariablesWithKnownStatistics())) {
             assertVariableStatsEqual(variable, actual.getVariableStatistics(variable), expected.getVariableStatistics(variable));
diff --git
a/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java index 49c93bd3e955a..6d27c737cde96 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java @@ -118,4 +118,46 @@ public void testExchange() .distinctValuesCount(4) .nullsFraction(0.1))); } + + @Test + public void testExchangeConfidence() + { + // Confidence of exchange stats should be logical AND of its source nodes' confidence + + tester().assertStatsFor(pb -> pb + .exchange(exchangeBuilder -> exchangeBuilder + .addInputsSet() + .addInputsSet() + .singleDistributionPartitioningScheme() + .addSource(pb.values()) + .addSource(pb.values()))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .setConfident(true) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .setConfident(true) + .build()) + .check(check -> check + .confident(true)); + + tester().assertStatsFor(pb -> pb + .exchange(exchangeBuilder -> exchangeBuilder + .addInputsSet() + .addInputsSet() + .singleDistributionPartitioningScheme() + .addSource(pb.values()) + .addSource(pb.values()))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .setConfident(true) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .setConfident(false) + .build()) + .check(check -> check + .confident(false)); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java b/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java index 15f0f0b898ec8..85fe3004871f2 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java @@ -47,6 +47,7 @@ public void 
testStatsForValuesNode() .check(outputStats -> outputStats.equalTo( PlanNodeStatsEstimate.builder() .setOutputRowCount(3) + .setConfident(true) .addVariableStatistics( new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.builder() @@ -76,6 +77,7 @@ public void testStatsForValuesNode() .check(outputStats -> outputStats.equalTo( PlanNodeStatsEstimate.builder() .setOutputRowCount(4) + .setConfident(true) .addVariableStatistics( new VariableReferenceExpression("v", createVarcharType(30)), VariableStatsEstimate.builder() @@ -92,6 +94,7 @@ public void testStatsForValuesNodeWithJustNulls() FunctionResolution resolution = new FunctionResolution(tester().getMetadata().getFunctionAndTypeManager()); PlanNodeStatsEstimate bigintNullAStats = PlanNodeStatsEstimate.builder() .setOutputRowCount(1) + .setConfident(true) .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.zero()) .build(); @@ -110,6 +113,7 @@ public void testStatsForValuesNodeWithJustNulls() PlanNodeStatsEstimate unknownNullAStats = PlanNodeStatsEstimate.builder() .setOutputRowCount(1) + .setConfident(true) .addVariableStatistics(new VariableReferenceExpression("a", UNKNOWN), VariableStatsEstimate.zero()) .build(); @@ -130,6 +134,7 @@ public void testStatsForEmptyValues() .check(outputStats -> outputStats.equalTo( PlanNodeStatsEstimate.builder() .setOutputRowCount(0) + .setConfident(true) .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.zero()) .build())); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 3fb13fb75b3e0..4810f206669e6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -18,6 +18,7 @@ import 
com.facebook.presto.operator.aggregation.arrayagg.ArrayAggGroupImplementation; import com.facebook.presto.operator.aggregation.histogram.HistogramGroupImplementation; import com.facebook.presto.operator.aggregation.multimapagg.MultimapAggGroupImplementation; +import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.SingleStreamSpillerChoice; import com.google.common.collect.ImmutableMap; @@ -131,6 +132,8 @@ public void testDefaults() .setFilterAndProjectMinOutputPageRowCount(256) .setUseMarkDistinct(true) .setPreferPartialAggregation(true) + .setPartialAggregationStrategy(PartialAggregationStrategy.ALWAYS) + .setPartialAggregationByteReductionThreshold(0.5) .setOptimizeTopNRowNumber(true) .setHistogramGroupImplementation(HistogramGroupImplementation.NEW) .setArrayAggGroupImplementation(ArrayAggGroupImplementation.NEW) @@ -255,6 +258,8 @@ public void testExplicitPropertyMappings() .put("multimapagg.implementation", "LEGACY") .put("optimizer.use-mark-distinct", "false") .put("optimizer.prefer-partial-aggregation", "false") + .put("optimizer.partial-aggregation-strategy", "automatic") + .put("optimizer.partial-aggregation-byte-reduction-threshold", "0.8") .put("optimizer.optimize-top-n-row-number", "false") .put("distributed-sort", "false") .put("analyzer.max-grouping-sets", "2047") @@ -369,6 +374,8 @@ public void testExplicitPropertyMappings() .setFilterAndProjectMinOutputPageRowCount(2048) .setUseMarkDistinct(false) .setPreferPartialAggregation(false) + .setPartialAggregationStrategy(PartialAggregationStrategy.AUTOMATIC) + .setPartialAggregationByteReductionThreshold(0.8) .setOptimizeTopNRowNumber(false) .setHistogramGroupImplementation(HistogramGroupImplementation.LEGACY) .setArrayAggGroupImplementation(ArrayAggGroupImplementation.LEGACY) diff --git 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java new file mode 100644 index 0000000000000..a51e28748b924 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.VariableStatsEstimate; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.PARTIAL_AGGREGATION_STRATEGY; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; +import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.relational.Expressions.variable; + +public class TestPushPartialAggregationThroughExchange + extends BaseRuleTest +{ + @Test + public void testPartialAggregationAdded() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + return p.aggregation(ab -> ab + .source( + p.exchange(e -> e + .addSource(p.values(a)) + .addInputsSet(a) + .singleDistributionPartitioningScheme(a))) + .addAggregation(p.variable("SUM", DOUBLE), expression("SUM(a)"), ImmutableList.of(DOUBLE)) 
+ .globalGrouping() + .step(PARTIAL)); + }) + .matches(exchange( + project( + aggregation( + ImmutableMap.of("SUM", functionCall("sum", ImmutableList.of("a"))), + PARTIAL, + values("a"))))); + } + + @Test + public void testNoPartialAggregationWhenDisabled() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "NEVER") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + return p.aggregation(ab -> ab + .source( + p.exchange(e -> e + .addSource(p.values(a)) + .addInputsSet(a) + .singleDistributionPartitioningScheme(a))) + .addAggregation(p.variable("SUM", DOUBLE), expression("SUM(a)"), ImmutableList.of(DOUBLE)) + .globalGrouping() + .step(PARTIAL)); + }) + .doesNotFire(); + } + + @Test + public void testNoPartialAggregationWhenReductionBelowThreshold() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression b = p.variable("b", DOUBLE); + return p.aggregation(ab -> ab + .source( + p.exchange(e -> e + .addSource(p.values(new PlanNodeId("values"), a, b)) + .addInputsSet(a, b) + .singleDistributionPartitioningScheme(a, b))) + .addAggregation(p.variable("SUM", DOUBLE), expression("SUM(a)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(b) + .step(SINGLE)); + }) + .overrideStats("values", PlanNodeStatsEstimate.builder() + .setOutputRowCount(1000) + .addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800)) + .setConfident(true) + .build()) + .doesNotFire(); + } + + @Test + public void testPartialAggregationEnabledWhenNotConfident() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .on(p -> { + VariableReferenceExpression a = 
p.variable("a", DOUBLE); + VariableReferenceExpression b = p.variable("b", DOUBLE); + return p.aggregation(ab -> ab + .source( + p.exchange(e -> e + .addSource(p.values(new PlanNodeId("values"), a, b)) + .addInputsSet(a, b) + .singleDistributionPartitioningScheme(a, b))) + .addAggregation(p.variable("SUM", DOUBLE), expression("SUM(a)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(b) + .step(PARTIAL)); + }) + .overrideStats("values", PlanNodeStatsEstimate.builder() + .setOutputRowCount(1000) + .addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800)) + .setConfident(false) + .build()) + .matches(exchange( + project( + aggregation( + ImmutableMap.of("SUM", functionCall("sum", ImmutableList.of("a"))), + PARTIAL, + values("a", "b"))))); + } +}