Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions presto-docs/src/main/sphinx/admin/properties-session.rst
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,24 @@ to make the query plan easier to read.

The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.optimize-hash-generation\`\``.

``pre_aggregate_before_grouping_sets``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* **Type:** ``boolean``
* **Default value:** ``false``

When enabled, inserts a partial aggregation below the ``GroupId`` node in grouping sets
queries to reduce the number of rows that ``GroupId`` multiplies across grouping sets.
The partial aggregation groups by the union of all grouping set columns (the finest
granularity needed), which can drastically reduce the input to ``GroupId``. This is
most effective when the number of distinct combinations of grouping column values is
small relative to the number of input rows, because many rows then collapse into each
pre-aggregated group before ``GroupId`` multiplies them across the grouping sets.

Only applies to decomposable aggregation functions such as ``SUM``, ``COUNT``, ``MIN``,
or ``MAX`` that support partial/intermediate/final splitting.

The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.pre-aggregate-before-grouping-sets\`\``.

``push_aggregation_through_join``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
12 changes: 12 additions & 0 deletions presto-docs/src/main/sphinx/admin/properties.rst
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,18 @@ create them).
The single distinct optimization will try to replace multiple ``DISTINCT`` clauses
with a single ``GROUP BY`` clause, which can be substantially faster to execute.

``optimizer.pre-aggregate-before-grouping-sets``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* **Type:** ``boolean``
* **Default value:** ``false``

When enabled, inserts a partial aggregation below the ``GroupId`` node in grouping sets
queries to reduce the number of rows that ``GroupId`` multiplies across grouping sets.
Only applies to decomposable aggregation functions.

The corresponding session property is :ref:`admin/properties-session:\`\`pre_aggregate_before_grouping_sets\`\``.

``optimizer.push-aggregation-through-join``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.hive.benchmark;

import org.testng.annotations.Test;

/**
 * Times a {@code GROUP BY CUBE} query under three planner configurations:
 * a baseline with both optimizations off, the PreAggregateBeforeGroupId
 * optimization alone, and the existing
 * AddExchangesBelowPartialAggregationOverGroupIdRuleSet alone.
 *
 * <p>The lineitem rows are amplified with CROSS JOIN UNNEST rather than
 * UNION ALL, because UNION ALL would trigger
 * PushPartialAggregationThroughExchange.
 *
 * <p>Run via:
 * <pre>
 * mvn test -pl presto-hive \
 *     -Dtest=BenchmarkGroupingSetsPreAggregation \
 *     -DfailIfNoTests=false
 * </pre>
 */
public final class BenchmarkGroupingSetsPreAggregation
{
    // Session properties toggled by the scenarios below.
    private static final String PRE_AGGREGATE_PROPERTY = "pre_aggregate_before_grouping_sets";
    private static final String ADD_EXCHANGE_PROPERTY = "add_exchange_below_partial_aggregation_over_group_id";

    // Iteration counts handed to the benchmark runner.
    private static final int WARMUP_ITERATIONS = 3;
    private static final int MEASURED_ITERATIONS = 5;

    // CUBE over five columns; the inner query multiplies every lineitem row by five.
    private static final String QUERY = String.join(
            "",
            "SELECT yr, mo, dy, shipmode, returnflag, ",
            "sum(quantity), count(*), min(extendedprice), max(extendedprice) ",
            "FROM (",
            " SELECT year(shipdate) AS yr, month(shipdate) AS mo, day(shipdate) AS dy, ",
            " shipmode, returnflag, quantity, extendedprice ",
            " FROM lineitem ",
            " CROSS JOIN UNNEST(ARRAY[1,2,3,4,5]) AS t(x)",
            ") t ",
            "GROUP BY CUBE (yr, mo, dy, shipmode, returnflag)");

    /**
     * Registers the three scenarios (baseline first) and runs the query with
     * result verification across scenarios.
     */
    @Test
    public void benchmark()
            throws Exception
    {
        try (HiveDistributedBenchmarkRunner runner = new HiveDistributedBenchmarkRunner(WARMUP_ITERATIONS, MEASURED_ITERATIONS)) {
            registerScenario(runner, "baseline", false, false);
            registerScenario(runner, "pre_aggregate_before_groupid", true, false);
            registerScenario(runner, "add_exchange_below_agg", false, true);

            runner.runWithVerification(QUERY);
        }
    }

    /**
     * Adds a scenario that sets both session properties explicitly, so no
     * scenario depends on the server-side defaults.
     */
    private static void registerScenario(
            HiveDistributedBenchmarkRunner runner,
            String scenarioName,
            boolean preAggregate,
            boolean addExchange)
    {
        runner.addScenario(scenarioName, sessionBuilder -> {
            sessionBuilder.setSystemProperty(PRE_AGGREGATE_PROPERTY, String.valueOf(preAggregate));
            sessionBuilder.setSystemProperty(ADD_EXCHANGE_PROPERTY, String.valueOf(addExchange));
        });
    }

    /**
     * Allows running the benchmark directly, outside of the test harness.
     */
    public static void main(String[] args)
            throws Exception
    {
        BenchmarkGroupingSetsPreAggregation instance = new BenchmarkGroupingSetsPreAggregation();
        instance.benchmark();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.hive.benchmark;

import com.facebook.presto.Session;
import com.facebook.presto.hive.HiveQueryRunner;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.QueryRunner;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static io.airlift.tpch.TpchTable.getTables;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

/**
 * Runs SQL benchmarks against a Hive-backed query runner and compares
 * several session configurations ("scenarios") side by side.
 *
 * <p>Scenarios are executed in registration order. The first registered
 * scenario is the baseline: speedups in the summary are relative to it, and
 * {@link #runWithVerification(String)} compares every other scenario's
 * results against it.
 *
 * <p>Usage:
 * <pre>
 * try (HiveDistributedBenchmarkRunner runner = new HiveDistributedBenchmarkRunner(3, 5)) {
 *     runner.addScenario("baseline", sessionBuilder -> {});
 *     runner.addScenario("optimized", sessionBuilder ->
 *             sessionBuilder.setSystemProperty("my_property", "true"));
 *     runner.run("SELECT ... GROUP BY CUBE(...)");
 * }
 * </pre>
 */
public class HiveDistributedBenchmarkRunner
        implements AutoCloseable
{
    private static final String RESULTS_FILE_NAME = "hive_benchmark_results.txt";
    private static final long NANOS_PER_MILLI = 1_000_000L;

    private final QueryRunner queryRunner;
    private final int warmupIterations;
    private final int measuredIterations;
    // Insertion-ordered so that the first registered scenario is the baseline.
    private final Map<String, Session> scenarios = new LinkedHashMap<>();

    /**
     * Creates the runner and the underlying query runner loaded with the TPC-H tables.
     *
     * @param warmupIterations number of untimed executions per scenario; must not be negative
     * @param measuredIterations number of timed executions per scenario; must be positive
     * @throws IllegalArgumentException if an iteration count is out of range
     * @throws Exception if the query runner cannot be created
     */
    public HiveDistributedBenchmarkRunner(int warmupIterations, int measuredIterations)
            throws Exception
    {
        // Validate before starting the query runner so a bad argument neither
        // leaks it nor surfaces later as a division by zero.
        if (warmupIterations < 0) {
            throw new IllegalArgumentException("warmupIterations is negative: " + warmupIterations);
        }
        if (measuredIterations <= 0) {
            throw new IllegalArgumentException("measuredIterations must be positive: " + measuredIterations);
        }
        this.warmupIterations = warmupIterations;
        this.measuredIterations = measuredIterations;
        this.queryRunner = HiveQueryRunner.createQueryRunner(getTables());
    }

    /**
     * Registers a scenario whose session is the query runner's default session
     * with {@code configurator} applied on top. Registering an existing name
     * replaces that scenario's session.
     *
     * @param name label used in the report
     * @param configurator callback that customizes the session builder
     */
    public void addScenario(String name, SessionConfigurator configurator)
    {
        requireNonNull(name, "name is null");
        requireNonNull(configurator, "configurator is null");
        Session.SessionBuilder builder = Session.builder(queryRunner.getDefaultSession());
        configurator.configure(builder);
        scenarios.put(name, builder.build());
    }

    /**
     * Benchmarks {@code sql} under every registered scenario, prints the report
     * to standard output and, on a best-effort basis, writes it to
     * {@code hive_benchmark_results.txt} in the {@code java.io.tmpdir} directory.
     *
     * @param sql the query to benchmark
     * @return the formatted report
     * @throws IllegalStateException if no scenario has been registered
     */
    public String run(String sql)
    {
        requireNonNull(sql, "sql is null");
        if (scenarios.isEmpty()) {
            throw new IllegalStateException("No scenarios registered; call addScenario() before run()");
        }

        StringBuilder report = new StringBuilder();
        Map<String, Long> averageNanosByScenario = new LinkedHashMap<>();
        for (Map.Entry<String, Session> scenario : scenarios.entrySet()) {
            long averageNanos = runScenario(scenario.getKey(), scenario.getValue(), sql, report);
            averageNanosByScenario.put(scenario.getKey(), averageNanos);
        }

        // Summary: speedup is relative to the first (baseline) scenario.
        report.append("\n=== Summary ===\n");
        long baselineNanos = averageNanosByScenario.values().iterator().next();
        for (Map.Entry<String, Long> entry : averageNanosByScenario.entrySet()) {
            long averageNanos = entry.getValue();
            double speedup = (double) baselineNanos / averageNanos;
            report.append(String.format("%-30s %6d ms (%.2fx)\n",
                    entry.getKey(), averageNanos / NANOS_PER_MILLI, speedup));
        }

        String output = report.toString();
        System.out.println(output);
        writeResultsFile(output);
        return output;
    }

    /**
     * Runs the benchmark, then executes the query once more per scenario and
     * checks that every scenario returns the same rows (order-insensitive,
     * duplicate-sensitive, exact value equality) as the first scenario.
     *
     * @param sql the query to benchmark and verify
     * @return the formatted benchmark report
     * @throws AssertionError if a scenario's results differ from the first scenario's
     */
    public String runWithVerification(String sql)
    {
        String output = run(sql);

        MaterializedResult expected = null;
        for (Map.Entry<String, Session> scenario : scenarios.entrySet()) {
            MaterializedResult actual = queryRunner.execute(scenario.getValue(), sql);
            if (expected == null) {
                expected = actual;
            }
            else if (!resultsMatch(expected, actual)) {
                throw new AssertionError(
                        "Results mismatch for scenario '" + scenario.getKey() + "'");
            }
        }
        return output;
    }

    /**
     * Executes the warmup and measured iterations for one scenario and appends
     * the per-run timings to {@code report}.
     *
     * @return the average measured wall time in nanoseconds
     */
    private long runScenario(String name, Session session, String sql, StringBuilder report)
    {
        report.append(String.format("--- %s ---\n", name));

        // Warmup
        for (int i = 0; i < warmupIterations; i++) {
            queryRunner.execute(session, sql);
        }

        // Measured runs. Accumulate nanoseconds and convert only for display, so
        // per-run truncation to milliseconds does not bias the average.
        long totalNanos = 0;
        for (int i = 0; i < measuredIterations; i++) {
            long start = System.nanoTime();
            queryRunner.execute(session, sql);
            long elapsedNanos = System.nanoTime() - start;
            totalNanos += elapsedNanos;
            report.append(String.format(" run %d: %d ms\n", i + 1, elapsedNanos / NANOS_PER_MILLI));
        }
        long averageNanos = totalNanos / measuredIterations;
        report.append(String.format(" avg: %d ms\n\n", averageNanos / NANOS_PER_MILLI));
        return averageNanos;
    }

    /**
     * Writes the report to the temp directory. This is a convenience copy
     * (surefire mixes stdout with logging), so a failure is reported on
     * standard error instead of failing the benchmark.
     */
    private static void writeResultsFile(String output)
    {
        try {
            Path path = Paths.get(System.getProperty("java.io.tmpdir"), RESULTS_FILE_NAME);
            Files.write(path, output.getBytes(UTF_8));
            System.out.println("Results written to: " + path);
        }
        catch (IOException | RuntimeException e) {
            System.err.println("Failed to write benchmark results file: " + e);
        }
    }

    /**
     * Compares two results as multisets of rows: row order is ignored but the
     * number of occurrences of each row must agree. A plain set comparison
     * would treat {@code [a, a, b]} and {@code [a, b, b]} as equal.
     */
    private static boolean resultsMatch(MaterializedResult a, MaterializedResult b)
    {
        return a.getMaterializedRows().size() == b.getMaterializedRows().size()
                && countOccurrences(a.getMaterializedRows()).equals(countOccurrences(b.getMaterializedRows()));
    }

    /**
     * Counts how many times each distinct element appears in {@code rows}.
     */
    private static <T> Map<T, Integer> countOccurrences(List<T> rows)
    {
        Map<T, Integer> counts = new HashMap<>();
        for (T row : rows) {
            counts.merge(row, 1, Integer::sum);
        }
        return counts;
    }

    /**
     * Returns the underlying query runner, e.g. for ad-hoc queries or setup.
     */
    public QueryRunner getQueryRunner()
    {
        return queryRunner;
    }

    /**
     * Closes the underlying query runner.
     */
    @Override
    public void close()
    {
        queryRunner.close();
    }

    /**
     * Callback that customizes the session builder for one scenario.
     */
    @FunctionalInterface
    public interface SessionConfigurator
    {
        void configure(Session.SessionBuilder builder);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ public final class SystemSessionProperties
public static final String PUSHDOWN_THROUGH_UNNEST = "pushdown_through_unnest";
public static final String SIMPLIFY_AGGREGATIONS_OVER_CONSTANT = "simplify_aggregations_over_constant";
public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join";
public static final String PRE_AGGREGATE_BEFORE_GROUPING_SETS = "pre_aggregate_before_grouping_sets";
public static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double";
public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output";
public static final String FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = "filter_and_project_min_output_page_size";
Expand Down Expand Up @@ -955,6 +956,11 @@ public SystemSessionProperties(
"Push partial aggregations below joins",
featuresConfig.isPushPartialAggregationThroughJoin(),
false),
booleanProperty(
PRE_AGGREGATE_BEFORE_GROUPING_SETS,
"Pre-aggregate data before GroupId node to reduce row multiplication in grouping sets queries",
featuresConfig.isPreAggregateBeforeGroupingSets(),
false),
booleanProperty(
PARSE_DECIMAL_LITERALS_AS_DOUBLE,
"Parse decimal literals as DOUBLE instead of DECIMAL",
Expand Down Expand Up @@ -2678,6 +2684,11 @@ public static boolean isSimplifyAggregationsOverConstant(Session session)
return session.getSystemProperty(SIMPLIFY_AGGREGATIONS_OVER_CONSTANT, Boolean.class);
}

/**
 * Reads the {@code pre_aggregate_before_grouping_sets} system session property,
 * which controls whether data is pre-aggregated before the GroupId node in
 * grouping sets queries.
 *
 * @param session the session whose system properties are consulted
 * @return the value of the property for {@code session}
 */
public static boolean isPreAggregateBeforeGroupingSets(Session session)
{
return session.getSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, Boolean.class);
}

public static boolean isParseDecimalLiteralsAsDouble(Session session)
{
return session.getSystemProperty(PARSE_DECIMAL_LITERALS_AS_DOUBLE, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ private void initializeAggregationBuilderIfNeeded()
.map(PartialAggregationController::isPartialAggregationDisabled)
.orElse(false);

if (step.isOutputPartial() && partialAggregationDisabled) {
if (step.isInputRaw() && step.isOutputPartial() && partialAggregationDisabled) {
aggregationBuilder = new SkipAggregationBuilder(
groupByChannels,
hashChannel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ public class FeaturesConfig
private boolean simplifyCoalesceOverJoinKeys;
private boolean pushdownThroughUnnest;
private boolean simplifyAggregationsOverConstant;
private boolean preAggregateBeforeGroupingSets;
private double memoryRevokingTarget = 0.5;
private double memoryRevokingThreshold = 0.9;
private boolean useMarkDistinct = true;
Expand Down Expand Up @@ -1744,6 +1745,18 @@ public FeaturesConfig setSimplifyAggregationsOverConstant(boolean simplifyAggreg
return this;
}

/**
 * Returns the configured value of {@code optimizer.pre-aggregate-before-grouping-sets}.
 * The backing field has no initializer, so this is {@code false} unless the setter was called.
 */
public boolean isPreAggregateBeforeGroupingSets()
{
return preAggregateBeforeGroupingSets;
}

/**
 * Sets the value bound to the {@code optimizer.pre-aggregate-before-grouping-sets}
 * configuration property.
 *
 * @param preAggregateBeforeGroupingSets the new value of the flag
 * @return this config instance, to allow chained setter calls
 */
@Config("optimizer.pre-aggregate-before-grouping-sets")
public FeaturesConfig setPreAggregateBeforeGroupingSets(boolean preAggregateBeforeGroupingSets)
{
this.preAggregateBeforeGroupingSets = preAggregateBeforeGroupingSets;
return this;
}

public boolean isForceSingleNodeOutput()
{
return forceSingleNodeOutput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3648,7 +3648,7 @@ private static Optional<PartialAggregationController> createPartialAggregationCo
AggregationNode.Step step,
Session session)
{
if (maxPartialAggregationMemorySize.isPresent() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) {
if (maxPartialAggregationMemorySize.isPresent() && step.isInputRaw() && step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session)) {
return Optional.of(new PartialAggregationController(maxPartialAggregationMemorySize.get(), getAdaptivePartialAggregationRowsReductionRatioThreshold(session)));
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout;
import com.facebook.presto.sql.planner.iterative.rule.PlanRemoteProjections;
import com.facebook.presto.sql.planner.iterative.rule.PreAggregateBeforeGroupId;
import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneAggregationSourceColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneCountAggregationOverScalar;
Expand Down Expand Up @@ -1036,6 +1037,14 @@ public PlanOptimizers(
ImmutableSet.of(
new PruneJoinColumns())));

builder.add(new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(
new PreAggregateBeforeGroupId(metadata.getFunctionAndTypeManager()))));

builder.add(new IterativeOptimizer(
metadata,
ruleStats,
Expand Down
Loading
Loading