Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public static RowExpression and(RowExpression... expressions)
return and(asList(expressions));
}

public static RowExpression and(Collection<RowExpression> expressions)
public static RowExpression and(Collection<? extends RowExpression> expressions)
{
return binaryExpression(AND, expressions);
}
Expand All @@ -132,12 +132,12 @@ public static RowExpression or(RowExpression... expressions)
return or(asList(expressions));
}

public static RowExpression or(Collection<RowExpression> expressions)
public static RowExpression or(Collection<? extends RowExpression> expressions)
{
return binaryExpression(OR, expressions);
}

public static RowExpression binaryExpression(Form form, Collection<RowExpression> expressions)
public static RowExpression binaryExpression(Form form, Collection<? extends RowExpression> expressions)
{
requireNonNull(form, "operator is null");
requireNonNull(expressions, "expressions is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ public final class SystemSessionProperties
public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled";
public static final String MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED = "materialized_view_data_consistency_enabled";
public static final String QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED = "query_optimization_with_materialized_view_enabled";
public static final String AGGREGATION_IF_TO_FILTER_REWRITE_ENABLED = "aggregation_if_to_filter_rewrite_enabled";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -735,8 +736,8 @@ public SystemSessionProperties(
PARTIAL_AGGREGATION_STRATEGY,
format("Partial aggregation strategy to use. Options are %s",
Stream.of(PartialAggregationStrategy.values())
.map(PartialAggregationStrategy::name)
.collect(joining(","))),
.map(PartialAggregationStrategy::name)
.collect(joining(","))),
VARCHAR,
PartialAggregationStrategy.class,
featuresConfig.getPartialAggregationStrategy(),
Expand Down Expand Up @@ -1066,7 +1067,12 @@ public SystemSessionProperties(
QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED,
"Enable query optimization with materialized view",
featuresConfig.isQueryOptimizationWithMaterializedViewEnabled(),
true));
true),
booleanProperty(
AGGREGATION_IF_TO_FILTER_REWRITE_ENABLED,
"Enable rewriting the IF expression inside an aggregation function to a filter clause outside the aggregation",
featuresConfig.isAggregationIfToFilterRewriteEnabled(),
false));
}

public static boolean isEmptyJoinOptimization(Session session)
Expand Down Expand Up @@ -1801,4 +1807,9 @@ public static boolean isQueryOptimizationWithMaterializedViewEnabled(Session ses
{
return session.getSystemProperty(QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, Boolean.class);
}

public static boolean isAggregationIfToFilterRewriteEnabled(Session session)
{
return session.getSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_ENABLED, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ public class FeaturesConfig
private boolean materializedViewDataConsistencyEnabled = true;

private boolean queryOptimizationWithMaterializedViewEnabled;
private boolean aggregationIfToFilterRewriteEnabled = true;

public enum PartitioningPrecisionStrategy
{
Expand Down Expand Up @@ -1786,4 +1787,17 @@ public FeaturesConfig setQueryOptimizationWithMaterializedViewEnabled(boolean va
this.queryOptimizationWithMaterializedViewEnabled = value;
return this;
}

public boolean isAggregationIfToFilterRewriteEnabled()
{
return aggregationIfToFilterRewriteEnabled;
}

@Config("optimizer.aggregation-if-to-filter-rewrite-enabled")
@ConfigDescription("Enable rewriting the IF expression inside an aggregation function to a filter clause outside the aggregation")
public FeaturesConfig setAggregationIfToFilterRewriteEnabled(boolean value)
{
this.aggregationIfToFilterRewriteEnabled = value;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes;
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter;
import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject;
import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation;
import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides;
Expand Down Expand Up @@ -399,6 +400,11 @@ public PlanOptimizers(
// After this point, all planNodes should not contain OriginalExpression

builder.add(
new IterativeOptimizer(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RewriteAggregationIfToFilter())),
predicatePushDown,
new IterativeOptimizer(
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedSet;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isAggregationIfToFilterRewriteEnabled;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap;
import static java.util.function.Function.identity;

/**
* A optimizer rule which rewrites
* AGG(IF(condition, expr))
* to
* AGG(expr) FILTER (WHERE condition).
* <p>
* The latter plan is more efficient because:
* 1. The filter can be pushed down to the scan node.
* 2. The rows not matching the condition are not aggregated.
* 3. The IF() expression wrapper is removed.
*/
public class RewriteAggregationIfToFilter
implements Rule<AggregationNode>
{
private static final Capture<ProjectNode> CHILD = newCapture();

private static final Pattern<AggregationNode> PATTERN = aggregation()
.with(source().matching(project().capturedAs(CHILD)));

@Override
public boolean isEnabled(Session session)
{
return isAggregationIfToFilterRewriteEnabled(session);
}

@Override
public Pattern<AggregationNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
{
ProjectNode sourceProject = captures.get(CHILD);

Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream()
.filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject))
.collect(toImmutableSet());
if (aggregationsToRewrite.isEmpty()) {
return Result.empty();
}

// Get the corresponding assignments in the input project.
// The aggregationReferences only has the aggregations to rewrite, thus the sourceAssignments only has IF expressions with NULL false results.
// Multiple aggregations may reference the same input. We use a map to dedup them based on the VariableReferenceExpression, so that we only do the rewrite once per input
// IF expression.
// The order of sourceAssignments determines the order of generating the new variables for the IF conditions and results. We use a sorted map to get a deterministic
// order based on the name of the VariableReferenceExpressions.
Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream()
.map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0))
.collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left));

Assignments.Builder newAssignments = Assignments.builder();
// We don't remove the IF expression now in case the aggregation has other references to it. These will be cleaned up by the PruneUnreferencedOutputs rule later.
newAssignments.putAll(sourceProject.getAssignments());

// Map from the aggregation reference to the IF condition reference.
Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
// Map from the aggregation reference to the IF result reference.
Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();

for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
VariableReferenceExpression outputVariable = entry.getKey();
SpecialFormExpression ifExpression = (SpecialFormExpression) entry.getValue();

RowExpression condition = ifExpression.getArguments().get(0);
VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
newAssignments.put(conditionReference, condition);
aggregationReferenceToConditionReference.put(outputVariable, conditionReference);

RowExpression trueResult = ifExpression.getArguments().get(1);
VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
newAssignments.put(ifResultReference, trueResult);
aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
}

// Build new aggregations.
ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
// Stores the masks used to build the filter predicates. Use set to dedup the predicates.
ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
VariableReferenceExpression output = entry.getKey();
Aggregation aggregation = entry.getValue();
if (!aggregationsToRewrite.contains(aggregation)) {
aggregations.put(output, aggregation);
continue;
}
VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
CallExpression callExpression = aggregation.getCall();
VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
aggregations.put(output, new Aggregation(
new CallExpression(
callExpression.getDisplayName(),
callExpression.getFunctionHandle(),
callExpression.getType(),
ImmutableList.of(aggregationReferenceToIfResultReference.get(aggregationReference))),
Optional.empty(),
aggregation.getOrderBy(),
aggregation.isDistinct(),
Optional.of(aggregationReferenceToConditionReference.get(aggregationReference))));
masks.add(mask);
}

RowExpression predicate = TRUE_CONSTANT;
if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) {
// All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient.
predicate = or(masks.build());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This or might actually cause more slowdown than help because all masks have to be evaluated for every row anyway. I don't think this helps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it helps because

  • The filter could be pushed down to the scan node which might be able to evaluated very efficiently, e.g., if it is on some partition columns, the partition columns are using dictionary encoding. This only need to evaluated once per dictionary item.
  • The filter can be used to prune the partitions/splits based on column stats.

Besides, the AGG() FILTER implementation also adds this predicate. It is better to keep the same behavior:

predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL);

}
return Result.ofPlanNode(
new AggregationNode(
context.getIdAllocator().getNextId(),
new FilterNode(
context.getIdAllocator().getNextId(),
new ProjectNode(
context.getIdAllocator().getNextId(),
sourceProject.getSource(),
newAssignments.build()),
predicate),
aggregations.build(),
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedVariables(),
aggregationNode.getStep(),
aggregationNode.getHashVariable(),
aggregationNode.getGroupIdVariable()));
}

private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject)
{
if (!(aggregation.getArguments().size() == 1 && aggregation.getArguments().get(0) instanceof VariableReferenceExpression)) {
// Currently we only handle aggregation with a single VariableReferenceExpression. The detailed expressions are in a project node below this aggregation.
return false;
}
if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
// Do not rewrite the aggregation if it already has a filter or mask.
return false;
}
RowExpression sourceExpression = sourceProject.getAssignments().get((VariableReferenceExpression) aggregation.getArguments().get(0));
if (!(sourceExpression instanceof SpecialFormExpression)) {
return false;
}
SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;
// Only rewrite the aggregation if the else branch is not present.
return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also add when the else part is the null literal like IF(x, y, null)

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ public void testDefaults()
.setOffsetClauseEnabled(false)
.setPartialResultsMaxExecutionTimeMultiplier(2.0)
.setMaterializedViewDataConsistencyEnabled(true)
.setQueryOptimizationWithMaterializedViewEnabled(false));
.setQueryOptimizationWithMaterializedViewEnabled(false)
.setAggregationIfToFilterRewriteEnabled(true));
}

@Test
Expand Down Expand Up @@ -301,6 +302,7 @@ public void testExplicitPropertyMappings()
.put("offset-clause-enabled", "true")
.put("materialized-view-data-consistency-enabled", "false")
.put("query-optimization-with-materialized-view-enabled", "true")
.put("optimizer.aggregation-if-to-filter-rewrite-enabled", "false")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -424,7 +426,8 @@ public void testExplicitPropertyMappings()
.setOffsetClauseEnabled(true)
.setPartialResultsMaxExecutionTimeMultiplier(1.5)
.setMaterializedViewDataConsistencyEnabled(false)
.setQueryOptimizationWithMaterializedViewEnabled(true);
.setQueryOptimizationWithMaterializedViewEnabled(true)
.setAggregationIfToFilterRewriteEnabled(false);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses
List<VariableReferenceExpression> aggregationsWithMask = aggregationNode.getAggregations()
.entrySet()
.stream()
.filter(entry -> entry.getValue().isDistinct())
.filter(entry -> entry.getValue().getMask().isPresent())
.map(Map.Entry::getKey)
.collect(Collectors.toList());

Expand Down
Loading