Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.SingleStreamSpillerChoice;
Expand Down Expand Up @@ -56,6 +57,8 @@
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.NONE;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.ALWAYS;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.min;
import static java.lang.String.format;
Expand Down Expand Up @@ -140,6 +143,8 @@ public final class SystemSessionProperties
public static final String DISTRIBUTED_SORT = "distributed_sort";
public static final String USE_MARK_DISTINCT = "use_mark_distinct";
public static final String PREFER_PARTIAL_AGGREGATION = "prefer_partial_aggregation";
public static final String PARTIAL_AGGREGATION_STRATEGY = "partial_aggregation_strategy";
public static final String PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD = "partial_aggregation_byte_reduction_threshold";
public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number";
public static final String MAX_GROUPING_SETS = "max_grouping_sets";
public static final String LEGACY_UNNEST = "legacy_unnest";
Expand Down Expand Up @@ -720,7 +725,24 @@ public SystemSessionProperties(
booleanProperty(
PREFER_PARTIAL_AGGREGATION,
"Prefer splitting aggregations into partial and final stages",
featuresConfig.isPreferPartialAggregation(),
null,
false),
new PropertyMetadata<>(
PARTIAL_AGGREGATION_STRATEGY,
format("Partial aggregation strategy to use. Options are %s",
Stream.of(PartialAggregationStrategy.values())
.map(PartialAggregationStrategy::name)
.collect(joining(","))),
VARCHAR,
PartialAggregationStrategy.class,
featuresConfig.getPartialAggregationStrategy(),
false,
value -> PartialAggregationStrategy.valueOf(((String) value).toUpperCase()),
PartialAggregationStrategy::name),
doubleProperty(
PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD,
"Byte reduction ratio threshold at which to disable partial aggregation",
featuresConfig.getPartialAggregationByteReductionThreshold(),
false),
booleanProperty(
OPTIMIZE_TOP_N_ROW_NUMBER,
Expand Down Expand Up @@ -1442,9 +1464,21 @@ public static boolean useMarkDistinct(Session session)
return session.getSystemProperty(USE_MARK_DISTINCT, Boolean.class);
}

public static boolean preferPartialAggregation(Session session)
public static PartialAggregationStrategy getPartialAggregationStrategy(Session session)
{
Boolean preferPartialAggregation = session.getSystemProperty(PREFER_PARTIAL_AGGREGATION, Boolean.class);
if (preferPartialAggregation != null) {
if (preferPartialAggregation) {
return ALWAYS;
}
return NEVER;
}
return session.getSystemProperty(PARTIAL_AGGREGATION_STRATEGY, PartialAggregationStrategy.class);
}

public static double getPartialAggregationByteReductionThreshold(Session session)
{
return session.getSystemProperty(PREFER_PARTIAL_AGGREGATION, Boolean.class);
return session.getSystemProperty(PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD, Double.class);
}

public static boolean isOptimizeTopNRowNumber(Session session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsPr
{
Optional<PlanNodeStatsEstimate> estimate = Optional.empty();
double totalSize = 0;
boolean confident = true;
for (int i = 0; i < node.getSources().size(); i++) {
PlanNode source = node.getSources().get(i);
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(source);
totalSize += sourceStats.getOutputSizeInBytes();
if (!sourceStats.isConfident()) {
confident = false;
}

PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputVariables(sourceStats, node.getInputs().get(i), node.getOutputVariables());

Expand All @@ -69,6 +73,7 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsPr
verify(estimate.isPresent());
return Optional.of(buildFrom(estimate.get())
.setTotalSize(totalSize)
.setConfident(confident)
.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@
public class PlanNodeStatsEstimate
{
private static final double DEFAULT_DATA_SIZE_PER_COLUMN = 50;
private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, ImmutableMap.of());
private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of());

private final double outputRowCount;
private final double totalSize;
private final boolean confident;
private final PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics;

public static PlanNodeStatsEstimate unknown()
Expand All @@ -54,16 +55,18 @@ public static PlanNodeStatsEstimate unknown()
public PlanNodeStatsEstimate(
@JsonProperty("outputRowCount") double outputRowCount,
@JsonProperty("totalSize") double totalSize,
@JsonProperty("confident") boolean confident,
@JsonProperty("variableStatistics") Map<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
{
this(outputRowCount, totalSize, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null")));
this(outputRowCount, totalSize, confident, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null")));
}

private PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean confident, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
{
checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative");
this.outputRowCount = outputRowCount;
this.totalSize = totalSize;
this.confident = confident;
this.variableStatistics = variableStatistics;
}

Expand All @@ -83,6 +86,12 @@ public double getTotalSize()
return totalSize;
}

@JsonProperty
public boolean isConfident()
{
return confident;
}

/**
* Only use when getting all columns and meanwhile do not want to
* do per-column estimation.
Expand Down Expand Up @@ -210,24 +219,26 @@ public static Builder builder()
// we should propagate totalSize as default to simplify the relevant operations in rules that do not change this field.
public static Builder buildFrom(PlanNodeStatsEstimate other)
{
return new Builder(other.getOutputRowCount(), NaN, other.variableStatistics);
return new Builder(other.getOutputRowCount(), NaN, other.isConfident(), other.variableStatistics);
}

public static final class Builder
{
private double outputRowCount;
private double totalSize;
private boolean confident;
private PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics;

public Builder()
{
this(NaN, NaN, HashTreePMap.empty());
this(NaN, NaN, false, HashTreePMap.empty());
}

private Builder(double outputRowCount, double totalSize, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
private Builder(double outputRowCount, double totalSize, boolean confident, PMap<VariableReferenceExpression, VariableStatsEstimate> variableStatistics)
{
this.outputRowCount = outputRowCount;
this.totalSize = totalSize;
this.confident = confident;
this.variableStatistics = variableStatistics;
}

Expand All @@ -243,6 +254,12 @@ public Builder setTotalSize(double totalSize)
return this;
}

public Builder setConfident(boolean confident)
{
this.confident = confident;
return this;
}

public Builder addVariableStatistics(VariableReferenceExpression variable, VariableStatsEstimate statistics)
{
variableStatistics = variableStatistics.plus(variable, statistics);
Expand All @@ -263,7 +280,7 @@ public Builder removeVariableStatistics(VariableReferenceExpression variable)

public PlanNodeStatsEstimate build()
{
return new PlanNodeStatsEstimate(outputRowCount, totalSize, variableStatistics);
return new PlanNodeStatsEstimate(outputRowCount, totalSize, confident, variableStatistics);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsPro
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
.setOutputRowCount(sourceStats.getOutputRowCount());
.setOutputRowCount(sourceStats.getOutputRowCount())
.setConfident(sourceStats.isConfident() && noChangeToSourceColumns(node));

for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getAssignments().entrySet()) {
RowExpression expression = entry.getValue();
Expand All @@ -66,4 +67,9 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsPro
}
return Optional.of(calculatedStats.build());
}

private boolean noChangeToSourceColumns(ProjectNode node)
{
return node.getOutputVariables().containsAll(node.getSource().getOutputVariables());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(TableScanNode node, StatsP
return Optional.of(PlanNodeStatsEstimate.builder()
.setOutputRowCount(tableStatistics.getRowCount().getValue())
.setTotalSize(tableStatistics.getTotalSize().getValue())
.setConfident(true)
.addVariableStatistics(outputVariableStats)
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public Pattern<ValuesNode> getPattern()
public Optional<PlanNodeStatsEstimate> calculate(ValuesNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
{
PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
statsBuilder.setOutputRowCount(node.getRows().size());
statsBuilder.setOutputRowCount(node.getRows().size())
.setConfident(true);

for (int variableId = 0; variableId < node.getOutputVariables().size(); ++variableId) {
VariableReferenceExpression variable = node.getOutputVariables().get(variableId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ public class FeaturesConfig
private boolean parseDecimalLiteralsAsDouble;
private boolean useMarkDistinct = true;
private boolean preferPartialAggregation = true;
private PartialAggregationStrategy partialAggregationStrategy = PartialAggregationStrategy.ALWAYS;
private double partialAggregationByteReductionThreshold = 0.5;
private boolean optimizeTopNRowNumber = true;
private boolean pushLimitThroughOuterJoin = true;

Expand Down Expand Up @@ -266,6 +268,13 @@ public enum SingleStreamSpillerChoice
TEMP_STORAGE
}

public enum PartialAggregationStrategy
{
ALWAYS, // Always do partial aggregation
NEVER, // Never do partial aggregation
AUTOMATIC // Let the optimizer decide for each aggregation
}

public double getCpuCostWeight()
{
return cpuCostWeight;
Expand Down Expand Up @@ -755,6 +764,30 @@ public FeaturesConfig setPreferPartialAggregation(boolean value)
return this;
}

public PartialAggregationStrategy getPartialAggregationStrategy()
{
return partialAggregationStrategy;
}

@Config("optimizer.partial-aggregation-strategy")
public FeaturesConfig setPartialAggregationStrategy(PartialAggregationStrategy partialAggregationStrategy)
{
this.partialAggregationStrategy = partialAggregationStrategy;
return this;
}

public double getPartialAggregationByteReductionThreshold()
{
return partialAggregationByteReductionThreshold;
}

@Config("optimizer.partial-aggregation-byte-reduction-threshold")
public FeaturesConfig setPartialAggregationByteReductionThreshold(double partialAggregationByteReductionThreshold)
{
this.partialAggregationByteReductionThreshold = partialAggregationByteReductionThreshold;
return this;
}

public boolean isOptimizeTopNRowNumber()
{
return optimizeTopNRowNumber;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
Expand All @@ -27,6 +29,7 @@
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
Expand All @@ -41,12 +44,15 @@
import java.util.Optional;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.preferPartialAggregation;
import static com.facebook.presto.SystemSessionProperties.getPartialAggregationByteReductionThreshold;
import static com.facebook.presto.SystemSessionProperties.getPartialAggregationStrategy;
import static com.facebook.presto.operator.aggregation.AggregationUtils.isDecomposable;
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.AUTOMATIC;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
Expand Down Expand Up @@ -101,7 +107,12 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
return Result.ofPlanNode(split(aggregationNode, context));
}

if (!decomposable || !preferPartialAggregation(context.getSession())) {
PartialAggregationStrategy partialAggregationStrategy = getPartialAggregationStrategy(context.getSession());
if (!decomposable ||
partialAggregationStrategy == NEVER ||
partialAggregationStrategy == AUTOMATIC &&
partialAggregationNotUseful(aggregationNode, exchangeNode, context) &&
aggregationNode.getGroupingKeys().size() == 1) {
return Result.empty();
}

Expand Down Expand Up @@ -265,6 +276,18 @@ private PlanNode split(AggregationNode node, Context context)
node.getGroupIdVariable());
}

private boolean partialAggregationNotUseful(AggregationNode aggregationNode, ExchangeNode exchangeNode, Context context)
{
StatsProvider stats = context.getStatsProvider();
PlanNodeStatsEstimate exchangeStats = stats.getStats(exchangeNode);
PlanNodeStatsEstimate aggregationStats = stats.getStats(aggregationNode);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This estimate might not be reliable, especially when aggregating on more than one column. What do you think about looking at the exchangeStats only? Can we extract the cardinallity for the aggregation column (if there's only a single column) and check if the cardinality above a threshold?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limited the scope of disabling partial aggregation to group bys on a single column

double inputBytes = exchangeStats.getOutputSizeInBytes(exchangeNode.getOutputVariables());
double outputBytes = aggregationStats.getOutputSizeInBytes(aggregationNode.getOutputVariables());
double byteReductionThreshold = getPartialAggregationByteReductionThreshold(context.getSession());

return exchangeStats.isConfident() && outputBytes > inputBytes * byteReductionThreshold;
}

private static boolean isLambda(RowExpression rowExpression)
{
return rowExpression instanceof LambdaDefinitionExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public PlanNodeStatsAssertion totalSize(double expected)
return this;
}

public PlanNodeStatsAssertion confident(boolean expected)
{
assertEquals(actual.isConfident(), expected);
return this;
}

public PlanNodeStatsAssertion outputRowsCountUnknown()
{
assertTrue(Double.isNaN(actual.getOutputRowCount()), "expected unknown outputRowsCount but got " + actual.getOutputRowCount());
Expand Down Expand Up @@ -87,6 +93,7 @@ public PlanNodeStatsAssertion variablesWithKnownStats(VariableReferenceExpressio
public PlanNodeStatsAssertion equalTo(PlanNodeStatsEstimate expected)
{
assertEstimateEquals(actual.getOutputRowCount(), expected.getOutputRowCount(), "outputRowCount mismatch");
assertEquals(actual.isConfident(), expected.isConfident());

for (VariableReferenceExpression variable : union(expected.getVariablesWithKnownStatistics(), actual.getVariablesWithKnownStatistics())) {
assertVariableStatsEqual(variable, actual.getVariableStatistics(variable), expected.getVariableStatistics(variable));
Expand Down
Loading