diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java index 4bed0b39e8b36..a8aa552763d44 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java @@ -343,11 +343,11 @@ public void testHistoryBasedStatsCalculatorCTE() .setSystemProperty(CTE_PARTITIONING_PROVIDER_CATALOG, "hive") .build(); // CBO Statistics - assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(Double.NaN))); + assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(0D))); // HBO Statistics executeAndTrackHistory(sql, cteMaterialization); - assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(3))); + assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(3D))); } @Test diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java b/presto-main-base/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java index c11fedce9c259..9962d12fef241 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java @@ -20,12 +20,14 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.google.common.collect.ImmutableMap; import java.util.Collection; import java.util.Map; import java.util.Optional; -import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.spi.plan.AggregationNode.Step.INTERMEDIATE; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static java.lang.Math.min; @@ -54,49 +56,81 @@ protected Optional doCalculate(AggregationNode node, Stat return Optional.empty(); } - if (node.getStep() != SINGLE) { - return Optional.empty(); - } + PlanNodeStatsEstimate estimate; - return Optional.of(groupBy( - statsProvider.getStats(node.getSource()), - node.getGroupingKeys(), - node.getAggregations())); + if (node.getStep() == PARTIAL || node.getStep() == INTERMEDIATE) { + estimate = partialGroupBy( + statsProvider.getStats(node.getSource()), + node.getGroupingKeys(), + node.getAggregations()); + } + else { + estimate = groupBy( + statsProvider.getStats(node.getSource()), + node.getGroupingKeys(), + node.getAggregations()); + } + return Optional.of(estimate); } public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection groupByVariables, Map aggregations) { + // Used to estimate FINAL or SINGLE step aggregations PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); - - if (isGlobalAggregation(groupByVariables)) { + if (groupByVariables.isEmpty()) { result.setConfidence(FACT); + result.setOutputRowCount(1); } - - for (VariableReferenceExpression groupByVariable : groupByVariables) { - VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable); - result.addVariableStatistics(groupByVariable, symbolStatistics.mapNullsFraction(nullsFraction -> { - if (nullsFraction == 0.0) { - return 0.0; - } - return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1); - })); + else { + result.addVariableStatistics(getGroupByVariablesStatistics(sourceStats, groupByVariables)); + double rowsCount = getRowsCount(sourceStats, groupByVariables); + result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount())); } + aggregations.forEach((key, value) -> result.addVariableStatistics(key, estimateAggregationStats(value, sourceStats))); + + return result.build(); + } + + public static double getRowsCount(PlanNodeStatsEstimate sourceStats, Collection groupByVariables) + { double rowsCount = 1; for (VariableReferenceExpression groupByVariable : groupByVariables) { VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable); int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1; rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow; } - result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount())); + return rowsCount; + } - for (Map.Entry aggregationEntry : aggregations.entrySet()) { - result.addVariableStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats)); - } + private static PlanNodeStatsEstimate partialGroupBy(PlanNodeStatsEstimate sourceStats, Collection groupByVariables, Map aggregations) + { + // Pessimistic assumption of no reduction from PARTIAL and INTERMEDIATE aggregation, forwarding of the source statistics. + // This makes the CBO estimates in the EXPLAIN plan output easier to understand, + // even though partial aggregations are added after the CBO rules have been run. + PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); + result.setOutputRowCount(sourceStats.getOutputRowCount()); + result.addVariableStatistics(getGroupByVariablesStatistics(sourceStats, groupByVariables)); + aggregations.forEach((key, value) -> result.addVariableStatistics(key, estimateAggregationStats(value, sourceStats))); return result.build(); } + private static Map getGroupByVariablesStatistics(PlanNodeStatsEstimate sourceStats, Collection groupByVariables) + { + ImmutableMap.Builder variableStatsEstimates = ImmutableMap.builder(); + for (VariableReferenceExpression groupByVariable : groupByVariables) { + VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable); + variableStatsEstimates.put(groupByVariable, symbolStatistics.mapNullsFraction(nullsFraction -> { + if (nullsFraction == 0.0) { + return 0.0; + } + return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1); + })); + } + return variableStatsEstimates.build(); + } + private static VariableStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats) { requireNonNull(aggregation, "aggregation is null"); @@ -105,9 +139,4 @@ private static VariableStatsEstimate estimateAggregationStats(Aggregation aggreg // TODO implement simple aggregations like: min, max, count, sum return VariableStatsEstimate.unknown(); } - - private static boolean isGlobalAggregation(Collection groupingKeys) - { - return groupingKeys.isEmpty(); - } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java b/presto-main-base/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java index 2656d9777a549..128d271a993e6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java +++ b/presto-main-base/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.Test; @@ -20,21 +21,26 @@ import java.util.function.Consumer; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT; public class TestAggregationStatsRule extends BaseStatsCalculatorTest { + private static final VariableReferenceExpression VARIABLE_X = new VariableReferenceExpression(Optional.empty(), "x", BIGINT); + private static final VariableReferenceExpression VARIABLE_Y = new VariableReferenceExpression(Optional.empty(), "y", BIGINT); + private static final VariableReferenceExpression VARIABLE_Z = new VariableReferenceExpression(Optional.empty(), "z", BIGINT); + @Test public void testAggregationWhenAllStatisticsAreKnown() { Consumer outputRowCountAndZStatsAreCalculated = check -> check .outputRowsCount(15) - .variableStats(new VariableReferenceExpression(Optional.empty(), "z", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .variableStats(VARIABLE_Z, symbolStatsAssertion -> symbolStatsAssertion .lowValue(10) .highValue(15) .distinctValuesCount(4) .nullsFraction(0.2)) - .variableStats(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion .lowValue(0) .highValue(3) .distinctValuesCount(3) @@ -59,11 +65,11 @@ public void testAggregationWhenAllStatisticsAreKnown() Consumer outputRowsCountAndZStatsAreNotFullyCalculated = check -> check .outputRowsCountUnknown() - .variableStats(new VariableReferenceExpression(Optional.empty(), "z", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .variableStats(VARIABLE_Z, symbolStatsAssertion -> symbolStatsAssertion .unknownRange() .distinctValuesCountUnknown() .nullsFractionUnknown()) - .variableStats(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion .unknownRange() .nullsFractionUnknown() .distinctValuesCountUnknown()); @@ -96,19 +102,19 @@ private StatsCalculatorAssertion testAggregation(VariableStatsEstimate zStats) .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT), pb.variable("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0.3) .build()) - .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), VariableStatsEstimate.builder() + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(3) .setNullsFraction(0) .build()) - .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "z", BIGINT), zStats) + .addVariableStatistics(VARIABLE_Z, zStats) .build()) .check(check -> check .variableStats(new VariableReferenceExpression(Optional.empty(), "sum", BIGINT), symbolStatsAssertion -> symbolStatsAssertion @@ -126,7 +132,7 @@ private StatsCalculatorAssertion testAggregation(VariableStatsEstimate zStats) .highValueUnknown() .distinctValuesCountUnknown() .nullsFractionUnknown()) - .variableStats(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .variableStats(VARIABLE_X, symbolStatsAssertion -> symbolStatsAssertion .lowValueUnknown() .highValueUnknown() .distinctValuesCountUnknown() @@ -144,9 +150,459 @@ public void testAggregationStatsCappedToInputRows() .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT), pb.variable("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(50).build()) - .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "z", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(50).build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder().setDistinctValuesCount(50).build()) + .addVariableStatistics(VARIABLE_Z, VariableStatsEstimate.builder().setDistinctValuesCount(50).build()) .build()) .check(check -> check.outputRowsCount(100)); } + + /** + * Verifies that a global aggregation (no grouping keys) always produces + * exactly one output row with FACT confidence level, regardless of the + * input statistics. + */ + @Test + public void testGlobalAggregationReturnsOneRow() + { + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("sum", BIGINT), pb.rowExpression("sum(x)")) + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .globalGrouping() + .source(pb.values(pb.variable("x", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(1000) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(100) + .setDistinctValuesCount(50) + .setNullsFraction(0.1) + .build()) + .build()) + .check(check -> check + .outputRowsCount(1) + .confident(FACT) + .variableStats(new VariableReferenceExpression(Optional.empty(), "sum", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown()) + .variableStats(new VariableReferenceExpression(Optional.empty(), "count", BIGINT), symbolStatsAssertion -> symbolStatsAssertion + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown())); + } + + /** + * Verifies that a global aggregation with zero input rows still produces + * exactly one output row with FACT confidence. This is the expected behavior + * for queries like {@code SELECT count(*) FROM empty_table}. + */ + @Test + public void testGlobalAggregationWithZeroInputRows() + { + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .globalGrouping() + .source(pb.values(pb.variable("x", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(0) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setDistinctValuesCount(0) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(1) + .confident(FACT)); + } + + /** + * Verifies that a PARTIAL aggregation step does not reduce the estimated + * row count. The rule pessimistically assumes no reduction for partial + * aggregations and forwards the source row count directly. + */ + @Test + public void testPartialAggregationPreservesSourceRowCount() + { + double sourceRowCount = 500; + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("sum", BIGINT), pb.rowExpression("sum(x)")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .step(AggregationNode.Step.PARTIAL) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(sourceRowCount) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(3) + .setDistinctValuesCount(3) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(sourceRowCount) + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(0) + .highValue(3) + .distinctValuesCount(3) + .nullsFraction(0))); + } + + /** + * Verifies that an INTERMEDIATE aggregation step behaves identically to a + * PARTIAL step: no reduction in estimated row count, source stats forwarded. + */ + @Test + public void testIntermediateAggregationPreservesSourceRowCount() + { + double sourceRowCount = 500; + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("sum", BIGINT), pb.rowExpression("sum(x)")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .step(AggregationNode.Step.INTERMEDIATE) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(sourceRowCount) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(3) + .setDistinctValuesCount(3) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(sourceRowCount) + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(0) + .highValue(3) + .distinctValuesCount(3) + .nullsFraction(0))); + } + + /** + * Verifies that for a SINGLE-step aggregation with a single grouping key, + * the output row count equals the distinct value count of the grouping key. + * Also verifies that the grouping key's nulls fraction is set to zero when + * the source has no nulls in that column. + */ + @Test + public void testSingleGroupingKeyNoNulls() + { + // y has 10 distinct values with no nulls, source has 200 rows + // Expected output: 10 rows (= NDV of y) + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(200) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(50) + .setDistinctValuesCount(50) + .setNullsFraction(0) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(9) + .setDistinctValuesCount(10) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(10) + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(0) + .highValue(9) + .distinctValuesCount(10) + .nullsFraction(0))); + } + + /** + * Verifies that when a grouping key has a non-zero nulls fraction, the + * output row count accounts for the null group (NDV + 1 for the null row). + * Also checks that the resulting nulls fraction for the grouping key is + * adjusted to {@code 1 / (NDV + 1)}. + */ + @Test + public void testSingleGroupingKeyWithNulls() + { + // y has 10 distinct values with 20% nulls, source has 200 rows + // Expected output: 10 + 1 = 11 rows (NDV + 1 for null group) + // Expected nulls fraction: 1 / (10 + 1) = 1/11 + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(200) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(50) + .setDistinctValuesCount(50) + .setNullsFraction(0) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(9) + .setDistinctValuesCount(10) + .setNullsFraction(0.2) + .build()) + .build()) + .check(check -> check + .outputRowsCount(11) + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(0) + .highValue(9) + .distinctValuesCount(10) + .nullsFraction(1.0 / 11))); + } + + /** + * Verifies the row count estimate for multiple grouping keys with nulls. + * The output row count is the product of (NDV + null_row) for each key, + * capped at the source row count. The nulls fractions of grouping keys + * are each adjusted to {@code 1 / (NDV + 1)}. + */ + @Test + public void testMultipleGroupingKeysWithNulls() + { + // y: NDV=3, nullsFraction=0.1 -> contributes 3+1=4 + // z: NDV=5, nullsFraction=0.2 -> contributes 5+1=6 + // Product = 4 * 6 = 24, source has 200 rows, so 24 is used + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .singleGroupingSet(pb.variable("y", BIGINT), pb.variable("z", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT), pb.variable("z", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(200) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(10) + .setNullsFraction(0) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(2) + .setDistinctValuesCount(3) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(VARIABLE_Z, VariableStatsEstimate.builder() + .setLowValue(10) + .setHighValue(14) + .setDistinctValuesCount(5) + .setNullsFraction(0.2) + .build()) + .build()) + .check(check -> check + .outputRowsCount(24) + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(0) + .highValue(2) + .distinctValuesCount(3) + .nullsFraction(1.0 / 4)) + .variableStats(VARIABLE_Z, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(10) + .highValue(14) + .distinctValuesCount(5) + .nullsFraction(1.0 / 6))); + } + + /** + * Verifies that when grouping key statistics are completely unknown + * (all NaN), the output row count estimate is also unknown. This mirrors + * how join stats handle missing column statistics. + */ + @Test + public void testAggregationWithUnknownGroupingKeyStats() + { + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.unknown()) + .build()) + .check(check -> check + .outputRowsCountUnknown() + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .unknownRange() + .distinctValuesCountUnknown() + .nullsFractionUnknown())); + } + + /** + * Verifies that a FINAL-step aggregation with a single grouping key + * produces the same estimates as a SINGLE-step aggregation, since both + * are handled by the same {@code groupBy} code path. + */ + @Test + public void testFinalAggregationMatchesSingleStep() + { + double sourceRowCount = 500; + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("sum", BIGINT), pb.rowExpression("sum(x)")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .step(AggregationNode.Step.FINAL) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(sourceRowCount) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(19) + .setDistinctValuesCount(20) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(20) + .variableStats(VARIABLE_Y, symbolStatsAssertion -> symbolStatsAssertion + .lowValue(0) + .highValue(19) + .distinctValuesCount(20) + .nullsFraction(0))); + } + + /** + * Verifies that a partial aggregation with a global grouping (no grouping keys) + * preserves the full source row count, since partial aggregations assume + * pessimistic (no) reduction. + */ + @Test + public void testPartialGlobalAggregationPreservesSourceRows() + { + double sourceRowCount = 300; + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .globalGrouping() + .step(AggregationNode.Step.PARTIAL) + .source(pb.values(pb.variable("x", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(sourceRowCount) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(100) + .setDistinctValuesCount(50) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(sourceRowCount)); + } + + /** + * Verifies that the aggregation output row count is correctly capped at the + * source row count when multiple grouping keys with no nulls produce a + * product of NDVs that exceeds the number of input rows. + */ + @Test + public void testMultipleGroupingKeysCappedToInputRows() + { + // y: NDV=50, no nulls -> contributes 50 + // z: NDV=50, no nulls -> contributes 50 + // Product = 50 * 50 = 2500, but source has only 100 rows => capped to 100 + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .singleGroupingSet(pb.variable("y", BIGINT), pb.variable("z", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT), pb.variable("z", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(49) + .setDistinctValuesCount(50) + .setNullsFraction(0) + .build()) + .addVariableStatistics(VARIABLE_Z, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(49) + .setDistinctValuesCount(50) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check.outputRowsCount(100)); + } + + /** + * Verifies that aggregation statistics are correctly computed for a + * SINGLE-step aggregation when the grouping key has a high NDV but the + * source row count is low. The output row count should equal the source + * row count since NDV cannot exceed it. + */ + @Test + public void testGroupingKeyNdvExceedsSourceRows() + { + // y: NDV=200, no nulls, but source only has 50 rows => capped to 50 + tester().assertStatsFor(pb -> pb + .registerVariable(pb.variable("x")) + .aggregation(ab -> ab + .addAggregation(pb.variable("count", BIGINT), pb.rowExpression("count()")) + .singleGroupingSet(pb.variable("y", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(50) + .addVariableStatistics(VARIABLE_X, VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(10) + .setNullsFraction(0) + .build()) + .addVariableStatistics(VARIABLE_Y, VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(199) + .setDistinctValuesCount(200) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check.outputRowsCount(50)); + } } diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java index fddc4fc06371d..28785a9327457 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java @@ -120,7 +120,7 @@ public void testHistoryBasedStatsCalculator() anyTree(node(ProjectNode.class, node(FilterNode.class, any())).withOutputRowCount(12.5))); assertPlan( "SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", - anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(Double.NaN))); + anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(5).withOutputSize(90))); // HBO Statistics executeAndTrackHistory("SELECT max(nationkey) FROM nation where name < 'D' group by regionkey"); @@ -227,7 +227,7 @@ public void testHistoryBasedStatsCalculatorEnforceTimeOut() assertPlan( sessionWithDefaultTimeoutLimit, "SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", - anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(Double.NaN))); + anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(5).withOutputSize(90))); // Write HBO statistics failed as we set timeout limit to be 0 executeAndNoHistoryWritten("SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", sessionWithZeroTimeoutLimit); @@ -239,7 +239,7 @@ public void testHistoryBasedStatsCalculatorEnforceTimeOut() assertPlan( sessionWithDefaultTimeoutLimit, "SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", - anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(Double.NaN))); + anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(5).withOutputSize(90))); // Write HBO Statistics is successful, as we use the default 10 seconds timeout limit executeAndTrackHistory("SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", sessionWithDefaultTimeoutLimit); @@ -261,7 +261,7 @@ public void testHistoryBasedStatsCalculatorEnforceTimeOut() assertPlan( sessionWithZeroTimeoutLimit, "SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", - anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(Double.NaN))); + anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(5))); } @Test diff --git a/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoryBasedRedisStatisticsTracking.java b/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoryBasedRedisStatisticsTracking.java index 01c6773751c1a..f50cddedf36dc 100644 --- a/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoryBasedRedisStatisticsTracking.java +++ b/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoryBasedRedisStatisticsTracking.java @@ -146,7 +146,7 @@ public void testHistoryBasedStatsCalculator() anyTree(node(ProjectNode.class, node(FilterNode.class, any())).withOutputRowCount(12.5))); assertPlan( "SELECT max(nationkey) FROM nation where name < 'D' group by regionkey", - anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(Double.NaN))); + anyTree(node(AggregationNode.class, node(ExchangeNode.class, anyTree(any()))).withOutputRowCount(5).withOutputSize(90))); // HBO Statistics executeAndTrackHistory("SELECT max(nationkey) FROM nation where name < 'D' group by regionkey");