diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 3db158b22b4fa..24bf4052ee1ac 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -95,6 +95,7 @@ public final class SystemSessionProperties
public static final String LEGACY_ROW_FIELD_ORDINAL_ACCESS = "legacy_row_field_ordinal_access";
public static final String ITERATIVE_OPTIMIZER = "iterative_optimizer_enabled";
public static final String ITERATIVE_OPTIMIZER_TIMEOUT = "iterative_optimizer_timeout";
+ public static final String ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID = "enable_forced_exchange_below_group_id";
public static final String EXCHANGE_COMPRESSION = "exchange_compression";
public static final String LEGACY_TIMESTAMP = "legacy_timestamp";
public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations";
@@ -427,6 +428,11 @@ public SystemSessionProperties(
false,
value -> Duration.valueOf((String) value),
Duration::toString),
+ booleanProperty(
+ ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID,
+ "Enable a stats-based rule adding exchanges below GroupId",
+ featuresConfig.isEnableForcedExchangeBelowGroupId(),
+ true),
booleanProperty(
EXCHANGE_COMPRESSION,
"Enable compression in exchanges",
@@ -765,6 +771,11 @@ public static Duration getOptimizerTimeout(Session session)
return session.getSystemProperty(ITERATIVE_OPTIMIZER_TIMEOUT, Duration.class);
}
+ public static boolean isEnableForcedExchangeBelowGroupId(Session session)
+ {
+ return session.getSystemProperty(ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID, Boolean.class);
+ }
+
public static boolean isExchangeCompressionEnabled(Session session)
{
return session.getSystemProperty(EXCHANGE_COMPRESSION, Boolean.class);
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java
index fcb71fd24306a..91809218bc0bd 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java
@@ -70,7 +70,7 @@ public CostCalculatorUsingExchanges(NodeSchedulerConfig nodeSchedulerConfig, Int
this(currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), nodeManager));
}
- static IntSupplier currentNumberOfWorkerNodes(boolean includeCoordinator, InternalNodeManager nodeManager)
+ public static IntSupplier currentNumberOfWorkerNodes(boolean includeCoordinator, InternalNodeManager nodeManager)
{
requireNonNull(nodeManager, "nodeManager is null");
return () -> {
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
index bc50fb7619cad..c2b6d70945ddd 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
@@ -108,6 +108,7 @@ public class FeaturesConfig
private double spillMaxUsedSpaceThreshold = 0.9;
private boolean iterativeOptimizerEnabled = true;
private boolean enableStatsCalculator = true;
+ private boolean enableForcedExchangeBelowGroupId = true;
private boolean pushAggregationThroughJoin = true;
private double memoryRevokingTarget = 0.5;
private double memoryRevokingThreshold = 0.9;
@@ -593,6 +594,18 @@ public FeaturesConfig setEnableStatsCalculator(boolean enableStatsCalculator)
return this;
}
+ public boolean isEnableForcedExchangeBelowGroupId()
+ {
+ return enableForcedExchangeBelowGroupId;
+ }
+
+ @Config("enable-forced-exchange-below-group-id")
+ public FeaturesConfig setEnableForcedExchangeBelowGroupId(boolean enableForcedExchangeBelowGroupId)
+ {
+ this.enableForcedExchangeBelowGroupId = enableForcedExchangeBelowGroupId;
+ return this;
+ }
+
public DataSize getAggregationOperatorUnspillMemoryLimit()
{
return aggregationOperatorUnspillMemoryLimit;
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index d4d4e7b3b152e..c2a9454274d7b 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -17,6 +17,9 @@
import com.facebook.presto.cost.CostCalculator.EstimatedExchanges;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.StatsCalculator;
+import com.facebook.presto.execution.TaskManagerConfig;
+import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
+import com.facebook.presto.metadata.InternalNodeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.split.PageSourceManager;
import com.facebook.presto.split.SplitManager;
@@ -24,6 +27,7 @@
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.IterativeOptimizer;
import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet;
import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations;
import com.facebook.presto.sql.planner.iterative.rule.CanonicalizeExpressions;
import com.facebook.presto.sql.planner.iterative.rule.CreatePartialTopN;
@@ -127,6 +131,9 @@
import java.util.List;
import java.util.Set;
+import java.util.function.IntSupplier;
+
+import static com.facebook.presto.cost.CostCalculatorUsingExchanges.currentNumberOfWorkerNodes;
public class PlanOptimizers
{
@@ -140,6 +147,9 @@ public PlanOptimizers(
Metadata metadata,
SqlParser sqlParser,
FeaturesConfig featuresConfig,
+ NodeSchedulerConfig nodeSchedulerConfig,
+ InternalNodeManager nodeManager,
+ TaskManagerConfig taskManagerConfig,
MBeanExporter exporter,
SplitManager splitManager,
PageSourceManager pageSourceManager,
@@ -151,6 +161,8 @@ public PlanOptimizers(
this(metadata,
sqlParser,
featuresConfig,
+ currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), nodeManager),
+ taskManagerConfig,
false,
exporter,
splitManager,
@@ -179,6 +191,8 @@ public PlanOptimizers(
Metadata metadata,
SqlParser sqlParser,
FeaturesConfig featuresConfig,
+ IntSupplier numberOfNodes,
+ TaskManagerConfig taskManagerConfig,
boolean forceSingleNode,
MBeanExporter exporter,
SplitManager splitManager,
@@ -484,6 +498,11 @@ public PlanOptimizers(
new PushPartialAggregationThroughJoin(),
new PushPartialAggregationThroughExchange(metadata.getFunctionRegistry()),
new PruneJoinColumns())));
+ builder.add(new IterativeOptimizer(
+ ruleStats,
+ statsCalculator,
+ costCalculator,
+ new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, sqlParser, numberOfNodes, taskManagerConfig).rules()));
builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
new file mode 100644
index 0000000000000..9f2b27bd81c34
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
@@ -0,0 +1,349 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.cost.PlanNodeStatsEstimate;
+import com.facebook.presto.cost.SymbolStatsEstimate;
+import com.facebook.presto.execution.TaskManagerConfig;
+import com.facebook.presto.matching.Capture;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.sql.parser.SqlParser;
+import com.facebook.presto.sql.planner.Partitioning;
+import com.facebook.presto.sql.planner.PartitioningScheme;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
+import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.GroupIdNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Multiset;
+import io.airlift.units.DataSize;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.IntSupplier;
+
+import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency;
+import static com.facebook.presto.SystemSessionProperties.isEnableForcedExchangeBelowGroupId;
+import static com.facebook.presto.SystemSessionProperties.isEnableStatsCalculator;
+import static com.facebook.presto.matching.Capture.newCapture;
+import static com.facebook.presto.matching.Pattern.nonEmpty;
+import static com.facebook.presto.matching.Pattern.typeOf;
+import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
+import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism;
+import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.deriveProperties;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange;
+import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.groupingColumns;
+import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step;
+import static com.facebook.presto.sql.planner.plan.Patterns.Exchange.scope;
+import static com.facebook.presto.sql.planner.plan.Patterns.source;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableMultiset.toImmutableMultiset;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static java.lang.Double.isNaN;
+import static java.lang.Math.min;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Transforms
+ *
+ * - Exchange
+ * - [ Projection ]
+ * - Partial Aggregation
+ * - GroupId
+ *
+ * to
+ *
+ * - Exchange
+ * - [ Projection ]
+ * - Partial Aggregation
+ * - GroupId
+ * - LocalExchange
+ * - RemoteExchange
+ *
+ *
+ * Rationale: GroupId increases number of rows (number of times equal to number of grouping sets) and then
+ * partial aggregation reduces number of rows. However, under certain conditions, exchanging the rows before
+ * GroupId (before multiplication) makes partial aggregation more effective, resulting in less data being
+ * exchanged afterwards.
+ */
+public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet
+{
+ private static final Capture PROJECTION = newCapture();
+ private static final Capture AGGREGATION = newCapture();
+ private static final Capture GROUP_ID = newCapture();
+
+ private static final Pattern WITH_PROJECTION =
+ // If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges
+ typeOf(ExchangeNode.class)
+ .with(scope().equalTo(REMOTE))
+ .with(source().matching(
+ // PushPartialAggregationThroughExchange adds a projection. However, it can be removed if RemoveRedundantIdentityProjections is run in the mean-time.
+ typeOf(ProjectNode.class).capturedAs(PROJECTION)
+ .with(source().matching(
+ typeOf(AggregationNode.class).capturedAs(AGGREGATION)
+ .with(step().equalTo(AggregationNode.Step.PARTIAL))
+ .with(nonEmpty(groupingColumns()))
+ .with(source().matching(
+ typeOf(GroupIdNode.class).capturedAs(GROUP_ID)))))));
+
+ private static final Pattern WITHOUT_PROJECTION =
+ // If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges
+ typeOf(ExchangeNode.class)
+ .with(scope().equalTo(REMOTE))
+ .with(source().matching(
+ typeOf(AggregationNode.class).capturedAs(AGGREGATION)
+ .with(step().equalTo(AggregationNode.Step.PARTIAL))
+ .with(nonEmpty(groupingColumns()))
+ .with(source().matching(
+ typeOf(GroupIdNode.class).capturedAs(GROUP_ID)))));
+
+ private static final double GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY = 0.5;
+ private static final double ANTI_SKEWNESS_MARGIN = 3;
+
+ private final Metadata metadata;
+ private final SqlParser parser;
+ private final IntSupplier numberOfNodes;
+ private final DataSize maxPartialAggregationMemoryUsage;
+
+ public AddExchangesBelowPartialAggregationOverGroupIdRuleSet(
+ Metadata metadata,
+ SqlParser parser,
+ IntSupplier numberOfNodes,
+ TaskManagerConfig taskManagerConfig)
+ {
+ this.metadata = requireNonNull(metadata, "metadata is null");
+ this.parser = requireNonNull(parser, "parser is null");
+ this.numberOfNodes = requireNonNull(numberOfNodes, "numberOfNodes is null");
+ this.maxPartialAggregationMemoryUsage = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage();
+ }
+
+ public Set> rules()
+ {
+ return ImmutableSet.of(
+ new AddExchangesBelowProjectionPartialAggregationGroupId(),
+ new AddExchangesBelowExchangePartialAggregationGroupId());
+ }
+
+ private class AddExchangesBelowProjectionPartialAggregationGroupId
+ extends BaseAddExchangesBelowExchangePartialAggregationGroupId
+ {
+ @Override
+ public Pattern getPattern()
+ {
+ return WITH_PROJECTION;
+ }
+
+ @Override
+ public Result apply(ExchangeNode exchange, Captures captures, Context context)
+ {
+ ProjectNode project = captures.get(PROJECTION);
+ AggregationNode aggregation = captures.get(AGGREGATION);
+ GroupIdNode groupId = captures.get(GROUP_ID);
+ return transform(aggregation, groupId, context)
+ .map(newAggregation -> {
+ PlanNode newProject = project.replaceChildren(ImmutableList.of(newAggregation));
+ PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newProject));
+ return Result.ofPlanNode(newExchange);
+ })
+ .orElseGet(Result::empty);
+ }
+ }
+
+ private class AddExchangesBelowExchangePartialAggregationGroupId
+ extends BaseAddExchangesBelowExchangePartialAggregationGroupId
+ {
+ @Override
+ public Pattern getPattern()
+ {
+ return WITHOUT_PROJECTION;
+ }
+
+ @Override
+ public Result apply(ExchangeNode exchange, Captures captures, Context context)
+ {
+ AggregationNode aggregation = captures.get(AGGREGATION);
+ GroupIdNode groupId = captures.get(GROUP_ID);
+ return transform(aggregation, groupId, context)
+ .map(newAggregation -> {
+ PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newAggregation));
+ return Result.ofPlanNode(newExchange);
+ })
+ .orElseGet(Result::empty);
+ }
+ }
+
+ private abstract class BaseAddExchangesBelowExchangePartialAggregationGroupId
+ implements Rule
+ {
+ @Override
+ public boolean isEnabled(Session session)
+ {
+ if (!isEnableStatsCalculator(session)) {
+ // Old stats calculator is not trust-worthy
+ return false;
+ }
+
+ return isEnableForcedExchangeBelowGroupId(session);
+ }
+
+ protected Optional transform(AggregationNode aggregation, GroupIdNode groupId, Context context)
+ {
+ if (groupId.getGroupingSets().size() < 2) {
+ return Optional.empty();
+ }
+
+ Set groupingKeys = aggregation.getGroupingKeys().stream()
+ .filter(symbol -> !groupId.getGroupIdSymbol().equals(symbol))
+ .collect(toImmutableSet());
+
+ Multiset groupingSetHistogram = groupId.getGroupingSets().stream()
+ .flatMap(Collection::stream)
+ .collect(toImmutableMultiset());
+
+ if (!Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)) {
+ // TODO handle the case when some aggregation keys are pass-through in GroupId (e.g. common in all grouping sets). However, this is never the case for ROLLUP.
+ // TODO handle the case when some grouping set symbols are not used in aggregation (possible?)
+ return Optional.empty();
+ }
+
+ double aggregationMemoryRequirements = estimateAggregationMemoryRequirements(groupingKeys, groupId, groupingSetHistogram, context);
+ if (isNaN(aggregationMemoryRequirements) || aggregationMemoryRequirements < maxPartialAggregationMemoryUsage.toBytes()) {
+ // Aggregation will be effective even without exchanges.
+ return Optional.empty();
+ }
+
+ List desiredHashSymbols = groupingSetHistogram.entrySet().stream()
+ // Take only frequently used symbols
+ .filter(entry -> entry.getCount() >= groupId.getGroupingSets().size() * GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY)
+ .map(Multiset.Entry::getElement)
+ // And only the symbols used in the aggregation (these are usually all symbols)
+ .peek(symbol -> verify(groupingKeys.contains(symbol)))
+ // Transform to symbols before GroupId
+ .map(groupId.getGroupingColumns()::get)
+ .collect(toImmutableList());
+
+ StreamPreferredProperties requiredProperties = fixedParallelism().withPartitioning(desiredHashSymbols);
+ StreamProperties sourceProperties = derivePropertiesRecursively(groupId.getSource(), context);
+ if (requiredProperties.isSatisfiedBy(sourceProperties)) {
+ // Stream is already (locally) partitioned just as we want.
+ // In fact, there might be just a LocalExchange below and no Remote. For now, we give up in this situation anyway. To properly support such situation:
+ // 1. aggregation effectiveness estimation below need to consider the (helpful) fact that stream is already partitioned, so each operator will need less memory
+ // 2. if the local exchange becomes unnecessary (after we add a remove on top of it), it should be removed. What if the local exchange is somewhere further
+ // down the tree?
+ return Optional.empty();
+ }
+
+ double estimatedGroups = estimatedGroupCount(desiredHashSymbols, context.getStatsProvider().getStats(groupId.getSource()));
+ if (isNaN(estimatedGroups) || estimatedGroups * ANTI_SKEWNESS_MARGIN < maximalConcurrency(context)) {
+ // Desired hash symbols form too few groups. Hashing over them would harm concurrency.
+ // TODO instead of taking symbols with >GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY presence, we could take symbols from high freq to low until there are enough groups
+ return Optional.empty();
+ }
+
+ PlanNode source = groupId.getSource();
+
+ // Above we only checked the data is not yet locally partitioned and it could be already globally partitioned (but not locally). TODO avoid remote exchange in this case
+ // TODO If the aggregation memory requirements are only slightly above `maxPartialAggregationMemoryUsage`, adding only LocalExchange could be enough
+ source = partitionedExchange(
+ context.getIdAllocator().getNextId(),
+ REMOTE,
+ source,
+ new PartitioningScheme(
+ Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashSymbols),
+ source.getOutputSymbols()));
+
+ source = partitionedExchange(
+ context.getIdAllocator().getNextId(),
+ LOCAL,
+ source,
+ new PartitioningScheme(
+ Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashSymbols),
+ source.getOutputSymbols()));
+
+ PlanNode newGroupId = groupId.replaceChildren(ImmutableList.of(source));
+ PlanNode newAggregation = aggregation.replaceChildren(ImmutableList.of(newGroupId));
+
+ return Optional.of(newAggregation);
+ }
+
+ private int maximalConcurrency(Context context)
+ {
+ return getTaskConcurrency(context.getSession()) * numberOfNodes.getAsInt();
+ }
+
+ private double estimateAggregationMemoryRequirements(Set groupingKeys, GroupIdNode groupId, Multiset groupingSetHistogram, Context context)
+ {
+ checkArgument(Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)); // Otherwise math below would be off-topic
+
+ PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource());
+ double keysMemoryRequirements = 0;
+
+ for (List groupingSet : groupId.getGroupingSets()) {
+ List sourceSymbols = groupingSet.stream()
+ .map(groupId.getGroupingColumns()::get)
+ .collect(toImmutableList());
+
+ double keyWidth = sourceStats.getOutputSizeInBytes(sourceSymbols, context.getSymbolAllocator().getTypes()) / sourceStats.getOutputRowCount();
+ double keyNdv = min(estimatedGroupCount(sourceSymbols, sourceStats), sourceStats.getOutputRowCount());
+
+ keysMemoryRequirements += keyWidth * keyNdv;
+ }
+
+ // TODO consider also memory requirements for aggregation values
+ return keysMemoryRequirements;
+ }
+
+ private double estimatedGroupCount(List symbols, PlanNodeStatsEstimate statsEstimate)
+ {
+ return symbols.stream()
+ .map(statsEstimate::getSymbolStatistics)
+ .mapToDouble(this::ndvIncludingNull)
+ // This assumes no correlation, maximum number of aggregation keys
+ .reduce(1, (a, b) -> a * b);
+ }
+
+ private double ndvIncludingNull(SymbolStatsEstimate symbolStatsEstimate)
+ {
+ if (symbolStatsEstimate.getNullsFraction() == 0.) {
+ return symbolStatsEstimate.getDistinctValuesCount();
+ }
+ return symbolStatsEstimate.getDistinctValuesCount() + 1;
+ }
+
+ private StreamProperties derivePropertiesRecursively(PlanNode node, Context context)
+ {
+ PlanNode resolvedPlanNode = context.getLookup().resolve(node);
+ List inputProperties = resolvedPlanNode.getSources().stream()
+ .map(source -> derivePropertiesRecursively(source, context))
+ .collect(toImmutableList());
+ return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), parser);
+ }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
index c648cebed4bc8..21f8912dce84f 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
@@ -361,7 +361,6 @@ private static EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Se
private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
{
- // TODO avoid stat (but not cost) recalculation for all considered (distribution,flip) pairs, since resulting relation is the same in all case
if (isAtMostScalar(joinNode.getRight(), lookup)) {
return createJoinEnumerationResult(joinNode.withDistributionType(REPLICATED));
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java
index dcc1c555e5f64..f19c697790ed9 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java
@@ -37,7 +37,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
-class StreamPreferredProperties
+public class StreamPreferredProperties
{
private final Optional distribution;
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java
index faf80bb169cca..2539a72e688b5 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java
@@ -188,6 +188,14 @@ public static Property> correlation()
}
}
+ public static class Exchange
+ {
+ public static Property scope()
+ {
+ return property("scope", ExchangeNode::getScope);
+ }
+ }
+
public static class LateralJoin
{
public static Property> correlation()
diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
index 6884adebac0ee..932bd93641cec 100644
--- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
+++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
@@ -215,6 +215,7 @@ public class LocalQueryRunner
private final PageSorter pageSorter;
private final PageIndexerFactory pageIndexerFactory;
private final MetadataManager metadata;
+ private final int nodeCountForStats;
private final StatsCalculator statsCalculator;
private final CostCalculator costCalculator;
private final CostCalculator estimatedExchangesCostCalculator;
@@ -238,6 +239,7 @@ public class LocalQueryRunner
private final PluginManager pluginManager;
private final ImmutableMap, DataDefinitionTask>> dataDefinitionTask;
+ private final TaskManagerConfig taskManagerConfig;
private final boolean alwaysRevokeMemory;
private final NodeSpillConfig nodeSpillConfig;
private final NodeSchedulerConfig nodeSchedulerConfig;
@@ -265,6 +267,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig,
requireNonNull(defaultSession, "defaultSession is null");
checkArgument(!defaultSession.getTransactionId().isPresent() || !withInitialTransaction, "Already in transaction");
+ this.taskManagerConfig = new TaskManagerConfig().setTaskConcurrency(4);
this.nodeSpillConfig = requireNonNull(nodeSpillConfig, "nodeSpillConfig is null");
this.alwaysRevokeMemory = alwaysRevokeMemory;
this.notificationExecutor = newCachedThreadPool(daemonThreadsNamed("local-query-runner-executor-%s"));
@@ -299,7 +302,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig,
featuresConfig,
typeRegistry,
blockEncodingManager,
- new SessionPropertyManager(new SystemSessionProperties(new QueryManagerConfig(), new TaskManagerConfig(), new MemoryManagerConfig(), featuresConfig)),
+ new SessionPropertyManager(new SystemSessionProperties(new QueryManagerConfig(), taskManagerConfig, new MemoryManagerConfig(), featuresConfig)),
new SchemaPropertyManager(),
new TablePropertyManager(),
new ColumnPropertyManager(),
@@ -307,6 +310,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig,
this.joinCompiler = new JoinCompiler(metadata, featuresConfig);
this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler);
this.statsCalculator = createNewStatsCalculator(metadata);
+ this.nodeCountForStats = nodeCountForStats;
this.costCalculator = new CostCalculatorUsingExchanges(() -> nodeCountForStats);
this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, () -> nodeCountForStats);
this.accessControl = new TestingAccessControlManager(transactionManager);
@@ -684,7 +688,7 @@ public List createDrivers(Session session, @Language("SQL") String sql,
pageFunctionCompiler,
joinFilterFunctionCompiler,
new IndexJoinLookupStats(),
- new TaskManagerConfig().setTaskConcurrency(4),
+ this.taskManagerConfig,
spillerFactory,
singleStreamSpillerFactory,
partitioningSpillerFactory,
@@ -792,6 +796,8 @@ public List getPlanOptimizers(boolean forceSingleNode)
metadata,
sqlParser,
featuresConfig,
+ () -> nodeCountForStats,
+ taskManagerConfig,
forceSingleNode,
new MBeanExporter(new TestingMBeanServer()),
splitManager,
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index 2306073e754b0..593196b4ed27a 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -86,6 +86,7 @@ public void testDefaults()
.setIterativeOptimizerEnabled(true)
.setIterativeOptimizerTimeout(new Duration(3, MINUTES))
.setEnableStatsCalculator(true)
+ .setEnableForcedExchangeBelowGroupId(true)
.setExchangeCompressionEnabled(false)
.setLegacyTimestamp(true)
.setLegacyRowFieldOrdinalAccess(false)
@@ -118,6 +119,7 @@ public void testExplicitPropertyMappings()
.put("experimental.iterative-optimizer-enabled", "false")
.put("experimental.iterative-optimizer-timeout", "10s")
.put("experimental.enable-stats-calculator", "false")
+ .put("enable-forced-exchange-below-group-id", "false")
.put("deprecated.legacy-array-agg", "true")
.put("deprecated.legacy-log-function", "true")
.put("deprecated.group-by-uses-equal", "true")
@@ -179,6 +181,7 @@ public void testExplicitPropertyMappings()
.setIterativeOptimizerEnabled(false)
.setIterativeOptimizerTimeout(new Duration(10, SECONDS))
.setEnableStatsCalculator(false)
+ .setEnableForcedExchangeBelowGroupId(false)
.setDistributedIndexJoinsEnabled(true)
.setJoinDistributionType(BROADCAST)
.setJoinMaxBroadcastTableSize(new DataSize(42, GIGABYTE))
diff --git a/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt b/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt
index 4fa2ab06bc137..413d791a53db0 100644
--- a/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt
+++ b/presto-main/src/test/resources/sql/presto/tpcds/q18.plan.txt
@@ -4,30 +4,32 @@ local exchange (GATHER, SINGLE, [])
local exchange (REPARTITION, HASH, ["ca_country$gid", "ca_county$gid", "ca_state$gid", "groupid", "i_item_id$gid"])
remote exchange (REPARTITION, HASH, ["ca_country$gid", "ca_county$gid", "ca_state$gid", "groupid", "i_item_id$gid"])
partial aggregation over (ca_country$gid, ca_county$gid, ca_state$gid, groupid, i_item_id$gid)
- join (INNER, REPLICATED):
- join (INNER, REPLICATED):
+ local exchange (REPARTITION, HASH, ["ca_country", "i_item_id"])
+ remote exchange (REPARTITION, HASH, ["ca_country", "i_item_id"])
join (INNER, REPLICATED):
join (INNER, REPLICATED):
- scan tpcds:catalog_sales:sf3000.0
+ join (INNER, REPLICATED):
+ join (INNER, REPLICATED):
+ scan tpcds:catalog_sales:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPLICATE, BROADCAST, [])
+ scan tpcds:customer_demographics:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPLICATE, BROADCAST, [])
+ join (INNER, PARTITIONED):
+ remote exchange (REPARTITION, HASH, ["c_current_cdemo_sk"])
+ join (INNER, PARTITIONED):
+ remote exchange (REPARTITION, HASH, ["c_current_addr_sk"])
+ scan tpcds:customer:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPARTITION, HASH, ["ca_address_sk"])
+ scan tpcds:customer_address:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPARTITION, HASH, ["cd_demo_sk_0"])
+ scan tpcds:customer_demographics:sf3000.0
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:customer_demographics:sf3000.0
+ scan tpcds:date_dim:sf3000.0
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
- join (INNER, PARTITIONED):
- remote exchange (REPARTITION, HASH, ["c_current_cdemo_sk"])
- join (INNER, PARTITIONED):
- remote exchange (REPARTITION, HASH, ["c_current_addr_sk"])
- scan tpcds:customer:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPARTITION, HASH, ["ca_address_sk"])
- scan tpcds:customer_address:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPARTITION, HASH, ["cd_demo_sk_0"])
- scan tpcds:customer_demographics:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:date_dim:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:item:sf3000.0
+ scan tpcds:item:sf3000.0
diff --git a/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt b/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt
index 12ad5bd5384ff..ce9b6c1edfc3b 100644
--- a/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt
+++ b/presto-main/src/test/resources/sql/presto/tpcds/q22.plan.txt
@@ -4,12 +4,14 @@ local exchange (GATHER, SINGLE, [])
local exchange (REPARTITION, HASH, ["groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid"])
remote exchange (REPARTITION, HASH, ["groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid"])
partial aggregation over (groupid, i_brand$gid, i_category$gid, i_class$gid, i_product_name$gid)
- join (INNER, REPLICATED):
- join (INNER, REPLICATED):
- scan tpcds:inventory:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:date_dim:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:item:sf3000.0
+ local exchange (REPARTITION, HASH, ["i_brand", "i_product_name"])
+ remote exchange (REPARTITION, HASH, ["i_brand", "i_product_name"])
+ join (INNER, REPLICATED):
+ join (INNER, REPLICATED):
+ scan tpcds:inventory:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPLICATE, BROADCAST, [])
+ scan tpcds:date_dim:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPLICATE, BROADCAST, [])
+ scan tpcds:item:sf3000.0
diff --git a/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt b/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt
index 9c8154f197d51..cf39fcba65af7 100644
--- a/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt
+++ b/presto-main/src/test/resources/sql/presto/tpcds/q67.plan.txt
@@ -6,16 +6,18 @@ local exchange (GATHER, SINGLE, [])
local exchange (REPARTITION, HASH, ["d_moy$gid", "d_qoy$gid", "d_year$gid", "groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid", "s_store_id$gid"])
remote exchange (REPARTITION, HASH, ["d_moy$gid", "d_qoy$gid", "d_year$gid", "groupid", "i_brand$gid", "i_category$gid", "i_class$gid", "i_product_name$gid", "s_store_id$gid"])
partial aggregation over (d_moy$gid, d_qoy$gid, d_year$gid, groupid, i_brand$gid, i_category$gid, i_class$gid, i_product_name$gid, s_store_id$gid)
- join (INNER, REPLICATED):
- join (INNER, REPLICATED):
+ local exchange (REPARTITION, HASH, ["i_brand", "i_category", "i_class", "i_product_name"])
+ remote exchange (REPARTITION, HASH, ["i_brand", "i_category", "i_class", "i_product_name"])
join (INNER, REPLICATED):
- scan tpcds:store_sales:sf3000.0
+ join (INNER, REPLICATED):
+ join (INNER, REPLICATED):
+ scan tpcds:store_sales:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPLICATE, BROADCAST, [])
+ scan tpcds:date_dim:sf3000.0
+ local exchange (GATHER, SINGLE, [])
+ remote exchange (REPLICATE, BROADCAST, [])
+ scan tpcds:store:sf3000.0
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:date_dim:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:store:sf3000.0
- local exchange (GATHER, SINGLE, [])
- remote exchange (REPLICATE, BROADCAST, [])
- scan tpcds:item:sf3000.0
+ scan tpcds:item:sf3000.0
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java
index e428a2923311e..cc40615065a72 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java
@@ -19,6 +19,7 @@
import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.execution.QueryManagerConfig;
+import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.security.AccessDeniedException;
@@ -325,6 +326,8 @@ private QueryExplainer getQueryExplainer()
metadata,
sqlParser,
featuresConfig,
+ queryRunner::getNodeCount,
+ new TaskManagerConfig(),
forceSingleNode,
new MBeanExporter(new TestingMBeanServer()),
queryRunner.getSplitManager(),