diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlPullConstantsAboveGroupByBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlPullConstantsAboveGroupByBenchmark.java
new file mode 100644
index 0000000000000..9b64a599596e4
--- /dev/null
+++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlPullConstantsAboveGroupByBenchmark.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.benchmark;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.google.common.collect.ImmutableMap;
+import org.intellij.lang.annotations.Language;
+
+import java.util.Map;
+
+import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner;
+
+/**
+ * Runs one GROUP BY query whose grouping keys include a constant column, first with the
+ * {@code optimize_constant_grouping_keys} session property set to false and then with the default session.
+ */
+public class SqlPullConstantsAboveGroupByBenchmark
+        extends AbstractSqlBenchmark
+{
+    private static final Logger LOGGER = Logger.get(SqlPullConstantsAboveGroupByBenchmark.class);
+
+    public SqlPullConstantsAboveGroupByBenchmark(LocalQueryRunner localQueryRunner, @Language("SQL") String sql)
+    {
+        super(localQueryRunner,
+                "pull_constants_above_group_by",
+                10,
+                20,
+                sql);
+    }
+
+    public static void main(String[] args)
+    {
+        Map<String, String> disableOptimization = ImmutableMap.of("optimize_constant_grouping_keys", "false");
+        String sql = "SELECT * FROM (SELECT regionkey, col, count(*) FROM (SELECT regionkey, 'bla' as col FROM nation) GROUP BY regionkey, col)";
+        LOGGER.info("Without optimization");
+        new SqlPullConstantsAboveGroupByBenchmark(createLocalQueryRunner(disableOptimization), sql).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
+        LOGGER.info("With optimization");
+        new SqlPullConstantsAboveGroupByBenchmark(createLocalQueryRunner(), sql).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index f8eea8c97a20e..17c169f441e8f 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -178,6 +178,7 @@ public final class SystemSessionProperties
     public static final String MAX_TASKS_PER_STAGE = "max_tasks_per_stage";
     public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled";
     public static final String PUSH_LIMIT_THROUGH_OUTER_JOIN = "push_limit_through_outer_join";
+    public static final String OPTIMIZE_CONSTANT_GROUPING_KEYS = "optimize_constant_grouping_keys";
     public static final String MAX_CONCURRENT_MATERIALIZATIONS = "max_concurrent_materializations";
     public static final String PUSHDOWN_SUBFIELDS_ENABLED = "pushdown_subfields_enabled";
     public static final String TABLE_WRITER_MERGE_OPERATOR_ENABLED = "table_writer_merge_operator_enabled";
@@ -958,6 +959,11 @@ public SystemSessionProperties(
                         "push limits to the outer side of an outer join",
                         featuresConfig.isPushLimitThroughOuterJoin(),
                         false),
+                booleanProperty(
+                        OPTIMIZE_CONSTANT_GROUPING_KEYS,
+                        "Pull constant grouping keys above the group by",
+                        featuresConfig.isOptimizeConstantGroupingKeys(),
+                        false),
                 integerProperty(
                         MAX_CONCURRENT_MATERIALIZATIONS,
                         "Maximum number of materializing plan sections that can run concurrently",
@@ -2055,6 +2061,11 @@ public static boolean isPushLimitThroughOuterJoin(Session
session)
         return session.getSystemProperty(PUSH_LIMIT_THROUGH_OUTER_JOIN, Boolean.class);
     }
 
+    public static boolean isOptimizeConstantGroupingKeys(Session session)
+    {
+        return session.getSystemProperty(OPTIMIZE_CONSTANT_GROUPING_KEYS, Boolean.class);
+    }
+
     public static int getMaxConcurrentMaterializations(Session session)
     {
         return session.getSystemProperty(MAX_CONCURRENT_MATERIALIZATIONS, Integer.class);
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
index f8573630c5503..83b6714c35df8 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
@@ -159,6 +159,7 @@ public class FeaturesConfig
     private double partialAggregationByteReductionThreshold = 0.5;
     private boolean optimizeTopNRowNumber = true;
     private boolean pushLimitThroughOuterJoin = true;
+    private boolean optimizeConstantGroupingKeys = true;
 
     private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata
     private Duration queryAnalyzerTimeout = new Duration(3, MINUTES);
@@ -1629,6 +1630,19 @@ public boolean isPushLimitThroughOuterJoin()
         return pushLimitThroughOuterJoin;
     }
 
+    @Config("optimizer.optimize-constant-grouping-keys")
+    @ConfigDescription("Pull constant grouping keys above the group by")
+    public FeaturesConfig setOptimizeConstantGroupingKeys(boolean optimizeConstantGroupingKeys)
+    {
+        this.optimizeConstantGroupingKeys = optimizeConstantGroupingKeys;
+        return this;
+    }
+
+    public boolean isOptimizeConstantGroupingKeys()
+    {
+        return optimizeConstantGroupingKeys;
+    }
+
     @Config("max-concurrent-materializations")
     @ConfigDescription("The maximum number of materializing plan sections that can run concurrently")
     public FeaturesConfig setMaxConcurrentMaterializations(int maxConcurrentMaterializations)
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index 78f46587727ae..5637d769e691a 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -75,6 +75,7 @@
 import com.facebook.presto.sql.planner.iterative.rule.PruneTopNColumns;
 import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns;
 import com.facebook.presto.sql.planner.iterative.rule.PruneWindowColumns;
+import com.facebook.presto.sql.planner.iterative.rule.PullConstantsAboveGroupBy;
 import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
 import com.facebook.presto.sql.planner.iterative.rule.PushDownDereferences;
 import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughMarkDistinct;
@@ -415,7 +416,13 @@ public PlanOptimizers(
                                 new InlineProjections(metadata.getFunctionAndTypeManager()),
                                 new RemoveRedundantIdentityProjections(),
                                 new TransformCorrelatedSingleRowSubqueryToProject())),
-                new CheckSubqueryNodesAreRewritten());
+                new CheckSubqueryNodesAreRewritten(),
+                // Pull constant grouping keys above aggregations (gated by optimize_constant_grouping_keys)
+                new IterativeOptimizer(
+                        ruleStats,
+                        statsCalculator,
+                        estimatedExchangesCostCalculator,
+                        ImmutableSet.of(new PullConstantsAboveGroupBy())));
 
         // TODO: move this before optimization if possible!!
// Replace all expressions with row expressions
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java
new file mode 100644
index 0000000000000..48150d7e087e6
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Capture;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.relation.ConstantExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.facebook.presto.SystemSessionProperties.isOptimizeConstantGroupingKeys;
+import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
+import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isConstant;
+import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
+import static com.facebook.presto.sql.planner.plan.Patterns.project;
+import static com.facebook.presto.sql.planner.plan.Patterns.source;
+import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
+import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.util.function.Function.identity;
+
+/**
+ * Pulls constant grouping keys above the aggregation.
+ * <p>
+ * Transforms:
+ * <pre>
+ * - GroupBy key1, constant, key2
+ * </pre>
+ * Into:
+ * <pre>
+ * - Project constant
+ *    - GroupBy key1, key2
+ * </pre>
+ * The rule does not fire when every grouping key is constant: removing all keys
+ * would turn the GROUP BY into a global aggregation, which returns a row even
+ * for empty input.
+ */
+public class PullConstantsAboveGroupBy
+        implements Rule<AggregationNode>
+{
+    private static final Capture<ProjectNode> SOURCE = Capture.newCapture();
+
+    private static final Pattern<AggregationNode> PATTERN =
+            aggregation()
+                    .matching(agg -> agg.getGroupingSetCount() == 1)
+                    .with(source().matching(project().capturedAs(SOURCE)));
+
+    @Override
+    public boolean isEnabled(Session session)
+    {
+        return isOptimizeConstantGroupingKeys(session);
+    }
+
+    @Override
+    public Pattern<AggregationNode> getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Result apply(AggregationNode parent, Captures captures, Context context)
+    {
+        if (!isEnabled(context.getSession())) {
+            return Result.empty();
+        }
+
+        // for each variable referenced in the grouping keys, check whether the source projection defines it as a constant
+        ProjectNode source = captures.get(SOURCE);
+        List<VariableReferenceExpression> outputVariables = parent.getOutputVariables();
+
+        Map<VariableReferenceExpression, RowExpression> constSourceVars = extractConstVars(source, outputVariables);
+
+        List<VariableReferenceExpression> groupingKeys = parent.getGroupingKeys();
+        List<VariableReferenceExpression> newGroupingKeys =
+                groupingKeys.stream()
+                        .filter(key -> !constSourceVars.containsKey(key))
+                        .collect(toImmutableList());
+
+        // Do not fire when there is nothing to pull up, or when every key is constant: with no keys left
+        // the GROUP BY would become a global aggregation, which returns a row even for empty input.
+        if (constSourceVars.isEmpty() || newGroupingKeys.isEmpty() || newGroupingKeys.equals(groupingKeys)) {
+            return Result.empty();
+        }
+
+        AggregationNode newAgg = new AggregationNode(
+                parent.getSourceLocation(),
+                parent.getId(),
+                source,
+                parent.getAggregations(),
+                singleGroupingSet(newGroupingKeys),
+                ImmutableList.of(),
+                parent.getStep(),
+                parent.getHashVariable(),
+                parent.getGroupIdVariable());
+
+        // every non-constant output of the original aggregation is passed through unchanged
+        Map<VariableReferenceExpression, RowExpression> remainingVars =
+                outputVariables.stream()
+                        .filter(var -> !constSourceVars.containsKey(var))
+                        .collect(toImmutableMap(identity(), identity()));
+
+        Assignments.Builder assignments = Assignments.builder();
+        assignments.putAll(constSourceVars);
+        assignments.putAll(remainingVars);
+        return Result.ofPlanNode(
+                new ProjectNode(
+                        parent.getSourceLocation(),
+                        context.getIdAllocator().getNextId(),
+                        newAgg,
+                        assignments.build(),
+                        source.getLocality()));
+    }
+
+    // Returns the projection's constant assignments whose variable is among the given output variables
+    private static Map<VariableReferenceExpression, RowExpression> extractConstVars(ProjectNode projectNode, List<VariableReferenceExpression> outputVariables)
+    {
+        return projectNode.getAssignments().entrySet().stream()
+                .filter((entry) -> isConstantRowExpr(entry.getValue()) && outputVariables.contains(entry.getKey()))
+                .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
+    // Constant means: isConstant(...) holds for the expression obtained via castToExpression, or it is a ConstantExpression
+    private static boolean isConstantRowExpr(RowExpression expr)
+    {
+        if (isExpression(expr)) {
+            return isConstant(castToExpression(expr));
+        }
+        return expr instanceof ConstantExpression;
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index 98a17bfd145de..16a715fd8e721 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -160,6 +160,7 @@ public void testDefaults()
                 .setLegacyUnnestArrayRows(false)
                 .setJsonSerdeCodeGenerationEnabled(false)
                 .setPushLimitThroughOuterJoin(true)
+                .setOptimizeConstantGroupingKeys(true)
                 .setMaxConcurrentMaterializations(3)
                 .setPushdownSubfieldsEnabled(false)
                 .setPushdownDereferenceEnabled(false)
@@ -321,6 +322,7 @@ public void testExplicitPropertyMappings()
                 .put("deprecated.legacy-unnest-array-rows", "true")
                 .put("experimental.json-serde-codegen-enabled", "true")
                 .put("optimizer.push-limit-through-outer-join", "false")
+                .put("optimizer.optimize-constant-grouping-keys", "false")
                 .put("max-concurrent-materializations", "5")
                 .put("experimental.pushdown-subfields-enabled", "true")
                 .put("experimental.pushdown-dereference-enabled", "true")
@@ -479,6 +481,7 @@ public void testExplicitPropertyMappings()
                 .setDefaultFilterFactorEnabled(true)
                 .setJsonSerdeCodeGenerationEnabled(true)
                 .setPushLimitThroughOuterJoin(false)
+                .setOptimizeConstantGroupingKeys(false)
                 .setMaxConcurrentMaterializations(5)
                 .setPushdownSubfieldsEnabled(true)
                 .setPushdownDereferenceEnabled(true)
diff --git
a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPullConstantsAboveGroupBy.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPullConstantsAboveGroupBy.java
new file mode 100644
index 0000000000000..9bddb4caa81f3
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPullConstantsAboveGroupBy.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
+import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
+import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
+import com.facebook.presto.sql.tree.FunctionCall;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
+import static com.facebook.presto.spi.plan.AggregationNode.groupingSets;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+
+/**
+ * Rule tests for {@link PullConstantsAboveGroupBy}: cases where it must not fire and the rewritten plan shape when it does.
+ */
+public class TestPullConstantsAboveGroupBy
+        extends BaseRuleTest
+{
+    @Test
+    public void testNoConstGroupingKeysDoesNotFire()
+    {
+        tester().assertThat(new PullConstantsAboveGroupBy())
+                .on(p -> p.aggregation(ab -> ab
+                        .source(
+                                p.project(
+                                        Assignments.builder()
+                                                .put(p.variable("COL"), p.rowExpression("COL"))
+                                                .put(p.variable("CONST_COL"), p.rowExpression("1"))
+                                                .build(),
+                                        p.values(p.variable("COL"))))
+                        .addAggregation(p.variable("AVG", DOUBLE), p.rowExpression("avg(COL)"))
+                        .singleGroupingSet(p.variable("COL"))))
+                .doesNotFire();
+    }
+
+    @Test
+    public void testMultipleGroupingSetsDoesNotFire()
+    {
+        tester().assertThat(new PullConstantsAboveGroupBy())
+                .on(p -> p.aggregation(ab -> ab
+                        .source(
+                                p.project(
+                                        Assignments.builder()
+                                                .put(p.variable("COL"), p.rowExpression("COL"))
+                                                .put(p.variable("CONST_COL"), p.rowExpression("1"))
+                                                .build(),
+                                        p.values(p.variable("COL"))))
+                        .addAggregation(p.variable("AVG", DOUBLE), p.rowExpression("avg(COL)"))
+                        .groupingSets(
+                                groupingSets(ImmutableList.of(p.variable("COL")), 2, ImmutableSet.of(0)))))
+                .doesNotFire();
+    }
+
+    @Test
+    public void testRuleDisabledDoesNotFire()
+    {
+        RuleTester tester = new RuleTester(ImmutableList.of(), ImmutableMap.of("optimize_constant_grouping_keys", "false"));
+
+        tester.assertThat(new PullConstantsAboveGroupBy())
+                .on(p -> p.aggregation(ab -> ab
+                        .source(
+                                p.project(
+                                        Assignments.builder()
+                                                .put(p.variable("COL"), p.rowExpression("COL"))
+                                                .put(p.variable("CONST_COL"), p.rowExpression("1"))
+                                                .build(),
+                                        p.values(p.variable("COL"))))
+                        .addAggregation(p.variable("AVG", DOUBLE), p.rowExpression("avg(COL)"))
+                        .singleGroupingSet(p.variable("CONST_COL"), p.variable("COL"))))
+                .doesNotFire();
+    }
+
+    @Test
+    public void testSingleConstColumn()
+    {
+        tester().assertThat(new PullConstantsAboveGroupBy())
+                .on(p -> p.aggregation(ab -> ab
+                        .source(
+                                p.project(
+                                        Assignments.builder()
+                                                .put(p.variable("COL"), p.rowExpression("COL"))
+                                                .put(p.variable("CONST_COL"), p.rowExpression("1"))
+                                                .build(),
+                                        p.values(p.variable("COL"))))
+                        .addAggregation(p.variable("AVG", DOUBLE), p.rowExpression("avg(COL)"))
+                        .singleGroupingSet(p.variable("CONST_COL"), p.variable("COL"))))
+                .matches(
+                        project(
+                                ImmutableMap.of(
+                                        "CONST_COL", expression("1")),
+                                aggregation(
+                                        singleGroupingSet("COL"),
+                                        ImmutableMap.<Optional<String>, ExpectedValueProvider<FunctionCall>>builder()
+                                                .put(Optional.of("AVG"), functionCall("avg", ImmutableList.of("col")))
+                                                .build(),
+                                        ImmutableMap.of(),
+                                        Optional.empty(),
+                                        SINGLE,
+                                        project(ImmutableMap.of(
+                                                        "CONST_COL", expression("1")),
+                                                values("COL")))));
+    }
+
+    @Test
+    public void testMultipleConstCols()
+    {
+        tester().assertThat(new PullConstantsAboveGroupBy())
+                .on(p -> p.aggregation(ab -> ab
+                        .source(
+                                p.project(
+                                        Assignments.builder()
+                                                .put(p.variable("COL"), p.rowExpression("COL"))
+                                                .put(p.variable("CONST_COL1"), p.rowExpression("1"))
+                                                .put(p.variable("CONST_COL2"), p.rowExpression("2"))
+                                                .build(),
+                                        p.values(p.variable("COL"))))
+                        .addAggregation(p.variable("AVG", DOUBLE), p.rowExpression("avg(COL)"))
+                        .singleGroupingSet(p.variable("CONST_COL1"), p.variable("COL"), p.variable("CONST_COL2"))))
+                .matches(
+                        project(
+                                ImmutableMap.of(
+                                        "CONST_COL1", expression("1"),
+                                        "CONST_COL2", expression("2")),
+                                aggregation(
+                                        singleGroupingSet("COL"),
+                                        ImmutableMap.<Optional<String>, ExpectedValueProvider<FunctionCall>>builder()
+                                                .put(Optional.of("AVG"), functionCall("avg", ImmutableList.of("col")))
+                                                .build(),
+                                        ImmutableMap.of(),
+                                        Optional.empty(),
+                                        SINGLE,
+                                        project(ImmutableMap.of(
+                                                        "CONST_COL1", expression("1"),
+                                                        "CONST_COL2", expression("2")),
+                                                values("COL")))));
+    }
+}
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index eff8d638fd605..41aeaba9511b6 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -4880,6 +4880,18 @@ public void testFilterPushdownWithAggregation()
         assertQuery("SELECT * FROM (SELECT count(*) FROM orders) WHERE null");
     }
 
+    @Test
+    public void testGroupByWithConstants()
+    {
+        assertQuery("SELECT * FROM (SELECT regionkey, col, count(*) FROM (SELECT regionkey, 'bla' as col FROM nation) GROUP BY regionkey, col)");
+        assertQuery("select 'blah', * from (select regionkey, count(*) FROM nation GROUP BY regionkey)");
+        assertQuery("SELECT * FROM (SELECT col, count(*) FROM (SELECT 'bla' as col FROM nation) GROUP BY col)");
+        assertQuery("SELECT cnt FROM (SELECT col, count(*) as cnt FROM (SELECT 'bla' as col from nation) GROUP BY col)");
+        assertQuery("SELECT MIN(10), 1 as col1 GROUP BY 2");
+        assertQuery("SELECT col, 'bla' as const_col, count(*) FROM (SELECT 1 as col) GROUP BY 1,2");
+        assertQuery("SELECT AVG(x) FROM (SELECT 1 AS x, orderstatus FROM orders) GROUP BY x, orderstatus");
+    }
+
     @Test
     public void testAccessControl()
     {