diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java index 8b255d2781316..cf0307914b014 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.statistics.HistoryBasedSourceInfo; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; @@ -126,7 +127,10 @@ private PlanNode getCostBasedJoin(JoinNode joinNode, Context context) if (possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents()) || possibleJoinNodes.isEmpty()) { // TODO: currently this session parameter is added so as to roll out the plan change gradually, after proved to be a better choice, make it default and get rid of the session parameter here. if (isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled(context.getSession()) && possibleJoinNodes.stream().anyMatch(result -> ((JoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED))) { - return getOnlyElement(possibleJoinNodes.stream().filter(result -> ((JoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED)).map(x -> x.getPlanNode()).collect(toImmutableList())); + JoinNode broadcastJoin = (JoinNode) getOnlyElement(possibleJoinNodes.stream().filter(result -> ((JoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED)).map(x -> x.getPlanNode()).collect(toImmutableList())); + if (context.getStatsProvider().getStats(broadcastJoin.getBuild()).getSourceInfo() instanceof HistoryBasedSourceInfo) { + return broadcastJoin; + } } if (isSizeBasedJoinDistributionTypeEnabled(context.getSession())) { return getSizeBasedJoin(joinNode, context); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java index c055d480510cd..085b80129095b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java @@ -36,6 +36,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.statistics.HistoryBasedSourceInfo; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -48,11 +49,14 @@ import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize; import static com.facebook.presto.SystemSessionProperties.isSizeBasedJoinDistributionTypeEnabled; +import static com.facebook.presto.SystemSessionProperties.isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled; import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput; import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.getSourceTablesSizeInBytes; import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin; import static com.facebook.presto.sql.planner.plan.SemiJoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.SemiJoinNode.DistributionType.REPLICATED; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; /** @@ -124,6 +128,12 @@ private PlanNode getCostBasedDistributionType(SemiJoinNode node, Context context possibleJoinNodes.add(getSemiJoinNodeWithCost(node.withDistributionType(PARTITIONED), context)); if (possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents())) { + if (isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled(context.getSession()) && possibleJoinNodes.stream().anyMatch(result -> ((SemiJoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED))) { + SemiJoinNode broadcastJoin = (SemiJoinNode) getOnlyElement(possibleJoinNodes.stream().filter(result -> ((SemiJoinNode) result.getPlanNode()).getDistributionType().get().equals(REPLICATED)).map(x -> x.getPlanNode()).collect(toImmutableList())); + if (context.getStatsProvider().getStats(broadcastJoin.getBuild()).getSourceInfo() instanceof HistoryBasedSourceInfo) { + return broadcastJoin; + } + } if (isSizeBasedJoinDistributionTypeEnabled(context.getSession())) { return getSizeBaseDistributionType(node, context); } @@ -155,7 +165,7 @@ private boolean canReplicate(SemiJoinNode node, Context context) PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide); double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide); return buildSideSizeInBytes <= joinMaxBroadcastTableSize.toBytes() - || (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) + || (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) && getSourceTablesSizeInBytes(buildSide, context) <= joinMaxBroadcastTableSize.toBytes()); }