createPartialAggregationCo
AggregationNode.Step step,
Session session)
{
- if (maxPartialAggregationMemorySize.isPresent() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) {
+ if (maxPartialAggregationMemorySize.isPresent() && step.isInputRaw() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) {
return Optional.of(new PartialAggregationController(maxPartialAggregationMemorySize.get(), getAdaptivePartialAggregationRowsReductionRatioThreshold(session)));
}
return Optional.empty();
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index 8f8f9715491c8..ff53bf73799c9 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -67,6 +67,7 @@
import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout;
import com.facebook.presto.sql.planner.iterative.rule.PlanRemoteProjections;
+import com.facebook.presto.sql.planner.iterative.rule.PreAggregateBeforeGroupId;
import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationSourceColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneCountAggregationOverScalar;
@@ -1036,6 +1037,14 @@ public PlanOptimizers(
ImmutableSet.of(
new PruneJoinColumns())));
+ builder.add(new IterativeOptimizer(
+ metadata,
+ ruleStats,
+ statsCalculator,
+ costCalculator,
+ ImmutableSet.of(
+ new PreAggregateBeforeGroupId(metadata.getFunctionAndTypeManager()))));
+
builder.add(new IterativeOptimizer(
metadata,
ruleStats,
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PreAggregateBeforeGroupId.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PreAggregateBeforeGroupId.java
new file mode 100644
index 0000000000000..cd3ffc730cd5f
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PreAggregateBeforeGroupId.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.spi.function.AggregationFunctionImplementation;
+import com.facebook.presto.spi.function.FunctionHandle;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
+import com.facebook.presto.spi.plan.Partitioning;
+import com.facebook.presto.spi.plan.PartitioningScheme;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.GroupIdNode;
+import com.google.common.collect.ImmutableList;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.presto.SystemSessionProperties.isPreAggregateBeforeGroupingSets;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.isDecomposable;
+import static com.facebook.presto.spi.plan.AggregationNode.Step.INTERMEDIATE;
+import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
+import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
+import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step;
+import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.toSet;
+
+/**
+ * Transforms
+ *
+ * - Partial Aggregation
+ * - GroupId
+ * - Source
+ *
+ * to
+ *
+ * - Intermediate Aggregation
+ * - GroupId
+ * - Intermediate Aggregation
+ * - RemoteExchange
+ * - Partial Aggregation
+ * - Source
+ *
+ *
+ * Rationale: GroupId increases the number of rows (one copy per grouping set), then partial
+ * aggregation reduces them. By pre-aggregating at the finest granularity (union of all grouping
+ * set columns) and shuffling by grouping keys before GroupId, we reduce the number of rows that
+ * get multiplied. The original PARTIAL above GroupId is changed to INTERMEDIATE to merge the
+ * pre-aggregated partial states within each grouping set.
+ *
+ * Also handles the case where a ProjectNode (e.g., from hash generation) sits between
+ * the Aggregation and GroupId.
+ *
+ * Only applies to decomposable aggregation functions ({@code SUM}, {@code COUNT}, {@code MIN},
+ * {@code MAX}) that support partial/intermediate/final splitting.
+ */
+public class PreAggregateBeforeGroupId
+ implements Rule
+{
+ // Match Aggregation(PARTIAL) whose source is either GroupId directly
+ // or a Project node (e.g., from hash generation) on top of GroupId.
+ private static final Pattern PATTERN = aggregation()
+ .with(step().equalTo(PARTIAL));
+
+ private final FunctionAndTypeManager functionAndTypeManager;
+
+ public PreAggregateBeforeGroupId(FunctionAndTypeManager functionAndTypeManager)
+ {
+ this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
+ }
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN;
+ }
+
+ @Override
+ public boolean isEnabled(Session session)
+ {
+ return isPreAggregateBeforeGroupingSets(session);
+ }
+
+ @Override
+ public Result apply(AggregationNode node, Captures captures, Context context)
+ {
+ // Find GroupIdNode: either the direct source or behind a ProjectNode.
+ // Must resolve through Lookup since the Memo wraps sources in GroupReference.
+ GroupIdNode groupIdNode = findGroupIdNode(node, context.getLookup());
+ if (groupIdNode == null) {
+ return Result.empty();
+ }
+
+ // Safety check: must be decomposable (no DISTINCT, no ORDER BY)
+ if (!isDecomposable(node, functionAndTypeManager)) {
+ return Result.empty();
+ }
+
+ // Skip checksum whose XOR-based intermediate merge produces different results
+ // at different grouping granularities.
+ for (Aggregation agg : node.getAggregations().values()) {
+ String name = functionAndTypeManager.getFunctionMetadata(agg.getFunctionHandle()).getName().getObjectName();
+ if (name.equals("checksum")) {
+ return Result.empty();
+ }
+ }
+
+ // Verify that the aggregation's grouping keys are consistent with GroupId output.
+ if (!aggregationGroupingMatchesGroupId(node, groupIdNode)) {
+ return Result.empty();
+ }
+
+ // Compute the union of all grouping set columns mapped back to source variables.
+ Map groupingColumns = groupIdNode.getGroupingColumns();
+ Set allSourceGroupingKeys = new LinkedHashSet<>();
+ for (List groupingSet : groupIdNode.getGroupingSets()) {
+ for (VariableReferenceExpression outputVar : groupingSet) {
+ VariableReferenceExpression sourceVar = groupingColumns.get(outputVar);
+ if (sourceVar != null) {
+ allSourceGroupingKeys.add(sourceVar);
+ }
+ }
+ }
+
+ if (allSourceGroupingKeys.isEmpty()) {
+ return Result.empty();
+ }
+
+ // Build variable mappings for the three aggregation levels:
+ // 1. PARTIAL: raw values → partialVar (intermediate type)
+ // 2. INTERMEDIATE below GroupId: partialVar → preGroupIdVar (intermediate type)
+ // 3. INTERMEDIATE above GroupId: preGroupIdVar → originalOutputVar (intermediate type)
+ Map outputToPartialVarMap = new HashMap<>();
+ Map outputToPreGroupIdVarMap = new HashMap<>();
+ Map newPartialAggregations = new HashMap<>();
+
+ for (Map.Entry entry : node.getAggregations().entrySet()) {
+ Aggregation originalAggregation = entry.getValue();
+ FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
+ String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
+ AggregationFunctionImplementation function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);
+
+ // Variable for PARTIAL output
+ VariableReferenceExpression partialVariable = context.getVariableAllocator().newVariable(
+ entry.getValue().getCall().getSourceLocation(),
+ functionName,
+ function.getIntermediateType());
+
+ // Variable for INTERMEDIATE-below-GroupId output
+ VariableReferenceExpression preGroupIdVariable = context.getVariableAllocator().newVariable(
+ entry.getValue().getCall().getSourceLocation(),
+ functionName,
+ function.getIntermediateType());
+
+ outputToPartialVarMap.put(entry.getKey(), partialVariable);
+ outputToPreGroupIdVarMap.put(entry.getKey(), preGroupIdVariable);
+
+ // The new PARTIAL aggregation uses the original arguments (which are
+ // GroupIdNode.aggregationArguments — source-side pass-through variables)
+ newPartialAggregations.put(partialVariable, new Aggregation(
+ new CallExpression(
+ originalAggregation.getCall().getSourceLocation(),
+ functionName,
+ functionHandle,
+ function.getIntermediateType(),
+ originalAggregation.getArguments()),
+ originalAggregation.getFilter(),
+ originalAggregation.getOrderBy(),
+ originalAggregation.isDistinct(),
+ originalAggregation.getMask()));
+ }
+
+ // Step 1: Create new PARTIAL AggregationNode on Source
+ ImmutableList groupingKeysList = ImmutableList.copyOf(allSourceGroupingKeys);
+ PlanNode newPartialAggregation = new AggregationNode(
+ node.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ groupIdNode.getSource(),
+ newPartialAggregations,
+ AggregationNode.singleGroupingSet(groupingKeysList),
+ ImmutableList.of(),
+ PARTIAL,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ // Step 2: Create Exchange (hash partitioned by grouping keys) to shuffle partial states
+ PlanNode exchange = ExchangeNode.partitionedExchange(
+ context.getIdAllocator().getNextId(),
+ REMOTE_STREAMING,
+ newPartialAggregation,
+ new PartitioningScheme(
+ Partitioning.create(FIXED_HASH_DISTRIBUTION, groupingKeysList),
+ newPartialAggregation.getOutputVariables()));
+
+ // Step 3: Create INTERMEDIATE AggregationNode below GroupId to merge partial states after shuffle
+ Map preGroupIdIntermediateAggregations = new HashMap<>();
+ for (Map.Entry entry : node.getAggregations().entrySet()) {
+ Aggregation originalAggregation = entry.getValue();
+ FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
+ String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
+ AggregationFunctionImplementation function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);
+ VariableReferenceExpression partialVariable = outputToPartialVarMap.get(entry.getKey());
+ VariableReferenceExpression preGroupIdVariable = outputToPreGroupIdVarMap.get(entry.getKey());
+
+ preGroupIdIntermediateAggregations.put(preGroupIdVariable, new Aggregation(
+ new CallExpression(
+ originalAggregation.getCall().getSourceLocation(),
+ functionName,
+ functionHandle,
+ function.getIntermediateType(),
+ ImmutableList.builder()
+ .add(partialVariable)
+ .addAll(originalAggregation.getArguments()
+ .stream()
+ .filter(PreAggregateBeforeGroupId::isLambda)
+ .collect(toImmutableList()))
+ .build()),
+ Optional.empty(),
+ Optional.empty(),
+ false,
+ Optional.empty()));
+ }
+
+ PlanNode preGroupIdIntermediate = new AggregationNode(
+ node.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ exchange,
+ preGroupIdIntermediateAggregations,
+ AggregationNode.singleGroupingSet(groupingKeysList),
+ ImmutableList.of(),
+ INTERMEDIATE,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ // Step 4: Create new GroupIdNode with INTERMEDIATE output as source
+ ImmutableList newAggregationArguments = ImmutableList.copyOf(outputToPreGroupIdVarMap.values());
+
+ GroupIdNode newGroupIdNode = new GroupIdNode(
+ groupIdNode.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ Optional.empty(),
+ preGroupIdIntermediate,
+ groupIdNode.getGroupingSets(),
+ groupIdNode.getGroupingColumns(),
+ newAggregationArguments,
+ groupIdNode.getGroupIdVariable());
+
+ // Step 5: Change the original PARTIAL aggregation above GroupId to INTERMEDIATE.
+ // It takes intermediate state (passed through GroupId) and produces intermediate
+ // state for the existing FINAL above.
+ Map aboveGroupIdIntermediateAggregations = new HashMap<>();
+ for (Map.Entry entry : node.getAggregations().entrySet()) {
+ Aggregation originalAggregation = entry.getValue();
+ FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
+ String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
+ AggregationFunctionImplementation function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);
+ VariableReferenceExpression preGroupIdVariable = outputToPreGroupIdVarMap.get(entry.getKey());
+
+ aboveGroupIdIntermediateAggregations.put(entry.getKey(), new Aggregation(
+ new CallExpression(
+ originalAggregation.getCall().getSourceLocation(),
+ functionName,
+ functionHandle,
+ function.getIntermediateType(),
+ ImmutableList.builder()
+ .add(preGroupIdVariable)
+ .addAll(originalAggregation.getArguments()
+ .stream()
+ .filter(PreAggregateBeforeGroupId::isLambda)
+ .collect(toImmutableList()))
+ .build()),
+ Optional.empty(),
+ Optional.empty(),
+ false,
+ Optional.empty()));
+ }
+
+ PlanNode aboveGroupIdIntermediate = new AggregationNode(
+ node.getSourceLocation(),
+ node.getId(),
+ newGroupIdNode,
+ aboveGroupIdIntermediateAggregations,
+ node.getGroupingSets(),
+ ImmutableList.of(),
+ INTERMEDIATE,
+ node.getHashVariable(),
+ node.getGroupIdVariable(),
+ node.getAggregationId());
+
+ return Result.ofPlanNode(aboveGroupIdIntermediate);
+ }
+
+ /**
+ * Finds the GroupIdNode below the aggregation, looking through an optional
+ * ProjectNode (e.g., inserted by hash generation). Uses Lookup to resolve
+ * through GroupReference nodes in the Memo.
+ */
+ private static GroupIdNode findGroupIdNode(AggregationNode aggregation, Lookup lookup)
+ {
+ PlanNode source = lookup.resolve(aggregation.getSource());
+ if (source instanceof GroupIdNode) {
+ return (GroupIdNode) source;
+ }
+ if (source instanceof ProjectNode) {
+ PlanNode projectSource = lookup.resolve(((ProjectNode) source).getSource());
+ if (projectSource instanceof GroupIdNode) {
+ return (GroupIdNode) projectSource;
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Verifies that the aggregation's grouping keys are exactly the GroupIdNode's
+ * grouping set columns plus the groupId variable.
+ */
+ private static boolean aggregationGroupingMatchesGroupId(AggregationNode aggregation, GroupIdNode groupId)
+ {
+ Set aggregationGroupingKeys = new HashSet<>(aggregation.getGroupingKeys());
+
+ if (!aggregationGroupingKeys.contains(groupId.getGroupIdVariable())) {
+ return false;
+ }
+
+ Set expectedKeys = groupId.getGroupingSets().stream()
+ .flatMap(Collection::stream)
+ .collect(toSet());
+ expectedKeys.add(groupId.getGroupIdVariable());
+
+ return aggregationGroupingKeys.equals(expectedKeys);
+ }
+
+ private static boolean isLambda(RowExpression rowExpression)
+ {
+ return rowExpression instanceof LambdaDefinitionExpression;
+ }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java
index 9d12d01fed842..ddae310d6b61b 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java
@@ -134,7 +134,7 @@ public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context
List innerArgs = innerIf.getArguments();
RowExpression innerCondition = innerArgs.get(0);
if (falseValue.equals(innerArgs.get(2)) && determinismEvaluator.isDeterministic(innerCondition)) {
- RowExpression combinedCondition = new SpecialFormExpression(AND, BOOLEAN, condition, innerCondition);
+ RowExpression combinedCondition = new SpecialFormExpression(rewritten.getSourceLocation(), AND, BOOLEAN, condition, innerCondition);
return new SpecialFormExpression(rewritten.getSourceLocation(), IF, rewritten.getType(), combinedCondition, innerArgs.get(1), falseValue);
}
}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index 0c7010f0ef9a2..b71ffb0988527 100644
--- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -141,6 +141,7 @@ public void testDefaults()
.setSimplifyCoalesceOverJoinKeys(false)
.setPushdownThroughUnnest(false)
.setSimplifyAggregationsOverConstant(false)
+ .setPreAggregateBeforeGroupingSets(false)
.setForceSingleNodeOutput(true)
.setPagesIndexEagerCompactionEnabled(false)
.setFilterAndProjectMinOutputPageSize(new DataSize(500, KILOBYTE))
@@ -366,6 +367,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.simplify-coalesce-over-join-keys", "true")
.put("optimizer.pushdown-through-unnest", "true")
.put("optimizer.simplify-aggregations-over-constant", "true")
+ .put("optimizer.pre-aggregate-before-grouping-sets", "true")
.put("optimizer.aggregation-partition-merging", "top_down")
.put("optimizer.local-exchange-parent-preference-strategy", "automatic")
.put("experimental.spill-enabled", "true")
@@ -602,6 +604,7 @@ public void testExplicitPropertyMappings()
.setSimplifyCoalesceOverJoinKeys(true)
.setPushdownThroughUnnest(true)
.setSimplifyAggregationsOverConstant(true)
+ .setPreAggregateBeforeGroupingSets(true)
.setSpillEnabled(true)
.setJoinSpillingEnabled(false)
.setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2")
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPreAggregateBeforeGroupId.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPreAggregateBeforeGroupId.java
new file mode 100644
index 0000000000000..ed1d61ea951f1
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPreAggregateBeforeGroupId.java
@@ -0,0 +1,356 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.GroupIdNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.SystemSessionProperties.PRE_AGGREGATE_BEFORE_GROUPING_SETS;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+
+public class TestPreAggregateBeforeGroupId
+ extends BaseRuleTest
+{
+ @Test
+ public void testPreAggregatesBeforeGroupId()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x),
+ groupId,
+ p.values(y, z, x))));
+ })
+ .matches(
+ aggregation(
+ ImmutableMap.of("sum_x", functionCall("sum", ImmutableList.of("sum_0"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(GroupIdNode.class,
+ aggregation(
+ ImmutableMap.of("sum_0", functionCall("sum", ImmutableList.of("sum"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(ExchangeNode.class,
+ aggregation(
+ ImmutableMap.of("sum", functionCall("sum", ImmutableList.of("x"))),
+ AggregationNode.Step.PARTIAL,
+ values("y", "z", "x")))))));
+ }
+
+ @Test
+ public void testDoesNotFireWhenDisabled()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "false")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x),
+ groupId,
+ p.values(y, z, x))));
+ })
+ .doesNotFire();
+ }
+
+ @Test
+ public void testDoesNotFireOnSingleStep()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.SINGLE)
+ .source(p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x),
+ groupId,
+ p.values(y, z, x))));
+ })
+ .doesNotFire();
+ }
+
+ @Test
+ public void testDoesNotFireOnNonGroupIdSource()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .singleGroupingSet(x)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.values(x)));
+ })
+ .doesNotFire();
+ }
+
+ @Test
+ public void testDoesNotFireOnDistinctAggregation()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"),
+ true)
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x),
+ groupId,
+ p.values(y, z, x))));
+ })
+ .doesNotFire();
+ }
+
+ @Test
+ public void testPreAggregatesMultipleAggregations()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression w = p.variable("w", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .addAggregation(
+ p.variable("count_w", BIGINT),
+ p.rowExpression("count(w)"))
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x, w),
+ groupId,
+ p.values(y, z, x, w))));
+ })
+ .matches(
+ aggregation(
+ ImmutableMap.of(
+ "sum_x", functionCall("sum", ImmutableList.of("sum_0")),
+ "count_w", functionCall("count", ImmutableList.of("count_0"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(GroupIdNode.class,
+ aggregation(
+ ImmutableMap.of(
+ "sum_0", functionCall("sum", ImmutableList.of("sum")),
+ "count_0", functionCall("count", ImmutableList.of("count"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(ExchangeNode.class,
+ aggregation(
+ ImmutableMap.of(
+ "sum", functionCall("sum", ImmutableList.of("x")),
+ "count", functionCall("count", ImmutableList.of("w"))),
+ AggregationNode.Step.PARTIAL,
+ values("y", "z", "x", "w")))))));
+ }
+
+ @Test
+ public void testPreAggregatesWithCountStar()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .addAggregation(
+ p.variable("count_star", BIGINT),
+ p.rowExpression("count()"))
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x),
+ groupId,
+ p.values(y, z, x))));
+ })
+ .matches(
+ aggregation(
+ ImmutableMap.of(
+ "sum_x", functionCall("sum", ImmutableList.of("sum_0")),
+ "count_star", functionCall("count", ImmutableList.of("count_0"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(GroupIdNode.class,
+ aggregation(
+ ImmutableMap.of(
+ "sum_0", functionCall("sum", ImmutableList.of("sum")),
+ "count_0", functionCall("count", ImmutableList.of("count"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(ExchangeNode.class,
+ aggregation(
+ ImmutableMap.of(
+ "sum", functionCall("sum", ImmutableList.of("x")),
+ "count", functionCall("count", ImmutableList.of())),
+ AggregationNode.Step.PARTIAL,
+ values("y", "z", "x")))))));
+ }
+
+ @Test
+ public void testFiresThroughProjectNode()
+ {
+ tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager()))
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .on(p -> {
+ VariableReferenceExpression y = p.variable("y", BIGINT);
+ VariableReferenceExpression z = p.variable("z", BIGINT);
+ VariableReferenceExpression x = p.variable("x", BIGINT);
+ VariableReferenceExpression groupId = p.variable("groupId", BIGINT);
+
+ // Simulate a Project node (e.g., hash generation) between Agg and GroupId
+ Assignments identityAssignments = Assignments.builder()
+ .put(y, y)
+ .put(z, z)
+ .put(x, x)
+ .put(groupId, groupId)
+ .build();
+
+ return p.aggregation(a -> a
+ .addAggregation(
+ p.variable("sum_x", BIGINT),
+ p.rowExpression("sum(x)"))
+ .groupingSets(new AggregationNode.GroupingSetDescriptor(
+ ImmutableList.of(y, z, groupId),
+ 3,
+ ImmutableSet.of()))
+ .groupIdVariable(groupId)
+ .step(AggregationNode.Step.PARTIAL)
+ .source(p.project(
+ identityAssignments,
+ p.groupId(
+ ImmutableList.of(
+ ImmutableList.of(y, z),
+ ImmutableList.of(y)),
+ ImmutableList.of(x),
+ groupId,
+ p.values(y, z, x)))));
+ })
+ .matches(
+ aggregation(
+ ImmutableMap.of("sum_x", functionCall("sum", ImmutableList.of("sum_0"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(GroupIdNode.class,
+ aggregation(
+ ImmutableMap.of("sum_0", functionCall("sum", ImmutableList.of("sum"))),
+ AggregationNode.Step.INTERMEDIATE,
+ node(ExchangeNode.class,
+ aggregation(
+ ImmutableMap.of("sum", functionCall("sum", ImmutableList.of("x"))),
+ AggregationNode.Step.PARTIAL,
+ values("y", "z", "x")))))));
+ }
+}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java
index 2214946b08ffd..034208b6b9319 100644
--- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java
@@ -175,6 +175,11 @@ public void testSimplifyNestedIf()
assertSimplifies(
"IF(X, V, CAST(null AS boolean))",
"IF(X, V)");
+
+ // No simplification: inner condition is non-deterministic
+ assertSimplifies(
+ "IF(X, IF(random() > 0.5e0, V, Z), Z)",
+ "IF(X, IF(random() > 0.5e0, V, Z), Z)");
}
@Test
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index 3a6c366e4e191..f837ad0c7a1f5 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -72,6 +72,7 @@
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION;
import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT;
import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT_TIMEOUT_MS;
+import static com.facebook.presto.SystemSessionProperties.PRE_AGGREGATE_BEFORE_GROUPING_SETS;
import static com.facebook.presto.SystemSessionProperties.PRE_PROCESS_METADATA_CALLS;
import static com.facebook.presto.SystemSessionProperties.PULL_EXPRESSION_FROM_LAMBDA_ENABLED;
import static com.facebook.presto.SystemSessionProperties.PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN;
@@ -1453,6 +1454,53 @@ private void testGroupingSets(Session session)
" (415502467, NULL, '3-MEDIUM')");
}
+ @Test
+ public void testGroupingSetsWithPreAggregation()
+ {
+ Session enabled = Session.builder(getSession())
+ .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true")
+ .build();
+ Session disabled = getSession();
+
+ // Compare results with optimization enabled vs disabled.
+ // Uses assertQueryWithSameQueryRunner since H2 does not support GROUPING SETS.
+ // Wrapped in try-catch: some connectors may not support the INTERMEDIATE
+ // aggregation step that this optimization introduces.
+ String[] queries = {
+ "SELECT sum(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT min(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderstatus, orderpriority))",
+ "SELECT max(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderpriority))",
+ "SELECT sum(totalprice), min(totalprice), max(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT sum(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderpriority))",
+ "SELECT sum(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderpriority), (orderstatus, orderpriority))",
+ "SELECT avg(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT variance(totalprice), stddev(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT approx_distinct(custkey), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT bool_and(totalprice > 0), bool_or(totalprice > 100000), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT count_if(totalprice > 100000), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT min_by(comment, totalprice), max_by(comment, totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT cardinality(approx_set(custkey)), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT bitwise_and_agg(CAST(custkey AS BIGINT)), bitwise_or_agg(CAST(custkey AS BIGINT)), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT covar_samp(totalprice, CAST(custkey AS DOUBLE)), corr(totalprice, CAST(custkey AS DOUBLE)), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())",
+ "SELECT sum(extendedprice), count(extendedprice), day(shipdate), month(shipdate), shipdate FROM lineitem GROUP BY CUBE (day(shipdate), month(shipdate), shipdate)",
+ };
+
+ for (String query : queries) {
+ try {
+ assertQueryWithSameQueryRunner(enabled, query, disabled);
+ }
+ catch (AssertionError e) {
+ // LocalQueryRunner cannot handle REMOTE_STREAMING exchanges that
+ // this optimization introduces. Skip rather than fail.
+ if (e.getMessage() != null && (e.getMessage().contains("query failed")
+ || e.getMessage().contains("subplan"))) {
+ continue;
+ }
+ throw e;
+ }
+ }
+ }
+
@Test
public void testGroupingWithFortyArguments()
{