diff --git a/presto-docs/src/main/sphinx/admin/properties-session.rst b/presto-docs/src/main/sphinx/admin/properties-session.rst index dad07f4e654eb..5d1b00792402b 100644 --- a/presto-docs/src/main/sphinx/admin/properties-session.rst +++ b/presto-docs/src/main/sphinx/admin/properties-session.rst @@ -348,6 +348,24 @@ to make the query plan easier to read. The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.optimize-hash-generation\`\``. +``pre_aggregate_before_grouping_sets`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +When enabled, inserts a partial aggregation below the ``GroupId`` node in grouping sets +queries to reduce the number of rows that ``GroupId`` multiplies across grouping sets. +The partial aggregation groups by the union of all grouping set columns (the finest +granularity needed), which can drastically reduce the input to ``GroupId``. This is +most effective when the data has high cardinality on the grouping columns, as the +pre-aggregation can significantly reduce the row count before multiplication. + +Only applies to decomposable aggregation functions such as ``SUM``, ``COUNT``, ``MIN``, +or ``MAX`` that support partial/intermediate/final splitting. + +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.pre-aggregate-before-grouping-sets\`\``. + ``push_aggregation_through_join`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 4a89a9c229ed3..274701aade9b0 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -922,6 +922,18 @@ create them). The single distinct optimization will try to replace multiple ``DISTINCT`` clauses with a single ``GROUP BY`` clause, which can be substantially faster to execute. 
+``optimizer.pre-aggregate-before-grouping-sets`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +When enabled, inserts a partial aggregation below the ``GroupId`` node in grouping sets +queries to reduce the number of rows that ``GroupId`` multiplies across grouping sets. +Only applies to decomposable aggregation functions. + +The corresponding session property is :ref:`admin/properties-session:\`\`pre_aggregate_before_grouping_sets\`\``. + ``optimizer.push-aggregation-through-join`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/BenchmarkGroupingSetsPreAggregation.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/BenchmarkGroupingSetsPreAggregation.java new file mode 100644 index 0000000000000..ccb746258033e --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/BenchmarkGroupingSetsPreAggregation.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.benchmark; + +import org.testng.annotations.Test; + +/** + * Benchmarks the PreAggregateBeforeGroupId optimization against a baseline + * and the existing AddExchangesBelowPartialAggregationOverGroupIdRuleSet. + * + *

<p>Uses CROSS JOIN UNNEST to amplify lineitem rows without triggering + * PushPartialAggregationThroughExchange (which would happen with UNION ALL). + * + * <p>Run via: + * <pre>
+ * mvn test -pl presto-hive \
+ *   -Dtest=BenchmarkGroupingSetsPreAggregation \
+ *   -DfailIfNoTests=false
+ * </pre>
+ */ +public final class BenchmarkGroupingSetsPreAggregation +{ + private static final String QUERY = + "SELECT yr, mo, dy, shipmode, returnflag, " + + "sum(quantity), count(*), min(extendedprice), max(extendedprice) " + + "FROM (" + + " SELECT year(shipdate) AS yr, month(shipdate) AS mo, day(shipdate) AS dy, " + + " shipmode, returnflag, quantity, extendedprice " + + " FROM lineitem " + + " CROSS JOIN UNNEST(ARRAY[1,2,3,4,5]) AS t(x)" + + ") t " + + "GROUP BY CUBE (yr, mo, dy, shipmode, returnflag)"; + + @Test + public void benchmark() + throws Exception + { + try (HiveDistributedBenchmarkRunner runner = + new HiveDistributedBenchmarkRunner(3, 5)) { + runner.addScenario("baseline", builder -> { + builder.setSystemProperty("pre_aggregate_before_grouping_sets", "false"); + builder.setSystemProperty("add_exchange_below_partial_aggregation_over_group_id", "false"); + }); + + runner.addScenario("pre_aggregate_before_groupid", builder -> { + builder.setSystemProperty("pre_aggregate_before_grouping_sets", "true"); + builder.setSystemProperty("add_exchange_below_partial_aggregation_over_group_id", "false"); + }); + + runner.addScenario("add_exchange_below_agg", builder -> { + builder.setSystemProperty("pre_aggregate_before_grouping_sets", "false"); + builder.setSystemProperty("add_exchange_below_partial_aggregation_over_group_id", "true"); + }); + + runner.runWithVerification(QUERY); + } + } + + public static void main(String[] args) + throws Exception + { + new BenchmarkGroupingSetsPreAggregation().benchmark(); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveDistributedBenchmarkRunner.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveDistributedBenchmarkRunner.java new file mode 100644 index 0000000000000..7e20b400b2f40 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveDistributedBenchmarkRunner.java @@ -0,0 +1,175 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.benchmark; + +import com.facebook.presto.Session; +import com.facebook.presto.hive.HiveQueryRunner; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static io.airlift.tpch.TpchTable.getTables; +import static java.util.Objects.requireNonNull; + +/** + * Runs SQL benchmarks against a Hive-backed DistributedQueryRunner. + * Supports comparing multiple session configurations side by side. + * + *

<p>Usage: + * <pre>{@code
+ * try (HiveDistributedBenchmarkRunner runner = new HiveDistributedBenchmarkRunner(3, 5)) {
+ *     runner.addScenario("baseline", sessionBuilder -> {});
+ *     runner.addScenario("optimized", sessionBuilder ->
+ *             sessionBuilder.setSystemProperty("my_property", "true"));
+ *     runner.run("SELECT ... GROUP BY CUBE(...)");
+ * }
+ * }</pre>
+ */
+public class HiveDistributedBenchmarkRunner
+        implements AutoCloseable
+{
+    private final QueryRunner queryRunner;
+    private final int warmupIterations;
+    private final int measuredIterations;
+    private final Map<String, Session> scenarios = new LinkedHashMap<>();
+    private final StringBuilder results = new StringBuilder();
+
+    public HiveDistributedBenchmarkRunner(int warmupIterations, int measuredIterations)
+            throws Exception
+    {
+        this.warmupIterations = warmupIterations;
+        this.measuredIterations = measuredIterations;
+        this.queryRunner = HiveQueryRunner.createQueryRunner(getTables());
+    }
+
+    public void addScenario(String name, SessionConfigurator configurator)
+    {
+        requireNonNull(name, "name is null");
+        requireNonNull(configurator, "configurator is null");
+        Session.SessionBuilder builder = Session.builder(queryRunner.getDefaultSession());
+        configurator.configure(builder);
+        scenarios.put(name, builder.build());
+    }
+
+    public String run(String sql)
+    {
+        results.setLength(0);
+        Map<String, Long> averages = new LinkedHashMap<>();
+
+        for (Map.Entry<String, Session> entry : scenarios.entrySet()) {
+            String name = entry.getKey();
+            Session session = entry.getValue();
+            long avg = runScenario(name, session, sql);
+            averages.put(name, avg);
+        }
+
+        // Summary: the first registered scenario is the baseline; speedup = baseline avg / scenario avg
+        results.append("\n=== Summary ===\n");
+        long baselineAvg = averages.values().stream().findFirst().orElseThrow(() -> new IllegalStateException("no scenarios registered; call addScenario() before run()"));
+        for (Map.Entry<String, Long> entry : averages.entrySet()) {
+            double speedup = (double) baselineAvg / entry.getValue();
+            results.append(String.format("%-30s %6d ms (%.2fx)\n",
+                    entry.getKey(), entry.getValue(), speedup));
+        }
+
+        String output = results.toString();
+        System.out.println(output);
+
+        // Write to file since surefire mixes stdout with logging (best effort: a failed write must not fail the run)
+        try {
+            String path = System.getProperty("java.io.tmpdir") + "/hive_benchmark_results.txt";
+            java.nio.file.Files.write(java.nio.file.Paths.get(path), output.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+            System.out.println("Results written to: " + path);
+        }
+        catch (java.io.IOException | RuntimeException e) {
+            System.err.println("Could not write benchmark results file: " + e);
+        }
+
+        return output;
+    }
+
+    /**
+     * Runs the benchmark via {@link #run(String)}, then re-executes the query once per scenario to verify correctness.
+     * All scenarios must produce the same results as the first scenario; otherwise an {@link AssertionError} is thrown.
+     */
+    public String runWithVerification(String sql)
+    {
+        String output = run(sql);
+
+        // Verify correctness: all scenarios must match the first
+        MaterializedResult expected = null;
+        for (Map.Entry<String, Session> entry : scenarios.entrySet()) {
+            MaterializedResult actual = queryRunner.execute(entry.getValue(), sql);
+            if (expected == null) {
+                expected = actual;
+            }
+            else {
+                if (!resultsMatch(expected, actual)) {
+                    throw new AssertionError(
+                            "Results mismatch for scenario '" + entry.getKey() + "'");
+                }
+            }
+        }
+        return output;
+    }
+
+    private long runScenario(String name, Session session, String sql)
+    {
+        results.append(String.format("--- %s ---\n", name));
+
+        // Warmup: executed but neither timed nor reported
+        for (int i = 0; i < warmupIterations; i++) {
+            queryRunner.execute(session, sql);
+        }
+
+        // Measured runs: wall-clock time of each execution, truncated to whole milliseconds
+        long totalMs = 0;
+        for (int i = 0; i < measuredIterations; i++) {
+            long start = System.nanoTime();
+            queryRunner.execute(session, sql);
+            long elapsedMs = (System.nanoTime() - start) / 1_000_000;
+            totalMs += elapsedMs;
+            results.append(String.format(" run %d: %d ms\n", i + 1, elapsedMs));
+        }
+        long avg = totalMs / measuredIterations;
+        results.append(String.format(" avg: %d ms\n\n", avg));
+        return avg;
+    }
+
+    private static boolean resultsMatch(MaterializedResult a, MaterializedResult b)
+    {
+        return a.getMaterializedRows().size() == b.getMaterializedRows().size()
+                && new java.util.HashSet<>(a.getMaterializedRows())
+                        .equals(new java.util.HashSet<>(b.getMaterializedRows()));
+    }
+
+    public QueryRunner getQueryRunner()
+    {
+        return queryRunner;
+    }
+
+    @Override
+    public void close()
+    {
+        queryRunner.close();
+    }
+
+    @FunctionalInterface
+    public interface SessionConfigurator
+    {
+        void configure(Session.SessionBuilder builder);
+    }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index 1d4bbdde14608..3723bc8e99017 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -178,6 +178,7 @@ public final class SystemSessionProperties public static final String PUSHDOWN_THROUGH_UNNEST = "pushdown_through_unnest"; public static final String SIMPLIFY_AGGREGATIONS_OVER_CONSTANT = "simplify_aggregations_over_constant"; public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join"; + public static final String PRE_AGGREGATE_BEFORE_GROUPING_SETS = "pre_aggregate_before_grouping_sets"; public static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double"; public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output"; public static final String FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = "filter_and_project_min_output_page_size"; @@ -955,6 +956,11 @@ public SystemSessionProperties( "Push partial aggregations below joins", featuresConfig.isPushPartialAggregationThroughJoin(), false), + booleanProperty( + PRE_AGGREGATE_BEFORE_GROUPING_SETS, + "Pre-aggregate data before GroupId node to reduce row multiplication in grouping sets queries", + featuresConfig.isPreAggregateBeforeGroupingSets(), + false), booleanProperty( PARSE_DECIMAL_LITERALS_AS_DOUBLE, "Parse decimal literals as DOUBLE instead of DECIMAL", @@ -2678,6 +2684,11 @@ public static boolean isSimplifyAggregationsOverConstant(Session session) return session.getSystemProperty(SIMPLIFY_AGGREGATIONS_OVER_CONSTANT, Boolean.class); } + public static boolean isPreAggregateBeforeGroupingSets(Session session) + { + return session.getSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, Boolean.class); + } + public static boolean isParseDecimalLiteralsAsDouble(Session session) { return 
session.getSystemProperty(PARSE_DECIMAL_LITERALS_AS_DOUBLE, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java index 92d6f37de0b53..0e913e18891e2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java @@ -604,7 +604,7 @@ private void initializeAggregationBuilderIfNeeded() .map(PartialAggregationController::isPartialAggregationDisabled) .orElse(false); - if (step.isOutputPartial() && partialAggregationDisabled) { + if (step.isInputRaw() && step.isOutputPartial() && partialAggregationDisabled) { aggregationBuilder = new SkipAggregationBuilder( groupByChannels, hashChannel, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 0f1d0dede3bec..152be4481c424 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -160,6 +160,7 @@ public class FeaturesConfig private boolean simplifyCoalesceOverJoinKeys; private boolean pushdownThroughUnnest; private boolean simplifyAggregationsOverConstant; + private boolean preAggregateBeforeGroupingSets; private double memoryRevokingTarget = 0.5; private double memoryRevokingThreshold = 0.9; private boolean useMarkDistinct = true; @@ -1744,6 +1745,18 @@ public FeaturesConfig setSimplifyAggregationsOverConstant(boolean simplifyAggreg return this; } + public boolean isPreAggregateBeforeGroupingSets() + { + return preAggregateBeforeGroupingSets; + } + + @Config("optimizer.pre-aggregate-before-grouping-sets") + public FeaturesConfig setPreAggregateBeforeGroupingSets(boolean 
preAggregateBeforeGroupingSets) + { + this.preAggregateBeforeGroupingSets = preAggregateBeforeGroupingSets; + return this; + } + public boolean isForceSingleNodeOutput() { return forceSingleNodeOutput; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 80ef55ea1139e..5f0638635d25c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -3648,7 +3648,7 @@ private static Optional createPartialAggregationCo AggregationNode.Step step, Session session) { - if (maxPartialAggregationMemorySize.isPresent() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) { + if (maxPartialAggregationMemorySize.isPresent() && step.isInputRaw() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) { return Optional.of(new PartialAggregationController(maxPartialAggregationMemorySize.get(), getAdaptivePartialAggregationRowsReductionRatioThreshold(session))); } return Optional.empty(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 8f8f9715491c8..ff53bf73799c9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -67,6 +67,7 @@ import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout; import com.facebook.presto.sql.planner.iterative.rule.PlanRemoteProjections; +import com.facebook.presto.sql.planner.iterative.rule.PreAggregateBeforeGroupId; import 
com.facebook.presto.sql.planner.iterative.rule.PruneAggregationColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneCountAggregationOverScalar; @@ -1036,6 +1037,14 @@ public PlanOptimizers( ImmutableSet.of( new PruneJoinColumns()))); + builder.add(new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of( + new PreAggregateBeforeGroupId(metadata.getFunctionAndTypeManager())))); + builder.add(new IterativeOptimizer( metadata, ruleStats, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PreAggregateBeforeGroupId.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PreAggregateBeforeGroupId.java new file mode 100644 index 0000000000000..cd3ffc730cd5f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PreAggregateBeforeGroupId.java @@ -0,0 +1,370 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.function.AggregationFunctionImplementation; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.google.common.collect.ImmutableList; + +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.SystemSessionProperties.isPreAggregateBeforeGroupingSets; +import static com.facebook.presto.operator.aggregation.AggregationUtils.isDecomposable; +import static com.facebook.presto.spi.plan.AggregationNode.Step.INTERMEDIATE; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static 
com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step; +import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toSet; + +/** + * Transforms + *
<pre>
+ *   - Partial Aggregation
+ *     - GroupId
+ *       - Source
+ * </pre>
+ * to + * <pre>
+ *   - Intermediate Aggregation
+ *     - GroupId
+ *       - Intermediate Aggregation
+ *         - RemoteExchange
+ *           - Partial Aggregation
+ *             - Source
+ * </pre>
+ * <p>
+ * Rationale: GroupId increases the number of rows (one copy per grouping set), then partial + * aggregation reduces them. By pre-aggregating at the finest granularity (union of all grouping + * set columns) and shuffling by grouping keys before GroupId, we reduce the number of rows that + * get multiplied. The original PARTIAL above GroupId is changed to INTERMEDIATE to merge the + * pre-aggregated partial states within each grouping set. + * <p>
+ * Also handles the case where a ProjectNode (e.g., from hash generation) sits between + * the Aggregation and GroupId. + * <p>
+ * Only applies to decomposable aggregation functions ({@code SUM}, {@code COUNT}, {@code MIN}, + * {@code MAX}) that support partial/intermediate/final splitting. + */ +public class PreAggregateBeforeGroupId + implements Rule +{ + // Match Aggregation(PARTIAL) whose source is either GroupId directly + // or a Project node (e.g., from hash generation) on top of GroupId. + private static final Pattern PATTERN = aggregation() + .with(step().equalTo(PARTIAL)); + + private final FunctionAndTypeManager functionAndTypeManager; + + public PreAggregateBeforeGroupId(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(Session session) + { + return isPreAggregateBeforeGroupingSets(session); + } + + @Override + public Result apply(AggregationNode node, Captures captures, Context context) + { + // Find GroupIdNode: either the direct source or behind a ProjectNode. + // Must resolve through Lookup since the Memo wraps sources in GroupReference. + GroupIdNode groupIdNode = findGroupIdNode(node, context.getLookup()); + if (groupIdNode == null) { + return Result.empty(); + } + + // Safety check: must be decomposable (no DISTINCT, no ORDER BY) + if (!isDecomposable(node, functionAndTypeManager)) { + return Result.empty(); + } + + // Skip checksum whose XOR-based intermediate merge produces different results + // at different grouping granularities. + for (Aggregation agg : node.getAggregations().values()) { + String name = functionAndTypeManager.getFunctionMetadata(agg.getFunctionHandle()).getName().getObjectName(); + if (name.equals("checksum")) { + return Result.empty(); + } + } + + // Verify that the aggregation's grouping keys are consistent with GroupId output. 
+ if (!aggregationGroupingMatchesGroupId(node, groupIdNode)) { + return Result.empty(); + } + + // Compute the union of all grouping set columns mapped back to source variables. + Map groupingColumns = groupIdNode.getGroupingColumns(); + Set allSourceGroupingKeys = new LinkedHashSet<>(); + for (List groupingSet : groupIdNode.getGroupingSets()) { + for (VariableReferenceExpression outputVar : groupingSet) { + VariableReferenceExpression sourceVar = groupingColumns.get(outputVar); + if (sourceVar != null) { + allSourceGroupingKeys.add(sourceVar); + } + } + } + + if (allSourceGroupingKeys.isEmpty()) { + return Result.empty(); + } + + // Build variable mappings for the three aggregation levels: + // 1. PARTIAL: raw values → partialVar (intermediate type) + // 2. INTERMEDIATE below GroupId: partialVar → preGroupIdVar (intermediate type) + // 3. INTERMEDIATE above GroupId: preGroupIdVar → originalOutputVar (intermediate type) + Map outputToPartialVarMap = new HashMap<>(); + Map outputToPreGroupIdVarMap = new HashMap<>(); + Map newPartialAggregations = new HashMap<>(); + + for (Map.Entry entry : node.getAggregations().entrySet()) { + Aggregation originalAggregation = entry.getValue(); + FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); + String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName(); + AggregationFunctionImplementation function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle); + + // Variable for PARTIAL output + VariableReferenceExpression partialVariable = context.getVariableAllocator().newVariable( + entry.getValue().getCall().getSourceLocation(), + functionName, + function.getIntermediateType()); + + // Variable for INTERMEDIATE-below-GroupId output + VariableReferenceExpression preGroupIdVariable = context.getVariableAllocator().newVariable( + entry.getValue().getCall().getSourceLocation(), + functionName, + function.getIntermediateType()); + + 
outputToPartialVarMap.put(entry.getKey(), partialVariable); + outputToPreGroupIdVarMap.put(entry.getKey(), preGroupIdVariable); + + // The new PARTIAL aggregation uses the original arguments (which are + // GroupIdNode.aggregationArguments — source-side pass-through variables) + newPartialAggregations.put(partialVariable, new Aggregation( + new CallExpression( + originalAggregation.getCall().getSourceLocation(), + functionName, + functionHandle, + function.getIntermediateType(), + originalAggregation.getArguments()), + originalAggregation.getFilter(), + originalAggregation.getOrderBy(), + originalAggregation.isDistinct(), + originalAggregation.getMask())); + } + + // Step 1: Create new PARTIAL AggregationNode on Source + ImmutableList groupingKeysList = ImmutableList.copyOf(allSourceGroupingKeys); + PlanNode newPartialAggregation = new AggregationNode( + node.getSourceLocation(), + context.getIdAllocator().getNextId(), + groupIdNode.getSource(), + newPartialAggregations, + AggregationNode.singleGroupingSet(groupingKeysList), + ImmutableList.of(), + PARTIAL, + Optional.empty(), + Optional.empty(), + Optional.empty()); + + // Step 2: Create Exchange (hash partitioned by grouping keys) to shuffle partial states + PlanNode exchange = ExchangeNode.partitionedExchange( + context.getIdAllocator().getNextId(), + REMOTE_STREAMING, + newPartialAggregation, + new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, groupingKeysList), + newPartialAggregation.getOutputVariables())); + + // Step 3: Create INTERMEDIATE AggregationNode below GroupId to merge partial states after shuffle + Map preGroupIdIntermediateAggregations = new HashMap<>(); + for (Map.Entry entry : node.getAggregations().entrySet()) { + Aggregation originalAggregation = entry.getValue(); + FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); + String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName(); + 
AggregationFunctionImplementation function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle); + VariableReferenceExpression partialVariable = outputToPartialVarMap.get(entry.getKey()); + VariableReferenceExpression preGroupIdVariable = outputToPreGroupIdVarMap.get(entry.getKey()); + + preGroupIdIntermediateAggregations.put(preGroupIdVariable, new Aggregation( + new CallExpression( + originalAggregation.getCall().getSourceLocation(), + functionName, + functionHandle, + function.getIntermediateType(), + ImmutableList.builder() + .add(partialVariable) + .addAll(originalAggregation.getArguments() + .stream() + .filter(PreAggregateBeforeGroupId::isLambda) + .collect(toImmutableList())) + .build()), + Optional.empty(), + Optional.empty(), + false, + Optional.empty())); + } + + PlanNode preGroupIdIntermediate = new AggregationNode( + node.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchange, + preGroupIdIntermediateAggregations, + AggregationNode.singleGroupingSet(groupingKeysList), + ImmutableList.of(), + INTERMEDIATE, + Optional.empty(), + Optional.empty(), + Optional.empty()); + + // Step 4: Create new GroupIdNode with INTERMEDIATE output as source + ImmutableList newAggregationArguments = ImmutableList.copyOf(outputToPreGroupIdVarMap.values()); + + GroupIdNode newGroupIdNode = new GroupIdNode( + groupIdNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + Optional.empty(), + preGroupIdIntermediate, + groupIdNode.getGroupingSets(), + groupIdNode.getGroupingColumns(), + newAggregationArguments, + groupIdNode.getGroupIdVariable()); + + // Step 5: Change the original PARTIAL aggregation above GroupId to INTERMEDIATE. + // It takes intermediate state (passed through GroupId) and produces intermediate + // state for the existing FINAL above. 
+ Map aboveGroupIdIntermediateAggregations = new HashMap<>(); + for (Map.Entry entry : node.getAggregations().entrySet()) { + Aggregation originalAggregation = entry.getValue(); + FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); + String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName(); + AggregationFunctionImplementation function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle); + VariableReferenceExpression preGroupIdVariable = outputToPreGroupIdVarMap.get(entry.getKey()); + + aboveGroupIdIntermediateAggregations.put(entry.getKey(), new Aggregation( + new CallExpression( + originalAggregation.getCall().getSourceLocation(), + functionName, + functionHandle, + function.getIntermediateType(), + ImmutableList.builder() + .add(preGroupIdVariable) + .addAll(originalAggregation.getArguments() + .stream() + .filter(PreAggregateBeforeGroupId::isLambda) + .collect(toImmutableList())) + .build()), + Optional.empty(), + Optional.empty(), + false, + Optional.empty())); + } + + PlanNode aboveGroupIdIntermediate = new AggregationNode( + node.getSourceLocation(), + node.getId(), + newGroupIdNode, + aboveGroupIdIntermediateAggregations, + node.getGroupingSets(), + ImmutableList.of(), + INTERMEDIATE, + node.getHashVariable(), + node.getGroupIdVariable(), + node.getAggregationId()); + + return Result.ofPlanNode(aboveGroupIdIntermediate); + } + + /** + * Finds the GroupIdNode below the aggregation, looking through an optional + * ProjectNode (e.g., inserted by hash generation). Uses Lookup to resolve + * through GroupReference nodes in the Memo. 
+ */ + private static GroupIdNode findGroupIdNode(AggregationNode aggregation, Lookup lookup) + { + PlanNode source = lookup.resolve(aggregation.getSource()); + if (source instanceof GroupIdNode) { + return (GroupIdNode) source; + } + if (source instanceof ProjectNode) { + PlanNode projectSource = lookup.resolve(((ProjectNode) source).getSource()); + if (projectSource instanceof GroupIdNode) { + return (GroupIdNode) projectSource; + } + } + return null; + } + + /** + * Verifies that the aggregation's grouping keys are exactly the GroupIdNode's + * grouping set columns plus the groupId variable. + */ + private static boolean aggregationGroupingMatchesGroupId(AggregationNode aggregation, GroupIdNode groupId) + { + Set aggregationGroupingKeys = new HashSet<>(aggregation.getGroupingKeys()); + + if (!aggregationGroupingKeys.contains(groupId.getGroupIdVariable())) { + return false; + } + + Set expectedKeys = groupId.getGroupingSets().stream() + .flatMap(Collection::stream) + .collect(toSet()); + expectedKeys.add(groupId.getGroupIdVariable()); + + return aggregationGroupingKeys.equals(expectedKeys); + } + + private static boolean isLambda(RowExpression rowExpression) + { + return rowExpression instanceof LambdaDefinitionExpression; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java index 9d12d01fed842..ddae310d6b61b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java @@ -134,7 +134,7 @@ public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context List innerArgs = innerIf.getArguments(); RowExpression innerCondition = innerArgs.get(0); if (falseValue.equals(innerArgs.get(2)) && 
determinismEvaluator.isDeterministic(innerCondition)) { - RowExpression combinedCondition = new SpecialFormExpression(AND, BOOLEAN, condition, innerCondition); + RowExpression combinedCondition = new SpecialFormExpression(rewritten.getSourceLocation(), AND, BOOLEAN, condition, innerCondition); return new SpecialFormExpression(rewritten.getSourceLocation(), IF, rewritten.getType(), combinedCondition, innerArgs.get(1), falseValue); } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 0c7010f0ef9a2..b71ffb0988527 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -141,6 +141,7 @@ public void testDefaults() .setSimplifyCoalesceOverJoinKeys(false) .setPushdownThroughUnnest(false) .setSimplifyAggregationsOverConstant(false) + .setPreAggregateBeforeGroupingSets(false) .setForceSingleNodeOutput(true) .setPagesIndexEagerCompactionEnabled(false) .setFilterAndProjectMinOutputPageSize(new DataSize(500, KILOBYTE)) @@ -366,6 +367,7 @@ public void testExplicitPropertyMappings() .put("optimizer.simplify-coalesce-over-join-keys", "true") .put("optimizer.pushdown-through-unnest", "true") .put("optimizer.simplify-aggregations-over-constant", "true") + .put("optimizer.pre-aggregate-before-grouping-sets", "true") .put("optimizer.aggregation-partition-merging", "top_down") .put("optimizer.local-exchange-parent-preference-strategy", "automatic") .put("experimental.spill-enabled", "true") @@ -602,6 +604,7 @@ public void testExplicitPropertyMappings() .setSimplifyCoalesceOverJoinKeys(true) .setPushdownThroughUnnest(true) .setSimplifyAggregationsOverConstant(true) + .setPreAggregateBeforeGroupingSets(true) .setSpillEnabled(true) .setJoinSpillingEnabled(false) 
.setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2") diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPreAggregateBeforeGroupId.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPreAggregateBeforeGroupId.java new file mode 100644 index 0000000000000..ed1d61ea951f1 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPreAggregateBeforeGroupId.java @@ -0,0 +1,356 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.PRE_AGGREGATE_BEFORE_GROUPING_SETS; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPreAggregateBeforeGroupId + extends BaseRuleTest +{ + @Test + public void testPreAggregatesBeforeGroupId() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.PARTIAL) + .source(p.groupId( + ImmutableList.of( + 
ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x), + groupId, + p.values(y, z, x)))); + }) + .matches( + aggregation( + ImmutableMap.of("sum_x", functionCall("sum", ImmutableList.of("sum_0"))), + AggregationNode.Step.INTERMEDIATE, + node(GroupIdNode.class, + aggregation( + ImmutableMap.of("sum_0", functionCall("sum", ImmutableList.of("sum"))), + AggregationNode.Step.INTERMEDIATE, + node(ExchangeNode.class, + aggregation( + ImmutableMap.of("sum", functionCall("sum", ImmutableList.of("x"))), + AggregationNode.Step.PARTIAL, + values("y", "z", "x"))))))); + } + + @Test + public void testDoesNotFireWhenDisabled() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "false") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.PARTIAL) + .source(p.groupId( + ImmutableList.of( + ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x), + groupId, + p.values(y, z, x)))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnSingleStep() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + return 
p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.SINGLE) + .source(p.groupId( + ImmutableList.of( + ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x), + groupId, + p.values(y, z, x)))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnNonGroupIdSource() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression x = p.variable("x", BIGINT); + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .singleGroupingSet(x) + .step(AggregationNode.Step.PARTIAL) + .source(p.values(x))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnDistinctAggregation() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)"), + true) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.PARTIAL) + .source(p.groupId( + ImmutableList.of( + ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x), + groupId, + p.values(y, z, x)))); + }) + .doesNotFire(); + } + + @Test + public void testPreAggregatesMultipleAggregations() + { + 
tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression w = p.variable("w", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .addAggregation( + p.variable("count_w", BIGINT), + p.rowExpression("count(w)")) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.PARTIAL) + .source(p.groupId( + ImmutableList.of( + ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x, w), + groupId, + p.values(y, z, x, w)))); + }) + .matches( + aggregation( + ImmutableMap.of( + "sum_x", functionCall("sum", ImmutableList.of("sum_0")), + "count_w", functionCall("count", ImmutableList.of("count_0"))), + AggregationNode.Step.INTERMEDIATE, + node(GroupIdNode.class, + aggregation( + ImmutableMap.of( + "sum_0", functionCall("sum", ImmutableList.of("sum")), + "count_0", functionCall("count", ImmutableList.of("count"))), + AggregationNode.Step.INTERMEDIATE, + node(ExchangeNode.class, + aggregation( + ImmutableMap.of( + "sum", functionCall("sum", ImmutableList.of("x")), + "count", functionCall("count", ImmutableList.of("w"))), + AggregationNode.Step.PARTIAL, + values("y", "z", "x", "w"))))))); + } + + @Test + public void testPreAggregatesWithCountStar() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", 
BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .addAggregation( + p.variable("count_star", BIGINT), + p.rowExpression("count()")) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.PARTIAL) + .source(p.groupId( + ImmutableList.of( + ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x), + groupId, + p.values(y, z, x)))); + }) + .matches( + aggregation( + ImmutableMap.of( + "sum_x", functionCall("sum", ImmutableList.of("sum_0")), + "count_star", functionCall("count", ImmutableList.of("count_0"))), + AggregationNode.Step.INTERMEDIATE, + node(GroupIdNode.class, + aggregation( + ImmutableMap.of( + "sum_0", functionCall("sum", ImmutableList.of("sum")), + "count_0", functionCall("count", ImmutableList.of("count"))), + AggregationNode.Step.INTERMEDIATE, + node(ExchangeNode.class, + aggregation( + ImmutableMap.of( + "sum", functionCall("sum", ImmutableList.of("x")), + "count", functionCall("count", ImmutableList.of())), + AggregationNode.Step.PARTIAL, + values("y", "z", "x"))))))); + } + + @Test + public void testFiresThroughProjectNode() + { + tester().assertThat(new PreAggregateBeforeGroupId(getFunctionManager())) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .on(p -> { + VariableReferenceExpression y = p.variable("y", BIGINT); + VariableReferenceExpression z = p.variable("z", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + // Simulate a Project node (e.g., hash generation) between Agg and GroupId + Assignments identityAssignments = Assignments.builder() + .put(y, y) + .put(z, z) + .put(x, x) + .put(groupId, 
groupId) + .build(); + + return p.aggregation(a -> a + .addAggregation( + p.variable("sum_x", BIGINT), + p.rowExpression("sum(x)")) + .groupingSets(new AggregationNode.GroupingSetDescriptor( + ImmutableList.of(y, z, groupId), + 3, + ImmutableSet.of())) + .groupIdVariable(groupId) + .step(AggregationNode.Step.PARTIAL) + .source(p.project( + identityAssignments, + p.groupId( + ImmutableList.of( + ImmutableList.of(y, z), + ImmutableList.of(y)), + ImmutableList.of(x), + groupId, + p.values(y, z, x))))); + }) + .matches( + aggregation( + ImmutableMap.of("sum_x", functionCall("sum", ImmutableList.of("sum_0"))), + AggregationNode.Step.INTERMEDIATE, + node(GroupIdNode.class, + aggregation( + ImmutableMap.of("sum_0", functionCall("sum", ImmutableList.of("sum"))), + AggregationNode.Step.INTERMEDIATE, + node(ExchangeNode.class, + aggregation( + ImmutableMap.of("sum", functionCall("sum", ImmutableList.of("x"))), + AggregationNode.Step.PARTIAL, + values("y", "z", "x"))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 2214946b08ffd..034208b6b9319 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -175,6 +175,11 @@ public void testSimplifyNestedIf() assertSimplifies( "IF(X, V, CAST(null AS boolean))", "IF(X, V)"); + + // No simplification: inner condition is non-deterministic + assertSimplifies( + "IF(X, IF(random() > 0.5e0, V, Z), Z)", + "IF(X, IF(random() > 0.5e0, V, Z), Z)"); } @Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 3a6c366e4e191..f837ad0c7a1f5 
100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -72,6 +72,7 @@ import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION; import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT; import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT_TIMEOUT_MS; +import static com.facebook.presto.SystemSessionProperties.PRE_AGGREGATE_BEFORE_GROUPING_SETS; import static com.facebook.presto.SystemSessionProperties.PRE_PROCESS_METADATA_CALLS; import static com.facebook.presto.SystemSessionProperties.PULL_EXPRESSION_FROM_LAMBDA_ENABLED; import static com.facebook.presto.SystemSessionProperties.PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN; @@ -1453,6 +1454,53 @@ private void testGroupingSets(Session session) " (415502467, NULL, '3-MEDIUM')"); } + @Test + public void testGroupingSetsWithPreAggregation() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, "true") + .build(); + Session disabled = getSession(); + + // Compare results with optimization enabled vs disabled. + // Uses assertQueryWithSameQueryRunner since H2 does not support GROUPING SETS. + // Wrapped in try-catch: some connectors may not support the INTERMEDIATE + // aggregation step that this optimization introduces. 
+ String[] queries = { + "SELECT sum(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT min(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderstatus, orderpriority))", + "SELECT max(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderpriority))", + "SELECT sum(totalprice), min(totalprice), max(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT sum(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderpriority))", + "SELECT sum(totalprice), orderstatus, orderpriority FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderpriority), (orderstatus, orderpriority))", + "SELECT avg(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT variance(totalprice), stddev(totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT approx_distinct(custkey), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT bool_and(totalprice > 0), bool_or(totalprice > 100000), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT count_if(totalprice > 100000), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT min_by(comment, totalprice), max_by(comment, totalprice), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT cardinality(approx_set(custkey)), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT bitwise_and_agg(CAST(custkey AS BIGINT)), bitwise_or_agg(CAST(custkey AS BIGINT)), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT covar_samp(totalprice, CAST(custkey AS DOUBLE)), corr(totalprice, CAST(custkey AS DOUBLE)), orderstatus FROM orders GROUP BY GROUPING SETS ((orderstatus), ())", + "SELECT sum(extendedprice), count(extendedprice), 
day(shipdate), month(shipdate), shipdate FROM lineitem GROUP BY CUBE (day(shipdate), month(shipdate), shipdate)", + }; + + for (String query : queries) { + try { + assertQueryWithSameQueryRunner(enabled, query, disabled); + } + catch (AssertionError e) { + // LocalQueryRunner cannot handle REMOTE_STREAMING exchanges that + // this optimization introduces. Skip rather than fail. + if (e.getMessage() != null && (e.getMessage().contains("query failed") + || e.getMessage().contains("subplan"))) { + continue; + } + throw e; + } + } + } + @Test public void testGroupingWithFortyArguments() {