diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java index 03c1d21c0dd10..8f70de16438a5 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java @@ -122,7 +122,7 @@ public static RowExpression and(RowExpression... expressions) return and(asList(expressions)); } - public static RowExpression and(Collection expressions) + public static RowExpression and(Collection expressions) { return binaryExpression(AND, expressions); } @@ -132,12 +132,12 @@ public static RowExpression or(RowExpression... expressions) return or(asList(expressions)); } - public static RowExpression or(Collection expressions) + public static RowExpression or(Collection expressions) { return binaryExpression(OR, expressions); } - public static RowExpression binaryExpression(Form form, Collection expressions) + public static RowExpression binaryExpression(Form form, Collection expressions) { requireNonNull(form, "operator is null"); requireNonNull(expressions, "expressions is null"); diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 5f1934285094e..fa4d2f1fc80d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -199,6 +199,7 @@ public final class SystemSessionProperties public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled"; public static final String MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED = "materialized_view_data_consistency_enabled"; public static final String QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED = "query_optimization_with_materialized_view_enabled"; + public static final String AGGREGATION_IF_TO_FILTER_REWRITE_ENABLED = "aggregation_if_to_filter_rewrite_enabled"; private final List> sessionProperties; @@ -735,8 +736,8 @@ public SystemSessionProperties( PARTIAL_AGGREGATION_STRATEGY, format("Partial aggregation strategy to use. Options are %s", Stream.of(PartialAggregationStrategy.values()) - .map(PartialAggregationStrategy::name) - .collect(joining(","))), + .map(PartialAggregationStrategy::name) + .collect(joining(","))), VARCHAR, PartialAggregationStrategy.class, featuresConfig.getPartialAggregationStrategy(), @@ -1066,7 +1067,12 @@ public SystemSessionProperties( QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, "Enable query optimization with materialized view", featuresConfig.isQueryOptimizationWithMaterializedViewEnabled(), - true)); + true), + booleanProperty( + AGGREGATION_IF_TO_FILTER_REWRITE_ENABLED, + "Enable rewriting the IF expression inside an aggregation function to a filter clause outside the aggregation", + featuresConfig.isAggregationIfToFilterRewriteEnabled(), + false)); } public static boolean isEmptyJoinOptimization(Session session) @@ -1801,4 +1807,9 @@ public static boolean isQueryOptimizationWithMaterializedViewEnabled(Session ses { return session.getSystemProperty(QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, Boolean.class); } + + public static boolean isAggregationIfToFilterRewriteEnabled(Session session) + { + return session.getSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 314c188dfb3bd..8f693f1e358c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -203,6 +203,7 @@ public class FeaturesConfig private boolean materializedViewDataConsistencyEnabled = true; private boolean queryOptimizationWithMaterializedViewEnabled; + private boolean aggregationIfToFilterRewriteEnabled = true; public enum PartitioningPrecisionStrategy { @@ -1786,4 +1787,17 @@ public FeaturesConfig setQueryOptimizationWithMaterializedViewEnabled(boolean va this.queryOptimizationWithMaterializedViewEnabled = value; return this; } + + public boolean isAggregationIfToFilterRewriteEnabled() + { + return aggregationIfToFilterRewriteEnabled; + } + + @Config("optimizer.aggregation-if-to-filter-rewrite-enabled") + @ConfigDescription("Enable rewriting the IF expression inside an aggregation function to a filter clause outside the aggregation") + public FeaturesConfig setAggregationIfToFilterRewriteEnabled(boolean value) + { + this.aggregationIfToFilterRewriteEnabled = value; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 4e83c0d9c04c8..bc985fbbda3d7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -97,6 +97,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes; import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins; +import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter; import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; @@ -399,6 +400,11 @@ public PlanOptimizers( // After this point, all planNodes should not contain OriginalExpression builder.add( + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new RewriteAggregationIfToFilter())), predicatePushDown, new IterativeOptimizer( ruleStats, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java new file mode 100644 index 0000000000000..9136e88107cb7 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.relational.Expressions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSortedSet; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.SystemSessionProperties.isAggregationIfToFilterRewriteEnabled; +import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; +import static com.facebook.presto.expressions.LogicalRowExpressions.or; +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap; +import static java.util.function.Function.identity; + +/** + * A optimizer rule which rewrites + * AGG(IF(condition, expr)) + * to + * AGG(expr) FILTER (WHERE condition). + *

+ * The latter plan is more efficient because: + * 1. The filter can be pushed down to the scan node. + * 2. The rows not matching the condition are not aggregated. + * 3. The IF() expression wrapper is removed. + */ +public class RewriteAggregationIfToFilter + implements Rule +{ + private static final Capture CHILD = newCapture(); + + private static final Pattern PATTERN = aggregation() + .with(source().matching(project().capturedAs(CHILD))); + + @Override + public boolean isEnabled(Session session) + { + return isAggregationIfToFilterRewriteEnabled(session); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode aggregationNode, Captures captures, Context context) + { + ProjectNode sourceProject = captures.get(CHILD); + + Set aggregationsToRewrite = aggregationNode.getAggregations().values().stream() + .filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject)) + .collect(toImmutableSet()); + if (aggregationsToRewrite.isEmpty()) { + return Result.empty(); + } + + // Get the corresponding assignments in the input project. + // The aggregationReferences only has the aggregations to rewrite, thus the sourceAssignments only has IF expressions with NULL false results. + // Multiple aggregations may reference the same input. We use a map to dedup them based on the VariableReferenceExpression, so that we only do the rewrite once per input + // IF expression. + // The order of sourceAssignments determines the order of generating the new variables for the IF conditions and results. We use a sorted map to get a deterministic + // order based on the name of the VariableReferenceExpressions. + Map sourceAssignments = aggregationsToRewrite.stream() + .map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0)) + .collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left)); + + Assignments.Builder newAssignments = Assignments.builder(); + // We don't remove the IF expression now in case the aggregation has other references to it. These will be cleaned up by the PruneUnreferencedOutputs rule later. + newAssignments.putAll(sourceProject.getAssignments()); + + // Map from the aggregation reference to the IF condition reference. + Map aggregationReferenceToConditionReference = new HashMap<>(); + // Map from the aggregation reference to the IF result reference. + Map aggregationReferenceToIfResultReference = new HashMap<>(); + + for (Map.Entry entry : sourceAssignments.entrySet()) { + VariableReferenceExpression outputVariable = entry.getKey(); + SpecialFormExpression ifExpression = (SpecialFormExpression) entry.getValue(); + + RowExpression condition = ifExpression.getArguments().get(0); + VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition); + newAssignments.put(conditionReference, condition); + aggregationReferenceToConditionReference.put(outputVariable, conditionReference); + + RowExpression trueResult = ifExpression.getArguments().get(1); + VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult); + newAssignments.put(ifResultReference, trueResult); + aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference); + } + + // Build new aggregations. + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + // Stores the masks used to build the filter predicates. Use set to dedup the predicates. + ImmutableSortedSet.Builder masks = ImmutableSortedSet.naturalOrder(); + for (Map.Entry entry : aggregationNode.getAggregations().entrySet()) { + VariableReferenceExpression output = entry.getKey(); + Aggregation aggregation = entry.getValue(); + if (!aggregationsToRewrite.contains(aggregation)) { + aggregations.put(output, aggregation); + continue; + } + VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0); + CallExpression callExpression = aggregation.getCall(); + VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference); + aggregations.put(output, new Aggregation( + new CallExpression( + callExpression.getDisplayName(), + callExpression.getFunctionHandle(), + callExpression.getType(), + ImmutableList.of(aggregationReferenceToIfResultReference.get(aggregationReference))), + Optional.empty(), + aggregation.getOrderBy(), + aggregation.isDistinct(), + Optional.of(aggregationReferenceToConditionReference.get(aggregationReference)))); + masks.add(mask); + } + + RowExpression predicate = TRUE_CONSTANT; + if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) { + // All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient. + predicate = or(masks.build()); + } + return Result.ofPlanNode( + new AggregationNode( + context.getIdAllocator().getNextId(), + new FilterNode( + context.getIdAllocator().getNextId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + sourceProject.getSource(), + newAssignments.build()), + predicate), + aggregations.build(), + aggregationNode.getGroupingSets(), + aggregationNode.getPreGroupedVariables(), + aggregationNode.getStep(), + aggregationNode.getHashVariable(), + aggregationNode.getGroupIdVariable())); + } + + private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject) + { + if (!(aggregation.getArguments().size() == 1 && aggregation.getArguments().get(0) instanceof VariableReferenceExpression)) { + // Currently we only handle aggregation with a single VariableReferenceExpression. The detailed expressions are in a project node below this aggregation. + return false; + } + if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) { + // Do not rewrite the aggregation if it already has a filter or mask. + return false; + } + RowExpression sourceExpression = sourceProject.getAssignments().get((VariableReferenceExpression) aggregation.getArguments().get(0)); + if (!(sourceExpression instanceof SpecialFormExpression)) { + return false; + } + SpecialFormExpression expression = (SpecialFormExpression) sourceExpression; + // Only rewrite the aggregation if the else branch is not present. + return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 6cd349dc1f668..198c5f53fd196 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -174,7 +174,8 @@ public void testDefaults() .setOffsetClauseEnabled(false) .setPartialResultsMaxExecutionTimeMultiplier(2.0) .setMaterializedViewDataConsistencyEnabled(true) - .setQueryOptimizationWithMaterializedViewEnabled(false)); + .setQueryOptimizationWithMaterializedViewEnabled(false) + .setAggregationIfToFilterRewriteEnabled(true)); } @Test @@ -301,6 +302,7 @@ public void testExplicitPropertyMappings() .put("offset-clause-enabled", "true") .put("materialized-view-data-consistency-enabled", "false") .put("query-optimization-with-materialized-view-enabled", "true") + .put("optimizer.aggregation-if-to-filter-rewrite-enabled", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -424,7 +426,8 @@ public void testExplicitPropertyMappings() .setOffsetClauseEnabled(true) .setPartialResultsMaxExecutionTimeMultiplier(1.5) .setMaterializedViewDataConsistencyEnabled(false) - .setQueryOptimizationWithMaterializedViewEnabled(true); + .setQueryOptimizationWithMaterializedViewEnabled(true) + .setAggregationIfToFilterRewriteEnabled(false); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java index ca94f6ef51920..1d3e2e639ce71 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java @@ -85,7 +85,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses List aggregationsWithMask = aggregationNode.getAggregations() .entrySet() .stream() - .filter(entry -> entry.getValue().isDistinct()) + .filter(entry -> entry.getValue().getMask().isPresent()) .map(Map.Entry::getKey) .collect(Collectors.toList()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index 970a742485a2f..970ed9692086b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -37,6 +37,7 @@ import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; +import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.IsNotNullPredicate; @@ -73,6 +74,7 @@ import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; @@ -150,6 +152,17 @@ protected Boolean visitCast(Cast expected, RowExpression actual) return process(expected.getExpression(), ((CallExpression) actual).getArguments().get(0)); } + @Override + protected Boolean visitIfExpression(IfExpression expected, RowExpression actual) + { + if (!(actual instanceof SpecialFormExpression) || !((SpecialFormExpression) actual).getForm().equals(IF)) { + return false; + } + return process(expected.getCondition(), ((SpecialFormExpression) actual).getArguments().get(0)) && + process(expected.getTrueValue(), ((SpecialFormExpression) actual).getArguments().get(1)) && + process(expected.getFalseValue().orElse(new NullLiteral()), ((SpecialFormExpression) actual).getArguments().get(2)); + } + @Override protected Boolean visitIsNullPredicate(IsNullPredicate expected, RowExpression actual) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java index 6d8efbde77888..64a37dbc401fd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.google.common.collect.ImmutableList; @@ -50,7 +51,7 @@ public void testNotAllInputsReferenced() ImmutableMap.of( Optional.of("avg"), functionCall("avg", ImmutableList.of("input"))), - ImmutableMap.of(), + ImmutableMap.of(new Symbol("avg"), new Symbol("mask")), Optional.empty(), SINGLE, strictProject( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java index 400679343e7bb..fb53960cb557b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java @@ -17,6 +17,7 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.tree.SortItem; @@ -54,7 +55,9 @@ public void testBasics() "array_agg", ImmutableList.of("input"), ImmutableList.of(sort("input", SortItem.Ordering.ASCENDING, SortItem.NullOrdering.UNDEFINED)))), - ImmutableMap.of(), + ImmutableMap.of( + new Symbol("avg"), new Symbol("mask"), + new Symbol("array_agg"), new Symbol("mask")), Optional.empty(), SINGLE, values("input", "key", "keyHash", "mask"))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteAggregationIfToFilter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteAggregationIfToFilter.java new file mode 100644 index 0000000000000..1d8934dadd095 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteAggregationIfToFilter.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.globalAggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; + +public class TestRewriteAggregationIfToFilter + extends BaseRuleTest +{ + @Test + public void testDoesNotFireForNonIf() + { + // The aggregation expression is not an if expression. + tester().assertThat(new RewriteAggregationIfToFilter()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BooleanType.BOOLEAN); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL) + .addAggregation(p.variable("expr"), p.rowExpression("count(a)")) + .source(p.project( + assignment(a, p.rowExpression("ds > '2021-07-01'")), + p.values(ds)))); + }).doesNotFire(); + } + + @Test + public void testDoesNotFireForIfWithElse() + { + // The if expression has an else branch. We cannot rewrite it. + tester().assertThat(new RewriteAggregationIfToFilter()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL) + .addAggregation(p.variable("expr"), p.rowExpression("count(a)")) + .source(p.project( + assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1, 2)")), + p.values(ds)))); + }).doesNotFire(); + } + + @Test + public void testFireOneAggregation() + { + tester().assertThat(new RewriteAggregationIfToFilter()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL) + .addAggregation(p.variable("expr"), p.rowExpression("count(a)")) + .source(p.project( + assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")), + p.values(ds)))); + }) + .matches( + aggregation( + globalAggregation(), + ImmutableMap.of(Optional.of("expr"), functionCall("count", ImmutableList.of("expr_0"))), + ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")), + Optional.empty(), + AggregationNode.Step.FINAL, + filter( + "greater_than", + project(ImmutableMap.of( + "a", expression("IF(ds > '2021-07-01', 1)"), + "expr_0", expression("1"), + "greater_than", expression("ds > '2021-07-01'")), + values("ds"))))); + } + + @Test + public void testFireTwoAggregations() + { + tester().assertThat(new RewriteAggregationIfToFilter()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL) + .addAggregation(p.variable("expr0"), p.rowExpression("count(a)")) + .addAggregation(p.variable("expr1"), p.rowExpression("count(b)")) + .source(p.project( + assignment( + a, p.rowExpression("IF(ds > '2021-07-01', 1)"), + b, p.rowExpression("IF(ds > '2021-06-01', 2)")), + p.values(ds)))); + }) + .matches( + aggregation( + globalAggregation(), + ImmutableMap.of( + Optional.of("expr0"), functionCall("count", ImmutableList.of("expr")), + Optional.of("expr1"), functionCall("count", ImmutableList.of("expr_1"))), + ImmutableMap.of( + new Symbol("expr0"), new Symbol("greater_than"), + new Symbol("expr1"), new Symbol("greater_than_0")), + Optional.empty(), + AggregationNode.Step.FINAL, + filter( + "greater_than or greater_than_0", + project(new ImmutableMap.Builder() + .put("a", expression("IF(ds > '2021-07-01', 1)")) + .put("b", expression("IF(ds > '2021-06-01', 2)")) + .put("expr", expression("1")) + .put("expr_1", expression("2")) + .put("greater_than", expression("ds > '2021-07-01'")) + .put("greater_than_0", expression("ds > '2021-06-01'")) + .build(), + values("ds"))))); + } + + @Test + public void testFireTwoAggregationsWithSharedInput() + { + tester().assertThat(new RewriteAggregationIfToFilter()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression column0 = p.variable("column0", BIGINT); + return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL) + .addAggregation(p.variable("expr0"), p.rowExpression("MIN(a)")) + .addAggregation(p.variable("expr1"), p.rowExpression("MAX(a)")) + .source(p.project( + assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")), + p.values(ds, column0)))); + }) + .matches( + aggregation( + globalAggregation(), + ImmutableMap.of( + Optional.of("expr0"), functionCall("min", ImmutableList.of("expr")), + Optional.of("expr1"), functionCall("max", ImmutableList.of("expr"))), + ImmutableMap.of( + new Symbol("expr0"), new Symbol("greater_than"), + new Symbol("expr1"), new Symbol("greater_than")), + Optional.empty(), + AggregationNode.Step.FINAL, + filter( + "greater_than", + project(new ImmutableMap.Builder() + .put("a", expression("IF(ds > '2021-06-01', column0)")) + .put("expr", expression("column0")) + .put("greater_than", expression("ds > '2021-06-01'")) + .build(), + values("ds", "column0"))))); + } + + @Test + public void testFireForOneOfTwoAggregations() + { + tester().assertThat(new RewriteAggregationIfToFilter()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL) + .addAggregation(p.variable("expr0"), p.rowExpression("count(a)")) + .addAggregation(p.variable("expr1"), p.rowExpression("count(b)")) + .source(p.project( + assignment( + a, p.rowExpression("IF(ds > '2021-07-01', 1)"), + b, p.rowExpression("ds")), + p.values(ds)))); + }) + .matches( + aggregation( + globalAggregation(), + ImmutableMap.of( + Optional.of("expr0"), functionCall("count", ImmutableList.of("expr")), + Optional.of("expr1"), functionCall("count", ImmutableList.of("b"))), + ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")), + Optional.empty(), + AggregationNode.Step.FINAL, + filter( + "true", + project(new ImmutableMap.Builder() + .put("a", expression("IF(ds > '2021-07-01', 1)")) + .put("b", expression("ds")) + .put("greater_than", expression("ds > '2021-07-01'")) + .put("expr", expression("1")) + .build(), + values("ds"))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/query/TestFilteredAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/query/TestFilteredAggregations.java index 29bc2df6abf78..85fd47e1165d5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/query/TestFilteredAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/query/TestFilteredAggregations.java @@ -52,14 +52,26 @@ public void testAddPredicateForFilterClauses() assertions.assertQuery( "SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)", "VALUES (BIGINT '10')"); + assertions.assertQuery( + "SELECT sum(IF(x > 0, x)) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)", + "VALUES (BIGINT '10')"); assertions.assertQuery( "SELECT sum(x) FILTER(WHERE x > 0), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)", "VALUES (BIGINT '18', BIGINT '2')"); + assertions.assertQuery( + "SELECT sum(IF(x > 0, x)), sum(IF(x < 3, x)) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)", + "VALUES (BIGINT '18', BIGINT '2')"); + assertions.assertQuery( + "SELECT sum(IF(x > 0, x)), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)", + "VALUES (BIGINT '18', BIGINT '2')"); assertions.assertQuery( "SELECT sum(x) FILTER(WHERE x > 1), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)", "VALUES (BIGINT '8', BIGINT '10')"); + assertions.assertQuery( + "SELECT sum(IF(x > 1, x)), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)", + "VALUES (BIGINT '8', BIGINT '10')"); } @Test @@ -69,11 +81,19 @@ public void testGroupAll() "SELECT count(DISTINCT x) FILTER (WHERE x > 1) " + "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)", "VALUES BIGINT '2'"); + assertions.assertQuery( + "SELECT count(DISTINCT IF(x > 1, x)) " + + "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)", + "VALUES BIGINT '2'"); assertions.assertQuery( "SELECT count(DISTINCT x) FILTER (WHERE x > 1), sum(DISTINCT x) " + "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)", "VALUES (BIGINT '2', BIGINT '6')"); + assertions.assertQuery( + "SELECT count(DISTINCT IF(x > 1, x)), sum(DISTINCT x) " + + "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)", + "VALUES (BIGINT '2', BIGINT '6')"); assertions.assertQuery( "SELECT count(DISTINCT x) FILTER (WHERE x > 1), sum(DISTINCT y) FILTER (WHERE x < 3)" + @@ -84,11 +104,24 @@ public void testGroupAll() "(2, 20)," + "(3, 30)) t(x, y)", "VALUES (BIGINT '2', BIGINT '30')"); + assertions.assertQuery( + "SELECT count(DISTINCT IF(x > 1, x)), sum(DISTINCT IF(x < 3, y)) " + + "FROM (VALUES " + + "(1, 10)," + + "(1, 20)," + + "(1, 20)," + + "(2, 20)," + + "(3, 30)) t(x, y)", + "VALUES (BIGINT '2', BIGINT '30')"); assertions.assertQuery( "SELECT count(x) FILTER (WHERE x > 1), sum(DISTINCT x) " + "FROM (VALUES 1, 2, 3, 3) t(x)", "VALUES (BIGINT '3', BIGINT '6')"); + assertions.assertQuery( + "SELECT count(IF(x > 1, x)), sum(DISTINCT x) " + + "FROM (VALUES 1, 2, 3, 3) t(x)", + "VALUES (BIGINT '3', BIGINT '6')"); } @Test @@ -113,6 +146,26 @@ public void testGroupingSets() "(1, BIGINT '2', BIGINT '1'), " + "(2, BIGINT '4', BIGINT '1'), " + "(CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')"); + + assertions.assertQuery( + "SELECT k, count(DISTINCT IF(y = 100, x)), count(DISTINCT IF(y = 200, x)) FROM " + + "(VALUES " + + " (1, 1, 100)," + + " (1, 1, 200)," + + " (1, 2, 100)," + + " (1, 3, 300)," + + " (2, 1, 100)," + + " (2, 10, 100)," + + " (2, 20, 100)," + + " (2, 20, 200)," + + " (2, 30, 300)," + + " (2, 40, 100)" + + ") t(k, x, y) " + + "GROUP BY GROUPING SETS ((), (k))", + "VALUES " + + "(1, BIGINT '2', BIGINT '1'), " + + "(2, BIGINT '4', BIGINT '1'), " + + "(CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')"); } @Test @@ -123,31 +176,41 @@ public void rewriteAddFilterWithMultipleFilters() anyTree( filter( "(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')", - tableScan( - "orders", ImmutableMap.of("totalprice", "totalprice", - "custkey", "custkey"))))); + tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "custkey", "custkey"))))); + + assertPlan( + "SELECT sum(IF(totalprice > 0, totalprice)), sum(IF(custkey > 0, custkey)) FROM orders", + anyTree( + filter( + "(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')", + tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "custkey", "custkey"))))); } @Test public void testDoNotPushdownPredicateIfNonFilteredAggregateIsPresent() { assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FROM orders"); + assertPlanContainsNoFilter("SELECT sum(IF(totalprice > 0, totalprice)), sum(custkey) FROM orders"); } @Test public void testPushDownConstantFilterPredicate() { assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE FALSE) FROM orders"); + assertPlanContainsNoFilter("SELECT sum(IF(FALSE, totalprice)) FROM orders"); assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE TRUE) FROM orders"); + assertPlanContainsNoFilter("SELECT sum(IF(TRUE, totalprice)) FROM orders"); } @Test public void testNoFilterAddedForConstantValueFilters() { assertPlanContainsNoFilter("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x"); + assertPlanContainsNoFilter("SELECT sum(IF(x > 0, x)) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x"); assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0) FROM orders GROUP BY totalprice"); + assertPlanContainsNoFilter("SELECT sum(IF(totalprice > 0, totalprice)) FROM orders GROUP BY totalprice"); } private void assertPlanContainsNoFilter(String sql) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestLogicalRowExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestLogicalRowExpressions.java index 8ec289bce4146..f01cb3de42d9a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestLogicalRowExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestLogicalRowExpressions.java @@ -63,6 +63,9 @@ public class TestLogicalRowExpressions private static final RowExpression f = name("f"); private static final RowExpression g = name("g"); private static final RowExpression h = name("h"); + private static final VariableReferenceExpression V_0 = variable("v0"); + private static final VariableReferenceExpression V_1 = variable("v1"); + private static final VariableReferenceExpression V_2 = variable("v2"); @BeforeClass public void setup() @@ -95,6 +98,18 @@ public void testAnd() ImmutableList.of(a, b, c, d, e)); } + @Test + public void testAndWithSubclassOfRowExpression() + { + assertEquals( + LogicalRowExpressions.and(V_0, V_1, V_2), + and(and(V_0, V_1), V_2)); + + assertEquals( + LogicalRowExpressions.and(ImmutableList.of(V_0, V_1, V_2)), + and(and(V_0, V_1), V_2)); + } + @Test public void testOr() { @@ -115,6 +130,18 @@ public void testOr() ImmutableList.of(a, b, c, d, e)); } + @Test + public void testOrWithSubclassOfRowExpression() + { + assertEquals( + LogicalRowExpressions.or(V_0, V_1, V_2), + or(or(V_0, V_1), V_2)); + + assertEquals( + LogicalRowExpressions.or(ImmutableList.of(V_0, V_1, V_2)), + or(or(V_0, V_1), V_2)); + } + @Test public void testDeterminism() { @@ -502,6 +529,11 @@ private static RowExpression name(String name) return new VariableReferenceExpression(name, BOOLEAN); } + private static VariableReferenceExpression variable(String name) + { + return new VariableReferenceExpression(name, BOOLEAN); + } + private RowExpression compare(RowExpression left, OperatorType operator, RowExpression right) { return call(