Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,12 @@ private void collectOptimizerInformation(PlanOptimizer optimizer, PlanNode oldNo
isTriggered ||
!optimizer.isEnabled(session) && isVerboseOptimizerInfoEnabled(session) &&
optimizer.isApplicable(oldNode, session, TypeProvider.viewOf(variableAllocator.getVariables()), variableAllocator, idAllocator, warningCollector);
boolean isCostBased = isTriggered && optimizer.isCostBased(session);
String statsSource = optimizer.getStatsSource();

if (isTriggered || isApplicable) {
session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(optimizerName, isTriggered, Optional.of(isApplicable), Optional.empty()));
if (isTriggered || isApplicable || isCostBased) {
Copy link
Contributor

@vivek-bharathan vivek-bharathan Oct 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this disjunction cause all cost-based optimizers to be logged? Currently it does not matter since none of the PlanOptimizers are cost-based, but if that should change this would not be correct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right, I'll change it to only log isCostBased if the optimization triggered

session.getOptimizerInformationCollector().addInformation(
new PlanOptimizerInformation(optimizerName, isTriggered, Optional.of(isApplicable), Optional.empty(), Optional.of(isCostBased), statsSource == null ? Optional.empty() : Optional.of(statsSource)));
}

if (isTriggered && isVerboseOptimizerResults(session, optimizerName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import com.google.common.collect.ImmutableList;
import io.airlift.units.Duration;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -195,7 +194,7 @@ private boolean exploreNode(int group, Context context, Matcher matcher)
transformedNode = transformedNode.assignStatsEquivalentPlanNode(node.getStatsEquivalentPlanNode());
}
}
context.addRulesTriggered(rule.getClass().getSimpleName(), node, transformedNode);
context.addRulesTriggered(rule.getClass().getSimpleName(), node, transformedNode, rule.isCostBased(context.session), rule.getStatsSource());
node = context.memo.replace(group, transformedNode, rule.getClass().getName());

done = false;
Expand Down Expand Up @@ -337,12 +336,16 @@ private static class RuleTriggered
private final String rule;
private final Optional<String> oldNode;
private final Optional<String> newNode;
private boolean isCostBased;
private final Optional<String> statsSource;

public RuleTriggered(String rule, Optional<String> oldNode, Optional<String> newNode)
public RuleTriggered(String rule, Optional<String> oldNode, Optional<String> newNode, boolean isCostBased, String statsSource)
{
this.rule = requireNonNull(rule, "rule is null");
this.oldNode = requireNonNull(oldNode, "oldNode is null");
this.newNode = requireNonNull(newNode, "newNode is null");
this.isCostBased = isCostBased;
this.statsSource = statsSource == null ? Optional.empty() : Optional.of(statsSource);
}

public String getRule()
Expand All @@ -359,6 +362,16 @@ public Optional<String> getNewNode()
{
return newNode;
}

public boolean isCostBased()
{
return isCostBased;
}

public Optional<String> getStatsSource()
{
return statsSource;
}
}

private static class Context
Expand All @@ -373,7 +386,7 @@ private static class Context
private final WarningCollector warningCollector;
private final CostProvider costProvider;
private final StatsProvider statsProvider;
private final List<RuleTriggered> rulesTriggered;
private final Set<RuleTriggered> rulesTriggered;
private final Set<String> rulesApplicable;
private final Metadata metadata;
private final TypeProvider types;
Expand Down Expand Up @@ -406,7 +419,7 @@ public Context(
this.statsProvider = statsProvider;
this.metadata = metadata;
this.types = types;
this.rulesTriggered = new ArrayList<>();
this.rulesTriggered = new HashSet<>();
this.rulesApplicable = new HashSet<>();
}

Expand All @@ -417,7 +430,7 @@ public void checkTimeoutNotExhausted()
}
}

public void addRulesTriggered(String rule, PlanNode oldNode, PlanNode newNode)
public void addRulesTriggered(String rule, PlanNode oldNode, PlanNode newNode, boolean isCostBased, String statsSource)
{
Optional<String> before = Optional.empty();
Optional<String> after = Optional.empty();
Expand All @@ -427,7 +440,7 @@ public void addRulesTriggered(String rule, PlanNode oldNode, PlanNode newNode)
after = Optional.of(PlannerUtils.getPlanString(newNode, session, types, metadata, false));
}

rulesTriggered.add(new RuleTriggered(rule, before, after));
rulesTriggered.add(new RuleTriggered(rule, before, after, isCostBased, statsSource));
}

public void addRulesApplicable(String rule)
Expand All @@ -437,11 +450,15 @@ public void addRulesApplicable(String rule)

public void collectOptimizerInformation()
{
rulesTriggered.stream().map(x -> x.getRule()).distinct().forEach(rule -> session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(rule, true, Optional.empty(), Optional.empty())));
rulesTriggered.stream().map(
x -> new PlanOptimizerInformation(x.getRule(), true, Optional.empty(), Optional.empty(), Optional.of(x.isCostBased()), x.getStatsSource()))
.distinct().forEach(rule -> session.getOptimizerInformationCollector().addInformation(rule));

if (SystemSessionProperties.isVerboseOptimizerResults(session)) {
rulesTriggered.stream().filter(x -> x.getNewNode().isPresent()).forEach(x -> session.getOptimizerResultCollector().addOptimizerResult(x.getRule(), x.getOldNode().get(), x.getNewNode().get()));
}
rulesApplicable.forEach(x -> session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(x, false, Optional.of(true), Optional.empty())));
rulesApplicable.forEach(x -> session.getOptimizerInformationCollector().addInformation(
new PlanOptimizerInformation(x, false, Optional.of(true), Optional.empty(), Optional.empty(), Optional.empty())));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ default boolean isEnabled(Session session)
{
return true;
}
default boolean isCostBased(Session session)
{
return false;
}

default String getStatsSource()
{
return null;
}

Result apply(T node, Captures captures, Context context);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.LocalCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
Expand Down Expand Up @@ -63,12 +64,27 @@ public class DetermineJoinDistributionType
private final CostComparator costComparator;
private final TaskCountEstimator taskCountEstimator;

// records whether distribution decision was cost-based
private String statsSource;

public DetermineJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator)
{
this.costComparator = requireNonNull(costComparator, "costComparator is null");
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
}

@Override
public boolean isCostBased(Session session)
{
return getJoinDistributionType(session) == AUTOMATIC;
}

@Override
public String getStatsSource()
{
return statsSource;
}

@Override
public Pattern<JoinNode> getPattern()
{
Expand All @@ -80,7 +96,9 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
{
JoinDistributionType joinDistributionType = getJoinDistributionType(context.getSession());
if (joinDistributionType == AUTOMATIC) {
return Result.ofPlanNode(getCostBasedJoin(joinNode, context));
PlanNode resultNode = getCostBasedJoin(joinNode, context);
statsSource = context.getStatsProvider().getStats(joinNode).getSourceInfo().getSourceInfoName();
return Result.ofPlanNode(resultNode);
}
return Result.ofPlanNode(getSyntacticOrderJoin(joinNode, context, joinDistributionType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.LocalCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
Expand Down Expand Up @@ -62,10 +63,13 @@
public class DetermineSemiJoinDistributionType
implements Rule<SemiJoinNode>
{
private static final Pattern<SemiJoinNode> PATTERN = semiJoin().matching(semiJoin -> !semiJoin.getDistributionType().isPresent());

private final TaskCountEstimator taskCountEstimator;
private final CostComparator costComparator;

private static final Pattern<SemiJoinNode> PATTERN = semiJoin().matching(semiJoin -> !semiJoin.getDistributionType().isPresent());
// records whether distribution decision was cost-based
private String statsSource;

public DetermineSemiJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator)
{
Expand All @@ -79,13 +83,27 @@ public Pattern<SemiJoinNode> getPattern()
return PATTERN;
}

@Override
public boolean isCostBased(Session session)
{
return getJoinDistributionType(session) == JoinDistributionType.AUTOMATIC;
}

@Override
public String getStatsSource()
{
return statsSource;
}

@Override
public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context)
{
JoinDistributionType joinDistributionType = getJoinDistributionType(context.getSession());
switch (joinDistributionType) {
case AUTOMATIC:
return Result.ofPlanNode(getCostBasedDistributionType(semiJoinNode, context));
PlanNode resultNode = getCostBasedDistributionType(semiJoinNode, context);
statsSource = context.getStatsProvider().getStats(semiJoinNode).getSourceInfo().getSourceInfoName();
return Result.ofPlanNode(resultNode);
case PARTITIONED:
return Result.ofPlanNode(semiJoinNode.withDistributionType(PARTITIONED));
case BROADCAST:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Capture;
Expand Down Expand Up @@ -68,6 +69,7 @@ public class PushPartialAggregationThroughExchange
implements Rule<AggregationNode>
{
private final FunctionAndTypeManager functionAndTypeManager;
private String statsSource;

public PushPartialAggregationThroughExchange(FunctionAndTypeManager functionAndTypeManager)
{
Expand All @@ -88,6 +90,18 @@ public Pattern<AggregationNode> getPattern()
return PATTERN;
}

@Override
public boolean isCostBased(Session session)
{
return getPartialAggregationStrategy(session) == AUTOMATIC;
}

@Override
public String getStatsSource()
{
return statsSource;
}

@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
{
Expand Down Expand Up @@ -145,18 +159,30 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
return Result.empty();
}

PlanNode resultNode = null;
switch (aggregationNode.getStep()) {
case SINGLE:
// Split it into a FINAL on top of a PARTIAL and
return Result.ofPlanNode(split(aggregationNode, context));
resultNode = split(aggregationNode, context);
storeStatsSourceInfo(context, partialAggregationStrategy, aggregationNode);
return Result.ofPlanNode(resultNode);
case PARTIAL:
// Push it underneath each branch of the exchange
return Result.ofPlanNode(pushPartial(aggregationNode, exchangeNode, context));
resultNode = pushPartial(aggregationNode, exchangeNode, context);
storeStatsSourceInfo(context, partialAggregationStrategy, aggregationNode);
return Result.ofPlanNode(resultNode);
default:
return Result.empty();
}
}

private void storeStatsSourceInfo(Context context, PartialAggregationStrategy partialAggregationStrategy, PlanNode resultNode)
{
if (partialAggregationStrategy == AUTOMATIC) {
statsSource = context.getStatsProvider().getStats(resultNode).getSourceInfo().getSourceInfoName();
}
}

private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Context context)
{
List<PlanNode> partials = new ArrayList<>();
Expand Down Expand Up @@ -299,6 +325,7 @@ private boolean partialAggregationNotUseful(AggregationNode aggregationNode, Exc
double outputBytes = aggregationStats.getOutputSizeInBytes(aggregationNode);
double byteReductionThreshold = getPartialAggregationByteReductionThreshold(context.getSession());

// calling this function means we are using a cost-based strategy for this optimization
return exchangeStats.isConfident() && outputBytes > inputBytes * byteReductionThreshold;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public class ReorderJoins
private final Metadata metadata;
private final FunctionResolution functionResolution;
private final DeterminismEvaluator determinismEvaluator;
private String statsSource;

public ReorderJoins(CostComparator costComparator, Metadata metadata)
{
Expand All @@ -132,6 +133,18 @@ public boolean isEnabled(Session session)
return getJoinReorderingStrategy(session) == AUTOMATIC;
}

@Override
public boolean isCostBased(Session session)
{
// when enabled, join order is always cost-based
return isEnabled(session);
}

public String getStatsSource()
{
return statsSource;
}

@Override
public Result apply(JoinNode joinNode, Captures captures, Context context)
{
Expand All @@ -147,6 +160,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
if (!result.getPlanNode().isPresent()) {
return Result.empty();
}
statsSource = context.getStatsProvider().getStats(joinNode).getSourceInfo().getSourceInfoName();
return Result.ofPlanNode(result.getPlanNode().get());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ private boolean checkTimeOut(long startTimeInNano, long timeoutInMilliseconds)

private void logOptimizerFailure(Session session)
{
session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(HistoricalStatisticsEquivalentPlanMarkingOptimizer.class.getSimpleName(), false, Optional.empty(), Optional.of(true)));
session.getOptimizerInformationCollector().addInformation(
new PlanOptimizerInformation(HistoricalStatisticsEquivalentPlanMarkingOptimizer.class.getSimpleName(), false, Optional.empty(), Optional.of(true), Optional.empty(), Optional.empty()));
}

private static class Rewriter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = node.getAggregations().entrySet().stream()
.filter(x -> !partialResultToMerge.contains(x.getKey())).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(MergePartialAggregationsWithFilter.class.getSimpleName(), true, Optional.empty(), Optional.empty()));
session.getOptimizerInformationCollector().addInformation(
new PlanOptimizerInformation(MergePartialAggregationsWithFilter.class.getSimpleName(), true, Optional.empty(), Optional.empty(), Optional.of(false), Optional.empty()));

return new AggregationNode(
node.getSourceLocation(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> cont
}

// replace the tablescan node with a values node
session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(MetadataQueryOptimizer.class.getSimpleName(), true, Optional.empty(), Optional.empty()));
session.getOptimizerInformationCollector().addInformation(
new PlanOptimizerInformation(MetadataQueryOptimizer.class.getSimpleName(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
return SimplePlanRewriter.rewriteWith(new Replacer(new ValuesNode(node.getSourceLocation(), idAllocator.getNextId(), inputs, rowsBuilder.build(), Optional.empty())), node);
}

Expand Down Expand Up @@ -324,7 +325,8 @@ else if (value.getValue() != null) {
return context.defaultRewrite(node);
}
}
session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(MetadataQueryOptimizer.class.getSimpleName(), true, Optional.empty(), Optional.empty()));
session.getOptimizerInformationCollector().addInformation(
new PlanOptimizerInformation(MetadataQueryOptimizer.class.getSimpleName(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
Assignments assignments = assignmentsBuilder.build();
ValuesNode valuesNode = new ValuesNode(node.getSourceLocation(), idAllocator.getNextId(), node.getOutputVariables(), ImmutableList.of(new ArrayList<>(assignments.getExpressions())), Optional.empty());
return new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), valuesNode, assignments, LOCAL);
Expand Down
Loading