From cbfcfef64667674291df20209e151deb5809a0ca Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 14 Sep 2018 16:14:06 +0200 Subject: [PATCH 1/2] Remove TODO related to stats caching The idea was abandoned during https://github.com/prestodb/presto/pull/11267 review. --- .../facebook/presto/sql/planner/iterative/rule/ReorderJoins.java | 1 - 1 file changed, 1 deletion(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index c648cebed4bc8..21f8912dce84f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -361,7 +361,6 @@ private static EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Se private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) { - // TODO avoid stat (but not cost) recalculation for all considered (distribution,flip) pairs, since resulting relation is the same in all case if (isAtMostScalar(joinNode.getRight(), lookup)) { return createJoinEnumerationResult(joinNode.withDistributionType(REPLICATED)); } From d50db6ca9f7dfa0430de0b50a8ae11ab7d12916f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Mon, 13 Aug 2018 18:33:05 +0200 Subject: [PATCH 2/2] Add Exchange before GroupId to improve Partial Aggregation The rule brings significant improvement in TPC-DS Q22 and Q67 while not causing much regression in other TPC-H, TPC-DS queries. (The only observably regressing queries were still much better than non-CBO baseline.) --- .../presto/SystemSessionProperties.java | 11 + .../cost/CostCalculatorUsingExchanges.java | 2 +- .../presto/sql/analyzer/FeaturesConfig.java | 13 + .../presto/sql/planner/PlanOptimizers.java | 19 + ...wPartialAggregationOverGroupIdRuleSet.java | 349 ++++++++++++++++++ .../StreamPreferredProperties.java | 2 +- .../presto/sql/planner/plan/Patterns.java | 8 + .../presto/testing/LocalQueryRunner.java | 10 +- .../sql/analyzer/TestFeaturesConfig.java | 3 + .../resources/sql/presto/tpcds/q18.plan.txt | 44 +-- .../resources/sql/presto/tpcds/q22.plan.txt | 20 +- .../resources/sql/presto/tpcds/q67.plan.txt | 22 +- .../tests/AbstractTestQueryFramework.java | 3 + 13 files changed, 462 insertions(+), 44 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 3db158b22b4fa..24bf4052ee1ac 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -95,6 +95,7 @@ public final class SystemSessionProperties public static final String LEGACY_ROW_FIELD_ORDINAL_ACCESS = "legacy_row_field_ordinal_access"; public static final String ITERATIVE_OPTIMIZER = "iterative_optimizer_enabled"; public static final String ITERATIVE_OPTIMIZER_TIMEOUT = "iterative_optimizer_timeout"; + public static final String ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID = "enable_forced_exchange_below_group_id"; public static final String EXCHANGE_COMPRESSION = "exchange_compression"; public static final String LEGACY_TIMESTAMP = "legacy_timestamp"; public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations"; @@ -427,6 +428,11 @@ public SystemSessionProperties( false, value -> Duration.valueOf((String) value), Duration::toString), + booleanProperty( + ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID, + "Enable a stats-based rule adding exchanges below GroupId", + featuresConfig.isEnableForcedExchangeBelowGroupId(), + true), booleanProperty( EXCHANGE_COMPRESSION, "Enable compression in exchanges", @@ -765,6 +771,11 @@ public static Duration getOptimizerTimeout(Session session) return session.getSystemProperty(ITERATIVE_OPTIMIZER_TIMEOUT, Duration.class); } + public static boolean isEnableForcedExchangeBelowGroupId(Session session) + { + return session.getSystemProperty(ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID, Boolean.class); + } + public static boolean isExchangeCompressionEnabled(Session session) { return session.getSystemProperty(EXCHANGE_COMPRESSION, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java index fcb71fd24306a..91809218bc0bd 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java @@ -70,7 +70,7 @@ public CostCalculatorUsingExchanges(NodeSchedulerConfig nodeSchedulerConfig, Int this(currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), nodeManager)); } - static IntSupplier currentNumberOfWorkerNodes(boolean includeCoordinator, InternalNodeManager nodeManager) + public static IntSupplier currentNumberOfWorkerNodes(boolean includeCoordinator, InternalNodeManager nodeManager) { requireNonNull(nodeManager, "nodeManager is null"); return () -> { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index bc50fb7619cad..c2b6d70945ddd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -108,6 +108,7 @@ public class FeaturesConfig private double spillMaxUsedSpaceThreshold = 0.9; private boolean iterativeOptimizerEnabled = true; private boolean enableStatsCalculator = true; + private boolean enableForcedExchangeBelowGroupId = true; private boolean pushAggregationThroughJoin = true; private double memoryRevokingTarget = 0.5; private double memoryRevokingThreshold = 0.9; @@ -593,6 +594,18 @@ public FeaturesConfig setEnableStatsCalculator(boolean enableStatsCalculator) return this; } + public boolean isEnableForcedExchangeBelowGroupId() + { + return enableForcedExchangeBelowGroupId; + } + + @Config("enable-forced-exchange-below-group-id") + public FeaturesConfig setEnableForcedExchangeBelowGroupId(boolean enableForcedExchangeBelowGroupId) + { + this.enableForcedExchangeBelowGroupId = enableForcedExchangeBelowGroupId; + return this; + } + public DataSize getAggregationOperatorUnspillMemoryLimit() { return aggregationOperatorUnspillMemoryLimit; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index d4d4e7b3b152e..c2a9454274d7b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -17,6 +17,9 @@ import com.facebook.presto.cost.CostCalculator.EstimatedExchanges; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; +import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; @@ -24,6 +27,7 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations; import com.facebook.presto.sql.planner.iterative.rule.CanonicalizeExpressions; import com.facebook.presto.sql.planner.iterative.rule.CreatePartialTopN; @@ -127,6 +131,9 @@ import java.util.List; import java.util.Set; +import java.util.function.IntSupplier; + +import static com.facebook.presto.cost.CostCalculatorUsingExchanges.currentNumberOfWorkerNodes; public class PlanOptimizers { @@ -140,6 +147,9 @@ public PlanOptimizers( Metadata metadata, SqlParser sqlParser, FeaturesConfig featuresConfig, + NodeSchedulerConfig nodeSchedulerConfig, + InternalNodeManager nodeManager, + TaskManagerConfig taskManagerConfig, MBeanExporter exporter, SplitManager splitManager, PageSourceManager pageSourceManager, @@ -151,6 +161,8 @@ public PlanOptimizers( this(metadata, sqlParser, featuresConfig, + currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), nodeManager), + taskManagerConfig, false, exporter, splitManager, @@ -179,6 +191,8 @@ public PlanOptimizers( Metadata metadata, SqlParser sqlParser, FeaturesConfig featuresConfig, + IntSupplier numberOfNodes, + TaskManagerConfig taskManagerConfig, boolean forceSingleNode, MBeanExporter exporter, SplitManager splitManager, @@ -484,6 +498,11 @@ public PlanOptimizers( new PushPartialAggregationThroughJoin(), new PushPartialAggregationThroughExchange(metadata.getFunctionRegistry()), new PruneJoinColumns()))); + builder.add(new IterativeOptimizer( + ruleStats, + statsCalculator, + costCalculator, + new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, sqlParser, numberOfNodes, taskManagerConfig).rules())); builder.add(new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java new file mode 100644 index 0000000000000..9f2b27bd81c34 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -0,0 +1,349 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.SymbolStatsEstimate; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Partitioning; +import com.facebook.presto.sql.planner.PartitioningScheme; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties; +import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multiset; +import io.airlift.units.DataSize; + +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.IntSupplier; + +import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; +import static com.facebook.presto.SystemSessionProperties.isEnableForcedExchangeBelowGroupId; +import static com.facebook.presto.SystemSessionProperties.isEnableStatsCalculator; +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.matching.Pattern.nonEmpty; +import static com.facebook.presto.matching.Pattern.typeOf; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism; +import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.deriveProperties; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange; +import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.groupingColumns; +import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step; +import static com.facebook.presto.sql.planner.plan.Patterns.Exchange.scope; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMultiset.toImmutableMultiset; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.Double.isNaN; +import static java.lang.Math.min; +import static java.util.Objects.requireNonNull; + +/** + * Transforms + *
+ *   - Exchange
+ *     - [ Projection ]
+ *       - Partial Aggregation
+ *         - GroupId
+ * 
+ * to + *
+ *   - Exchange
+ *     - [ Projection ]
+ *       - Partial Aggregation
+ *         - GroupId
+ *           - LocalExchange
+ *             - RemoteExchange
+ * 
+ *

+ * Rationale: GroupId increases number of rows (number of times equal to number of grouping sets) and then + * partial aggregation reduces number of rows. However, under certain conditions, exchanging the rows before + * GroupId (before multiplication) makes partial aggregation more effective, resulting in less data being + * exchanged afterwards. + */ +public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet +{ + private static final Capture PROJECTION = newCapture(); + private static final Capture AGGREGATION = newCapture(); + private static final Capture GROUP_ID = newCapture(); + + private static final Pattern WITH_PROJECTION = + // If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges + typeOf(ExchangeNode.class) + .with(scope().equalTo(REMOTE)) + .with(source().matching( + // PushPartialAggregationThroughExchange adds a projection. However, it can be removed if RemoveRedundantIdentityProjections is run in the mean-time. + typeOf(ProjectNode.class).capturedAs(PROJECTION) + .with(source().matching( + typeOf(AggregationNode.class).capturedAs(AGGREGATION) + .with(step().equalTo(AggregationNode.Step.PARTIAL)) + .with(nonEmpty(groupingColumns())) + .with(source().matching( + typeOf(GroupIdNode.class).capturedAs(GROUP_ID))))))); + + private static final Pattern WITHOUT_PROJECTION = + // If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges + typeOf(ExchangeNode.class) + .with(scope().equalTo(REMOTE)) + .with(source().matching( + typeOf(AggregationNode.class).capturedAs(AGGREGATION) + .with(step().equalTo(AggregationNode.Step.PARTIAL)) + .with(nonEmpty(groupingColumns())) + .with(source().matching( + typeOf(GroupIdNode.class).capturedAs(GROUP_ID))))); + + private static final double GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY = 0.5; + private static final double ANTI_SKEWNESS_MARGIN = 3; + + private final Metadata metadata; + private final SqlParser parser; + private final IntSupplier numberOfNodes; + private final DataSize maxPartialAggregationMemoryUsage; + + public AddExchangesBelowPartialAggregationOverGroupIdRuleSet( + Metadata metadata, + SqlParser parser, + IntSupplier numberOfNodes, + TaskManagerConfig taskManagerConfig) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.parser = requireNonNull(parser, "parser is null"); + this.numberOfNodes = requireNonNull(numberOfNodes, "numberOfNodes is null"); + this.maxPartialAggregationMemoryUsage = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage(); + } + + public Set> rules() + { + return ImmutableSet.of( + new AddExchangesBelowProjectionPartialAggregationGroupId(), + new AddExchangesBelowExchangePartialAggregationGroupId()); + } + + private class AddExchangesBelowProjectionPartialAggregationGroupId + extends BaseAddExchangesBelowExchangePartialAggregationGroupId + { + @Override + public Pattern getPattern() + { + return WITH_PROJECTION; + } + + @Override + public Result apply(ExchangeNode exchange, Captures captures, Context context) + { + ProjectNode project = captures.get(PROJECTION); + AggregationNode aggregation = captures.get(AGGREGATION); + GroupIdNode groupId = captures.get(GROUP_ID); + return transform(aggregation, groupId, context) + .map(newAggregation -> { + PlanNode newProject = project.replaceChildren(ImmutableList.of(newAggregation)); + PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newProject)); + return Result.ofPlanNode(newExchange); + }) + .orElseGet(Result::empty); + } + } + + private class AddExchangesBelowExchangePartialAggregationGroupId + extends BaseAddExchangesBelowExchangePartialAggregationGroupId + { + @Override + public Pattern getPattern() + { + return WITHOUT_PROJECTION; + } + + @Override + public Result apply(ExchangeNode exchange, Captures captures, Context context) + { + AggregationNode aggregation = captures.get(AGGREGATION); + GroupIdNode groupId = captures.get(GROUP_ID); + return transform(aggregation, groupId, context) + .map(newAggregation -> { + PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newAggregation)); + return Result.ofPlanNode(newExchange); + }) + .orElseGet(Result::empty); + } + } + + private abstract class BaseAddExchangesBelowExchangePartialAggregationGroupId + implements Rule + { + @Override + public boolean isEnabled(Session session) + { + if (!isEnableStatsCalculator(session)) { + // Old stats calculator is not trust-worthy + return false; + } + + return isEnableForcedExchangeBelowGroupId(session); + } + + protected Optional transform(AggregationNode aggregation, GroupIdNode groupId, Context context) + { + if (groupId.getGroupingSets().size() < 2) { + return Optional.empty(); + } + + Set groupingKeys = aggregation.getGroupingKeys().stream() + .filter(symbol -> !groupId.getGroupIdSymbol().equals(symbol)) + .collect(toImmutableSet()); + + Multiset groupingSetHistogram = groupId.getGroupingSets().stream() + .flatMap(Collection::stream) + .collect(toImmutableMultiset()); + + if (!Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)) { + // TODO handle the case when some aggregation keys are pass-through in GroupId (e.g. common in all grouping sets). However, this is never the case for ROLLUP. + // TODO handle the case when some grouping set symbols are not used in aggregation (possible?) + return Optional.empty(); + } + + double aggregationMemoryRequirements = estimateAggregationMemoryRequirements(groupingKeys, groupId, groupingSetHistogram, context); + if (isNaN(aggregationMemoryRequirements) || aggregationMemoryRequirements < maxPartialAggregationMemoryUsage.toBytes()) { + // Aggregation will be effective even without exchanges. + return Optional.empty(); + } + + List desiredHashSymbols = groupingSetHistogram.entrySet().stream() + // Take only frequently used symbols + .filter(entry -> entry.getCount() >= groupId.getGroupingSets().size() * GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY) + .map(Multiset.Entry::getElement) + // And only the symbols used in the aggregation (these are usually all symbols) + .peek(symbol -> verify(groupingKeys.contains(symbol))) + // Transform to symbols before GroupId + .map(groupId.getGroupingColumns()::get) + .collect(toImmutableList()); + + StreamPreferredProperties requiredProperties = fixedParallelism().withPartitioning(desiredHashSymbols); + StreamProperties sourceProperties = derivePropertiesRecursively(groupId.getSource(), context); + if (requiredProperties.isSatisfiedBy(sourceProperties)) { + // Stream is already (locally) partitioned just as we want. + // In fact, there might be just a LocalExchange below and no Remote. For now, we give up in this situation anyway. To properly support such situation: + // 1. aggregation effectiveness estimation below need to consider the (helpful) fact that stream is already partitioned, so each operator will need less memory + // 2. if the local exchange becomes unnecessary (after we add a remove on top of it), it should be removed. What if the local exchange is somewhere further + // down the tree? + return Optional.empty(); + } + + double estimatedGroups = estimatedGroupCount(desiredHashSymbols, context.getStatsProvider().getStats(groupId.getSource())); + if (isNaN(estimatedGroups) || estimatedGroups * ANTI_SKEWNESS_MARGIN < maximalConcurrency(context)) { + // Desired hash symbols form too few groups. Hashing over them would harm concurrency. + // TODO instead of taking symbols with >GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY presence, we could take symbols from high freq to low until there are enough groups + return Optional.empty(); + } + + PlanNode source = groupId.getSource(); + + // Above we only checked the data is not yet locally partitioned and it could be already globally partitioned (but not locally). TODO avoid remote exchange in this case + // TODO If the aggregation memory requirements are only slightly above `maxPartialAggregationMemoryUsage`, adding only LocalExchange could be enough + source = partitionedExchange( + context.getIdAllocator().getNextId(), + REMOTE, + source, + new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashSymbols), + source.getOutputSymbols())); + + source = partitionedExchange( + context.getIdAllocator().getNextId(), + LOCAL, + source, + new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashSymbols), + source.getOutputSymbols())); + + PlanNode newGroupId = groupId.replaceChildren(ImmutableList.of(source)); + PlanNode newAggregation = aggregation.replaceChildren(ImmutableList.of(newGroupId)); + + return Optional.of(newAggregation); + } + + private int maximalConcurrency(Context context) + { + return getTaskConcurrency(context.getSession()) * numberOfNodes.getAsInt(); + } + + private double estimateAggregationMemoryRequirements(Set groupingKeys, GroupIdNode groupId, Multiset groupingSetHistogram, Context context) + { + checkArgument(Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)); // Otherwise math below would be off-topic + + PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource()); + double keysMemoryRequirements = 0; + + for (List groupingSet : groupId.getGroupingSets()) { + List sourceSymbols = groupingSet.stream() + .map(groupId.getGroupingColumns()::get) + .collect(toImmutableList()); + + double keyWidth = sourceStats.getOutputSizeInBytes(sourceSymbols, context.getSymbolAllocator().getTypes()) / sourceStats.getOutputRowCount(); + double keyNdv = min(estimatedGroupCount(sourceSymbols, sourceStats), sourceStats.getOutputRowCount()); + + keysMemoryRequirements += keyWidth * keyNdv; + } + + // TODO consider also memory requirements for aggregation values + return keysMemoryRequirements; + } + + private double estimatedGroupCount(List symbols, PlanNodeStatsEstimate statsEstimate) + { + return symbols.stream() + .map(statsEstimate::getSymbolStatistics) + .mapToDouble(this::ndvIncludingNull) + // This assumes no correlation, maximum number of aggregation keys + .reduce(1, (a, b) -> a * b); + } + + private double ndvIncludingNull(SymbolStatsEstimate symbolStatsEstimate) + { + if (symbolStatsEstimate.getNullsFraction() == 0.) { + return symbolStatsEstimate.getDistinctValuesCount(); + } + return symbolStatsEstimate.getDistinctValuesCount() + 1; + } + + private StreamProperties derivePropertiesRecursively(PlanNode node, Context context) + { + PlanNode resolvedPlanNode = context.getLookup().resolve(node); + List inputProperties = resolvedPlanNode.getSources().stream() + .map(source -> derivePropertiesRecursively(source, context)) + .collect(toImmutableList()); + return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), parser); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java index dcc1c555e5f64..f19c697790ed9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java @@ -37,7 +37,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; -class StreamPreferredProperties +public class StreamPreferredProperties { private final Optional distribution; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index faf80bb169cca..2539a72e688b5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -188,6 +188,14 @@ public static Property> correlation() } } + public static class Exchange + { + public static Property scope() + { + return property("scope", ExchangeNode::getScope); + } + } + public static class LateralJoin { public static Property> correlation() diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 6884adebac0ee..932bd93641cec 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -215,6 +215,7 @@ public class LocalQueryRunner private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final MetadataManager metadata; + private final int nodeCountForStats; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; private final CostCalculator estimatedExchangesCostCalculator; @@ -238,6 +239,7 @@ public class LocalQueryRunner private final PluginManager pluginManager; private final ImmutableMap, DataDefinitionTask> dataDefinitionTask; + private final TaskManagerConfig taskManagerConfig; private final boolean alwaysRevokeMemory; private final NodeSpillConfig nodeSpillConfig; private final NodeSchedulerConfig nodeSchedulerConfig; @@ -265,6 +267,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, requireNonNull(defaultSession, "defaultSession is null"); checkArgument(!defaultSession.getTransactionId().isPresent() || !withInitialTransaction, "Already in transaction"); + this.taskManagerConfig = new TaskManagerConfig().setTaskConcurrency(4); this.nodeSpillConfig = requireNonNull(nodeSpillConfig, "nodeSpillConfig is null"); this.alwaysRevokeMemory = alwaysRevokeMemory; this.notificationExecutor = newCachedThreadPool(daemonThreadsNamed("local-query-runner-executor-%s")); @@ -299,7 +302,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, featuresConfig, typeRegistry, blockEncodingManager, - new SessionPropertyManager(new SystemSessionProperties(new QueryManagerConfig(), new TaskManagerConfig(), new MemoryManagerConfig(), featuresConfig)), + new SessionPropertyManager(new SystemSessionProperties(new QueryManagerConfig(), taskManagerConfig, new MemoryManagerConfig(), featuresConfig)), new SchemaPropertyManager(), new TablePropertyManager(), new ColumnPropertyManager(), @@ -307,6 +310,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.joinCompiler = new JoinCompiler(metadata, featuresConfig); this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); this.statsCalculator = createNewStatsCalculator(metadata); + this.nodeCountForStats = nodeCountForStats; this.costCalculator = new CostCalculatorUsingExchanges(() -> nodeCountForStats); this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, () -> nodeCountForStats); this.accessControl = new TestingAccessControlManager(transactionManager); @@ -684,7 +688,7 @@ public List createDrivers(Session session, @Language("SQL") String sql, pageFunctionCompiler, joinFilterFunctionCompiler, new IndexJoinLookupStats(), - new TaskManagerConfig().setTaskConcurrency(4), + this.taskManagerConfig, spillerFactory, singleStreamSpillerFactory, partitioningSpillerFactory, @@ -792,6 +796,8 @@ public List getPlanOptimizers(boolean forceSingleNode) metadata, sqlParser, featuresConfig, + () -> nodeCountForStats, + taskManagerConfig, forceSingleNode, new MBeanExporter(new TestingMBeanServer()), splitManager, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 2306073e754b0..593196b4ed27a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -86,6 +86,7 @@ public void testDefaults() .setIterativeOptimizerEnabled(true) .setIterativeOptimizerTimeout(new Duration(3, MINUTES)) .setEnableStatsCalculator(true) + .setEnableForcedExchangeBelowGroupId(true) .setExchangeCompressionEnabled(false) .setLegacyTimestamp(true) .setLegacyRowFieldOrdinalAccess(false) @@ -118,6 +119,7 @@ public void testExplicitPropertyMappings() .put("experimental.iterative-optimizer-enabled", "false") .put("experimental.iterative-optimizer-timeout", "10s") .put("experimental.enable-stats-calculator", "false") + .put("enable-forced-exchange-below-group-id", "false") .put("deprecated.legacy-array-agg", "true") .put("deprecated.legacy-log-function", "true") .put("deprecated.group-by-uses-equal", "true") @@ -179,6 +181,7 @@ public void testExplicitPropertyMappings() .setIterativeOptimizerEnabled(false) .setIterativeOptimizerTimeout(new Duration(10, SECONDS)) .setEnableStatsCalculator(false) + .setEnableForcedExchangeBelowGroupId(false) .setDistributedIndexJoinsEnabled(true) .setJoinDistributionType(BROADCAST) .setJoinMaxBroadcastTableSize(new DataSize(42, GIGABYTE)) diff --git a/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt b/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt index 4fa2ab06bc137..413d791a53db0 100644 --- a/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt +++ b/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt @@ -4,30 +4,32 @@ local exchange (GATHER, SINGLE, []) local exchange (REPARTITION, HASH, ["ca_country$gid", "ca_county$gid", "ca_state$gid", "groupid", "i_item_id$gid"]) remote exchange (REPARTITION, HASH, ["ca_country$gid", "ca_county$gid", "ca_state$gid", "groupid", "i_item_id$gid"]) partial aggregation over (ca_country$gid, ca_county$gid, ca_state$gid, groupid, i_item_id$gid) - join (INNER, REPLICATED): - join (INNER, REPLICATED): + local exchange (REPARTITION, HASH, ["ca_country", "i_item_id"]) + remote exchange (REPARTITION, HASH, ["ca_country", "i_item_id"]) join (INNER, REPLICATED): join (INNER, REPLICATED): - scan tpcds:catalog_sales:sf3000.0 + join (INNER, REPLICATED): + join (INNER, REPLICATED): + scan tpcds:catalog_sales:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan tpcds:customer_demographics:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, ["c_current_cdemo_sk"]) + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, ["c_current_addr_sk"]) + scan tpcds:customer:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["ca_address_sk"]) + scan tpcds:customer_address:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["cd_demo_sk_0"]) + scan tpcds:customer_demographics:sf3000.0 local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:customer_demographics:sf3000.0 + scan tpcds:date_dim:sf3000.0 local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["c_current_cdemo_sk"]) - join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["c_current_addr_sk"]) - scan tpcds:customer:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ca_address_sk"]) - scan tpcds:customer_address:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["cd_demo_sk_0"]) - scan tpcds:customer_demographics:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:date_dim:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:item:sf3000.0 + scan tpcds:item:sf3000.0 diff --git a/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt b/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt index 12ad5bd5384ff..ce9b6c1edfc3b 100644 --- a/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt +++ b/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt @@ -4,12 +4,14 @@ local exchange (GATHER, SINGLE, []) local exchange (REPARTITION, HASH, ["groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid"]) remote exchange (REPARTITION, HASH, ["groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid"]) partial aggregation over (groupid, i_brand$gid, i_category$gid, i_class$gid, i_product_name$gid) - join (INNER, REPLICATED): - join (INNER, REPLICATED): - scan tpcds:inventory:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:date_dim:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:item:sf3000.0 + local exchange (REPARTITION, HASH, ["i_brand", "i_product_name"]) + remote exchange (REPARTITION, HASH, ["i_brand", "i_product_name"]) + join (INNER, REPLICATED): + join (INNER, REPLICATED): + scan tpcds:inventory:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan tpcds:date_dim:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan tpcds:item:sf3000.0 diff --git a/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt b/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt index 9c8154f197d51..cf39fcba65af7 100644 --- a/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt +++ b/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt @@ -6,16 +6,18 @@ local exchange (GATHER, SINGLE, []) local exchange (REPARTITION, HASH, ["d_moy$gid", "d_qoy$gid", "d_year$gid", "groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid", "s_store_id$gid"]) remote exchange (REPARTITION, HASH, ["d_moy$gid", "d_qoy$gid", "d_year$gid", "groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid", "s_store_id$gid"]) partial aggregation over (d_moy$gid, d_qoy$gid, d_year$gid, groupid, i_brand$gid, i_category$gid, i_class$gid, i_product_name$gid, s_store_id$gid) - join (INNER, REPLICATED): - join (INNER, REPLICATED): + local exchange (REPARTITION, HASH, ["i_brand", "i_category", "i_class", "i_product_name"]) + remote exchange (REPARTITION, HASH, ["i_brand", "i_category", "i_class", "i_product_name"]) join (INNER, REPLICATED): - scan tpcds:store_sales:sf3000.0 + join (INNER, REPLICATED): + join (INNER, REPLICATED): + scan tpcds:store_sales:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan tpcds:date_dim:sf3000.0 + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan tpcds:store:sf3000.0 local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:date_dim:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:store:sf3000.0 - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan tpcds:item:sf3000.0 + scan tpcds:item:sf3000.0 diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index e428a2923311e..cc40615065a72 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -19,6 +19,7 @@ import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.security.AccessDeniedException; @@ -325,6 +326,8 @@ private QueryExplainer getQueryExplainer() metadata, sqlParser, featuresConfig, + queryRunner::getNodeCount, + new TaskManagerConfig(), forceSingleNode, new MBeanExporter(new TestingMBeanServer()), queryRunner.getSplitManager(),