diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 34e9330eab12e..2953460c3a461 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -320,6 +320,7 @@ public final class SystemSessionProperties
public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression";
public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache";
public static final String REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT = "remove_cross_join_with_constant_single_row_input";
+ public static final String OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT = "optimize_conditional_constant_approximate_distinct";
public static final String EAGER_PLAN_VALIDATION_ENABLED = "eager_plan_validation_enabled";
public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";
@@ -1904,6 +1905,11 @@ public SystemSessionProperties(
"Enable adding an exchange below partial aggregation over a GroupId node to improve partial aggregation performance",
featuresConfig.getAddExchangeBelowPartialAggregationOverGroupId(),
false),
+ booleanProperty(
+ OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT,
+ "Optimize out APPROX_DISTINCT operations over constant conditionals",
+ featuresConfig.isOptimizeConditionalApproxDistinct(),
+ false),
new PropertyMetadata<>(
QUERY_CLIENT_TIMEOUT,
"Configures how long the query runs without contact from the client application, such as the CLI, before it's abandoned",
@@ -3267,4 +3273,9 @@ public static Duration getQueryClientTimeout(Session session)
{
return session.getSystemProperty(QUERY_CLIENT_TIMEOUT, Duration.class);
}
+
+ public static boolean isOptimizeConditionalApproxDistinctEnabled(Session session)
+ {
+ return session.getSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, Boolean.class);
+ }
}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
index 46f6ecf466d93..a92a1ab8c7faf 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
@@ -269,6 +269,7 @@ public class FeaturesConfig
private boolean pullUpExpressionFromLambda;
private boolean rewriteConstantArrayContainsToIn;
private boolean rewriteExpressionWithConstantVariable = true;
+ private boolean optimizeConditionalApproxDistinct = true;
private boolean preProcessMetadataCalls;
private boolean handleComplexEquiJoins;
@@ -2787,6 +2788,19 @@ public FeaturesConfig setRewriteExpressionWithConstantVariable(boolean rewriteEx
return this;
}
+ public boolean isOptimizeConditionalApproxDistinct()
+ {
+ return this.optimizeConditionalApproxDistinct;
+ }
+
+ @Config("optimizer.optimize-constant-approx-distinct")
+ @ConfigDescription("Optimize out APPROX_DISTINCT over conditional constant expressions")
+ public FeaturesConfig setOptimizeConditionalApproxDistinct(boolean optimizeConditionalApproxDistinct)
+ {
+ this.optimizeConditionalApproxDistinct = optimizeConditionalApproxDistinct;
+ return this;
+ }
+
public CreateView.Security getDefaultViewSecurityMode()
{
return this.defaultViewSecurityMode;
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index b017a7e1c8247..958375430b3c5 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -125,6 +125,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes;
import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
+import com.facebook.presto.sql.planner.iterative.rule.ReplaceConditionalApproxDistinct;
import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter;
import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate;
import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseToMap;
@@ -455,6 +456,12 @@ public PlanOptimizers(
new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()),
simplifyRowExpressionOptimizer,
new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()),
+ new IterativeOptimizer(
+ metadata,
+ ruleStats,
+ statsCalculator,
+ estimatedExchangesCostCalculator,
+ ImmutableSet.of(new ReplaceConditionalApproxDistinct(metadata.getFunctionAndTypeManager()))),
new IterativeOptimizer(
metadata,
ruleStats,
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java
new file mode 100644
index 0000000000000..6079ed2bd45b7
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Capture;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.spi.VariableAllocator;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.ConstantExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.SpecialFormExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Map.Entry;
+
+import static com.facebook.presto.SystemSessionProperties.isOptimizeConditionalApproxDistinctEnabled;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
+import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
+import static com.facebook.presto.sql.planner.plan.Patterns.project;
+import static com.facebook.presto.sql.planner.plan.Patterns.source;
+import static com.facebook.presto.sql.relational.Expressions.constant;
+import static com.facebook.presto.sql.relational.Expressions.constantNull;
+import static com.facebook.presto.sql.relational.Expressions.isNull;
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * elimination of approx distinct on conditional constant values.
+ *
+ * depending on the inner conditional, the expression is converted
+ * to its equivalent arbitrary() expression.
+ *
+ * - approx_distinct(if(..., non-null)) -> arbitrary(if(..., 1, NULL))
+ * - approx_distinct(if(..., null, non-null)) -> arbitrary(if(..., NULL, 1))
+ * - approx_distinct(if(..., null, null)) -> arbitrary(0)
+ *
+ * An intermediate projection is inserted to convert any NULL arbitrary output
+ * to zero values.
+ */
+public class ReplaceConditionalApproxDistinct
+ implements Rule
+{
+ private static final Capture SOURCE = Capture.newCapture();
+
+ private static final Pattern PATTERN = aggregation()
+ .with(source().matching(project().capturedAs(SOURCE)));
+
+ private final StandardFunctionResolution functionResolution;
+
+ private static final String ARBITRARY = "arbitrary";
+
+ public ReplaceConditionalApproxDistinct(FunctionAndTypeManager functionAndTypeManager)
+ {
+ requireNonNull(functionAndTypeManager, "functionManager is null");
+ this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
+ }
+
+ @Override
+ public boolean isEnabled(Session session)
+ {
+ return isOptimizeConditionalApproxDistinctEnabled(session);
+ }
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN;
+ }
+
+ @Override
+ public Result apply(AggregationNode parent, Captures captures, Context context)
+ {
+ VariableAllocator variableAllocator = context.getVariableAllocator();
+ boolean changed = false;
+ ProjectNode project = captures.get(SOURCE);
+ Assignments.Builder outputs = Assignments.builder();
+ Assignments.Builder inputs = Assignments.builder();
+
+ ImmutableMap.Builder aggregations = ImmutableMap.builder();
+ for (Entry entry : parent.getAggregations().entrySet()) {
+ VariableReferenceExpression variable = entry.getKey();
+ AggregationNode.Aggregation aggregation = entry.getValue();
+ SpecialFormExpression replaced;
+ VariableReferenceExpression intermediate;
+ VariableReferenceExpression expression;
+
+ if (!isApproxDistinct(aggregation) || !aggregationIsReplaceable(aggregation, project.getAssignments())) {
+ aggregations.put(variable, aggregation);
+ outputs.put(variable, variable);
+ continue;
+ }
+ changed = true;
+ replaced = (SpecialFormExpression) project.getAssignments().get(
+ (VariableReferenceExpression) aggregation.getArguments().get(0));
+
+ expression = variableAllocator.newVariable("expression", BIGINT);
+ inputs.put(expression, replaceIfExpression(replaced));
+
+ intermediate = variableAllocator.newVariable("intermediate", BIGINT);
+ aggregations.put(intermediate, new AggregationNode.Aggregation(
+ new CallExpression(
+ aggregation.getCall().getSourceLocation(),
+ ARBITRARY,
+ functionResolution.arbitraryFunction(BIGINT),
+ BIGINT,
+ ImmutableList.of(expression)),
+ aggregation.getFilter(),
+ aggregation.getOrderBy(),
+ aggregation.isDistinct(),
+ aggregation.getMask()));
+
+ outputs.put(variable, new SpecialFormExpression(
+ COALESCE,
+ BIGINT,
+ ImmutableList.of(
+ intermediate,
+ constant(0L, BIGINT))));
+ }
+
+ if (!changed) {
+ return Result.empty();
+ }
+
+ ProjectNode child = new ProjectNode(
+ project.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ project.getSource(),
+ inputs.putAll(project.getAssignments()).build(),
+ project.getLocality());
+
+ AggregationNode aggregation = new AggregationNode(
+ parent.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ child,
+ aggregations.build(),
+ parent.getGroupingSets(),
+ ImmutableList.of(),
+ parent.getStep(),
+ parent.getHashVariable(),
+ parent.getGroupIdVariable(),
+ parent.getAggregationId());
+
+ aggregation.getHashVariable().ifPresent(hashvariable -> outputs.put(hashvariable, hashvariable));
+ aggregation.getGroupingSets().getGroupingKeys().forEach(groupingKey -> outputs.put(groupingKey, groupingKey));
+ return Result.ofPlanNode(new ProjectNode(
+ context.getIdAllocator().getNextId(),
+ aggregation,
+ outputs.build()));
+ }
+
+ private boolean isApproxDistinct(AggregationNode.Aggregation aggregation)
+ {
+ return functionResolution.isApproximateCountDistinctFunction(aggregation.getFunctionHandle());
+ }
+
+ private ConstantExpression convertConstant(ConstantExpression expression)
+ {
+ return isNull(expression) ? constantNull(BIGINT) : constant(1L, BIGINT);
+ }
+
+ private RowExpression replaceIfExpression(SpecialFormExpression ifCondition)
+ {
+ ConstantExpression trueThen = (ConstantExpression) ifCondition.getArguments().get(1);
+ ConstantExpression falseThen = (ConstantExpression) ifCondition.getArguments().get(2);
+ RowExpression replace;
+
+ if ((isNull(trueThen) && !isNull(falseThen)) || (!isNull(trueThen) && isNull(falseThen))) {
+ // if(..., null, non-null) or if(..., non-null, null)
+ replace = new SpecialFormExpression(
+ ifCondition.getSourceLocation(),
+ IF,
+ BIGINT,
+ ImmutableList.of(
+ ifCondition.getArguments().get(0),
+ convertConstant(trueThen),
+ convertConstant(falseThen)));
+ }
+ else {
+ // if(..., null, null)
+ checkState(isNull(trueThen) && isNull(falseThen),
+ "expected true (%s) and false (%s) predicates to be null",
+ trueThen, falseThen);
+ replace = convertConstant(trueThen);
+ }
+ return replace;
+ }
+
+ private boolean aggregationIsReplaceable(AggregationNode.Aggregation aggregation, Assignments inputs)
+ {
+ RowExpression argument = aggregation.getArguments().get(0);
+ RowExpression ifCondition = null;
+ RowExpression trueThen = null;
+ RowExpression falseThen = null;
+
+ if (argument instanceof VariableReferenceExpression) {
+ ifCondition = inputs.get((VariableReferenceExpression) argument);
+ }
+
+ if (ifCondition instanceof SpecialFormExpression && ((SpecialFormExpression) ifCondition).getForm() == IF) {
+ trueThen = ((SpecialFormExpression) ifCondition).getArguments().get(1);
+ falseThen = ((SpecialFormExpression) ifCondition).getArguments().get(2);
+ }
+
+ return trueThen instanceof ConstantExpression &&
+ falseThen instanceof ConstantExpression &&
+ (isNull(trueThen) || isNull(falseThen));
+ }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
index f971ff1027de6..94b00e4812a25 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
@@ -340,6 +340,12 @@ public boolean isMinByFunction(FunctionHandle functionHandle)
return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("min_by")));
}
+ @Override
+ public FunctionHandle arbitraryFunction(Type valueType)
+ {
+ return functionAndTypeResolver.lookupFunction("arbitrary", fromTypes(valueType));
+ }
+
@Override
public boolean isMaxFunction(FunctionHandle functionHandle)
{
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index c764262338fab..db27a9c9a0a07 100644
--- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -239,6 +239,7 @@ public void testDefaults()
.setCteFilterAndProjectionPushdownEnabled(true)
.setGenerateDomainFilters(false)
.setRewriteExpressionWithConstantVariable(true)
+ .setOptimizeConditionalApproxDistinct(true)
.setDefaultWriterReplicationCoefficient(3.0)
.setDefaultViewSecurityMode(DEFINER)
.setCteHeuristicReplicationThreshold(4)
@@ -449,6 +450,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true")
.put("optimizer.generate-domain-filters", "true")
.put("optimizer.rewrite-expression-with-constant-variable", "false")
+ .put("optimizer.optimize-constant-approx-distinct", "false")
.put("optimizer.default-writer-replication-coefficient", "5.0")
.put("default-view-security-mode", INVOKER.name())
.put("cte-heuristic-replication-threshold", "2")
@@ -656,6 +658,7 @@ public void testExplicitPropertyMappings()
.setCteFilterAndProjectionPushdownEnabled(false)
.setGenerateDomainFilters(true)
.setRewriteExpressionWithConstantVariable(false)
+ .setOptimizeConditionalApproxDistinct(false)
.setDefaultWriterReplicationCoefficient(5.0)
.setDefaultViewSecurityMode(INVOKER)
.setCteHeuristicReplicationThreshold(2)
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java
index fa61a6354e26a..19621f3ce1f97 100644
--- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java
@@ -117,8 +117,8 @@ private static boolean verifyAggregationOrderBy(OrderingScheme orderingScheme, O
private static boolean isEquivalent(Optional expression, Optional rowExpression)
{
// Function's argument provided by FunctionCallProvider is SymbolReference that already resolved from symbolAliases.
- if (rowExpression.isPresent() && expression.isPresent()) {
- checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference");
+ if (rowExpression.isPresent() && expression.isPresent() && !(expression.get() instanceof AnySymbolReference)) {
+ checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference: " + rowExpression.get());
return expression.get().equals(createSymbolReference(((VariableReferenceExpression) rowExpression.get())));
}
return rowExpression.isPresent() == expression.isPresent();
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java
new file mode 100644
index 0000000000000..8cb9014afafa2
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.parser.ParsingOptions;
+import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+
+public class TestReplaceConditionalApproxDistinct
+ extends BaseRuleTest
+{
+ @Test
+ public void testReplaceConditionalConstant()
+ {
+ tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+ .on(p -> {
+ VariableReferenceExpression original = p.variable("original", BOOLEAN);
+ VariableReferenceExpression a = p.variable("a", BIGINT);
+ VariableReferenceExpression b = p.variable("b", BIGINT);
+ return p.aggregation((builder) -> builder
+ .addAggregation(
+ p.variable("output"),
+ p.rowExpression("approx_distinct(original)"))
+ .globalGrouping()
+ .step(AggregationNode.Step.SINGLE)
+ .source(
+ p.project(
+ p.assignment(
+ original, p.rowExpression("if(a > b, 'constant')")),
+ p.values(a, b))));
+ })
+ .matches(
+ project(
+ ImmutableMap.of(
+ "output", expression("coalesce(intermediate, 0)")),
+ aggregation(
+ ImmutableMap.of("intermediate",
+ functionCall("arbitrary", ImmutableList.of("expression"))),
+ SINGLE,
+ project(
+ ImmutableMap.of(
+ "original", expression("if(a > b, 'constant')"),
+ "expression", expression("if(a > b, 1, NULL)")),
+ values("a", "b")))));
+ }
+
+ @Test
+ public void testReplaceConditionalErrorBounds()
+ {
+ tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+ .on(p -> {
+ VariableReferenceExpression original = p.variable("original", BOOLEAN);
+ VariableReferenceExpression a = p.variable("a", BIGINT);
+ VariableReferenceExpression b = p.variable("b", BIGINT);
+ VariableReferenceExpression bounds = p.variable("bounds", DOUBLE);
+ return p.aggregation((builder) -> builder
+ .addAggregation(
+ p.variable("output"),
+ p.rowExpression("approx_distinct(original, bounds)"))
+ .globalGrouping()
+ .step(AggregationNode.Step.SINGLE)
+ .source(
+ p.project(
+ p.assignment(
+ original, p.rowExpression("if(a > b, 'constant')"),
+ bounds, p.rowExpression("0.0040625", ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)),
+ p.values(a, b))));
+ })
+ .matches(
+ project(
+ ImmutableMap.of(
+ "output", expression("coalesce(intermediate, 0)")),
+ aggregation(
+ ImmutableMap.of("intermediate",
+ functionCall("arbitrary", ImmutableList.of("expression"))),
+ SINGLE,
+ project(
+ ImmutableMap.of(
+ "original", expression("if(a > b, 'constant')"),
+ "expression", expression("if(a > b, 1, NULL)")),
+ values("a", "b")))));
+ }
+
+ @Test
+ public void testReplaceMultipleConditionalConstant()
+ {
+ tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+ .on(p -> {
+ VariableReferenceExpression original1 = p.variable("original1", BOOLEAN);
+ VariableReferenceExpression original2 = p.variable("original2", BOOLEAN);
+ VariableReferenceExpression a = p.variable("a", BIGINT);
+ VariableReferenceExpression b = p.variable("b", BIGINT);
+ return p.aggregation((builder) -> builder
+ .addAggregation(
+ p.variable("output1"),
+ p.rowExpression("approx_distinct(original1)"))
+ .addAggregation(
+ p.variable("output2"),
+ p.rowExpression("approx_distinct(original2)"))
+ .globalGrouping()
+ .step(AggregationNode.Step.SINGLE)
+ .source(
+ p.project(
+ p.assignment(
+ original1, p.rowExpression("if(a > b, 'constant')"),
+ original2, p.rowExpression("if(a < b, NULL, 'constant')")),
+ p.values(a, b))));
+ })
+ .matches(
+ project(
+ ImmutableMap.of(
+ "output1", expression("coalesce(intermediate1, 0)"),
+ "output2", expression("coalesce(intermediate2, 0)")),
+ aggregation(
+ ImmutableMap.of(
+ "intermediate1", functionCall("arbitrary", ImmutableList.of("expression1")),
+ "intermediate2", functionCall("arbitrary", ImmutableList.of("expression2"))),
+ SINGLE,
+ project(
+ ImmutableMap.of(
+ "original1", expression("if(a > b, 'constant')"),
+ "original2", expression("if(a < b, NULL, 'constant')"),
+ "expression1", expression("if(a > b, 1, NULL)"),
+ "expression2", expression("if(a < b, NULL, 1)")),
+ values("a", "b")))));
+ }
+
+ @Test
+ public void testDontReplaceConstant()
+ {
+ tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+ .on(p -> {
+ VariableReferenceExpression input = p.variable("input", VARCHAR);
+ return p.aggregation((builder) -> builder
+ .addAggregation(
+ p.variable("output"),
+ p.rowExpression("approx_distinct(input)"))
+ .globalGrouping()
+ .step(AggregationNode.Step.SINGLE)
+ .source(
+ p.project(
+ p.assignment(input, p.rowExpression("'constant'")),
+ p.values())));
+ }).doesNotFire();
+ }
+
+ @Test
+ public void testDontReplaceVariable()
+ {
+ tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+ .on(p -> {
+ VariableReferenceExpression input = p.variable("input", VARCHAR);
+ VariableReferenceExpression nonconstant = p.variable("nonconstant", VARCHAR);
+ return p.aggregation((builder) -> builder
+ .addAggregation(
+ p.variable("output"),
+ p.rowExpression("approx_distinct(input)"))
+ .globalGrouping()
+ .step(AggregationNode.Step.SINGLE)
+ .source(
+ p.project(
+ p.assignment(input, p.rowExpression("nonconstant")),
+ p.values(nonconstant))));
+ }).doesNotFire();
+ }
+
+ @Test
+ public void testDontReplaceConditionalVariable()
+ {
+ tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+ .on(p -> {
+ VariableReferenceExpression original = p.variable("original", BOOLEAN);
+ VariableReferenceExpression a = p.variable("a", BIGINT);
+ VariableReferenceExpression b = p.variable("b", BIGINT);
+ VariableReferenceExpression nonconstant = p.variable("nonconstant", BIGINT);
+ return p.aggregation((builder) -> builder
+ .addAggregation(
+ p.variable("output"),
+ p.rowExpression("approx_distinct(original)"))
+ .globalGrouping()
+ .step(AggregationNode.Step.SINGLE)
+ .source(
+ p.project(
+ p.assignment(
+ original, p.rowExpression("if(a > b, nonconstant)")),
+ p.values(a, b, nonconstant))));
+ }).doesNotFire();
+ }
+}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java
new file mode 100644
index 0000000000000..c978a895ada42
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.optimizations;
+
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.sql.planner.assertions.BasePlanTest;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
+
+public class TestApproxDistinctOptimizer
+ extends BasePlanTest
+{
+ @Test
+ public void testReplacesConditionalApproxDistinct()
+ {
+ assertPlan("SELECT APPROX_DISTINCT(IF(nationkey = 1, 1)) FROM nation",
+ output(
+ project(
+ ImmutableMap.of("output", expression("coalesce(intermediate, 0)")),
+ aggregation(
+ ImmutableMap.of("intermediate", functionCall("arbitrary", ImmutableList.of("partial"))),
+ AggregationNode.Step.FINAL,
+ anyTree(
+ aggregation(
+ ImmutableMap.of("partial", functionCall("arbitrary", false, ImmutableList.of(anySymbol()))),
+ AggregationNode.Step.PARTIAL,
+ anyTree(
+ tableScan("nation"))))))));
+ }
+
+ @Test
+ public void testReplacesConditionalApproxDistinctGrouped()
+ {
+ assertPlan("SELECT APPROX_DISTINCT(IF(nationkey = nationkey, 1)) FROM nation group by nationkey",
+ output(
+ project(
+ ImmutableMap.of("output", expression("coalesce(intermediate, 0)")),
+ aggregation(
+ ImmutableMap.of("intermediate", functionCall("arbitrary", ImmutableList.of("partial"))),
+ AggregationNode.Step.FINAL,
+ anyTree(
+ aggregation(
+ ImmutableMap.of("partial", functionCall("arbitrary", false, ImmutableList.of(anySymbol()))),
+ AggregationNode.Step.PARTIAL,
+ anyTree(
+ tableScan("nation"))))))));
+ }
+
+ @Test
+ public void testDontReplaceConstantApproxDistinct()
+ {
+ assertPlan("SELECT APPROX_DISTINCT('constant') FROM nation",
+ output(
+ aggregation(
+ ImmutableMap.of("final", functionCall("approx_distinct", ImmutableList.of("partial"))),
+ AggregationNode.Step.FINAL,
+ anyTree(
+ aggregation(
+ ImmutableMap.of("partial", functionCall("approx_distinct", false, ImmutableList.of(anySymbol()))),
+ AggregationNode.Step.PARTIAL,
+ anyTree(
+ tableScan("nation")))))));
+ }
+
+ @Test
+ public void testDontReplaceVariableApproxDistinct()
+ {
+ assertPlan("SELECT APPROX_DISTINCT(nationkey) FROM nation",
+ output(
+ aggregation(
+ ImmutableMap.of("final", functionCall("approx_distinct", ImmutableList.of("partial"))),
+ AggregationNode.Step.FINAL,
+ anyTree(
+ aggregation(
+ ImmutableMap.of("partial", functionCall("approx_distinct", ImmutableList.of("nationkey"))),
+ AggregationNode.Step.PARTIAL,
+ tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))));
+ }
+}
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java
index 2d6f73cda84ee..288d85930d3df 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java
@@ -68,6 +68,8 @@ public interface StandardFunctionResolution
FunctionHandle countFunction(Type valueType);
+ FunctionHandle arbitraryFunction(Type valueType);
+
boolean isMaxFunction(FunctionHandle functionHandle);
FunctionHandle maxFunction(Type valueType);
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index 6be4b53757fa5..914642b118be8 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -69,6 +69,7 @@
import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE;
+import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION;
import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT;
import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT_TIMEOUT_MS;
@@ -6803,6 +6804,41 @@ public void testMergeDuplicateAggregations()
"HAVING min(orderkey) < (SELECT avg(orderkey) FROM orders WHERE orderkey < 7)");
}
+ @Test
+ public void testReplaceConditionalApproxDistinct()
+ {
+ Session enabled = Session.builder(getSession())
+ .setSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, "true")
+ .build();
+ Session disabled = Session.builder(getSession())
+ .setSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, "false")
+ .build();
+
+ List queries = ImmutableList.of(
+ "select approx_distinct(if(%s, 1, 2)) from orders %s",
+ "select approx_distinct(if(%s, 3)) from orders %s",
+ "select approx_distinct(if(%s, 4, NULL)) from orders %s",
+ "select approx_distinct(if(%s, NULL, 5)) from orders %s",
+ "select approx_distinct(if(%s, NULL)) from orders %s",
+ "select approx_distinct(if(%s, NULL, NULL)) from orders %s");
+ List conditions = ImmutableList.of(
+ "orderkey = orderkey",
+ "orderkey % 2 = 0");
+ List suffixes = ImmutableList.of(
+ "",
+ "where orderkey % 3 = 0",
+ "group by orderkey");
+
+ for (String query : queries) {
+ for (String condition : conditions) {
+ for (String suffix : suffixes) {
+ String sql = format(query, condition, suffix);
+ assertQueryWithSameQueryRunner(enabled, sql, disabled);
+ }
+ }
+ }
+ }
+
@Test
public void testSameAggregationWithAndWithoutFilter()
{