Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import com.facebook.presto.tests.statistics.InMemoryHistoryBasedPlanStatisticsProvider;
import com.google.common.collect.ImmutableList;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -98,8 +98,8 @@ public void testBroadcastJoin()

// CBO Statistics
Plan plan = plan("SELECT * FROM " +
"(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " +
"(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", createSession());
"(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " +
"(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", createSession());

assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot())
.where(node -> node instanceof JoinNode && ((JoinNode) node).getDistributionType().get().equals(JoinNode.DistributionType.PARTITIONED))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ private static Optional<Integer> getSimilarStatsIndex(
return Optional.empty();
}

private static boolean similarStats(double stats1, double stats2, double threshold)
public static boolean similarStats(double stats1, double stats2, double threshold)
{
if (isNaN(stats1) && isNaN(stats2)) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ public PlanCanonicalInfoProvider getPlanCanonicalInfoProvider()
return planCanonicalInfoProvider;
}

@VisibleForTesting
public StatsCalculator getDelegate()
{
return delegate;
}

@VisibleForTesting
public Supplier<HistoryBasedPlanStatisticsProvider> getHistoryBasedPlanStatisticsProvider()
{
return historyBasedPlanStatisticsProvider;
}

private Map<PlanCanonicalizationStrategy, PlanNodeWithHash> getPlanNodeHashes(PlanNode plan, Session session)
{
if (!useHistoryBasedPlanStatisticsEnabled(session) || !plan.getStatsEquivalentPlanNode().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static StatsCalculator createNewStatsCalculator(
return historyBasedPlanStatisticsManager.getHistoryBasedPlanStatisticsCalculator(delegate);
}

private static ComposableStatsCalculator createComposableStatsCalculator(
public static ComposableStatsCalculator createComposableStatsCalculator(
Metadata metadata,
ScalarStatsCalculator scalarStatsCalculator,
StatsNormalizer normalizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ private PlanNode createRemoteStreamingExchange(ExchangeNode exchange, RewriteCon
.map(PlanFragment::getId)
.collect(toImmutableList());

return new RemoteSourceNode(exchange.getSourceLocation(), exchange.getId(), childrenIds, exchange.getOutputVariables(), exchange.isEnsureSourceOrdering(), exchange.getOrderingScheme(), exchange.getType());
return new RemoteSourceNode(exchange.getSourceLocation(), exchange.getId(), exchange.getStatsEquivalentPlanNode(), childrenIds, exchange.getOutputVariables(), exchange.isEnsureSourceOrdering(), exchange.getOrderingScheme(), exchange.getType());
}

protected void setDistributionForExchange(ExchangeNode.Type exchangeType, PartitioningScheme partitioningScheme, RewriteContext<FragmentProperties> context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,37 @@
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import io.airlift.units.DataSize;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize;
import static com.facebook.presto.SystemSessionProperties.isSizeBasedJoinDistributionTypeEnabled;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isBelowBroadcastLimit;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isSmallerThanThreshold;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED;
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.util.Objects.requireNonNull;

public class DetermineJoinDistributionType
implements Rule<JoinNode>
{
private static final Pattern<JoinNode> PATTERN = join().matching(joinNode -> !joinNode.getDistributionType().isPresent());
private static final List<Class<? extends PlanNode>> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class);
private static final double SIZE_DIFFERENCE_THRESHOLD = 8;

private final CostComparator costComparator;
private final TaskCountEstimator taskCountEstimator;
Expand Down Expand Up @@ -93,7 +90,7 @@ public static boolean isBelowMaxBroadcastSize(JoinNode joinNode, Context context
PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide);
double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide);
return buildSideSizeInBytes <= joinMaxBroadcastTableSize.toBytes()
|| (isSizeBasedJoinDistributionTypeEnabled(context.getSession())
|| (isSizeBasedJoinDistributionTypeEnabled(context.getSession())
&& getSourceTablesSizeInBytes(buildSide, context) <= joinMaxBroadcastTableSize.toBytes());
}

Expand Down Expand Up @@ -160,19 +157,6 @@ private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context)
return getSyntacticOrderJoin(joinNode, context, AUTOMATIC);
}

public static boolean isBelowBroadcastLimit(PlanNode planNode, Context context)
{
DataSize joinMaxBroadcastTableSize = getJoinMaxBroadcastTableSize(context.getSession());
return getSourceTablesSizeInBytes(planNode, context) <= joinMaxBroadcastTableSize.toBytes();
}

public static boolean isSmallerThanThreshold(PlanNode planNodeA, PlanNode planNodeB, Context context)
{
double aOutputSize = getFirstKnownOutputSizeInBytes(planNodeA, context);
double bOutputSize = getFirstKnownOutputSizeInBytes(planNodeB, context);
return aOutputSize * SIZE_DIFFERENCE_THRESHOLD < bOutputSize;
}

public static double getSourceTablesSizeInBytes(PlanNode node, Context context)
{
return getSourceTablesSizeInBytes(node, context.getLookup(), context.getStatsProvider());
Expand All @@ -182,71 +166,21 @@ public static double getSourceTablesSizeInBytes(PlanNode node, Context context)
static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider)
{
boolean hasExpandingNodes = PlanNodeSearcher.searchFrom(node, lookup)
.whereIsInstanceOfAny(EXPANDING_NODE_CLASSES)
.whereIsInstanceOfAny(JoinSwappingUtils.EXPANDING_NODE_CLASSES)
.matches();
if (hasExpandingNodes) {
return NaN;
}

List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(node, lookup)
.whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class))
.whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class, RemoteSourceNode.class))
.findAll();

return sourceNodes.stream()
.mapToDouble(sourceNode -> statsProvider.getStats(sourceNode).getOutputSizeInBytes(sourceNode))
.sum();
}

private static double getFirstKnownOutputSizeInBytes(PlanNode node, Context context)
{
return getFirstKnownOutputSizeInBytes(node, context.getLookup(), context.getStatsProvider());
}

/**
* Recursively looks for the first source node with a known estimate and uses that to return an approximate output size.
* Returns NaN if an un-estimated expanding node (Join or Unnest) is encountered.
* The amount of reduction in size from un-estimated non-expanding nodes (e.g. an un-estimated filter or aggregation)
* is not accounted here. We make use of the first available estimate and make decision about flipping join sides only if
* we find a large difference in output size of both sides.
*/
@VisibleForTesting
public static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider)
{
return Stream.of(node)
.flatMap(planNode -> {
if (planNode instanceof GroupReference) {
return lookup.resolveGroup(node);
}
return Stream.of(planNode);
})
.mapToDouble(resolvedNode -> {
double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode);
if (!isNaN(outputSizeInBytes)) {
return outputSizeInBytes;
}

if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) {
return NaN;
}

List<PlanNode> sourceNodes = resolvedNode.getSources();
if (sourceNodes.isEmpty()) {
return NaN;
}

double sourcesOutputSizeInBytes = 0;
for (PlanNode sourceNode : sourceNodes) {
double firstKnownOutputSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider);
if (isNaN(firstKnownOutputSizeInBytes)) {
return NaN;
}
sourcesOutputSizeInBytes += firstKnownOutputSizeInBytes;
}
return sourcesOutputSizeInBytes;
})
.sum();
}

private void addJoinsWithDifferentDistributions(JoinNode joinNode, List<PlanNodeWithCost> possibleJoinNodes, Context context)
{
if (!mustPartition(joinNode) && isBelowMaxBroadcastSize(joinNode, context)) {
Expand Down
Loading