Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,11 @@ public void testHistoryBasedStatsCalculatorCTE()
.setSystemProperty(CTE_PARTITIONING_PROVIDER_CATALOG, "hive")
.build();
// CBO Statistics
assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(Double.NaN)));
assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(0D)));

// HBO Statistics
executeAndTrackHistory(sql, cteMaterialization);
assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(3)));
assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(3D)));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.google.common.collect.ImmutableMap;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.plan.AggregationNode.Step.INTERMEDIATE;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static java.lang.Math.min;
Expand Down Expand Up @@ -54,49 +56,81 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, Stat
return Optional.empty();
}

if (node.getStep() != SINGLE) {
return Optional.empty();
}
PlanNodeStatsEstimate estimate;

return Optional.of(groupBy(
statsProvider.getStats(node.getSource()),
node.getGroupingKeys(),
node.getAggregations()));
if (node.getStep() == PARTIAL || node.getStep() == INTERMEDIATE) {
estimate = partialGroupBy(
statsProvider.getStats(node.getSource()),
node.getGroupingKeys(),
node.getAggregations());
}
else {
estimate = groupBy(
statsProvider.getStats(node.getSource()),
node.getGroupingKeys(),
node.getAggregations());
}
return Optional.of(estimate);
}

public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<VariableReferenceExpression> groupByVariables, Map<VariableReferenceExpression, Aggregation> aggregations)
{
// Used to estimate FINAL or SINGLE step aggregations
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();

if (isGlobalAggregation(groupByVariables)) {
if (groupByVariables.isEmpty()) {
result.setConfidence(FACT);
result.setOutputRowCount(1);
}

for (VariableReferenceExpression groupByVariable : groupByVariables) {
VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable);
result.addVariableStatistics(groupByVariable, symbolStatistics.mapNullsFraction(nullsFraction -> {
if (nullsFraction == 0.0) {
return 0.0;
}
return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1);
}));
else {
result.addVariableStatistics(getGroupByVariablesStatistics(sourceStats, groupByVariables));
double rowsCount = getRowsCount(sourceStats, groupByVariables);
result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount()));
}

aggregations.forEach((key, value) -> result.addVariableStatistics(key, estimateAggregationStats(value, sourceStats)));

return result.build();
}

public static double getRowsCount(PlanNodeStatsEstimate sourceStats, Collection<VariableReferenceExpression> groupByVariables)
{
double rowsCount = 1;
for (VariableReferenceExpression groupByVariable : groupByVariables) {
VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable);
int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1;
rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow;
}
result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount()));
return rowsCount;
}

for (Map.Entry<VariableReferenceExpression, Aggregation> aggregationEntry : aggregations.entrySet()) {
result.addVariableStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats));
}
private static PlanNodeStatsEstimate partialGroupBy(PlanNodeStatsEstimate sourceStats, Collection<VariableReferenceExpression> groupByVariables, Map<VariableReferenceExpression, Aggregation> aggregations)
{
// Pessimistic assumption of no reduction from PARTIAL and INTERMEDIATE aggregation, forwarding of the source statistics.
// This makes the CBO estimates in the EXPLAIN plan output easier to understand,
// even though partial aggregations are added after the CBO rules have been run.
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
result.setOutputRowCount(sourceStats.getOutputRowCount());
result.addVariableStatistics(getGroupByVariablesStatistics(sourceStats, groupByVariables));
aggregations.forEach((key, value) -> result.addVariableStatistics(key, estimateAggregationStats(value, sourceStats)));

return result.build();
}

private static Map<VariableReferenceExpression, VariableStatsEstimate> getGroupByVariablesStatistics(PlanNodeStatsEstimate sourceStats, Collection<VariableReferenceExpression> groupByVariables)
{
ImmutableMap.Builder<VariableReferenceExpression, VariableStatsEstimate> variableStatsEstimates = ImmutableMap.builder();
for (VariableReferenceExpression groupByVariable : groupByVariables) {
VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable);
variableStatsEstimates.put(groupByVariable, symbolStatistics.mapNullsFraction(nullsFraction -> {
if (nullsFraction == 0.0) {
return 0.0;
}
return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1);
}));
}
return variableStatsEstimates.build();
}

private static VariableStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats)
{
requireNonNull(aggregation, "aggregation is null");
Expand All @@ -105,9 +139,4 @@ private static VariableStatsEstimate estimateAggregationStats(Aggregation aggreg
// TODO implement simple aggregations like: min, max, count, sum
return VariableStatsEstimate.unknown();
}

private static boolean isGlobalAggregation(Collection<VariableReferenceExpression> groupingKeys)
{
return groupingKeys.isEmpty();
}
}
Loading
Loading