diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 1f2c30d30c1d9..88d9d6dc76b6c 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -292,6 +292,7 @@ public final class SystemSessionProperties public static final String PULL_EXPRESSION_FROM_LAMBDA_ENABLED = "pull_expression_from_lambda_enabled"; public static final String REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION = "rewrite_constant_array_contains_to_in_expression"; public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates"; + public static final String ENABLE_HISTORY_BASED_SCALED_WRITER = "enable_history_based_scaled_writer"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled"; @@ -1751,6 +1752,11 @@ public SystemSessionProperties( INFER_INEQUALITY_PREDICATES, "Infer nonequality predicates for joins", featuresConfig.getInferInequalityPredicates(), + false), + booleanProperty( + ENABLE_HISTORY_BASED_SCALED_WRITER, + "Enable setting the initial number of tasks for scaled writers with HBO", + featuresConfig.isUseHBOForScaledWriters(), false)); } @@ -2920,4 +2926,9 @@ public static boolean shouldInferInequalityPredicates(Session session) { return session.getSystemProperty(INFER_INEQUALITY_PREDICATES, Boolean.class); } + + public static boolean useHBOForScaledWriters(Session session) + { + return session.getSystemProperty(ENABLE_HISTORY_BASED_SCALED_WRITER, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java b/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java index c1faa91ceb947..3ce0e02f4090b 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java @@ -32,8 +32,11 @@ import com.facebook.presto.spi.statistics.JoinNodeStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.PlanStatisticsWithSourceInfo; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.facebook.presto.sql.planner.CanonicalPlan; import com.facebook.presto.sql.planner.PlanNodeCanonicalInfo; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.TableWriterNode; import com.facebook.presto.sql.planner.planPrinter.PlanNodeStats; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; @@ -52,6 +55,7 @@ import static com.facebook.presto.common.resourceGroups.QueryType.INSERT; import static com.facebook.presto.common.resourceGroups.QueryType.SELECT; import static com.facebook.presto.cost.HistoricalPlanStatisticsUtil.updatePlanStatistics; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; import static com.facebook.presto.sql.planner.planPrinter.PlanNodeStatsSummarizer.aggregateStageStats; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -130,6 +134,7 @@ public Map getQueryStats(QueryIn if (!stageInfo.getPlan().isPresent()) { continue; } + boolean isScaledWriterStage = stageInfo.getPlan().isPresent() && stageInfo.getPlan().get().getPartitioning().equals(SCALED_WRITER_DISTRIBUTION); PlanNode root = stageInfo.getPlan().get().getRoot(); for (PlanNode planNode : forTree(PlanNode::getSources).depthFirstPreOrder(root)) { if (!planNode.getStatsEquivalentPlanNode().isPresent()) { @@ -144,6 +149,16 @@ public Map getQueryStats(QueryIn double nullJoinBuildKeyCount = planNodeStats.getPlanNodeNullJoinBuildKeyCount(); double joinBuildKeyCount = planNodeStats.getPlanNodeJoinBuildKeyCount(); + JoinNodeStatistics joinNodeStatistics = JoinNodeStatistics.empty(); + if (planNode instanceof JoinNode) { + joinNodeStatistics = new JoinNodeStatistics(Estimate.of(nullJoinBuildKeyCount), Estimate.of(joinBuildKeyCount)); + } + + TableWriterNodeStatistics tableWriterNodeStatistics = TableWriterNodeStatistics.empty(); + if (isScaledWriterStage && planNode instanceof TableWriterNode) { + tableWriterNodeStatistics = new TableWriterNodeStatistics(Estimate.of(stageInfo.getLatestAttemptExecutionInfo().getStats().getTotalTasks())); + } + PlanNode statsEquivalentPlanNode = planNode.getStatsEquivalentPlanNode().get(); for (PlanCanonicalizationStrategy strategy : historyBasedPlanCanonicalizationStrategyList()) { Optional planNodeCanonicalInfo = Optional.ofNullable( @@ -152,20 +167,23 @@ public Map getQueryStats(QueryIn String hash = planNodeCanonicalInfo.get().getHash(); List inputTableStatistics = planNodeCanonicalInfo.get().getInputTableStatistics(); PlanNodeWithHash planNodeWithHash = new PlanNodeWithHash(statsEquivalentPlanNode, Optional.of(hash)); - // Plan node added after HistoricalStatisticsEquivalentPlanMarkingOptimizer will have the same hash as its source node. If the source node is join node, - // the newly added node will have the same hash with the join but no join statistics, hence we need to overwrite in this case. - if (!planStatistics.containsKey(planNodeWithHash) || nullJoinBuildKeyCount > 0 || joinBuildKeyCount > 0) { - planStatistics.put( - planNodeWithHash, - new PlanStatisticsWithSourceInfo( - planNode.getId(), - new PlanStatistics( - Estimate.of(outputPositions), - Double.isNaN(outputBytes) ? Estimate.unknown() : Estimate.of(outputBytes), - 1.0, - new JoinNodeStatistics(Estimate.of(nullJoinBuildKeyCount), Estimate.of(joinBuildKeyCount))), - new HistoryBasedSourceInfo(Optional.of(hash), Optional.of(inputTableStatistics)))); + // Plan node added after HistoricalStatisticsEquivalentPlanMarkingOptimizer will have the same hash as its source node. If the source node is not join or + // table writer node, the newly added node will have the same hash but no join/table writer statistics, hence we need to overwrite in this case. + PlanStatistics newPlanNodeStats = new PlanStatistics( + Estimate.of(outputPositions), + Double.isNaN(outputBytes) ? Estimate.unknown() : Estimate.of(outputBytes), + 1.0, + joinNodeStatistics, + tableWriterNodeStatistics); + if (planStatistics.containsKey(planNodeWithHash)) { + newPlanNodeStats = planStatistics.get(planNodeWithHash).getPlanStatistics().update(newPlanNodeStats); } + planStatistics.put( + planNodeWithHash, + new PlanStatisticsWithSourceInfo( + planNode.getId(), + newPlanNodeStats, + new HistoryBasedSourceInfo(Optional.of(hash), Optional.of(inputTableStatistics)))); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeStatsEstimate.java index bbd187ec2684e..917cc6e25c951 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeStatsEstimate.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.cost; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + import java.util.Objects; import static com.google.common.base.MoreObjects.toStringHelper; @@ -25,7 +28,8 @@ public class JoinNodeStatsEstimate private final double nullJoinBuildKeyCount; private final double joinBuildKeyCount; - public JoinNodeStatsEstimate(double nullJoinBuildKeyCount, double joinBuildKeyCount) + @JsonCreator + public JoinNodeStatsEstimate(@JsonProperty("nullJoinBuildKeyCount") double nullJoinBuildKeyCount, @JsonProperty("joinBuildKeyCount") double joinBuildKeyCount) { this.nullJoinBuildKeyCount = nullJoinBuildKeyCount; this.joinBuildKeyCount = joinBuildKeyCount; @@ -36,11 +40,13 @@ public static JoinNodeStatsEstimate unknown() return UNKNOWN; } + @JsonProperty public double getNullJoinBuildKeyCount() { return nullJoinBuildKeyCount; } + @JsonProperty public double getJoinBuildKeyCount() { return joinBuildKeyCount; diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java index f9bd99d593329..34b0d529beffd 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.PlanStatisticsWithSourceInfo; import com.facebook.presto.spi.statistics.SourceInfo; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.facebook.presto.sql.Serialization; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -49,7 +50,7 @@ public class PlanNodeStatsEstimate { private static final double DEFAULT_DATA_SIZE_PER_COLUMN = 50; - private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of()); + private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); private final double outputRowCount; private final double totalSize; @@ -59,6 +60,8 @@ public class PlanNodeStatsEstimate private final JoinNodeStatsEstimate joinNodeStatsEstimate; + private final TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate; + public static PlanNodeStatsEstimate unknown() { return UNKNOWN; @@ -69,9 +72,11 @@ public PlanNodeStatsEstimate( @JsonProperty("outputRowCount") double outputRowCount, @JsonProperty("totalSize") double totalSize, @JsonProperty("confident") boolean confident, - @JsonProperty("variableStatistics") Map variableStatistics) + @JsonProperty("variableStatistics") Map variableStatistics, + @JsonProperty("joinNodeStatsEstimate") JoinNodeStatsEstimate joinNodeStatsEstimate, + @JsonProperty("tableWriterNodeStatsEstimate") TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate) { - this(outputRowCount, totalSize, confident, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null"))); + this(outputRowCount, totalSize, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null")), new CostBasedSourceInfo(confident), joinNodeStatsEstimate, tableWriterNodeStatsEstimate); } private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean confident, PMap variableStatistics) @@ -81,11 +86,11 @@ private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean c public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap variableStatistics, SourceInfo sourceInfo) { - this(outputRowCount, totalSize, variableStatistics, sourceInfo, JoinNodeStatsEstimate.unknown()); + this(outputRowCount, totalSize, variableStatistics, sourceInfo, JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); } public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap variableStatistics, SourceInfo sourceInfo, - JoinNodeStatsEstimate joinNodeStatsEstimate) + JoinNodeStatsEstimate joinNodeStatsEstimate, TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate) { checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); this.outputRowCount = outputRowCount; @@ -93,6 +98,7 @@ public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap initialTaskCount; private final Set scheduledNodes = new HashSet<>(); @@ -61,7 +62,8 @@ public ScaledWriterScheduler( NodeSelector nodeSelector, ScheduledExecutorService executor, DataSize writerMinSize, - boolean optimizedScaleWriterProducerBuffer) + boolean optimizedScaleWriterProducerBuffer, + Optional initialTaskCount) { this.stage = requireNonNull(stage, "stage is null"); this.sourceTasksProvider = requireNonNull(sourceTasksProvider, "sourceTasksProvider is null"); @@ -70,6 +72,7 @@ public ScaledWriterScheduler( this.executor = requireNonNull(executor, "executor is null"); this.writerMinSizeBytes = requireNonNull(writerMinSize, "minWriterSize is null").toBytes(); this.optimizedScaleWriterProducerBuffer = optimizedScaleWriterProducerBuffer; + this.initialTaskCount = requireNonNull(initialTaskCount, "initialTaskCount is null"); } public void finish() @@ -93,7 +96,7 @@ public ScheduleResult schedule() private int getNewTaskCount() { if (scheduledNodes.isEmpty()) { - return 1; + return initialTaskCount.orElse(1); } double fullTasks = sourceTasksProvider.get().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java index add917473f64b..060379cd7a4a7 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java @@ -41,6 +41,7 @@ import com.facebook.presto.sql.planner.NodePartitionMap; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.PartitioningHandle; +import com.facebook.presto.sql.planner.PlanFragmenterUtils; import com.facebook.presto.sql.planner.SplitSourceFactory; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.PlanFragmentId; @@ -309,6 +310,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { .map(RemoteTask::getTaskStatus) .collect(toList()); + Optional taskNumberIfScaledWriter = PlanFragmenterUtils.getTableWriterTasks(plan.getFragment().getRoot()); + ScaledWriterScheduler scheduler = new ScaledWriterScheduler( stageExecution, sourceTasksProvider, @@ -316,7 +319,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { nodeScheduler.createNodeSelector(session, null, nodePredicate), scheduledExecutor, getWriterMinSize(session), - isOptimizedScaleWriterProducerBuffer(session)); + isOptimizedScaleWriterProducerBuffer(session), + taskNumberIfScaledWriter); whenAllStages(childStageExecutions, StageExecutionState::isDone) .addListener(scheduler::finish, directExecutor()); return scheduler; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 3d58f6479b739..ac8595fe6e473 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -281,6 +281,7 @@ public class FeaturesConfig private boolean rewriteConstantArrayContainsToIn; private boolean preProcessMetadataCalls; + private boolean useHBOForScaledWriters; public enum PartitioningPrecisionStrategy { @@ -2802,4 +2803,17 @@ public FeaturesConfig setRewriteConstantArrayContainsToInEnabled(boolean rewrite this.rewriteConstantArrayContainsToIn = rewriteConstantArrayContainsToIn; return this; } + + public boolean isUseHBOForScaledWriters() + { + return this.useHBOForScaledWriters; + } + + @Config("optimizer.use-hbo-for-scaled-writers") + @ConfigDescription("Enable HBO for setting initial number of tasks for scaled writers") + public FeaturesConfig setUseHBOForScaledWriters(boolean useHBOForScaledWriters) + { + this.useHBOForScaledWriters = useHBOForScaledWriters; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index 6399e82927102..b75f324839ff1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -579,7 +579,8 @@ private TableFinishNode createTemporaryTableWrite( outputNotNullColumnVariables, Optional.of(partitioningScheme), Optional.empty(), - enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty())), + enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty(), + Optional.empty())), variableAllocator.newVariable("intermediaterows", BIGINT), variableAllocator.newVariable("intermediatefragments", VARBINARY), variableAllocator.newVariable("intermediatetablecommitcontext", VARBINARY), @@ -599,7 +600,8 @@ private TableFinishNode createTemporaryTableWrite( outputNotNullColumnVariables, Optional.of(partitioningScheme), Optional.empty(), - enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty()); + enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty(), + Optional.empty()); } return new TableFinishNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java index d21c156a46700..57604e0085ccd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.statistics.JoinNodeStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -119,7 +120,7 @@ private PlanStatistics getPlanStatisticsForTable(Session session, TableScanNode return planStatistics; } TableStatistics tableStatistics = metadata.getTableStatistics(session, key.getTableHandle(), key.getColumnHandles(), key.getConstraint()); - planStatistics = new PlanStatistics(tableStatistics.getRowCount(), tableStatistics.getTotalSize(), 1, JoinNodeStatistics.empty()); + planStatistics = new PlanStatistics(tableStatistics.getRowCount(), tableStatistics.getTotalSize(), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()); cache.put(key, planStatistics); return planStatistics; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 603b78392b947..746af7070552b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -195,6 +195,7 @@ public Optional visitTableWriter(TableWriterNode node, Context context ImmutableSet.of(), Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); context.addPlan(node, new CanonicalPlan(result, strategy)); return Optional.of(result); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index e0d619dce5393..3b6bc58bf6a4d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -421,7 +421,8 @@ private RelationPlan createTableWriterPlan( preferredShufflePartitioningScheme, // partial aggregation is run within the TableWriteOperator to calculate the statistics for // the data consumed by the TableWriteOperator - Optional.of(aggregations.getPartialAggregation())), + Optional.of(aggregations.getPartialAggregation()), + Optional.empty()), Optional.of(target), variableAllocator.newVariable("rows", BIGINT), // final aggregation is run within the TableFinishOperator to summarize collected statistics @@ -448,6 +449,7 @@ private RelationPlan createTableWriterPlan( notNullColumnVariables, tablePartitioningScheme, preferredShufflePartitioningScheme, + Optional.empty(), Optional.empty()), Optional.of(target), variableAllocator.newVariable("rows", BIGINT), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java index 51141346e2475..1296b2febf928 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java @@ -36,6 +36,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.SystemSessionProperties.getExchangeMaterializationStrategy; @@ -244,6 +245,16 @@ public static Set getTableWriterNodeIds(PlanNode plan) .collect(toImmutableSet()); } + public static Optional getTableWriterTasks(PlanNode plan) + { + return stream(forTree(PlanNode::getSources).depthFirstPreOrder(plan)) + .filter(node -> node instanceof TableWriterNode) + .map(x -> ((TableWriterNode) x).getTaskCountIfScaledWriter()) + .filter(Optional::isPresent) + .map(Optional::get) + .max(Integer::compareTo); + } + private static final class PartitioningHandleReassigner extends SimplePlanRewriter { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index c08f46c880686..220a6d5b1c96a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -117,6 +117,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; +import com.facebook.presto.sql.planner.iterative.rule.ScaledWriterRule; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCardinalityMap; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant; import com.facebook.presto.sql.planner.iterative.rule.SimplifyRowExpressions; @@ -706,6 +707,14 @@ public PlanOptimizers( builder.add(new RemoveRedundantDistinctAggregation()); + builder.add( + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of(new ScaledWriterRule()))); + if (!forceSingleNode) { builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges builder.add(new IterativeOptimizer( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index 27de8b19f1252..bdb4885a15c33 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -565,6 +565,7 @@ public Result apply(TableWriterNode node, Captures captures, Context context) return Result.ofPlanNode(new TableWriterNode( node.getSourceLocation(), node.getId(), + node.getStatsEquivalentPlanNode(), node.getSource(), node.getTarget(), node.getRowCountVariable(), @@ -575,7 +576,8 @@ public Result apply(TableWriterNode node, Captures captures, Context context) node.getNotNullColumnVariables(), node.getTablePartitioningScheme(), node.getPreferredShufflePartitioningScheme(), - rewrittenStatisticsAggregation)); + rewrittenStatisticsAggregation, + node.getTaskCountIfScaledWriter())); } return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ScaledWriterRule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ScaledWriterRule.java new file mode 100644 index 0000000000000..950ce1fdef6ad --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ScaledWriterRule.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableWriterNode; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.useHBOForScaledWriters; +import static com.facebook.presto.sql.planner.plan.Patterns.tableWriterNode; +import static com.google.common.base.Preconditions.checkState; + +public class ScaledWriterRule + implements Rule +{ + @Override + public Pattern getPattern() + { + return tableWriterNode().matching(x -> !x.getTaskCountIfScaledWriter().isPresent()); + } + + @Override + public boolean isEnabled(Session session) + { + return useHBOForScaledWriters(session); + } + + @Override + public Result apply(TableWriterNode node, Captures captures, Context context) + { + double taskNumber = context.getStatsProvider().getStats(node).getTableWriterNodeStatsEstimate().getTaskCountIfScaledWriter(); + if (Double.isNaN(taskNumber)) { + return Result.empty(); + } + // We start from half of the original number + int initialTaskNumber = (int) Math.ceil(taskNumber / 2); + checkState(initialTaskNumber > 0, "taskCountIfScaledWriter should be at least 1"); + return Result.ofPlanNode(new TableWriterNode( + node.getSourceLocation(), + node.getId(), + node.getStatsEquivalentPlanNode(), + node.getSource(), + node.getTarget(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + node.getColumns(), + node.getColumnNames(), + node.getNotNullColumnVariables(), + node.getTablePartitioningScheme(), + node.getPreferredShufflePartitioningScheme(), + node.getStatisticsAggregation(), + Optional.of(initialTaskNumber))); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 227820bc91e2e..432da7f9e9a65 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -575,7 +575,8 @@ public PlanWithProperties visitTableWriter(TableWriterNode originalTableWriterNo originalTableWriterNode.getNotNullColumnVariables(), originalTableWriterNode.getTablePartitioningScheme(), originalTableWriterNode.getPreferredShufflePartitioningScheme(), - statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation)), + statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation), + originalTableWriterNode.getTaskCountIfScaledWriter()), fixedParallelism(), fixedParallelism()); } @@ -599,7 +600,8 @@ public PlanWithProperties visitTableWriter(TableWriterNode originalTableWriterNo originalTableWriterNode.getNotNullColumnVariables(), originalTableWriterNode.getTablePartitioningScheme(), originalTableWriterNode.getPreferredShufflePartitioningScheme(), - statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation)), + statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation), + originalTableWriterNode.getTaskCountIfScaledWriter()), exchange.getProperties()); } } @@ -627,7 +629,8 @@ public PlanWithProperties visitTableWriter(TableWriterNode originalTableWriterNo originalTableWriterNode.getNotNullColumnVariables(), originalTableWriterNode.getTablePartitioningScheme(), originalTableWriterNode.getPreferredShufflePartitioningScheme(), - statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation)), + statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation), + originalTableWriterNode.getTaskCountIfScaledWriter()), exchange.getProperties()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 56392ad81d83d..eb643f2606e8b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -705,7 +705,8 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext NULL_BUILD_KEY_COUNT_THRESHOLD - && joinEstimate.getJoinNodeStatsEstimate().getNullJoinBuildKeyCount() / joinEstimate.getJoinNodeStatsEstimate().getJoinBuildKeyCount() > getRandomizeOuterJoinNullKeyNullRatioThreshold(session); + JoinNodeStatsEstimate joinEstimate = statsProvider.getStats(joinNode).getJoinNodeStatsEstimate(); + boolean isValidEstimate = !Double.isNaN(joinEstimate.getJoinBuildKeyCount()) && !Double.isNaN(joinEstimate.getNullJoinBuildKeyCount()); + boolean enabledByCostModel = isValidEstimate && strategy.equals(COST_BASED) && joinEstimate.getNullJoinBuildKeyCount() > NULL_BUILD_KEY_COUNT_THRESHOLD + && joinEstimate.getNullJoinBuildKeyCount() / joinEstimate.getJoinBuildKeyCount() > getRandomizeOuterJoinNullKeyNullRatioThreshold(session); String statsSource = null; if (enabledByCostModel) { - statsSource = joinEstimate.getSourceInfo().getSourceInfoName(); + statsSource = statsProvider.getStats(joinNode).getSourceInfo().getSourceInfoName(); } List candidateEquiJoinClauses = joinNode.getCriteria().stream() .filter(x -> isSupportedType(x.getLeft()) && isSupportedType(x.getRight())) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 79c077d34ec98..2961d28d0ecee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -252,7 +252,8 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new node.getNotNullColumnVariables(), node.getTablePartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)), node.getPreferredShufflePartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)), - node.getStatisticsAggregation().map(this::map)); + node.getStatisticsAggregation().map(this::map), + node.getTaskCountIfScaledWriter()); } public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java index 29d0c434d1647..a5c5fee9c1163 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java @@ -55,6 +55,7 @@ public class TableWriterNode private final Optional preferredShufflePartitioningScheme; private final Optional statisticsAggregation; private final List outputs; + private final Optional taskCountIfScaledWriter; @JsonCreator public TableWriterNode( @@ -70,9 +71,10 @@ public TableWriterNode( @JsonProperty("notNullColumnVariables") Set notNullColumnVariables, @JsonProperty("partitioningScheme") Optional tablePartitioningScheme, @JsonProperty("preferredShufflePartitioningScheme") Optional preferredShufflePartitioningScheme, - @JsonProperty("statisticsAggregation") Optional statisticsAggregation) + @JsonProperty("statisticsAggregation") Optional statisticsAggregation, + @JsonProperty("taskCountIfScaledWriter") Optional taskCountIfScaledWriter) { - this(sourceLocation, id, Optional.empty(), source, target, rowCountVariable, fragmentVariable, tableCommitContextVariable, columns, columnNames, notNullColumnVariables, tablePartitioningScheme, preferredShufflePartitioningScheme, statisticsAggregation); + this(sourceLocation, id, Optional.empty(), source, target, rowCountVariable, fragmentVariable, tableCommitContextVariable, columns, columnNames, notNullColumnVariables, tablePartitioningScheme, preferredShufflePartitioningScheme, statisticsAggregation, taskCountIfScaledWriter); } public TableWriterNode( @@ -89,7 +91,8 @@ public TableWriterNode( Set notNullColumnVariables, Optional tablePartitioningScheme, Optional preferredShufflePartitioningScheme, - Optional statisticsAggregation) + Optional statisticsAggregation, + Optional taskCountIfScaledWriter) { super(sourceLocation, id, statsEquivalentPlanNode); @@ -121,6 +124,7 @@ public TableWriterNode( outputs.addAll(aggregation.getAggregations().keySet()); }); this.outputs = outputs.build(); + this.taskCountIfScaledWriter = requireNonNull(taskCountIfScaledWriter, "taskCountIfScaledWriter is null"); } @JsonProperty @@ -201,6 +205,12 @@ public List getOutputVariables() return outputs; } + @JsonProperty + public Optional getTaskCountIfScaledWriter() + { + return taskCountIfScaledWriter; + } + @Override public R accept(InternalPlanVisitor visitor, C context) { @@ -224,7 +234,8 @@ public PlanNode replaceChildren(List newChildren) notNullColumnVariables, tablePartitioningScheme, preferredShufflePartitioningScheme, - statisticsAggregation); + statisticsAggregation, + taskCountIfScaledWriter); } @Override @@ -244,7 +255,8 @@ public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalent notNullColumnVariables, tablePartitioningScheme, preferredShufflePartitioningScheme, - statisticsAggregation); + statisticsAggregation, + taskCountIfScaledWriter); } // only used during planning -- will not be serialized diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 6a47a7f9a2eb8..456c1b98a8058 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -385,7 +385,7 @@ private static String formatFragment( double sdAmongTasks = Math.sqrt(squaredDifferences / tasks.size()); builder.append(indentString(1)) - .append(format("CPU: %s, Scheduled: %s, Input: %s (%s); per task: avg.: %s std.dev.: %s, Output: %s (%s)%n", + .append(format("CPU: %s, Scheduled: %s, Input: %s (%s); per task: avg.: %s std.dev.: %s, Output: %s (%s), %s tasks%n", stageExecutionStats.getTotalCpuTime().convertToMostSuccinctTimeUnit(), stageExecutionStats.getTotalScheduledTime().convertToMostSuccinctTimeUnit(), formatPositions(stageExecutionStats.getProcessedInputPositions()), @@ -393,7 +393,8 @@ private static String formatFragment( formatDouble(avgPositionsPerTask), formatDouble(sdAmongTasks), formatPositions(stageExecutionStats.getOutputPositions()), - stageExecutionStats.getOutputDataSize())); + stageExecutionStats.getOutputDataSize(), + tasks.size())); } PartitioningScheme partitioningScheme = fragment.getPartitioningScheme(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java index 2d575c395d28d..3acf1fc5303ab 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java @@ -15,6 +15,7 @@ import com.facebook.presto.cost.PlanCostEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.TableWriterNodeStatsEstimate; import com.facebook.presto.spi.eventlistener.CTEInformation; import com.facebook.presto.spi.eventlistener.PlanOptimizerInformation; import com.facebook.presto.sql.planner.optimizations.OptimizerResult; @@ -235,11 +236,14 @@ private String printEstimates(PlanRepresentation plan, NodeRepresentation node) for (int i = 0; i < estimateCount; i++) { PlanNodeStatsEstimate stats = node.getEstimatedStats().get(i); PlanCostEstimate cost = node.getEstimatedCost().get(i); - String formatStr = "{source: %s, rows: %s (%s), cpu: %s, memory: %s, network: %s%s%s}"; + String formatStr = "{source: %s, rows: %s (%s), cpu: %s, memory: %s, network: %s"; boolean hasHashtableStats = stats.getJoinNodeStatsEstimate().getJoinBuildKeyCount() > 0 || stats.getJoinNodeStatsEstimate().getNullJoinBuildKeyCount() > 0; - if (hasHashtableStats) { - formatStr = "{source: %s, rows: %s (%s), cpu: %s, memory: %s, network: %s, hashtable size: %s, hashtable null: %s}"; - } + String joinStatsFormatStr = hasHashtableStats ? ", hashtable[size: %s, nulls %s]" : "%s%s"; + boolean hasTableWriterStats = !stats.getTableWriterNodeStatsEstimate().equals(TableWriterNodeStatsEstimate.unknown()); + String tableWriterStatsFormatStr = hasTableWriterStats ? ", tablewriter[initial tasks: %s]" : "%s"; + formatStr += joinStatsFormatStr; + formatStr += tableWriterStatsFormatStr; + formatStr += "}"; output.append(format(formatStr, stats.getSourceInfo().getClass().getSimpleName(), formatAsLong(stats.getOutputRowCount()), @@ -248,7 +252,8 @@ private String printEstimates(PlanRepresentation plan, NodeRepresentation node) formatDouble(cost.getMaxMemory()), formatDouble(cost.getNetworkCost()), hasHashtableStats ? formatDouble(stats.getJoinNodeStatsEstimate().getJoinBuildKeyCount()) : "", - hasHashtableStats ? formatDouble(stats.getJoinNodeStatsEstimate().getNullJoinBuildKeyCount()) : "")); + hasHashtableStats ? formatDouble(stats.getJoinNodeStatsEstimate().getNullJoinBuildKeyCount()) : "", + hasTableWriterStats ? formatAsLong(stats.getTableWriterNodeStatsEstimate().getTaskCountIfScaledWriter()) : "")); if (i < estimateCount - 1) { output.append("/"); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java index 46f8492dd5b91..263f59ad29d9f 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java @@ -32,8 +32,8 @@ public void testFragmentStatsProvider() QueryId queryId2 = new QueryId("queryid2"); PlanFragmentId planFragmentId1 = new PlanFragmentId(1); PlanFragmentId planFragmentId2 = new PlanFragmentId(2); - PlanNodeStatsEstimate planNodeStatsEstimate1 = new PlanNodeStatsEstimate(NaN, 10, true, ImmutableMap.of()); - PlanNodeStatsEstimate planNodeStatsEstimate2 = new PlanNodeStatsEstimate(NaN, 100, true, ImmutableMap.of()); + PlanNodeStatsEstimate planNodeStatsEstimate1 = new PlanNodeStatsEstimate(NaN, 10, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); + PlanNodeStatsEstimate planNodeStatsEstimate2 = new PlanNodeStatsEstimate(NaN, 100, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); assertEquals(fragmentStatsProvider.getStats(queryId1, planFragmentId1), PlanNodeStatsEstimate.unknown()); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java index ecc4e867574c7..870a3012422d5 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.statistics.HistoricalPlanStatistics; import com.facebook.presto.spi.statistics.JoinNodeStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -83,7 +84,7 @@ public void testMaxStatistics() private PlanStatistics stats(double rows, double size) { - return new PlanStatistics(Estimate.of(rows), Estimate.of(size), 1, JoinNodeStatistics.empty()); + return new PlanStatistics(Estimate.of(rows), Estimate.of(size), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()); } private static HistoricalPlanStatistics updatePlanStatistics( diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java index 10b640e73b2bc..9b650d1427da2 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.spi.statistics.JoinNodeStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.facebook.presto.sql.Optimizer; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.assertions.PlanAssert; @@ -123,8 +124,8 @@ public Map getStats(List planBuilder.remoteSource(ImmutableList.of(new PlanFragmentId(1), new PlanFragmentId(2)))) .check(check -> check.totalSize(2000) .outputRowsCountUnknown()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 3a7bd96889c65..30066d2cf7a44 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -246,7 +246,8 @@ public void testDefaults() .setAddPartialNodeForRowNumberWithLimitEnabled(true) .setInferInequalityPredicates(false) .setPullUpExpressionFromLambdaEnabled(false) - .setRewriteConstantArrayContainsToInEnabled(false)); + .setRewriteConstantArrayContainsToInEnabled(false) + .setUseHBOForScaledWriters(false)); } @Test @@ -441,6 +442,7 @@ public void testExplicitPropertyMappings() .put("optimizer.infer-inequality-predicates", "true") .put("optimizer.pull-up-expression-from-lambda", "true") .put("optimizer.rewrite-constant-array-contains-to-in", "true") + .put("optimizer.use-hbo-for-scaled-writers", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -632,7 +634,8 @@ public void testExplicitPropertyMappings() .setAddPartialNodeForRowNumberWithLimitEnabled(false) .setInferInequalityPredicates(true) .setPullUpExpressionFromLambdaEnabled(true) - .setRewriteConstantArrayContainsToInEnabled(true); + .setRewriteConstantArrayContainsToInEnabled(true) + .setUseHBOForScaledWriters(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 6c6b7db6bd005..28e8cf7506676 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -869,6 +869,7 @@ public TableWriterNode tableWriter(List columns, Li ImmutableSet.of(), Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty()); } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java index 36258190176a3..c0459fe173ba9 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java @@ -18,8 +18,10 @@ import com.facebook.presto.cost.FragmentStatsProvider; import com.facebook.presto.cost.HistoryBasedOptimizationConfig; import com.facebook.presto.cost.HistoryBasedPlanStatisticsCalculator; +import com.facebook.presto.cost.JoinNodeStatsEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsCalculatorTester; +import com.facebook.presto.cost.TableWriterNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.QueryId; @@ -32,6 +34,7 @@ import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.spi.statistics.JoinNodeStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider; @@ -107,7 +110,7 @@ public void resetCaches() @Test public void testUsesHboStatsWhenMatchRuntime() { - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); PlanNode statsEquivalentRemoteSource = planBuilder .registerVariable(planBuilder.variable("c1")) @@ -123,7 +126,7 @@ public void testUsesHboStatsWhenMatchRuntime() new HistoricalPlanStatistics( ImmutableList.of( new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty()), + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), ImmutableList.of()))))); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)), statsEquivalentRemoteSource)) @@ -134,7 +137,7 @@ public void testUsesHboStatsWhenMatchRuntime() @Test public void testUsesRuntimeStatsWhenNoHboStats() { - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) .check(check -> check.totalSize(1000) .outputRowsCountUnknown()); @@ -157,7 +160,7 @@ public void testUsesRuntimeStatsWhenHboDisabled() StatsCalculatorTester tester = new StatsCalculatorTester( localQueryRunner, prestoSparkStatsCalculator); - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), localQueryRunner.getMetadata()); PlanNode statsEquivalentRemoteSource = planBuilder @@ -172,7 +175,7 @@ public void testUsesRuntimeStatsWhenHboDisabled() new HistoricalPlanStatistics( ImmutableList.of( new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty()), + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), ImmutableList.of()))))); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) @@ -184,7 +187,7 @@ public void testUsesRuntimeStatsWhenHboDisabled() @Test public void testUsesRuntimeStatsWhenDiffersFromHbo() { - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); PlanNode statsEquivalentRemoteSource = planBuilder @@ -199,7 +202,7 @@ public void testUsesRuntimeStatsWhenDiffersFromHbo() new HistoricalPlanStatistics( ImmutableList.of( new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(10), Estimate.of(100), 1, JoinNodeStatistics.empty()), + new PlanStatistics(Estimate.of(10), Estimate.of(100), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), ImmutableList.of()))))); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/JoinNodeStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/JoinNodeStatistics.java index 01d2070c31aee..7123ee0991041 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/JoinNodeStatistics.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/JoinNodeStatistics.java @@ -47,6 +47,11 @@ public static JoinNodeStatistics empty() return EMPTY; } + public boolean isEmpty() + { + return this.equals(empty()); + } + @JsonProperty @ThriftField(1) public Estimate getNullJoinBuildKeyCount() diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java index 5249a5e9a8921..a5e365fedcbf9 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java @@ -27,7 +27,7 @@ @ThriftStruct public class PlanStatistics { - private static final PlanStatistics EMPTY = new PlanStatistics(Estimate.unknown(), Estimate.unknown(), 0, JoinNodeStatistics.empty()); + private static final PlanStatistics EMPTY = new PlanStatistics(Estimate.unknown(), Estimate.unknown(), 0, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()); private final Estimate rowCount; private final Estimate outputSize; @@ -35,6 +35,8 @@ public class PlanStatistics private final double confidence; // Join node specific statistics private final JoinNodeStatistics joinNodeStatistics; + // TableWriter node specific statistics + private final TableWriterNodeStatistics tableWriterNodeStatistics; public static PlanStatistics empty() { @@ -46,13 +48,15 @@ public static PlanStatistics empty() public PlanStatistics(@JsonProperty("rowCount") Estimate rowCount, @JsonProperty("outputSize") Estimate outputSize, @JsonProperty("confidence") double confidence, - @JsonProperty("joinNodeStatistics") JoinNodeStatistics joinNodeStatistics) + @JsonProperty("joinNodeStatistics") JoinNodeStatistics joinNodeStatistics, + @JsonProperty("tableWriterNodeStatistics") TableWriterNodeStatistics tableWriterNodeStatistics) { this.rowCount = requireNonNull(rowCount, "rowCount is null"); this.outputSize = requireNonNull(outputSize, "outputSize is null"); checkArgument(confidence >= 0 && confidence <= 1, "confidence should be between 0 and 1"); this.confidence = confidence; this.joinNodeStatistics = requireNonNull(joinNodeStatistics == null ? JoinNodeStatistics.empty() : joinNodeStatistics, "joinNodeStatistics is null"); + this.tableWriterNodeStatistics = requireNonNull(tableWriterNodeStatistics == null ? TableWriterNodeStatistics.empty() : tableWriterNodeStatistics, "tableWriterNodeStatistics is null"); } @JsonProperty @@ -83,7 +87,23 @@ public JoinNodeStatistics getJoinNodeStatistics() return joinNodeStatistics; } - // Next ThriftField value 7 + @JsonProperty + @ThriftField(value = 7, requiredness = OPTIONAL) + public TableWriterNodeStatistics getTableWriterNodeStatistics() + { + return tableWriterNodeStatistics; + } + + // Next ThriftField value 8 + + public PlanStatistics update(PlanStatistics planStatistics) + { + return new PlanStatistics(planStatistics.getRowCount(), + planStatistics.getOutputSize(), + planStatistics.getConfidence(), + planStatistics.getJoinNodeStatistics().isEmpty() ? getJoinNodeStatistics() : planStatistics.getJoinNodeStatistics(), + planStatistics.getTableWriterNodeStatistics().isEmpty() ? getTableWriterNodeStatistics() : planStatistics.getTableWriterNodeStatistics()); + } private static void checkArgument(boolean condition, String message) { @@ -103,13 +123,13 @@ public boolean equals(Object o) } PlanStatistics that = (PlanStatistics) o; return Double.compare(that.confidence, confidence) == 0 && Objects.equals(rowCount, that.rowCount) && Objects.equals(outputSize, that.outputSize) - && Objects.equals(joinNodeStatistics, that.joinNodeStatistics); + && Objects.equals(joinNodeStatistics, that.joinNodeStatistics) && Objects.equals(tableWriterNodeStatistics, that.tableWriterNodeStatistics); } @Override public int hashCode() { - return Objects.hash(rowCount, outputSize, confidence, joinNodeStatistics); + return Objects.hash(rowCount, outputSize, confidence, joinNodeStatistics, tableWriterNodeStatistics); } @Override @@ -120,6 +140,7 @@ public String toString() ", outputSize=" + outputSize + ", confidence=" + confidence + ", joinNodeStatistics=" + joinNodeStatistics + + ", tableWriterNodeStatistics=" + tableWriterNodeStatistics + '}'; } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/TableWriterNodeStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/TableWriterNodeStatistics.java new file mode 100644 index 0000000000000..1ca65ef8b4591 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/TableWriterNodeStatistics.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.statistics; + +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public class TableWriterNodeStatistics +{ + private static final TableWriterNodeStatistics EMPTY = new TableWriterNodeStatistics(Estimate.unknown()); + // Number of writer tasks when the writer is scaled writer, otherwise unknown + private final Estimate taskCountIfScaledWriter; + + @JsonCreator + @ThriftConstructor + public TableWriterNodeStatistics( + @JsonProperty("taskCountIfScaledWriter") Estimate taskCountIfScaledWriter) + { + this.taskCountIfScaledWriter = requireNonNull(taskCountIfScaledWriter, "taskCountIfScaledWriter is null"); + } + + public static TableWriterNodeStatistics empty() + { + return EMPTY; + } + + public boolean isEmpty() + { + return this.equals(empty()); + } + + @JsonProperty + @ThriftField(1) + public Estimate getTaskCountIfScaledWriter() + { + return taskCountIfScaledWriter; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TableWriterNodeStatistics that = (TableWriterNodeStatistics) o; + return Objects.equals(taskCountIfScaledWriter, that.taskCountIfScaledWriter); + } + + @Override + public int hashCode() + { + return Objects.hash(taskCountIfScaledWriter); + } + + @Override + public String toString() + { + return "TableWriterNodeStatistics{" + + "taskCountIfScaledWriter=" + taskCountIfScaledWriter + + '}'; + } +} diff --git a/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java b/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java index 9bb44fa2d5384..def878e927957 100644 --- a/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java +++ b/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.statistics.HistoricalPlanStatisticsEntry; import com.facebook.presto.spi.statistics.JoinNodeStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; +import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -36,8 +37,8 @@ public class TestHistoricalStatisticsSerde public void testSimpleHistoricalStatisticsEncoderDecoder() { HistoricalPlanStatistics samplePlanStatistics = new HistoricalPlanStatistics(ImmutableList.of(new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty()), - ImmutableList.of(new PlanStatistics(Estimate.of(15000), Estimate.unknown(), 1, JoinNodeStatistics.empty()))))); + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), + ImmutableList.of(new PlanStatistics(Estimate.of(15000), Estimate.unknown(), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()))))); HistoricalStatisticsSerde historicalStatisticsEncoderDecoder = new HistoricalStatisticsSerde(); // Test PlanHash @@ -54,8 +55,8 @@ public void testHistoricalPlanStatisticsEntryList() { List historicalPlanStatisticsEntryList = new ArrayList<>(); for (int i = 0; i < 50; i++) { - historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty()), - ImmutableList.of(new PlanStatistics(Estimate.of(100), Estimate.of(i), 0, JoinNodeStatistics.empty())))); + historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), + ImmutableList.of(new PlanStatistics(Estimate.of(100), Estimate.of(i), 0, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty())))); } HistoricalPlanStatistics samplePlanStatistics = new HistoricalPlanStatistics(historicalPlanStatisticsEntryList); HistoricalStatisticsSerde historicalStatisticsEncoderDecoder = new HistoricalStatisticsSerde(); @@ -81,11 +82,11 @@ public void testPlanStatisticsList() { List planStatisticsEntryList = new ArrayList<>(); for (int i = 0; i < 50; i++) { - planStatisticsEntryList.add(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty())); + planStatisticsEntryList.add(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty())); } List historicalPlanStatisticsEntryList = new ArrayList<>(); for (int i = 0; i < 50; i++) { - historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty()), + historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), planStatisticsEntryList)); } HistoricalPlanStatistics samplePlanStatistics = new HistoricalPlanStatistics(historicalPlanStatisticsEntryList);