Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ public final class SystemSessionProperties
public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression";
public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache";
public static final String REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT = "remove_cross_join_with_constant_single_row_input";
public static final String OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT = "optimize_conditional_constant_approximate_distinct";
public static final String EAGER_PLAN_VALIDATION_ENABLED = "eager_plan_validation_enabled";
public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";
Expand Down Expand Up @@ -1904,6 +1905,11 @@ public SystemSessionProperties(
"Enable adding an exchange below partial aggregation over a GroupId node to improve partial aggregation performance",
featuresConfig.getAddExchangeBelowPartialAggregationOverGroupId(),
false),
booleanProperty(
OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT,
"Optimize out APPROX_DISTINCT operations over constant conditionals",
featuresConfig.isOptimizeConditionalApproxDistinct(),
false),
new PropertyMetadata<>(
QUERY_CLIENT_TIMEOUT,
"Configures how long the query runs without contact from the client application, such as the CLI, before it's abandoned",
Expand Down Expand Up @@ -3267,4 +3273,9 @@ public static Duration getQueryClientTimeout(Session session)
{
return session.getSystemProperty(QUERY_CLIENT_TIMEOUT, Duration.class);
}

/**
 * Looks up the {@code OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT} system property
 * on the given session.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for this session
 */
public static boolean isOptimizeConditionalApproxDistinctEnabled(Session session)
{
    Boolean enabled = session.getSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, Boolean.class);
    return enabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ public class FeaturesConfig
private boolean pullUpExpressionFromLambda;
private boolean rewriteConstantArrayContainsToIn;
private boolean rewriteExpressionWithConstantVariable = true;
private boolean optimizeConditionalApproxDistinct = true;

private boolean preProcessMetadataCalls;
private boolean handleComplexEquiJoins;
Expand Down Expand Up @@ -2787,6 +2788,19 @@ public FeaturesConfig setRewriteExpressionWithConstantVariable(boolean rewriteEx
return this;
}

/**
 * Returns the current value of the {@code optimizeConditionalApproxDistinct} flag.
 */
public boolean isOptimizeConditionalApproxDistinct()
{
    return optimizeConditionalApproxDistinct;
}

/**
 * Sets the flag described as "Optimize out APPROX_DISTINCT over conditional constant expressions".
 * Bound to the {@code optimizer.optimize-constant-approx-distinct} configuration key.
 *
 * @param optimizeConditionalApproxDistinct new value of the flag
 * @return this config instance, so calls can be chained
 */
@Config("optimizer.optimize-constant-approx-distinct")
@ConfigDescription("Optimize out APPROX_DISTINCT over conditional constant expressions")
public FeaturesConfig setOptimizeConditionalApproxDistinct(boolean optimizeConditionalApproxDistinct)
{
this.optimizeConditionalApproxDistinct = optimizeConditionalApproxDistinct;
return this;
}

public CreateView.Security getDefaultViewSecurityMode()
{
return this.defaultViewSecurityMode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes;
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
import com.facebook.presto.sql.planner.iterative.rule.ReplaceConditionalApproxDistinct;
import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter;
import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate;
import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseToMap;
Expand Down Expand Up @@ -455,6 +456,12 @@ public PlanOptimizers(
new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()),
simplifyRowExpressionOptimizer,
new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()),
new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new ReplaceConditionalApproxDistinct(metadata.getFunctionAndTypeManager()))),
new IterativeOptimizer(
metadata,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.Map.Entry;

import static com.facebook.presto.SystemSessionProperties.isOptimizeConditionalApproxDistinctEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.isNull;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

/**
 * Eliminates {@code approx_distinct} over conditional constant values.
 * <p>
 * The rule matches an aggregation whose source is a projection. An {@code approx_distinct}
 * aggregation is rewritten when its first argument is a variable that the projection assigns
 * to an {@code IF} expression whose two branches are both constants and at least one of them
 * is NULL. Depending on the branches, the aggregation is converted to an equivalent
 * {@code arbitrary()} call over a BIGINT expression:
 * <pre>
 * - approx_distinct(if(..., non-null))       -> arbitrary(if(..., 1, NULL))
 * - approx_distinct(if(..., null, non-null)) -> arbitrary(if(..., NULL, 1))
 * - approx_distinct(if(..., null, null))     -> arbitrary(NULL)
 * </pre>
 * A projection is placed on top of the rewritten aggregation to convert any NULL
 * {@code arbitrary()} output into zero via {@code COALESCE(..., 0)}.
 */
public class ReplaceConditionalApproxDistinct
        implements Rule<AggregationNode>
{
    private static final String ARBITRARY = "arbitrary";

    private static final Capture<ProjectNode> SOURCE = Capture.newCapture();

    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .with(source().matching(project().capturedAs(SOURCE)));

    private final StandardFunctionResolution functionResolution;

    public ReplaceConditionalApproxDistinct(FunctionAndTypeManager functionAndTypeManager)
    {
        requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isOptimizeConditionalApproxDistinctEnabled(session);
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Rewrites every replaceable {@code approx_distinct} aggregation of {@code parent}.
     *
     * @return an empty result when no aggregation was rewritten; otherwise a projection over a
     * new aggregation over a new projection that additionally computes the rewritten arguments
     */
    @Override
    public Result apply(AggregationNode parent, Captures captures, Context context)
    {
        VariableAllocator variableAllocator = context.getVariableAllocator();
        ProjectNode project = captures.get(SOURCE);
        // Assignments of the projection placed above the new aggregation.
        Assignments.Builder outputs = Assignments.builder();
        // Extra assignments added to the projection below the new aggregation.
        Assignments.Builder inputs = Assignments.builder();
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        boolean changed = false;

        for (Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
            VariableReferenceExpression variable = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();

            if (!isApproxDistinct(aggregation) || !aggregationIsReplaceable(aggregation, project.getAssignments())) {
                // Not rewritten: keep the aggregation and pass its output through unchanged.
                aggregations.put(variable, aggregation);
                outputs.put(variable, variable);
                continue;
            }
            changed = true;

            // Both casts are safe because aggregationIsReplaceable() verified the shapes.
            SpecialFormExpression replaced = (SpecialFormExpression) project.getAssignments().get(
                    (VariableReferenceExpression) aggregation.getArguments().get(0));

            VariableReferenceExpression expression = variableAllocator.newVariable("expression", BIGINT);
            inputs.put(expression, replaceIfExpression(replaced));

            VariableReferenceExpression intermediate = variableAllocator.newVariable("intermediate", BIGINT);
            aggregations.put(intermediate, new AggregationNode.Aggregation(
                    new CallExpression(
                            aggregation.getCall().getSourceLocation(),
                            ARBITRARY,
                            functionResolution.arbitraryFunction(BIGINT),
                            BIGINT,
                            ImmutableList.of(expression)),
                    aggregation.getFilter(),
                    aggregation.getOrderBy(),
                    aggregation.isDistinct(),
                    aggregation.getMask()));

            // The original output variable becomes COALESCE(arbitrary(...), 0).
            outputs.put(variable, new SpecialFormExpression(
                    COALESCE,
                    BIGINT,
                    ImmutableList.of(
                            intermediate,
                            constant(0L, BIGINT))));
        }

        if (!changed) {
            return Result.empty();
        }

        ProjectNode child = new ProjectNode(
                project.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                project.getSource(),
                inputs.putAll(project.getAssignments()).build(),
                project.getLocality());

        AggregationNode aggregation = new AggregationNode(
                parent.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                child,
                aggregations.build(),
                parent.getGroupingSets(),
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashVariable(),
                parent.getGroupIdVariable(),
                parent.getAggregationId());

        // Pass the hash variable and the grouping keys through the top projection as well.
        aggregation.getHashVariable().ifPresent(hashVariable -> outputs.put(hashVariable, hashVariable));
        aggregation.getGroupingSets().getGroupingKeys().forEach(groupingKey -> outputs.put(groupingKey, groupingKey));
        return Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                aggregation,
                outputs.build()));
    }

    private boolean isApproxDistinct(AggregationNode.Aggregation aggregation)
    {
        return functionResolution.isApproximateCountDistinctFunction(aggregation.getFunctionHandle());
    }

    /**
     * Maps a constant to its BIGINT stand-in: NULL stays NULL, anything else becomes 1.
     */
    private ConstantExpression convertConstant(ConstantExpression expression)
    {
        return isNull(expression) ? constantNull(BIGINT) : constant(1L, BIGINT);
    }

    /**
     * Builds the BIGINT expression fed to {@code arbitrary()} in place of the original IF.
     * Callers must have validated the expression with {@link #aggregationIsReplaceable}.
     */
    private RowExpression replaceIfExpression(SpecialFormExpression ifCondition)
    {
        ConstantExpression trueThen = (ConstantExpression) ifCondition.getArguments().get(1);
        ConstantExpression falseThen = (ConstantExpression) ifCondition.getArguments().get(2);

        if (isNull(trueThen) != isNull(falseThen)) {
            // if(..., null, non-null) or if(..., non-null, null)
            return new SpecialFormExpression(
                    ifCondition.getSourceLocation(),
                    IF,
                    BIGINT,
                    ImmutableList.of(
                            ifCondition.getArguments().get(0),
                            convertConstant(trueThen),
                            convertConstant(falseThen)));
        }

        // if(..., null, null)
        checkState(isNull(trueThen) && isNull(falseThen),
                "expected true (%s) and false (%s) predicates to be null",
                trueThen, falseThen);
        return constantNull(BIGINT);
    }

    /**
     * Checks whether the first argument of the aggregation is a variable that {@code inputs}
     * assigns to an IF expression with two constant branches, at least one of which is NULL.
     */
    private boolean aggregationIsReplaceable(AggregationNode.Aggregation aggregation, Assignments inputs)
    {
        RowExpression argument = aggregation.getArguments().get(0);
        if (!(argument instanceof VariableReferenceExpression)) {
            return false;
        }

        RowExpression definition = inputs.get((VariableReferenceExpression) argument);
        if (!(definition instanceof SpecialFormExpression) || ((SpecialFormExpression) definition).getForm() != IF) {
            return false;
        }

        RowExpression trueThen = ((SpecialFormExpression) definition).getArguments().get(1);
        RowExpression falseThen = ((SpecialFormExpression) definition).getArguments().get(2);
        return trueThen instanceof ConstantExpression &&
                falseThen instanceof ConstantExpression &&
                (isNull(trueThen) || isNull(falseThen));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,12 @@ public boolean isMinByFunction(FunctionHandle functionHandle)
return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("min_by")));
}

/**
 * Looks up the handle of the {@code arbitrary} function, using {@code valueType}
 * as its only argument type.
 */
@Override
public FunctionHandle arbitraryFunction(Type valueType)
{
    FunctionHandle handle = functionAndTypeResolver.lookupFunction("arbitrary", fromTypes(valueType));
    return handle;
}

@Override
public boolean isMaxFunction(FunctionHandle functionHandle)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ public void testDefaults()
.setCteFilterAndProjectionPushdownEnabled(true)
.setGenerateDomainFilters(false)
.setRewriteExpressionWithConstantVariable(true)
.setOptimizeConditionalApproxDistinct(true)
.setDefaultWriterReplicationCoefficient(3.0)
.setDefaultViewSecurityMode(DEFINER)
.setCteHeuristicReplicationThreshold(4)
Expand Down Expand Up @@ -449,6 +450,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true")
.put("optimizer.generate-domain-filters", "true")
.put("optimizer.rewrite-expression-with-constant-variable", "false")
.put("optimizer.optimize-constant-approx-distinct", "false")
.put("optimizer.default-writer-replication-coefficient", "5.0")
.put("default-view-security-mode", INVOKER.name())
.put("cte-heuristic-replication-threshold", "2")
Expand Down Expand Up @@ -656,6 +658,7 @@ public void testExplicitPropertyMappings()
.setCteFilterAndProjectionPushdownEnabled(false)
.setGenerateDomainFilters(true)
.setRewriteExpressionWithConstantVariable(false)
.setOptimizeConditionalApproxDistinct(false)
.setDefaultWriterReplicationCoefficient(5.0)
.setDefaultViewSecurityMode(INVOKER)
.setCteHeuristicReplicationThreshold(2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ private static boolean verifyAggregationOrderBy(OrderingScheme orderingScheme, O
private static boolean isEquivalent(Optional<Expression> expression, Optional<RowExpression> rowExpression)
{
// Function's argument provided by FunctionCallProvider is SymbolReference that already resolved from symbolAliases.
if (rowExpression.isPresent() && expression.isPresent()) {
checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference");
if (rowExpression.isPresent() && expression.isPresent() && !(expression.get() instanceof AnySymbolReference)) {
checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference: " + rowExpression.get());
return expression.get().equals(createSymbolReference(((VariableReferenceExpression) rowExpression.get())));
}
return rowExpression.isPresent() == expression.isPresent();
Expand Down
Loading
Loading