diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java index 8c43ba6ded079..b870e7d45a1a9 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java @@ -22,10 +22,10 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.facebook.presto.tests.DistributedQueryRunner; -import com.facebook.presto.tests.statistics.InMemoryHistoryBasedPlanStatisticsProvider; import com.google.common.collect.ImmutableList; import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; @@ -98,8 +98,8 @@ public void testBroadcastJoin() // CBO Statistics Plan plan = plan("SELECT * FROM " + - "(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " + - "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", createSession()); + "(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " + + "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", createSession()); assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot()) .where(node -> node instanceof JoinNode && ((JoinNode) node).getDistributionType().get().equals(JoinNode.DistributionType.PARTITIONED)) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/HistoricalPlanStatisticsUtil.java b/presto-main/src/main/java/com/facebook/presto/cost/HistoricalPlanStatisticsUtil.java index 9e943dcbc7ee1..c9cab9afac3cf 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/HistoricalPlanStatisticsUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/HistoricalPlanStatisticsUtil.java @@ -113,7 +113,7 @@ private static Optional getSimilarStatsIndex( return Optional.empty(); } - private static boolean similarStats(double stats1, double stats2, double threshold) + public static boolean similarStats(double stats1, double stats2, double threshold) { if (isNaN(stats1) && isNaN(stats2)) { return true; diff --git a/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java index b6572154d06fd..e1c8730ce5a38 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java @@ -106,6 +106,18 @@ public PlanCanonicalInfoProvider getPlanCanonicalInfoProvider() return planCanonicalInfoProvider; } + @VisibleForTesting + public StatsCalculator getDelegate() + { + return delegate; + } + + @VisibleForTesting + public Supplier getHistoryBasedPlanStatisticsProvider() + { + return historyBasedPlanStatisticsProvider; + } + private Map getPlanNodeHashes(PlanNode plan, Session session) { if (!useHistoryBasedPlanStatisticsEnabled(session) || !plan.getStatsEquivalentPlanNode().isPresent()) { diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java index 65c7eb9c7317d..83020a96b1206 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java @@ -52,7 +52,7 @@ public static StatsCalculator createNewStatsCalculator( return historyBasedPlanStatisticsManager.getHistoryBasedPlanStatisticsCalculator(delegate); } - private static ComposableStatsCalculator createComposableStatsCalculator( + public static ComposableStatsCalculator createComposableStatsCalculator( Metadata metadata, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer normalizer, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index 921cba5d43f3b..453cb4683bc32 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -328,7 +328,7 @@ private PlanNode createRemoteStreamingExchange(ExchangeNode exchange, RewriteCon .map(PlanFragment::getId) .collect(toImmutableList()); - return new RemoteSourceNode(exchange.getSourceLocation(), exchange.getId(), childrenIds, exchange.getOutputVariables(), exchange.isEnsureSourceOrdering(), exchange.getOrderingScheme(), exchange.getType()); + return new RemoteSourceNode(exchange.getSourceLocation(), exchange.getId(), exchange.getStatsEquivalentPlanNode(), childrenIds, exchange.getOutputVariables(), exchange.isEnsureSourceOrdering(), exchange.getOrderingScheme(), exchange.getType()); } protected void setDistributionForExchange(ExchangeNode.Type exchangeType, PartitioningScheme partitioningScheme, RewriteContext context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java index 4a8dfa2504bb6..25fe341dda7f5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -25,12 +25,11 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; -import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.JoinNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; @@ -38,27 +37,25 @@ import java.util.ArrayList; import java.util.List; -import java.util.stream.Stream; import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize; import static com.facebook.presto.SystemSessionProperties.isSizeBasedJoinDistributionTypeEnabled; import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isBelowBroadcastLimit; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isSmallerThanThreshold; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static com.facebook.presto.sql.planner.plan.Patterns.join; import static java.lang.Double.NaN; -import static java.lang.Double.isNaN; import static java.util.Objects.requireNonNull; public class DetermineJoinDistributionType implements Rule { private static final Pattern PATTERN = join().matching(joinNode -> !joinNode.getDistributionType().isPresent()); - private static final List> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class); - private static final double SIZE_DIFFERENCE_THRESHOLD = 8; private final CostComparator costComparator; private final TaskCountEstimator taskCountEstimator; @@ -93,7 +90,7 @@ public static boolean isBelowMaxBroadcastSize(JoinNode joinNode, Context context PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide); double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide); return buildSideSizeInBytes <= joinMaxBroadcastTableSize.toBytes() - || (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) + || (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) && getSourceTablesSizeInBytes(buildSide, context) <= joinMaxBroadcastTableSize.toBytes()); } @@ -160,19 +157,6 @@ private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context) return getSyntacticOrderJoin(joinNode, context, AUTOMATIC); } - public static boolean isBelowBroadcastLimit(PlanNode planNode, Context context) - { - DataSize joinMaxBroadcastTableSize = getJoinMaxBroadcastTableSize(context.getSession()); - return getSourceTablesSizeInBytes(planNode, context) <= joinMaxBroadcastTableSize.toBytes(); - } - - public static boolean isSmallerThanThreshold(PlanNode planNodeA, PlanNode planNodeB, Context context) - { - double aOutputSize = getFirstKnownOutputSizeInBytes(planNodeA, context); - double bOutputSize = getFirstKnownOutputSizeInBytes(planNodeB, context); - return aOutputSize * SIZE_DIFFERENCE_THRESHOLD < bOutputSize; - } - public static double getSourceTablesSizeInBytes(PlanNode node, Context context) { return getSourceTablesSizeInBytes(node, context.getLookup(), context.getStatsProvider()); @@ -182,14 +166,14 @@ public static double getSourceTablesSizeInBytes(PlanNode node, Context context) static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) { boolean hasExpandingNodes = PlanNodeSearcher.searchFrom(node, lookup) - .whereIsInstanceOfAny(EXPANDING_NODE_CLASSES) + .whereIsInstanceOfAny(JoinSwappingUtils.EXPANDING_NODE_CLASSES) .matches(); if (hasExpandingNodes) { return NaN; } List sourceNodes = PlanNodeSearcher.searchFrom(node, lookup) - .whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class)) + .whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class, RemoteSourceNode.class)) .findAll(); return sourceNodes.stream() @@ -197,56 +181,6 @@ static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProv .sum(); } - private static double getFirstKnownOutputSizeInBytes(PlanNode node, Context context) - { - return getFirstKnownOutputSizeInBytes(node, context.getLookup(), context.getStatsProvider()); - } - - /** - * Recursively looks for the first source node with a known estimate and uses that to return an approximate output size. - * Returns NaN if an un-estimated expanding node (Join or Unnest) is encountered. - * The amount of reduction in size from un-estimated non-expanding nodes (e.g. an un-estimated filter or aggregation) - * is not accounted here. We make use of the first available estimate and make decision about flipping join sides only if - * we find a large difference in output size of both sides. - */ - @VisibleForTesting - public static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) - { - return Stream.of(node) - .flatMap(planNode -> { - if (planNode instanceof GroupReference) { - return lookup.resolveGroup(node); - } - return Stream.of(planNode); - }) - .mapToDouble(resolvedNode -> { - double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode); - if (!isNaN(outputSizeInBytes)) { - return outputSizeInBytes; - } - - if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) { - return NaN; - } - - List sourceNodes = resolvedNode.getSources(); - if (sourceNodes.isEmpty()) { - return NaN; - } - - double sourcesOutputSizeInBytes = 0; - for (PlanNode sourceNode : sourceNodes) { - double firstKnownOutputSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider); - if (isNaN(firstKnownOutputSizeInBytes)) { - return NaN; - } - sourcesOutputSizeInBytes += firstKnownOutputSizeInBytes; - } - return sourcesOutputSizeInBytes; - }) - .sum(); - } - private void addJoinsWithDifferentDistributions(JoinNode joinNode, List possibleJoinNodes, Context context) { if (!mustPartition(joinNode) && isBelowMaxBroadcastSize(joinNode, context)) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java new file mode 100644 index 0000000000000..058f185ac255b --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java @@ -0,0 +1,250 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties; +import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; + +import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize; +import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; +import static com.facebook.presto.SystemSessionProperties.isJoinSpillingEnabled; +import static com.facebook.presto.SystemSessionProperties.isSpillEnabled; +import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.defaultParallelism; +import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.exactlyPartitionedOn; +import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism; +import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.singleStream; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Double.NaN; +import static java.lang.Double.isNaN; + +public class JoinSwappingUtils +{ + static final List> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class); + private static final double SIZE_DIFFERENCE_THRESHOLD = 8; + + private JoinSwappingUtils() {} + + public static Optional createRuntimeSwappedJoinNode( + JoinNode joinNode, + Metadata metadata, + SqlParser parser, + Lookup lookup, + Session session, + VariableAllocator variableAllocator, + PlanNodeIdAllocator idAllocator) + { + JoinNode swapped = joinNode.flipChildren(); + + PlanNode newLeft = swapped.getLeft(); + Optional leftHashVariable = swapped.getLeftHashVariable(); + // Remove unnecessary LocalExchange in the current probe side. If the immediate left child (new probe side) of the join node + // is a localExchange, there are two cases: an Exchange introduced by the current probe side (previous build side); or it is a UnionNode. + // If the exchangeNode has more than 1 sources, it corresponds to the second case, otherwise it corresponds to the first case and could be safe to remove + PlanNode resolvedSwappedLeft = lookup.resolve(newLeft); + if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) { + // Ensure the new probe after skipping the local exchange will satisfy the required probe side property + if (checkProbeSidePropertySatisfied(resolvedSwappedLeft.getSources().get(0), metadata, parser, lookup, session, variableAllocator)) { + newLeft = resolvedSwappedLeft.getSources().get(0); + // The HashGenerationOptimizer will generate hashVariables and append to the output layout of the nodes following the same order. Therefore, + // we use the index of the old hashVariable in the ExchangeNode output layout to retrieve the hashVariable from the new left node, and feed + // it as the leftHashVariable of the swapped join node. + if (swapped.getLeftHashVariable().isPresent()) { + int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(swapped.getLeftHashVariable().get()); + leftHashVariable = Optional.of(resolvedSwappedLeft.getSources().get(0).getOutputVariables().get(hashVariableIndex)); + // When join output layout contains new left side's hashVariable (e.g., a nested join in a single stage, the inner join's output layout possibly + // carry the join hashVariable from its new probe), after removing the local exchange at the new probe, the output variables of the join node will + // also change, which has to be broadcast upwards (rewriting plan nodes) until the point where this hashVariable is no longer the output. + // This is against typical iterativeOptimizer behavior and given this case is rare, just abort the swapping for this scenario. + if (swapped.getOutputVariables().contains(swapped.getLeftHashVariable().get())) { + return Optional.empty(); + } + } + } + } + + // Add additional localExchange if the new build side does not satisfy the partitioning conditions. + List buildJoinVariables = swapped.getCriteria().stream() + .map(JoinNode.EquiJoinClause::getRight) + .collect(toImmutableList()); + PlanNode newRight = swapped.getRight(); + if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, metadata, parser, lookup, session, variableAllocator)) { + if (getTaskConcurrency(session) > 1) { + newRight = systemPartitionedExchange( + idAllocator.getNextId(), + LOCAL, + swapped.getRight(), + buildJoinVariables, + swapped.getRightHashVariable()); + } + else { + newRight = gatheringExchange(idAllocator.getNextId(), LOCAL, swapped.getRight()); + } + } + + JoinNode newJoinNode = new JoinNode( + swapped.getSourceLocation(), + swapped.getId(), + swapped.getType(), + newLeft, + newRight, + swapped.getCriteria(), + swapped.getOutputVariables(), + swapped.getFilter(), + leftHashVariable, + swapped.getRightHashVariable(), + swapped.getDistributionType(), + swapped.getDynamicFilters()); + + return Optional.of(newJoinNode); + } + + // Check if the new probe side after removing unnecessary local exchange is valid. + public static boolean checkProbeSidePropertySatisfied(PlanNode node, Metadata metadata, SqlParser parser, Lookup lookup, Session session, VariableAllocator variableAllocator) + { + StreamPreferredProperties requiredProbeProperty; + if (isSpillEnabled(session) && isJoinSpillingEnabled(session)) { + requiredProbeProperty = fixedParallelism(); + } + else { + requiredProbeProperty = defaultParallelism(session); + } + StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, parser, lookup, session, variableAllocator); + return requiredProbeProperty.isSatisfiedBy(nodeProperty); + } + + // Check if the property of a planNode satisfies the requirements for directly feeding as the build side of a JoinNode. + private static boolean checkBuildSidePropertySatisfied( + PlanNode node, + List partitioningColumns, + Metadata metadata, + SqlParser parser, + Lookup lookup, + Session session, + VariableAllocator variableAllocator) + { + StreamPreferredProperties requiredBuildProperty; + if (getTaskConcurrency(session) > 1) { + requiredBuildProperty = exactlyPartitionedOn(partitioningColumns); + } + else { + requiredBuildProperty = singleStream(); + } + StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, parser, lookup, session, variableAllocator); + return requiredBuildProperty.isSatisfiedBy(nodeProperty); + } + + private static StreamPropertyDerivations.StreamProperties derivePropertiesRecursively( + PlanNode node, + Metadata metadata, + SqlParser parser, + Lookup lookup, + Session session, + VariableAllocator variableAllocator) + { + PlanNode actual = lookup.resolve(node); + List inputProperties = actual.getSources().stream() + .map(source -> derivePropertiesRecursively(source, metadata, parser, lookup, session, variableAllocator)) + .collect(toImmutableList()); + return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, session, TypeProvider.viewOf(variableAllocator.getVariables()), parser); + } + + public static boolean isBelowBroadcastLimit(PlanNode planNode, Rule.Context context) + { + DataSize joinMaxBroadcastTableSize = getJoinMaxBroadcastTableSize(context.getSession()); + return DetermineJoinDistributionType.getSourceTablesSizeInBytes(planNode, context) <= joinMaxBroadcastTableSize.toBytes(); + } + + public static boolean isSmallerThanThreshold(PlanNode planNodeA, PlanNode planNodeB, Rule.Context context) + { + double aOutputSize = getFirstKnownOutputSizeInBytes(planNodeA, context); + double bOutputSize = getFirstKnownOutputSizeInBytes(planNodeB, context); + return aOutputSize * SIZE_DIFFERENCE_THRESHOLD < bOutputSize; + } + + private static double getFirstKnownOutputSizeInBytes(PlanNode node, Rule.Context context) + { + return getFirstKnownOutputSizeInBytes(node, context.getLookup(), context.getStatsProvider()); + } + + /** + * Recursively looks for the first source node with a known estimate and uses that to return an approximate output size. + * Returns NaN if an un-estimated expanding node (Join or Unnest) is encountered. + * The amount of reduction in size from un-estimated non-expanding nodes (e.g. an un-estimated filter or aggregation) + * is not accounted here. We make use of the first available estimate and make decision about flipping join sides only if + * we find a large difference in output size of both sides. + */ + @VisibleForTesting + public static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) + { + return Stream.of(node) + .flatMap(planNode -> { + if (planNode instanceof GroupReference) { + return lookup.resolveGroup(node); + } + return Stream.of(planNode); + }) + .mapToDouble(resolvedNode -> { + double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode); + if (!isNaN(outputSizeInBytes)) { + return outputSizeInBytes; + } + + if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) { + return NaN; + } + + List sourceNodes = resolvedNode.getSources(); + if (sourceNodes.isEmpty()) { + return NaN; + } + + double sourcesOutputSizeInBytes = 0; + for (PlanNode sourceNode : sourceNodes) { + double firstKnownOutputSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider); + if (isNaN(firstKnownOutputSizeInBytes)) { + return NaN; + } + sourcesOutputSizeInBytes += firstKnownOutputSizeInBytes; + } + return sourcesOutputSizeInBytes; + }) + .sum(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java index cf52033983542..dd90dab4a6c7d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java @@ -19,38 +19,21 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; -import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties; -import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode; -import java.util.List; import java.util.Optional; -import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; -import static com.facebook.presto.SystemSessionProperties.isJoinSpillingEnabled; -import static com.facebook.presto.SystemSessionProperties.isSpillEnabled; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.createRuntimeSwappedJoinNode; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.defaultParallelism; -import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.exactlyPartitionedOn; -import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism; -import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.singleStream; -import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.facebook.presto.sql.planner.plan.Patterns.join; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -116,70 +99,13 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) if (!isSwappedJoinValid(joinNode)) { return Result.empty(); } - JoinNode swapped = joinNode.flipChildren(); - PlanNode newLeft = swapped.getLeft(); - Optional leftHashVariable = swapped.getLeftHashVariable(); - // Remove unnecessary LocalExchange in the current probe side. If the immediate left child (new probe side) of the join node - // is a localExchange, there are two cases: an Exchange introduced by the current probe side (previous build side); or it is a UnionNode. - // If the exchangeNode has more than 1 sources, it corresponds to the second case, otherwise it corresponds to the first case and could be safe to remove - PlanNode resolvedSwappedLeft = context.getLookup().resolve(newLeft); - if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) { - // Ensure the new probe after skipping the local exchange will satisfy the required probe side property - if (checkProbeSidePropertySatisfied(resolvedSwappedLeft.getSources().get(0), context)) { - newLeft = resolvedSwappedLeft.getSources().get(0); - // The HashGenerationOptimizer will generate hashVariables and append to the output layout of the nodes following the same order. Therefore, - // we use the index of the old hashVariable in the ExchangeNode output layout to retrieve the hashVariable from the new left node, and feed - // it as the leftHashVariable of the swapped join node. - if (swapped.getLeftHashVariable().isPresent()) { - int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(swapped.getLeftHashVariable().get()); - leftHashVariable = Optional.of(resolvedSwappedLeft.getSources().get(0).getOutputVariables().get(hashVariableIndex)); - // When join output layout contains new left side's hashVariable (e.g., a nested join in a single stage, the inner join's output layout possibly - // carry the join hashVariable from its new probe), after removing the local exchange at the new probe, the output variables of the join node will - // also change, which has to be broadcast upwards (rewriting plan nodes) until the point where this hashVariable is no longer the output. - // This is against typical iterativeOptimizer behavior and given this case is rare, just abort the swapping for this scenario. - if (swapped.getOutputVariables().contains(swapped.getLeftHashVariable().get())) { - return Result.empty(); - } - } - } + Optional rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, parser, context.getLookup(), context.getSession(), context.getVariableAllocator(), context.getIdAllocator()); + if (rewrittenNode.isPresent()) { + log.debug(format("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.", leftOutputSizeInBytes, rightOutputSizeInBytes, joinNode.getId())); + return Result.ofPlanNode(rewrittenNode.get()); } - - // Add additional localExchange if the new build side does not satisfy the partitioning conditions. - List buildJoinVariables = swapped.getCriteria().stream() - .map(JoinNode.EquiJoinClause::getRight) - .collect(toImmutableList()); - PlanNode newRight = swapped.getRight(); - if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, context)) { - if (getTaskConcurrency(context.getSession()) > 1) { - newRight = systemPartitionedExchange( - context.getIdAllocator().getNextId(), - LOCAL, - swapped.getRight(), - buildJoinVariables, - swapped.getRightHashVariable()); - } - else { - newRight = gatheringExchange(context.getIdAllocator().getNextId(), LOCAL, swapped.getRight()); - } - } - - JoinNode newJoinNode = new JoinNode( - swapped.getSourceLocation(), - swapped.getId(), - swapped.getType(), - newLeft, - newRight, - swapped.getCriteria(), - swapped.getOutputVariables(), - swapped.getFilter(), - leftHashVariable, - swapped.getRightHashVariable(), - swapped.getDistributionType(), - swapped.getDynamicFilters()); - - log.debug(format("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.", leftOutputSizeInBytes, rightOutputSizeInBytes, newJoinNode.getId())); - return Result.ofPlanNode(newJoinNode); + return Result.empty(); } private boolean isSwappedJoinValid(JoinNode join) @@ -187,41 +113,4 @@ private boolean isSwappedJoinValid(JoinNode join) return !(join.getDistributionType().get() == REPLICATED && join.getType() == LEFT) && !(join.getDistributionType().get() == PARTITIONED && join.getCriteria().isEmpty() && join.getType() == RIGHT); } - - // Check if the new probe side after removing unnecessary local exchange is valid. - private boolean checkProbeSidePropertySatisfied(PlanNode node, Context context) - { - StreamPreferredProperties requiredProbeProperty; - if (isSpillEnabled(context.getSession()) && isJoinSpillingEnabled(context.getSession())) { - requiredProbeProperty = fixedParallelism(); - } - else { - requiredProbeProperty = defaultParallelism(context.getSession()); - } - StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, parser, context); - return requiredProbeProperty.isSatisfiedBy(nodeProperty); - } - - // Check if the property of a planNode satisfies the requirements for directly feeding as the build side of a JoinNode. - private boolean checkBuildSidePropertySatisfied(PlanNode node, List partitioningColumns, Context context) - { - StreamPreferredProperties requiredBuildProperty; - if (getTaskConcurrency(context.getSession()) > 1) { - requiredBuildProperty = exactlyPartitionedOn(partitioningColumns); - } - else { - requiredBuildProperty = singleStream(); - } - StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, parser, context); - return requiredBuildProperty.isSatisfiedBy(nodeProperty); - } - - private StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, SqlParser parser, Context context) - { - PlanNode actual = context.getLookup().resolve(node); - List inputProperties = actual.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, parser, context)) - .collect(toImmutableList()); - return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, context.getSession(), TypeProvider.viewOf(context.getVariableAllocator().getVariables()), parser); - } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index 10e06e07016dc..1bee9ff07cbf7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -56,6 +56,7 @@ import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MergeJoinNode; +import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -788,6 +789,24 @@ public ActualProperties visitTableScan(TableScanNode node, List inputProperties) + { + if (node.getOrderingScheme().isPresent()) { + return ActualProperties.builder() + .global(singleStreamPartition()) + .unordered(false) + .build(); + } + if (node.isEnsureSourceOrdering()) { + return ActualProperties.builder() + .global(singleStreamPartition()) + .build(); + } + + return ActualProperties.builder().build(); + } + private Global deriveGlobalProperties(TableLayout layout, Map assignments, Map constants) { Optional> streamPartitioning = layout.getStreamPartitioningColumns() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index e4a3a78d02933..83baddf6f0148 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -47,6 +47,7 @@ import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MergeJoinNode; +import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -617,6 +618,19 @@ public StreamProperties visitSample(SampleNode node, List inpu { return Iterables.getOnlyElement(inputProperties); } + + @Override + public StreamProperties visitRemoteSource(RemoteSourceNode node, List inputProperties) + { + if (node.getOrderingScheme().isPresent()) { + return StreamProperties.ordered(); + } + if (node.isEnsureSourceOrdering()) { + return StreamProperties.singleStream(); + } + + return StreamProperties.fixedStreams(); + } } @Immutable diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 0d358a9c5fc6b..1df86a3d424db 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -354,6 +354,7 @@ public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext co return new RemoteSourceNode( node.getSourceLocation(), node.getId(), + node.getStatsEquivalentPlanNode(), node.getSourceFragmentIds(), canonicalizeAndDistinct(node.getOutputVariables()), node.isEnsureSourceOrdering(), diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/InMemoryHistoryBasedPlanStatisticsProvider.java b/presto-main/src/main/java/com/facebook/presto/testing/InMemoryHistoryBasedPlanStatisticsProvider.java similarity index 98% rename from presto-tests/src/main/java/com/facebook/presto/tests/statistics/InMemoryHistoryBasedPlanStatisticsProvider.java rename to presto-main/src/main/java/com/facebook/presto/testing/InMemoryHistoryBasedPlanStatisticsProvider.java index a4f6ff6c32061..642b94c5798d7 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/InMemoryHistoryBasedPlanStatisticsProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/InMemoryHistoryBasedPlanStatisticsProvider.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.tests.statistics; +package com.facebook.presto.testing; import com.facebook.presto.spi.plan.PlanNodeWithHash; import com.facebook.presto.spi.statistics.HistoricalPlanStatistics; diff --git a/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java b/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java index f57a9ababd363..eedc5b2f851f8 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java @@ -25,6 +25,7 @@ import java.util.function.Function; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.Objects.requireNonNull; public class StatsCalculatorTester implements AutoCloseable @@ -49,14 +50,14 @@ public Metadata getMetadata() return metadata; } - public FragmentStatsProvider getFragmentStatsProvider() + public StatsCalculatorTester(LocalQueryRunner queryRunner) { - return queryRunner.getFragmentStatsProvider(); + this(queryRunner, queryRunner.getStatsCalculator()); } - private StatsCalculatorTester(LocalQueryRunner queryRunner) + public StatsCalculatorTester(LocalQueryRunner queryRunner, StatsCalculator statsCalculator) { - this.statsCalculator = queryRunner.getStatsCalculator(); + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.session = queryRunner.getDefaultSession(); this.metadata = queryRunner.getMetadata(); this.queryRunner = queryRunner; diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestRemoteSourceStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestRemoteSourceStatsRule.java index a6f4623495e9c..dbe3494464f90 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestRemoteSourceStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestRemoteSourceStatsRule.java @@ -17,6 +17,7 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.QueryId; import com.facebook.presto.sql.planner.plan.PlanFragmentId; +import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -25,7 +26,6 @@ import static java.lang.Double.NaN; public class TestRemoteSourceStatsRule - extends BaseStatsCalculatorTest { @Test public void testRemoteSourceStatsRule() @@ -34,12 +34,12 @@ public void testRemoteSourceStatsRule() Session session = testSessionBuilder() .setQueryId(queryId) .build(); - StatsCalculatorTester tester = new StatsCalculatorTester(session); - - tester.getFragmentStatsProvider().putStats(queryId, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); - tester.getFragmentStatsProvider().putStats(queryId, new PlanFragmentId(2), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); - tester - .assertStatsFor(planBuilder -> planBuilder.remoteSource(ImmutableList.of(new PlanFragmentId(1), new PlanFragmentId(2)))) + LocalQueryRunner localQueryRunner = new LocalQueryRunner(session); + StatsCalculatorTester tester = new StatsCalculatorTester(localQueryRunner); + FragmentStatsProvider fragmentStatsProvider = localQueryRunner.getFragmentStatsProvider(); + fragmentStatsProvider.putStats(queryId, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + fragmentStatsProvider.putStats(queryId, new PlanFragmentId(2), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + tester.assertStatsFor(planBuilder -> planBuilder.remoteSource(ImmutableList.of(new PlanFragmentId(1), new PlanFragmentId(2)))) .check(check -> check.totalSize(2000) .outputRowsCountUnknown()); tester.close(); @@ -48,9 +48,10 @@ public void testRemoteSourceStatsRule() @Test public void testRemoteSourceStatsUnknown() { - tester() - .assertStatsFor(planBuilder -> planBuilder.remoteSource(ImmutableList.of(new PlanFragmentId(1), new PlanFragmentId(2)))) + StatsCalculatorTester tester = new StatsCalculatorTester(); + tester.assertStatsFor(planBuilder -> planBuilder.remoteSource(ImmutableList.of(new PlanFragmentId(1), new PlanFragmentId(2)))) .check(check -> check.outputRowsCountUnknown() .totalSizeUnknown()); + tester.close(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index d2200b5340d66..aa3cab5b4243d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -46,6 +46,8 @@ import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MergeJoinNode; import com.facebook.presto.sql.planner.plan.OffsetNode; +import com.facebook.presto.sql.planner.plan.PlanFragmentId; +import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.SpatialJoinNode; @@ -614,6 +616,11 @@ public static PlanMatchPattern tableWriter(List columns, List co return node(TableWriterNode.class, source).with(new TableWriterMatcher(columns, columnNames)); } + public static PlanMatchPattern remoteSource(List sourceFragmentIds, Map outputSymbolAliases) + { + return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RemoteSourceMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RemoteSourceMatcher.java new file mode 100644 index 0000000000000..dca9e556f5723 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RemoteSourceMatcher.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanFragmentId; +import com.facebook.presto.sql.planner.plan.RemoteSourceNode; +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class RemoteSourceMatcher + implements Matcher +{ + private List sourceFragmentIds; + private final Map outputSymbolAliases; + + public RemoteSourceMatcher(List sourceFragmentIds, Map outputSymbolAliases) + { + this.sourceFragmentIds = requireNonNull(sourceFragmentIds, "sourceFragmentIds is null"); + this.outputSymbolAliases = requireNonNull(outputSymbolAliases, "outputSymbolAliases is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof RemoteSourceNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + RemoteSourceNode remoteSourceNode = (RemoteSourceNode) node; + if (remoteSourceNode.getSourceFragmentIds().equals(sourceFragmentIds)) { + return match(SymbolAliases.builder() + .putAll(Maps.transformValues(outputSymbolAliases, index -> createSymbolReference(remoteSourceNode.getOutputVariables().get(index)))) + .build()); + } + return NO_MATCH; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java index 649ff81f18681..2929d7eff8745 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java @@ -54,8 +54,8 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; -import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.getFirstKnownOutputSizeInBytes; import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.getSourceTablesSizeInBytes; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.getFirstKnownOutputSizeInBytes; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; @@ -718,22 +718,22 @@ public void testReplicatesWhenSourceIsSmall() PlanNodeStatsEstimate aStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) .addVariableStatistics(ImmutableMap.of( - new VariableReferenceExpression(Optional.empty(), "A1", variableType), - new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + new VariableReferenceExpression(Optional.empty(), "A1", variableType), + new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // output size exceeds JOIN_MAX_BROADCAST_TABLE_SIZE limit PlanNodeStatsEstimate bStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) .addVariableStatistics(ImmutableMap.of( - new VariableReferenceExpression(Optional.empty(), "B1", variableType), - new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + new VariableReferenceExpression(Optional.empty(), "B1", variableType), + new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // output size does not exceed JOIN_MAX_BROADCAST_TABLE_SIZE limit PlanNodeStatsEstimate bSourceStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) .addVariableStatistics(ImmutableMap.of( - new VariableReferenceExpression(Optional.empty(), "B1", variableType), - new VariableStatsEstimate(0, 100, 0, 64, 10))) + new VariableReferenceExpression(Optional.empty(), "B1", variableType), + new VariableStatsEstimate(0, 100, 0, 64, 10))) .build(); // immediate join sources exceeds JOIN_MAX_BROADCAST_TABLE_SIZE limit but build tables are small @@ -748,20 +748,20 @@ public void testReplicatesWhenSourceIsSmall() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - INNER, - p.values(new PlanNodeId("valuesA"), aRows, a1), - p.filter(new PlanNodeId("filterB"), TRUE_CONSTANT, p.values(new PlanNodeId("valuesB"), bRows, b1)), - ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), - ImmutableList.of(a1, b1), - Optional.empty()); + INNER, + p.values(new PlanNodeId("valuesA"), aRows, a1), + p.filter(new PlanNodeId("filterB"), TRUE_CONSTANT, p.values(new PlanNodeId("valuesB"), bRows, b1)), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()); }) .matches(join( - INNER, - ImmutableList.of(equiJoinClause("A1", "B1")), - Optional.empty(), - Optional.of(REPLICATED), - values(ImmutableMap.of("A1", 0)), - filter("true", values(ImmutableMap.of("B1", 0))))); + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(REPLICATED), + values(ImmutableMap.of("A1", 0)), + filter("true", values(ImmutableMap.of("B1", 0))))); // same but with join sides reversed assertDetermineJoinDistributionType() @@ -774,20 +774,20 @@ public void testReplicatesWhenSourceIsSmall() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - INNER, - p.filter(new PlanNodeId("filterB"), TRUE_CONSTANT, p.values(new PlanNodeId("valuesB"), bRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), - ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), - ImmutableList.of(b1, a1), - Optional.empty()); + INNER, + p.filter(new PlanNodeId("filterB"), TRUE_CONSTANT, p.values(new PlanNodeId("valuesB"), bRows, b1)), + p.values(new PlanNodeId("valuesA"), aRows, a1), + ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), + ImmutableList.of(b1, a1), + Optional.empty()); }) .matches(join( - INNER, - ImmutableList.of(equiJoinClause("A1", "B1")), - Optional.empty(), - Optional.of(REPLICATED), - values(ImmutableMap.of("A1", 0)), - filter("true", values(ImmutableMap.of("B1", 0))))); + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(REPLICATED), + values(ImmutableMap.of("A1", 0)), + filter("true", values(ImmutableMap.of("B1", 0))))); // only probe side (with small tables) source stats are available, join sides should be flipped assertDetermineJoinDistributionType() @@ -800,20 +800,20 @@ public void testReplicatesWhenSourceIsSmall() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - LEFT, - p.filter(new PlanNodeId("filterB"), TRUE_CONSTANT, p.values(new PlanNodeId("valuesB"), bRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), - ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), - ImmutableList.of(b1, a1), - Optional.empty()); + LEFT, + p.filter(new PlanNodeId("filterB"), TRUE_CONSTANT, p.values(new PlanNodeId("valuesB"), bRows, b1)), + p.values(new PlanNodeId("valuesA"), aRows, a1), + ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), + ImmutableList.of(b1, a1), + Optional.empty()); }) .matches(join( - RIGHT, - ImmutableList.of(equiJoinClause("A1", "B1")), - Optional.empty(), - Optional.of(PARTITIONED), - values(ImmutableMap.of("A1", 0)), - filter("true", values(ImmutableMap.of("B1", 0))))); + RIGHT, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0)), + filter("true", values(ImmutableMap.of("B1", 0))))); } @Test @@ -827,15 +827,15 @@ public void testFlipWhenSizeDifferenceLarge() PlanNodeStatsEstimate aStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) .addVariableStatistics(ImmutableMap.of( - new VariableReferenceExpression(Optional.empty(), "A1", variableType), - new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + new VariableReferenceExpression(Optional.empty(), "A1", variableType), + new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // output size exceeds JOIN_MAX_BROADCAST_TABLE_SIZE limit PlanNodeStatsEstimate bStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) .addVariableStatistics(ImmutableMap.of( - new VariableReferenceExpression(Optional.empty(), "B1", variableType), - new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + new VariableReferenceExpression(Optional.empty(), "B1", variableType), + new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // source tables size exceeds JOIN_MAX_BROADCAST_TABLE_SIZE limit but one side is significantly bigger than the other @@ -850,23 +850,23 @@ public void testFlipWhenSizeDifferenceLarge() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - INNER, - p.values(new PlanNodeId("valuesA"), aRows, a1), - p.filter( - new PlanNodeId("filterB"), - TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), bRows, b1)), - ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), - ImmutableList.of(a1, b1), - Optional.empty()); + INNER, + p.values(new PlanNodeId("valuesA"), aRows, a1), + p.filter( + new PlanNodeId("filterB"), + TRUE_CONSTANT, + p.values(new PlanNodeId("valuesB"), bRows, b1)), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()); }) .matches(join( - INNER, - ImmutableList.of(equiJoinClause("A1", "B1")), - Optional.empty(), - Optional.of(PARTITIONED), - values(ImmutableMap.of("A1", 0)), - filter("true", values(ImmutableMap.of("B1", 0))))); + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0)), + filter("true", values(ImmutableMap.of("B1", 0))))); // same but with join sides reversed assertDetermineJoinDistributionType() @@ -879,23 +879,23 @@ public void testFlipWhenSizeDifferenceLarge() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - INNER, - p.filter( - new PlanNodeId("filterB"), - TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), bRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), - ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), - ImmutableList.of(b1, a1), - Optional.empty()); + INNER, + p.filter( + new PlanNodeId("filterB"), + TRUE_CONSTANT, + p.values(new PlanNodeId("valuesB"), bRows, b1)), + p.values(new PlanNodeId("valuesA"), aRows, a1), + ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), + ImmutableList.of(b1, a1), + Optional.empty()); }) .matches(join( - INNER, - ImmutableList.of(equiJoinClause("A1", "B1")), - Optional.empty(), - Optional.of(PARTITIONED), - values(ImmutableMap.of("A1", 0)), - filter("true", values(ImmutableMap.of("B1", 0))))); + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0)), + filter("true", values(ImmutableMap.of("B1", 0))))); // Use REPLICATED join type for cross join assertDetermineJoinDistributionType() @@ -908,30 +908,30 @@ public void testFlipWhenSizeDifferenceLarge() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - INNER, - p.filter( - new PlanNodeId("filterB"), - TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), bRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), - ImmutableList.of(), - ImmutableList.of(b1, a1), - Optional.empty()); + INNER, + p.filter( + new PlanNodeId("filterB"), + TRUE_CONSTANT, + p.values(new PlanNodeId("valuesB"), bRows, b1)), + p.values(new PlanNodeId("valuesA"), aRows, a1), + ImmutableList.of(), + ImmutableList.of(b1, a1), + Optional.empty()); }) .matches(join( - INNER, - ImmutableList.of(), - Optional.empty(), - Optional.of(REPLICATED), - filter("true", values(ImmutableMap.of("B1", 0))), - values(ImmutableMap.of("A1", 0)))); + INNER, + ImmutableList.of(), + Optional.empty(), + Optional.of(REPLICATED), + filter("true", values(ImmutableMap.of("B1", 0))), + values(ImmutableMap.of("A1", 0)))); // Don't flip sides when both are similar in size bStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) .addVariableStatistics(ImmutableMap.of( - new VariableReferenceExpression(Optional.empty(), "B1", variableType), - new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + new VariableReferenceExpression(Optional.empty(), "B1", variableType), + new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); assertDetermineJoinDistributionType() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) @@ -943,23 +943,23 @@ public void testFlipWhenSizeDifferenceLarge() VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( - INNER, - p.filter( - new PlanNodeId("filterB"), - TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), aRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), - ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), - ImmutableList.of(b1, a1), - Optional.empty()); + INNER, + p.filter( + new PlanNodeId("filterB"), + TRUE_CONSTANT, + p.values(new PlanNodeId("valuesB"), aRows, b1)), + p.values(new PlanNodeId("valuesA"), aRows, a1), + ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), + ImmutableList.of(b1, a1), + Optional.empty()); }) .matches(join( - INNER, - ImmutableList.of(equiJoinClause("B1", "A1")), - Optional.empty(), - Optional.of(PARTITIONED), - filter("true", values(ImmutableMap.of("B1", 0))), - values(ImmutableMap.of("A1", 0)))); + INNER, + ImmutableList.of(equiJoinClause("B1", "A1")), + Optional.empty(), + Optional.of(PARTITIONED), + filter("true", values(ImmutableMap.of("B1", 0))), + values(ImmutableMap.of("A1", 0)))); } @Test @@ -993,8 +993,8 @@ public void testGetSourceTablesSizeInBytes() .put(variable, sourceVariable2) .build(), ImmutableList.of(planBuilder.tableScan( - ImmutableList.of(sourceVariable1), - ImmutableMap.of(sourceVariable1, new TestingColumnHandle("col"))), + ImmutableList.of(sourceVariable1), + ImmutableMap.of(sourceVariable1, new TestingColumnHandle("col"))), planBuilder.values(new PlanNodeId("valuesNode"), sourceVariable2))), noLookup(), node -> { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 2457f00d82b92..d4b59eca59a4b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -323,7 +323,24 @@ public AggregationNode aggregation(Consumer aggregationBuild public RemoteSourceNode remoteSource(List sourceFragmentIds) { - return new RemoteSourceNode(Optional.empty(), idAllocator.getNextId(), sourceFragmentIds, ImmutableList.of(), false, Optional.empty(), REPARTITION); + return remoteSource(idAllocator.getNextId(), sourceFragmentIds, ImmutableList.of()); + } + + public RemoteSourceNode remoteSource(PlanNodeId planNodeId, List sourceFragmentIds, List outputVariables) + { + return new RemoteSourceNode(Optional.empty(), planNodeId, sourceFragmentIds, outputVariables, false, Optional.empty(), REPARTITION); + } + + public RemoteSourceNode remoteSource(List sourceFragmentIds, PlanNode statsEquivalentPlanNode) + { + return new RemoteSourceNode( + Optional.empty(), + idAllocator.getNextId(), + Optional.of(statsEquivalentPlanNode), + sourceFragmentIds, ImmutableList.of(), + false, + Optional.empty(), + REPARTITION); } public CallExpression binaryOperation(OperatorType operatorType, RowExpression left, RowExpression right) diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index 788dea643fe20..54de9c9ea13fd 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -38,7 +38,6 @@ import com.facebook.presto.cost.CostCalculatorUsingExchanges; import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges; import com.facebook.presto.cost.CostComparator; -import com.facebook.presto.cost.StatsCalculatorModule; import com.facebook.presto.cost.TaskCountEstimator; import com.facebook.presto.dispatcher.QueryPrerequisitesManager; import com.facebook.presto.event.QueryMonitor; @@ -123,6 +122,7 @@ import com.facebook.presto.spark.planner.PrestoSparkPlanFragmenter; import com.facebook.presto.spark.planner.PrestoSparkQueryPlanner; import com.facebook.presto.spark.planner.PrestoSparkRddFactory; +import com.facebook.presto.spark.planner.PrestoSparkStatsCalculatorModule; import com.facebook.presto.spark.planner.optimizers.AdaptivePlanOptimizers; import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; import com.facebook.presto.spi.ConnectorTypeSerde; @@ -439,7 +439,7 @@ protected void setup(Binder binder) jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(VariableReferenceExpressionDeserializer.class); // statistics calculator / cost calculator - binder.install(new StatsCalculatorModule()); + binder.install(new PrestoSparkStatsCalculatorModule()); binder.bind(CostCalculator.class).to(CostCalculatorUsingExchanges.class).in(Scopes.SINGLETON); binder.bind(CostCalculator.class).annotatedWith(CostCalculator.EstimatedExchanges.class).to(CostCalculatorWithEstimatedExchanges.class).in(Scopes.SINGLETON); binder.bind(CostComparator.class).in(Scopes.SINGLETON); diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java index 39d981df50547..eddfd3e76047c 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java @@ -26,6 +26,7 @@ import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.common.type.Type; +import com.facebook.presto.cost.FragmentStatsProvider; import com.facebook.presto.cost.HistoryBasedPlanStatisticsManager; import com.facebook.presto.cost.HistoryBasedPlanStatisticsTracker; import com.facebook.presto.cost.StatsAndCosts; @@ -71,12 +72,15 @@ import com.facebook.presto.spark.planner.PrestoSparkQueryPlanner; import com.facebook.presto.spark.planner.PrestoSparkQueryPlanner.PlanAndMore; import com.facebook.presto.spark.planner.PrestoSparkRddFactory; +import com.facebook.presto.spark.planner.optimizers.AdaptivePlanOptimizers; import com.facebook.presto.spark.util.PrestoSparkTransactionUtils; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; import com.facebook.presto.spi.memory.MemoryPoolId; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.security.AccessControl; @@ -194,6 +198,8 @@ public class PrestoSparkQueryExecutionFactory private final Map, DataDefinitionTask> ddlTasks; private final Optional errorClassifier; private final HistoryBasedPlanStatisticsTracker historyBasedPlanStatisticsTracker; + private final AdaptivePlanOptimizers adaptivePlanOptimizers; + private final FragmentStatsProvider fragmentStatsProvider; @Inject public PrestoSparkQueryExecutionFactory( @@ -231,7 +237,9 @@ public PrestoSparkQueryExecutionFactory( Set waitTimeMetrics, Map, DataDefinitionTask> ddlTasks, Optional errorClassifier, - HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager) + HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, + AdaptivePlanOptimizers adaptivePlanOptimizers, + FragmentStatsProvider fragmentStatsProvider) { this.queryIdGenerator = requireNonNull(queryIdGenerator, "queryIdGenerator is null"); this.sessionSupplier = requireNonNull(sessionSupplier, "sessionSupplier is null"); @@ -268,6 +276,8 @@ public PrestoSparkQueryExecutionFactory( this.ddlTasks = ImmutableMap.copyOf(requireNonNull(ddlTasks, "ddlTasks is null")); this.errorClassifier = requireNonNull(errorClassifier, "errorClassifier is null"); this.historyBasedPlanStatisticsTracker = requireNonNull(historyBasedPlanStatisticsManager, "historyBasedPlanStatisticsManager is null").getHistoryBasedPlanStatisticsTracker(); + this.adaptivePlanOptimizers = requireNonNull(adaptivePlanOptimizers, "adaptivePlanOptimizers is null"); + this.fragmentStatsProvider = requireNonNull(fragmentStatsProvider, "fragmentStatsProvider is null"); } public static QueryInfo createQueryInfo( @@ -650,7 +660,9 @@ else if (preparedQuery.isExplainTypeValidate()) { return accessControlChecker.createExecution(session, preparedQuery, queryStateTimer, warningCollector); } else { - planAndMore = queryPlanner.createQueryPlan(session, preparedQuery, warningCollector); + VariableAllocator variableAllocator = new VariableAllocator(); + PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); + planAndMore = queryPlanner.createQueryPlan(session, preparedQuery, warningCollector, variableAllocator, planNodeIdAllocator); JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext); CollectionAccumulator taskInfoCollector = new CollectionAccumulator<>(); taskInfoCollector.register(sparkContext, Option.empty(), false); @@ -735,6 +747,10 @@ else if (preparedQuery.isExplainTypeValidate()) { metadata, partitioningProviderManager, historyBasedPlanStatisticsTracker, + adaptivePlanOptimizers, + variableAllocator, + planNodeIdAllocator, + fragmentStatsProvider, bootstrapMetricsCollector); } } diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkAdaptiveQueryExecution.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkAdaptiveQueryExecution.java index dad49c983a222..005ac8f89b17c 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkAdaptiveQueryExecution.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkAdaptiveQueryExecution.java @@ -17,8 +17,8 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; +import com.facebook.presto.cost.FragmentStatsProvider; import com.facebook.presto.cost.HistoryBasedPlanStatisticsTracker; -import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.event.QueryMonitor; import com.facebook.presto.execution.QueryManagerConfig; import com.facebook.presto.execution.QueryStateTimer; @@ -43,6 +43,8 @@ import com.facebook.presto.spark.planner.PrestoSparkPlanFragmenter; import com.facebook.presto.spark.planner.PrestoSparkQueryPlanner.PlanAndMore; import com.facebook.presto.spark.planner.PrestoSparkRddFactory; +import com.facebook.presto.spark.planner.optimizers.AdaptivePlanOptimizers; +import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.page.PagesSerde; import com.facebook.presto.spi.plan.OutputNode; @@ -54,7 +56,9 @@ import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.SubPlan; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; +import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; @@ -95,6 +99,7 @@ import static com.google.common.base.Throwables.propagateIfPossible; import static com.google.common.base.Verify.verify; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; /** @@ -109,6 +114,10 @@ public class PrestoSparkAdaptiveQueryExecution private static final Logger log = Logger.get(PrestoSparkAdaptiveQueryExecution.class); private final IterativePlanFragmenter iterativePlanFragmenter; + private final List adaptivePlanOptimizers; + private final VariableAllocator variableAllocator; + private final PlanNodeIdAllocator idAllocator; + private final FragmentStatsProvider fragmentStatsProvider; /** * Set with the IDs of the fragments that have finished execution. @@ -158,6 +167,10 @@ public PrestoSparkAdaptiveQueryExecution( Metadata metadata, PartitioningProviderManager partitioningProviderManager, HistoryBasedPlanStatisticsTracker historyBasedPlanStatisticsTracker, + AdaptivePlanOptimizers adaptivePlanOptimizers, + VariableAllocator variableAllocator, + PlanNodeIdAllocator idAllocator, + FragmentStatsProvider fragmentStatsProvider, Optional>> bootstrapMetricsCollector) { super( @@ -198,6 +211,10 @@ public PrestoSparkAdaptiveQueryExecution( historyBasedPlanStatisticsTracker, bootstrapMetricsCollector); + this.fragmentStatsProvider = requireNonNull(fragmentStatsProvider, "fragmentStatsProvider is null"); + this.adaptivePlanOptimizers = requireNonNull(adaptivePlanOptimizers, "adaptivePlanOptimizers is null").getAdaptiveOptimizers(); + this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.iterativePlanFragmenter = createIterativePlanFragmenter(); } @@ -207,8 +224,17 @@ private IterativePlanFragmenter createIterativePlanFragmenter() Function isFragmentFinished = this.executedFragments::contains; // TODO Create the IterativePlanFragmenter by injection (it has to become stateless first--check PR 18811). - return new IterativePlanFragmenter(this.planAndMore.getPlan(), isFragmentFinished, this.metadata, new PlanChecker(this.featuresConfig, forceSingleNode), new SqlParser(), - new PlanNodeIdAllocator(), new PrestoSparkNodePartitioningManager(this.partitioningProviderManager), this.queryManagerConfig, this.session, this.warningCollector, + return new IterativePlanFragmenter( + this.planAndMore.getPlan(), + isFragmentFinished, + this.metadata, + new PlanChecker(this.featuresConfig, forceSingleNode), + new SqlParser(), + this.idAllocator, + new PrestoSparkNodePartitioningManager(this.partitioningProviderManager), + this.queryManagerConfig, + this.session, + this.warningCollector, forceSingleNode); } @@ -295,11 +321,23 @@ public Void apply(Try result) verify(fragmentEvent instanceof FragmentCompletionSuccessEvent, String.format("Unexpected FragmentCompletionEvent type: %s", fragmentEvent.getClass().getSimpleName())); FragmentCompletionSuccessEvent successEvent = (FragmentCompletionSuccessEvent) fragmentEvent; executedFragments.add(successEvent.getFragmentId()); - Optional runtimeStats = createRuntimeStats(successEvent.getMapOutputStats()); - // Re-optimizations here. + + // add runtime stats to the fragmentStatsProvider + createRuntimeStats(successEvent.getMapOutputStats()).ifPresent( + stats -> fragmentStatsProvider.putStats(session.getQueryId(), successEvent.getFragmentId(), stats)); + + // Re-optimize plan. + PlanNode optimizedPlan = planAndFragments.getRemainingPlan().get(); + for (PlanOptimizer optimizer : adaptivePlanOptimizers) { + optimizedPlan = optimizer.optimize(optimizedPlan, session, TypeProvider.viewOf(variableAllocator.getVariables()), variableAllocator, idAllocator, warningCollector); + } + + if (!optimizedPlan.equals(planAndFragments.getRemainingPlan().get())) { + log.info("adaptive plan optimizations triggered"); + } // Call the iterative fragmenter on the remaining plan that has not yet been submitted for execution. - planAndFragments = iterativePlanFragmenter.createReadySubPlans(planAndFragments.getRemainingPlan().get()); + planAndFragments = iterativePlanFragmenter.createReadySubPlans(optimizedPlan); } verify(planAndFragments.getReadyFragments().size() == 1, "The last step of the adaptive execution is expected to have a single fragment remaining."); @@ -307,7 +345,7 @@ public Void apply(Try result) setFinalFragmentedPlan(finalFragment); - return executeFinalFragment(session, finalFragment, tableWriteInfo); + return executeFinalFragment(finalFragment, tableWriteInfo); } private static Set getRootChildNodeFragmentIDs(PlanNode rootPlanNode) @@ -345,10 +383,8 @@ private void publishFragmentCompletionEvent(FragmentCompletionEvent fragmentComp /** * Execute the final fragment of the plan and collect the result. */ - private List> executeFinalFragment(Session session, - SubPlan finalFragment, - TableWriteInfo tableWriteInfo - ) throws SparkException, TimeoutException + private List> executeFinalFragment(SubPlan finalFragment, TableWriteInfo tableWriteInfo) + throws SparkException, TimeoutException { if (finalFragment.getFragment().getPartitioning().equals(COORDINATOR_DISTRIBUTION)) { Map> inputRdds = new HashMap<>(); diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkQueryPlanner.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkQueryPlanner.java index 73f07ea1b5f95..5a0fbd0cb4b6c 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkQueryPlanner.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkQueryPlanner.java @@ -96,10 +96,8 @@ public PrestoSparkQueryPlanner( this.planCanonicalInfoProvider = requireNonNull(historyBasedPlanStatisticsManager, "historyBasedPlanStatisticsManager is null").getPlanCanonicalInfoProvider(); } - public PlanAndMore createQueryPlan(Session session, BuiltInPreparedQuery preparedQuery, WarningCollector warningCollector) + public PlanAndMore createQueryPlan(Session session, BuiltInPreparedQuery preparedQuery, WarningCollector warningCollector, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) { - PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - Analyzer analyzer = new Analyzer( session, metadata, @@ -112,12 +110,11 @@ public PlanAndMore createQueryPlan(Session session, BuiltInPreparedQuery prepare Analysis analysis = analyzer.analyze(preparedQuery.getStatement()); - final VariableAllocator planVariableAllocator = new VariableAllocator(); LogicalPlanner logicalPlanner = new LogicalPlanner( session, idAllocator, metadata, - planVariableAllocator, + variableAllocator, sqlParser); PlanNode planNode = session.getRuntimeStats().profileNanos( @@ -130,7 +127,7 @@ public PlanAndMore createQueryPlan(Session session, BuiltInPreparedQuery prepare optimizers.getPlanningTimeOptimizers(), planChecker, sqlParser, - planVariableAllocator, + variableAllocator, idAllocator, warningCollector, statsCalculator, diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkStatsCalculator.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkStatsCalculator.java new file mode 100644 index 0000000000000..fede7ed44764d --- /dev/null +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkStatsCalculator.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spark.planner; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.HistoryBasedOptimizationConfig; +import com.facebook.presto.cost.HistoryBasedPlanStatisticsCalculator; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.statistics.HistoryBasedSourceInfo; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; +import com.facebook.presto.sql.planner.plan.RemoteSourceNode; + +import java.util.List; + +import static com.facebook.presto.cost.HistoricalPlanStatisticsUtil.similarStats; +import static java.util.Objects.requireNonNull; + +/** + * This StatsCalculator incorporates runtime statistics when available + * and decides whether to use historical or runtime-based statistics. + */ +public class PrestoSparkStatsCalculator + implements StatsCalculator + +{ + private final HistoryBasedPlanStatisticsCalculator historyBasedPlanStatisticsCalculator; + private final StatsCalculator delegate; + private final HistoryBasedOptimizationConfig historyBasedOptimizationConfig; + + public PrestoSparkStatsCalculator(HistoryBasedPlanStatisticsCalculator historyBasedPlanStatisticsCalculator, StatsCalculator delegate, HistoryBasedOptimizationConfig historyBasedOptimizationConfig) + { + this.historyBasedPlanStatisticsCalculator = requireNonNull(historyBasedPlanStatisticsCalculator, "historyBasedPlanStatisticsCalculator is null"); + this.delegate = requireNonNull(delegate, "delegate is null"); + this.historyBasedOptimizationConfig = requireNonNull(historyBasedOptimizationConfig, "historyBasedOptimizationConfig"); + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + { + boolean shouldUseHistoricalStats = shouldUseHistoricalStats(node, sourceStats, lookup, session, types); + if (shouldUseHistoricalStats) { + return historyBasedPlanStatisticsCalculator.calculateStats(node, sourceStats, lookup, session, types); + } + + return delegate.calculateStats(node, sourceStats, lookup, session, types); + } + + private boolean shouldUseHistoricalStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + { + // RemoteSourceNode stats are computed at runtime. If there are any + // RemoteSourceNodes, check whether we should use historical or runtime stats + // by comparing whether the runtime stats are similar to the historical stats + // if they are similar, use historical stats. if they differ, use runtime stats. + List remoteSourceNodes = PlanNodeSearcher.searchFrom(node, lookup) + .where(RemoteSourceNode.class::isInstance) + .findAll(); + for (RemoteSourceNode remoteSourceNode : remoteSourceNodes) { + PlanNodeStatsEstimate historicalStats = historyBasedPlanStatisticsCalculator.calculateStats(remoteSourceNode, sourceStats, lookup, session, types); + PlanNodeStatsEstimate runtimeStats = delegate.calculateStats(remoteSourceNode, sourceStats, lookup, session, types); + if (!runtimeStats.isTotalSizeUnknown() && + (!(historicalStats.getSourceInfo() instanceof HistoryBasedSourceInfo) || + !similarStats(historicalStats.getTotalSize(), runtimeStats.getTotalSize(), historyBasedOptimizationConfig.getHistoryMatchingThreshold()))) { + return false; + } + } + + return true; + } + + @Override + public void registerPlan(PlanNode root, Session session) + { + historyBasedPlanStatisticsCalculator.registerPlan(root, session); + } +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkStatsCalculatorModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkStatsCalculatorModule.java new file mode 100644 index 0000000000000..b2471077841e6 --- /dev/null +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkStatsCalculatorModule.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spark.planner; + +import com.facebook.presto.cost.FilterStatsCalculator; +import com.facebook.presto.cost.FragmentStatsProvider; +import com.facebook.presto.cost.HistoryBasedOptimizationConfig; +import com.facebook.presto.cost.HistoryBasedPlanStatisticsManager; +import com.facebook.presto.cost.ScalarStatsCalculator; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.cost.StatsNormalizer; +import com.facebook.presto.metadata.Metadata; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Scopes; + +import javax.inject.Singleton; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.presto.cost.StatsCalculatorModule.createComposableStatsCalculator; + +public class PrestoSparkStatsCalculatorModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ScalarStatsCalculator.class).in(Scopes.SINGLETON); + binder.bind(StatsNormalizer.class).in(Scopes.SINGLETON); + binder.bind(FilterStatsCalculator.class).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(HistoryBasedOptimizationConfig.class); + binder.bind(HistoryBasedPlanStatisticsManager.class).in(Scopes.SINGLETON); + binder.bind(FragmentStatsProvider.class).in(Scopes.SINGLETON); + } + + @Provides + @Singleton + public static StatsCalculator createNewStatsCalculator( + Metadata metadata, + ScalarStatsCalculator scalarStatsCalculator, + StatsNormalizer normalizer, + FilterStatsCalculator filterStatsCalculator, + FragmentStatsProvider fragmentStatsProvider, + HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, + HistoryBasedOptimizationConfig historyBasedOptimizationConfig) + { + StatsCalculator delegate = createComposableStatsCalculator(metadata, scalarStatsCalculator, normalizer, filterStatsCalculator, fragmentStatsProvider); + return new PrestoSparkStatsCalculator(historyBasedPlanStatisticsManager.getHistoryBasedPlanStatisticsCalculator(delegate), delegate, historyBasedOptimizationConfig); + } +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java index bef8857480ea9..c186b3a164102 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java @@ -16,6 +16,8 @@ import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.OptimizerStatsRecorder; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; @@ -44,11 +46,13 @@ public class AdaptivePlanOptimizers @Inject public AdaptivePlanOptimizers( MBeanExporter exporter, + Metadata metadata, + SqlParser sqlParser, StatsCalculator statsCalculator, CostCalculator costCalculator) { this.exporter = exporter; - this.adaptiveOptimizers = ImmutableList.of(new IterativeOptimizer(ruleStats, statsCalculator, costCalculator, ImmutableSet.of(new PickJoinSides()))); + this.adaptiveOptimizers = ImmutableList.of(new IterativeOptimizer(ruleStats, statsCalculator, costCalculator, ImmutableSet.of(new PickJoinSides(metadata, sqlParser)))); } @PostConstruct diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java index 30c29dfc5e809..8f8112fe06e80 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java @@ -32,17 +32,23 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; +import java.util.Optional; + import static com.facebook.presto.SystemSessionProperties.isSizeBasedJoinDistributionTypeEnabled; import static com.facebook.presto.spark.PrestoSparkSessionProperties.isAdaptiveJoinSideSwitchingEnabled; -import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowBroadcastLimit; -import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isSmallerThanThreshold; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.createRuntimeSwappedJoinNode; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isBelowBroadcastLimit; +import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isSmallerThanThreshold; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.facebook.presto.sql.planner.plan.Patterns.join; +import static java.util.Objects.requireNonNull; /** * This optimizer chooses the build and probe side of the join based on the size of the @@ -66,6 +72,15 @@ public class PickJoinSides // changing the distribution type too && !(joinNode.getCriteria().isEmpty() && (joinNode.getType() == LEFT || joinNode.getType() == RIGHT))); + private Metadata metadata; + private SqlParser sqlParser; + + public PickJoinSides(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + @Override public Pattern getPattern() { @@ -85,17 +100,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) double leftSize = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes(); double rightSize = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes(); - if (rightSize > leftSize) { - return Result.ofPlanNode(joinNode.flipChildren()); - } - + Optional rewrittenNode = Optional.empty(); // if we don't have exact costs for the join, but based on source tables we think the left side // is very small or much smaller than the right, then flip the join. - if (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) && (Double.isNaN(leftSize) || Double.isNaN(rightSize)) && isLeftSideSmall(joinNode, context)) { - return Result.ofPlanNode(joinNode.flipChildren()); + if (rightSize > leftSize || (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) && (Double.isNaN(leftSize) || Double.isNaN(rightSize)) && isLeftSideSmall(joinNode, context))) { + rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, sqlParser, context.getLookup(), context.getSession(), context.getVariableAllocator(), context.getIdAllocator()); } - return Result.empty(); + return rewrittenNode.map(Result::ofPlanNode).orElseGet(Result::empty); } // This logic is based on DetermineJoinDistributionType.getSizeBasedJoin(), diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkHistoryBasedTracking.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkHistoryBasedTracking.java index ff42046466c70..62a7076ce0ba5 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkHistoryBasedTracking.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkHistoryBasedTracking.java @@ -21,9 +21,9 @@ import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; -import com.facebook.presto.tests.statistics.InMemoryHistoryBasedPlanStatisticsProvider; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -42,7 +42,8 @@ public class TestPrestoSparkHistoryBasedTracking extends AbstractTestQueryFramework { @Override - protected QueryRunner createQueryRunner() throws Exception + protected QueryRunner createQueryRunner() + throws Exception { PrestoSparkQueryRunner queryRunner = createHivePrestoSparkQueryRunner(ImmutableList.of(NATION, ORDERS)); queryRunner.installPlugin(new Plugin() diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/adaptive/execution/TestPrestoSparkAdaptiveJoinQueries.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/adaptive/execution/TestPrestoSparkAdaptiveJoinQueries.java index 12fbbd19afd0b..a30f518ca2ecc 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/adaptive/execution/TestPrestoSparkAdaptiveJoinQueries.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/adaptive/execution/TestPrestoSparkAdaptiveJoinQueries.java @@ -15,8 +15,17 @@ import com.facebook.presto.Session; import com.facebook.presto.spark.TestPrestoSparkJoinQueries; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION; +import static com.facebook.presto.SystemSessionProperties.USE_HISTORY_BASED_PLAN_STATISTICS; +import static com.facebook.presto.spark.PrestoSparkSessionProperties.ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED; import static com.facebook.presto.spark.PrestoSparkSessionProperties.SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.PARTITIONED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.NONE; public class TestPrestoSparkAdaptiveJoinQueries extends TestPrestoSparkJoinQueries @@ -26,6 +35,29 @@ protected Session getSession() { return Session.builder(super.getSession()) .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "true") + .setSystemProperty(ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED, "true") .build(); } + + @DataProvider(name = "optimize_hash_generation") + public Object[][] optimizeHashGeneration() + { + return new Object[][] {{"true"}, {"false"}}; + } + + @Test(dataProvider = "optimize_hash_generation") + public void testQuerySucceedsWithAQE(String optimizeHashGeneration) + { + // we don't add a memory limit and test that the query succeeds with aqe and fails without + // because the memory used by the PrestoSparkRowOutputOperator is very variable and too + // close to the memory used by the build side of the join to allow such a test to run reliably + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_REORDERING_STRATEGY, NONE.name()) + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, PARTITIONED.name()) + .setSystemProperty(USE_HISTORY_BASED_PLAN_STATISTICS, "false") + .setSystemProperty(OPTIMIZE_HASH_GENERATION, optimizeHashGeneration) + .build(); + + assertQuery(session, "SELECT * FROM nation n JOIN orders o ON n.nationkey = o.orderkey"); + } } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java new file mode 100644 index 0000000000000..7bb606b83b944 --- /dev/null +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spark.planner; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.FragmentStatsProvider; +import com.facebook.presto.cost.HistoryBasedOptimizationConfig; +import com.facebook.presto.cost.HistoryBasedPlanStatisticsCalculator; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculatorTester; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.PlanNodeWithHash; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.HistoricalPlanStatistics; +import com.facebook.presto.spi.statistics.HistoricalPlanStatisticsEntry; +import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; +import com.facebook.presto.spi.statistics.PlanStatistics; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.PlanFragmentId; +import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.USE_HISTORY_BASED_PLAN_STATISTICS; +import static com.facebook.presto.common.plan.PlanCanonicalizationStrategy.REMOVE_SAFE_CONSTANTS; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.Double.NaN; + +@Test(singleThreaded = true) +public class TestPrestoSparkStatsCalculator +{ + private static final QueryId TEST_QUERY_ID = new QueryId("testqueryid"); + private HistoryBasedPlanStatisticsCalculator historyBasedPlanStatisticsCalculator; + private FragmentStatsProvider fragmentStatsProvider; + private PrestoSparkStatsCalculator prestoSparkStatsCalculator; + private Metadata metadata; + private StatsCalculatorTester tester; + private Session session; + + @BeforeClass + public void setUp() + { + session = testSessionBuilder() + .setQueryId(TEST_QUERY_ID) + .setSystemProperty(USE_HISTORY_BASED_PLAN_STATISTICS, "true") + .build(); + LocalQueryRunner queryRunner = new LocalQueryRunner(session); + queryRunner.installPlugin(new Plugin() + { + @Override + public Iterable getHistoryBasedPlanStatisticsProviders() + { + return ImmutableList.of(new InMemoryHistoryBasedPlanStatisticsProvider()); + } + }); + + historyBasedPlanStatisticsCalculator = (HistoryBasedPlanStatisticsCalculator) queryRunner.getStatsCalculator(); + fragmentStatsProvider = queryRunner.getFragmentStatsProvider(); + prestoSparkStatsCalculator = new PrestoSparkStatsCalculator( + historyBasedPlanStatisticsCalculator, + historyBasedPlanStatisticsCalculator.getDelegate(), + new HistoryBasedOptimizationConfig()); + metadata = queryRunner.getMetadata(); + tester = new StatsCalculatorTester( + queryRunner, + prestoSparkStatsCalculator); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + tester.close(); + tester = null; + } + + @AfterMethod(alwaysRun = true) + public void resetCaches() + { + ((InMemoryHistoryBasedPlanStatisticsProvider) historyBasedPlanStatisticsCalculator.getHistoryBasedPlanStatisticsProvider().get()).clearCache(); + fragmentStatsProvider.invalidateStats(TEST_QUERY_ID, 1); + } + + @Test + public void testUsesHboStatsWhenMatchRuntime() + { + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); + PlanNode statsEquivalentRemoteSource = planBuilder + .registerVariable(planBuilder.variable("c1")) + .filter(planBuilder.rowExpression("c1 IS NOT NULL"), + planBuilder.values(planBuilder.variable("c1"))); + Optional hash = historyBasedPlanStatisticsCalculator.getPlanCanonicalInfoProvider().hash(session, statsEquivalentRemoteSource, REMOVE_SAFE_CONSTANTS); + + InMemoryHistoryBasedPlanStatisticsProvider historyBasedPlanStatisticsProvider = (InMemoryHistoryBasedPlanStatisticsProvider) historyBasedPlanStatisticsCalculator.getHistoryBasedPlanStatisticsProvider().get(); + historyBasedPlanStatisticsProvider.putStats(ImmutableMap.of( + new PlanNodeWithHash( + statsEquivalentRemoteSource, + hash), + new HistoricalPlanStatistics( + ImmutableList.of( + new HistoricalPlanStatisticsEntry( + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1), + ImmutableList.of()))))); + + tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)), statsEquivalentRemoteSource)) + .check(check -> check.totalSize(1000) + .outputRowsCount(100)); + } + + @Test + public void testUsesRuntimeStatsWhenNoHboStats() + { + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) + .check(check -> check.totalSize(1000) + .outputRowsCountUnknown()); + } + + @Test + public void testUsesRuntimeStatsWhenHboDisabled() + { + Session session = testSessionBuilder() + .setQueryId(TEST_QUERY_ID) + .setSystemProperty(USE_HISTORY_BASED_PLAN_STATISTICS, "false") + .build(); + LocalQueryRunner localQueryRunner = new LocalQueryRunner(session); + HistoryBasedPlanStatisticsCalculator historyBasedPlanStatisticsCalculator = (HistoryBasedPlanStatisticsCalculator) localQueryRunner.getStatsCalculator(); + FragmentStatsProvider fragmentStatsProvider = localQueryRunner.getFragmentStatsProvider(); + PrestoSparkStatsCalculator prestoSparkStatsCalculator = new PrestoSparkStatsCalculator( + historyBasedPlanStatisticsCalculator, + historyBasedPlanStatisticsCalculator.getDelegate(), + new HistoryBasedOptimizationConfig()); + StatsCalculatorTester tester = new StatsCalculatorTester( + localQueryRunner, + prestoSparkStatsCalculator); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + + PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), localQueryRunner.getMetadata()); + PlanNode statsEquivalentRemoteSource = planBuilder + .registerVariable(planBuilder.variable("c1")) + .filter(planBuilder.rowExpression("c1 IS NOT NULL"), + planBuilder.values(planBuilder.variable("c1"))); + HistoryBasedPlanStatisticsProvider historyBasedPlanStatisticsProvider = historyBasedPlanStatisticsCalculator.getHistoryBasedPlanStatisticsProvider().get(); + historyBasedPlanStatisticsProvider.putStats(ImmutableMap.of( + new PlanNodeWithHash( + statsEquivalentRemoteSource, + Optional.empty()), + new HistoricalPlanStatistics( + ImmutableList.of( + new HistoricalPlanStatisticsEntry( + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1), + ImmutableList.of()))))); + + tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) + .check(check -> check.totalSize(1000) + .outputRowsCountUnknown()); + tester.close(); + } + + @Test + public void testUsesRuntimeStatsWhenDiffersFromHbo() + { + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + + PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); + PlanNode statsEquivalentRemoteSource = planBuilder + .registerVariable(planBuilder.variable("c1")) + .filter(planBuilder.rowExpression("c1 IS NOT NULL"), + planBuilder.values(planBuilder.variable("c1"))); + HistoryBasedPlanStatisticsProvider historyBasedPlanStatisticsProvider = historyBasedPlanStatisticsCalculator.getHistoryBasedPlanStatisticsProvider().get(); + historyBasedPlanStatisticsProvider.putStats(ImmutableMap.of( + new PlanNodeWithHash( + statsEquivalentRemoteSource, + Optional.empty()), + new HistoricalPlanStatistics( + ImmutableList.of( + new HistoricalPlanStatisticsEntry( + new PlanStatistics(Estimate.of(10), Estimate.of(100), 1), + ImmutableList.of()))))); + + tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) + .check(check -> check.totalSize(1000) + .outputRowsCountUnknown()); + } +} diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java index 703a5824e4025..a8d7d9dee31e2 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java @@ -37,7 +37,9 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.tpch.TpchConnectorFactory; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -50,13 +52,16 @@ import java.util.Optional; import static com.facebook.presto.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE; +import static com.facebook.presto.SystemSessionProperties.TASK_CONCURRENCY; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.spark.PrestoSparkSessionProperties.ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.remoteSource; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; @@ -277,19 +282,19 @@ public void testFlipWhenSizeDifferenceLarge() // source tables size exceeds JOIN_MAX_BROADCAST_TABLE_SIZE limit but one side is significantly bigger than the other // therefore we keep the smaller side to the build assertPickJoinSides() - .overrideStats("valuesA", aStatsEstimate) - .overrideStats("valuesB", bStatsEstimate) + .overrideStats("remoteSourceA", aStatsEstimate) + .overrideStats("remoteSourceB", bStatsEstimate) .overrideStats("filterB", PlanNodeStatsEstimate.unknown()) // unestimated term to trigger size based join ordering .on(p -> { VariableReferenceExpression a1 = p.variable("A1", variableType); VariableReferenceExpression b1 = p.variable("B1", variableType); return p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, a1), + p.remoteSource(new PlanNodeId("remoteSourceA"), ImmutableList.of(new PlanFragmentId(1)), ImmutableList.of(a1)), p.filter( new PlanNodeId("filterB"), TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), bRows, b1)), + p.remoteSource(new PlanNodeId("remoteSourceB"), ImmutableList.of(new PlanFragmentId(2)), ImmutableList.of(b1))), ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), ImmutableList.of(a1, b1), Optional.empty(), @@ -302,8 +307,9 @@ public void testFlipWhenSizeDifferenceLarge() // same but with join sides reversed assertPickJoinSides() - .overrideStats("valuesA", aStatsEstimate) - .overrideStats("valuesB", bStatsEstimate) + .setSystemProperty(TASK_CONCURRENCY, "2") + .overrideStats("remoteSourceA", aStatsEstimate) + .overrideStats("remoteSourceB", bStatsEstimate) .overrideStats("filterB", PlanNodeStatsEstimate.unknown()) // unestimated term to trigger size based join ordering .on(p -> { VariableReferenceExpression a1 = p.variable("A1", variableType); @@ -313,8 +319,11 @@ public void testFlipWhenSizeDifferenceLarge() p.filter( new PlanNodeId("filterB"), TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), bRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), + p.remoteSource(new PlanNodeId("remoteSourceB"), ImmutableList.of(new PlanFragmentId(2)), ImmutableList.of(b1))), + p.exchange(e -> e.scope(ExchangeNode.Scope.LOCAL) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(a1), ImmutableList.of(a1)) + .addInputsSet(a1) + .addSource(p.remoteSource(new PlanNodeId("remoteSourceA"), ImmutableList.of(new PlanFragmentId(1)), ImmutableList.of(a1)))), ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), ImmutableList.of(b1, a1), Optional.empty(), @@ -328,8 +337,10 @@ public void testFlipWhenSizeDifferenceLarge() ImmutableList.of(equiJoinClause("A1", "B1")), Optional.empty(), Optional.of(PARTITIONED), - values(ImmutableMap.of("A1", 0)), - filter("true", values(ImmutableMap.of("B1", 0))))); + remoteSource(ImmutableList.of(new PlanFragmentId(1)), ImmutableMap.of("A1", 0)), + exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, + filter("true", + remoteSource(ImmutableList.of(new PlanFragmentId(2)), ImmutableMap.of("B1", 0)))))); // Don't flip sides when both are similar in size bStatsEstimate = PlanNodeStatsEstimate.builder() @@ -339,8 +350,8 @@ public void testFlipWhenSizeDifferenceLarge() new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); assertPickJoinSides() - .overrideStats("valuesA", aStatsEstimate) - .overrideStats("valuesB", bStatsEstimate) + .overrideStats("remoteSourceA", aStatsEstimate) + .overrideStats("remoteSourceB", bStatsEstimate) .overrideStats("filterB", PlanNodeStatsEstimate.unknown()) // unestimated term to trigger size based join ordering .on(p -> { VariableReferenceExpression a1 = p.variable("A1", variableType); @@ -350,8 +361,8 @@ public void testFlipWhenSizeDifferenceLarge() p.filter( new PlanNodeId("filterB"), TRUE_CONSTANT, - p.values(new PlanNodeId("valuesB"), aRows, b1)), - p.values(new PlanNodeId("valuesA"), aRows, a1), + p.remoteSource(new PlanNodeId("remoteSourceB"), ImmutableList.of(new PlanFragmentId(2)), ImmutableList.of(b1))), + p.remoteSource(new PlanNodeId("remoteSourceA"), ImmutableList.of(new PlanFragmentId(1)), ImmutableList.of(a1)), ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)), ImmutableList.of(b1, a1), Optional.empty(), @@ -367,7 +378,7 @@ public void testDoesNotFireWhenDisabled() { int aSize = 100; int bSize = 10_000; - tester.assertThat(new PickJoinSides()) + tester.assertThat(new PickJoinSides(tester.getMetadata(), tester.getSqlParser())) .setSystemProperty(ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED, "false") .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setTotalSize(aSize) @@ -458,7 +469,7 @@ public void testDoesNotFireForRightCrossJoin() private RuleAssert assertPickJoinSides() { - return tester.assertThat(new PickJoinSides()) + return tester.assertThat(new PickJoinSides(tester.getMetadata(), tester.getSqlParser())) .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "100MB") .setSystemProperty(ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED, "true"); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RuntimeSourceInfo.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RuntimeSourceInfo.java index d308bbdd9f026..5ed35a8dab88a 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RuntimeSourceInfo.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RuntimeSourceInfo.java @@ -44,4 +44,10 @@ public boolean isConfident() { return true; } + + @Override + public boolean estimateSizeUsingVariables() + { + return false; + } } diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java index 4fce6388c9fdb..8ff22068aaff6 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java @@ -37,10 +37,10 @@ import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.facebook.presto.tests.DistributedQueryRunner; -import com.facebook.presto.tests.statistics.InMemoryHistoryBasedPlanStatisticsProvider; import com.facebook.presto.tpch.TpchPlugin; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap;