diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index 34e9330eab12e..2953460c3a461 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -320,6 +320,7 @@ public final class SystemSessionProperties public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression"; public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache"; public static final String REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT = "remove_cross_join_with_constant_single_row_input"; + public static final String OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT = "optimize_conditional_constant_approximate_distinct"; public static final String EAGER_PLAN_VALIDATION_ENABLED = "eager_plan_validation_enabled"; public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode"; public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side"; @@ -1904,6 +1905,11 @@ public SystemSessionProperties( "Enable adding an exchange below partial aggregation over a GroupId node to improve partial aggregation performance", featuresConfig.getAddExchangeBelowPartialAggregationOverGroupId(), false), + booleanProperty( + OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, + "Optimize out APPROX_DISTINCT operations over constant conditionals", + featuresConfig.isOptimizeConditionalApproxDistinct(), + false), new PropertyMetadata<>( QUERY_CLIENT_TIMEOUT, "Configures how long the query runs without contact from the client application, such as the CLI, before it's abandoned", @@ -3267,4 +3273,9 @@ public static Duration getQueryClientTimeout(Session session) { return session.getSystemProperty(QUERY_CLIENT_TIMEOUT, Duration.class); } + + public 
static boolean isOptimizeConditionalApproxDistinctEnabled(Session session) + { + return session.getSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, Boolean.class); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 46f6ecf466d93..a92a1ab8c7faf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -269,6 +269,7 @@ public class FeaturesConfig private boolean pullUpExpressionFromLambda; private boolean rewriteConstantArrayContainsToIn; private boolean rewriteExpressionWithConstantVariable = true; + private boolean optimizeConditionalApproxDistinct = true; private boolean preProcessMetadataCalls; private boolean handleComplexEquiJoins; @@ -2787,6 +2788,19 @@ public FeaturesConfig setRewriteExpressionWithConstantVariable(boolean rewriteEx return this; } + public boolean isOptimizeConditionalApproxDistinct() + { + return this.optimizeConditionalApproxDistinct; + } + + @Config("optimizer.optimize-constant-approx-distinct") + @ConfigDescription("Optimize out APPROX_DISTINCT over conditional constant expressions") + public FeaturesConfig setOptimizeConditionalApproxDistinct(boolean optimizeConditionalApproxDistinct) + { + this.optimizeConditionalApproxDistinct = optimizeConditionalApproxDistinct; + return this; + } + public CreateView.Security getDefaultViewSecurityMode() { return this.defaultViewSecurityMode; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index b017a7e1c8247..958375430b3c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -125,6 +125,7 @@
 import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes;
 import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters;
 import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
+import com.facebook.presto.sql.planner.iterative.rule.ReplaceConditionalApproxDistinct;
 import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter;
 import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate;
 import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseToMap;
@@ -455,6 +456,12 @@ public PlanOptimizers(
                 new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()),
                 simplifyRowExpressionOptimizer,
                 new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()),
+                new IterativeOptimizer(
+                        metadata,
+                        ruleStats,
+                        statsCalculator,
+                        estimatedExchangesCostCalculator,
+                        ImmutableSet.of(new ReplaceConditionalApproxDistinct(metadata.getFunctionAndTypeManager()))),
                 new IterativeOptimizer(
                         metadata,
                         ruleStats,
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java
new file mode 100644
index 0000000000000..6079ed2bd45b7
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Capture;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.spi.VariableAllocator;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.ConstantExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.SpecialFormExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Map.Entry;
+
+import static com.facebook.presto.SystemSessionProperties.isOptimizeConditionalApproxDistinctEnabled;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
+import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
+import static com.facebook.presto.sql.planner.plan.Patterns.project;
+import static com.facebook.presto.sql.planner.plan.Patterns.source;
+import static com.facebook.presto.sql.relational.Expressions.constant;
+import static com.facebook.presto.sql.relational.Expressions.constantNull;
+import static com.facebook.presto.sql.relational.Expressions.isNull;
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Eliminates {@code approx_distinct} over conditional constant values.
+ * <p>
+ * When the aggregated argument is an {@code IF} whose branches are both constants
+ * and at least one of them is NULL, the aggregation is rewritten to the
+ * equivalent {@code arbitrary()} expression:
+ * <ul>
+ * <li>{@code approx_distinct(if(..., non-null)) -> arbitrary(if(..., 1, NULL))}</li>
+ * <li>{@code approx_distinct(if(..., null, non-null)) -> arbitrary(if(..., NULL, 1))}</li>
+ * <li>{@code approx_distinct(if(..., null, null)) -> arbitrary(NULL)}</li>
+ * </ul>
+ * A projection is inserted above the aggregation to convert a NULL
+ * {@code arbitrary()} output to zero.
+ */
+public class ReplaceConditionalApproxDistinct
+        implements Rule<AggregationNode>
+{
+    private static final String ARBITRARY = "arbitrary";
+
+    private static final Capture<ProjectNode> SOURCE = Capture.newCapture();
+
+    private static final Pattern<AggregationNode> PATTERN = aggregation()
+            .with(source().matching(project().capturedAs(SOURCE)));
+
+    private final StandardFunctionResolution functionResolution;
+
+    public ReplaceConditionalApproxDistinct(FunctionAndTypeManager functionAndTypeManager)
+    {
+        requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
+        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
+    }
+
+    @Override
+    public boolean isEnabled(Session session)
+    {
+        return isOptimizeConditionalApproxDistinctEnabled(session);
+    }
+
+    @Override
+    public Pattern<AggregationNode> getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Result apply(AggregationNode parent, Captures captures, Context context)
+    {
+        VariableAllocator variableAllocator = context.getVariableAllocator();
+        ProjectNode project = captures.get(SOURCE);
+        // Assignments of the projection placed above the rewritten aggregation
+        Assignments.Builder outputs = Assignments.builder();
+        // Extra assignments added to the projection below the aggregation
+        Assignments.Builder inputs = Assignments.builder();
+        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
+        boolean changed = false;
+
+        for (Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
+            VariableReferenceExpression variable = entry.getKey();
+            AggregationNode.Aggregation aggregation = entry.getValue();
+
+            if (!isApproxDistinct(aggregation) || !aggregationIsReplaceable(aggregation, project.getAssignments())) {
+                // Aggregations that are not rewritten pass straight through the new top projection
+                aggregations.put(variable, aggregation);
+                outputs.put(variable, variable);
+                continue;
+            }
+            changed = true;
+
+            // Casts are safe: aggregationIsReplaceable verified the argument is a variable projected as an IF
+            SpecialFormExpression replaced = (SpecialFormExpression) project.getAssignments().get(
+                    (VariableReferenceExpression) aggregation.getArguments().get(0));
+
+            VariableReferenceExpression expression = variableAllocator.newVariable("expression", BIGINT);
+            inputs.put(expression, replaceIfExpression(replaced));
+
+            VariableReferenceExpression intermediate = variableAllocator.newVariable("intermediate", BIGINT);
+            aggregations.put(intermediate, new AggregationNode.Aggregation(
+                    new CallExpression(
+                            aggregation.getCall().getSourceLocation(),
+                            ARBITRARY,
+                            functionResolution.arbitraryFunction(BIGINT),
+                            BIGINT,
+                            ImmutableList.of(expression)),
+                    aggregation.getFilter(),
+                    aggregation.getOrderBy(),
+                    aggregation.isDistinct(),
+                    aggregation.getMask()));
+
+            // Convert a NULL arbitrary() output to zero under the original output variable
+            outputs.put(variable, new SpecialFormExpression(
+                    COALESCE,
+                    BIGINT,
+                    ImmutableList.of(
+                            intermediate,
+                            constant(0L, BIGINT))));
+        }
+
+        if (!changed) {
+            return Result.empty();
+        }
+
+        ProjectNode child = new ProjectNode(
+                project.getSourceLocation(),
+                context.getIdAllocator().getNextId(),
+                project.getSource(),
+                inputs.putAll(project.getAssignments()).build(),
+                project.getLocality());
+
+        AggregationNode aggregation = new AggregationNode(
+                parent.getSourceLocation(),
+                context.getIdAllocator().getNextId(),
+                child,
+                aggregations.build(),
+                parent.getGroupingSets(),
+                parent.getPreGroupedVariables(),
+                parent.getStep(),
+                parent.getHashVariable(),
+                parent.getGroupIdVariable(),
+                parent.getAggregationId());
+
+        // The hash variable and grouping keys must stay visible above the new projection
+        aggregation.getHashVariable().ifPresent(hashVariable -> outputs.put(hashVariable, hashVariable));
+        aggregation.getGroupingSets().getGroupingKeys().forEach(groupingKey -> outputs.put(groupingKey, groupingKey));
+        return Result.ofPlanNode(new ProjectNode(
+                context.getIdAllocator().getNextId(),
+                aggregation,
+                outputs.build()));
+    }
+
+    private boolean isApproxDistinct(AggregationNode.Aggregation aggregation)
+    {
+        return functionResolution.isApproximateCountDistinctFunction(aggregation.getFunctionHandle());
+    }
+
+    // Maps a branch constant to a BIGINT marker: NULL stays NULL, any non-null value becomes 1
+    private ConstantExpression convertConstant(ConstantExpression expression)
+    {
+        return isNull(expression) ? constantNull(BIGINT) : constant(1L, BIGINT);
+    }
+
+    // Builds the BIGINT expression fed to arbitrary() from the original IF expression
+    private RowExpression replaceIfExpression(SpecialFormExpression ifCondition)
+    {
+        ConstantExpression trueThen = (ConstantExpression) ifCondition.getArguments().get(1);
+        ConstantExpression falseThen = (ConstantExpression) ifCondition.getArguments().get(2);
+
+        if (isNull(trueThen) != isNull(falseThen)) {
+            // if(..., null, non-null) or if(..., non-null, null)
+            return new SpecialFormExpression(
+                    ifCondition.getSourceLocation(),
+                    IF,
+                    BIGINT,
+                    ImmutableList.of(
+                            ifCondition.getArguments().get(0),
+                            convertConstant(trueThen),
+                            convertConstant(falseThen)));
+        }
+
+        // if(..., null, null)
+        checkState(isNull(trueThen) && isNull(falseThen),
+                "expected true (%s) and false (%s) predicates to be null",
+                trueThen, falseThen);
+        return convertConstant(trueThen);
+    }
+
+    // True when the first argument is a variable projected as IF(condition, constant, constant) with a NULL branch
+    private boolean aggregationIsReplaceable(AggregationNode.Aggregation aggregation, Assignments inputs)
+    {
+        RowExpression argument = aggregation.getArguments().get(0);
+        if (!(argument instanceof VariableReferenceExpression)) {
+            return false;
+        }
+
+        RowExpression assigned = inputs.get((VariableReferenceExpression) argument);
+        if (!(assigned instanceof SpecialFormExpression) || ((SpecialFormExpression) assigned).getForm() != IF) {
+            return false;
+        }
+
+        RowExpression trueThen = ((SpecialFormExpression) assigned).getArguments().get(1);
+        RowExpression falseThen = ((SpecialFormExpression) assigned).getArguments().get(2);
+        return trueThen instanceof ConstantExpression &&
+                falseThen instanceof ConstantExpression &&
+                (isNull(trueThen) || isNull(falseThen));
+    }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
index f971ff1027de6..94b00e4812a25 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
@@ -340,6 +340,12 @@ public boolean isMinByFunction(FunctionHandle functionHandle)
         return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("min_by")));
     }
 
+    @Override
+    public FunctionHandle arbitraryFunction(Type valueType)
+    {
+        return functionAndTypeResolver.lookupFunction("arbitrary", fromTypes(valueType));
+    }
+
     @Override
     public boolean isMaxFunction(FunctionHandle functionHandle)
     {
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index c764262338fab..db27a9c9a0a07 100644
--- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -239,6 +239,7 @@ public void testDefaults()
                 .setCteFilterAndProjectionPushdownEnabled(true)
                 .setGenerateDomainFilters(false)
                 .setRewriteExpressionWithConstantVariable(true)
+                .setOptimizeConditionalApproxDistinct(true)
                 .setDefaultWriterReplicationCoefficient(3.0)
                 .setDefaultViewSecurityMode(DEFINER)
                 .setCteHeuristicReplicationThreshold(4)
@@ -449,6 +450,7 @@ public void testExplicitPropertyMappings()
                 .put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true")
                 .put("optimizer.generate-domain-filters", "true")
                 .put("optimizer.rewrite-expression-with-constant-variable", "false")
+                
.put("optimizer.optimize-constant-approx-distinct", "false") .put("optimizer.default-writer-replication-coefficient", "5.0") .put("default-view-security-mode", INVOKER.name()) .put("cte-heuristic-replication-threshold", "2") @@ -656,6 +658,7 @@ public void testExplicitPropertyMappings() .setCteFilterAndProjectionPushdownEnabled(false) .setGenerateDomainFilters(true) .setRewriteExpressionWithConstantVariable(false) + .setOptimizeConditionalApproxDistinct(false) .setDefaultWriterReplicationCoefficient(5.0) .setDefaultViewSecurityMode(INVOKER) .setCteHeuristicReplicationThreshold(2) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java index fa61a6354e26a..19621f3ce1f97 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java @@ -117,8 +117,8 @@ private static boolean verifyAggregationOrderBy(OrderingScheme orderingScheme, O private static boolean isEquivalent(Optional expression, Optional rowExpression) { // Function's argument provided by FunctionCallProvider is SymbolReference that already resolved from symbolAliases. 
-        if (rowExpression.isPresent() && expression.isPresent()) {
-            checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference");
+        if (rowExpression.isPresent() && expression.isPresent() && !(expression.get() instanceof AnySymbolReference)) {
+            checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference: " + rowExpression.get());
             return expression.get().equals(createSymbolReference(((VariableReferenceExpression) rowExpression.get())));
         }
         return rowExpression.isPresent() == expression.isPresent();
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java
new file mode 100644
index 0000000000000..8cb9014afafa2
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.parser.ParsingOptions;
+import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+
+public class TestReplaceConditionalApproxDistinct
+        extends BaseRuleTest
+{
+    @Test
+    public void testReplaceConditionalConstant()
+    {
+        tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+                .on(p -> {
+                    VariableReferenceExpression original = p.variable("original", VARCHAR);
+                    VariableReferenceExpression a = p.variable("a", BIGINT);
+                    VariableReferenceExpression b = p.variable("b", BIGINT);
+                    return p.aggregation((builder) -> builder
+                            .addAggregation(
+                                    p.variable("output"),
+                                    p.rowExpression("approx_distinct(original)"))
+                            .globalGrouping()
+                            .step(AggregationNode.Step.SINGLE)
+                            .source(
+                                    p.project(
+                                            p.assignment(
+                                                    original, p.rowExpression("if(a > b, 'constant')")),
+                                            p.values(a, b))));
+                })
+                .matches(
+                        project(
+                                ImmutableMap.of(
+                                        "output", expression("coalesce(intermediate, 0)")),
+                                aggregation(
+                                        ImmutableMap.of("intermediate",
+                                                functionCall("arbitrary", ImmutableList.of("expression"))),
+                                        SINGLE,
+                                        project(
+                                                ImmutableMap.of(
+                                                        "original", expression("if(a > b, 'constant')"),
+                                                        "expression", expression("if(a > b, 1, NULL)")),
+                                                values("a", "b")))));
+    }
+
+    @Test
+    public void testReplaceConditionalErrorBounds()
+    {
+        tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+                .on(p -> {
+                    VariableReferenceExpression original = p.variable("original", VARCHAR);
+                    VariableReferenceExpression a = p.variable("a", BIGINT);
+                    VariableReferenceExpression b = p.variable("b", BIGINT);
+                    VariableReferenceExpression bounds = p.variable("bounds", DOUBLE);
+                    return p.aggregation((builder) -> builder
+                            .addAggregation(
+                                    p.variable("output"),
+                                    p.rowExpression("approx_distinct(original, bounds)"))
+                            .globalGrouping()
+                            .step(AggregationNode.Step.SINGLE)
+                            .source(
+                                    p.project(
+                                            p.assignment(
+                                                    original, p.rowExpression("if(a > b, 'constant')"),
+                                                    bounds, p.rowExpression("0.0040625", ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)),
+                                            p.values(a, b))));
+                })
+                .matches(
+                        project(
+                                ImmutableMap.of(
+                                        "output", expression("coalesce(intermediate, 0)")),
+                                aggregation(
+                                        ImmutableMap.of("intermediate",
+                                                functionCall("arbitrary", ImmutableList.of("expression"))),
+                                        SINGLE,
+                                        project(
+                                                ImmutableMap.of(
+                                                        "original", expression("if(a > b, 'constant')"),
+                                                        "expression", expression("if(a > b, 1, NULL)")),
+                                                values("a", "b")))));
+    }
+
+    @Test
+    public void testReplaceMultipleConditionalConstant()
+    {
+        tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+                .on(p -> {
+                    VariableReferenceExpression original1 = p.variable("original1", VARCHAR);
+                    VariableReferenceExpression original2 = p.variable("original2", VARCHAR);
+                    VariableReferenceExpression a = p.variable("a", BIGINT);
+                    VariableReferenceExpression b = p.variable("b", BIGINT);
+                    return p.aggregation((builder) -> builder
+                            .addAggregation(
+                                    p.variable("output1"),
+                                    p.rowExpression("approx_distinct(original1)"))
+                            .addAggregation(
+                                    p.variable("output2"),
+                                    p.rowExpression("approx_distinct(original2)"))
+                            .globalGrouping()
+                            .step(AggregationNode.Step.SINGLE)
+                            .source(
+                                    p.project(
+                                            p.assignment(
+                                                    original1, p.rowExpression("if(a > b, 'constant')"),
+                                                    original2, p.rowExpression("if(a < b, NULL, 'constant')")),
+                                            p.values(a, b))));
+                })
+                .matches(
+                        project(
+                                ImmutableMap.of(
+                                        "output1", expression("coalesce(intermediate1, 0)"),
+                                        "output2", expression("coalesce(intermediate2, 0)")),
+                                aggregation(
+                                        ImmutableMap.of(
+                                                "intermediate1", functionCall("arbitrary", ImmutableList.of("expression1")),
+                                                "intermediate2", functionCall("arbitrary", ImmutableList.of("expression2"))),
+                                        SINGLE,
+                                        project(
+                                                ImmutableMap.of(
+                                                        "original1", expression("if(a > b, 'constant')"),
+                                                        "original2", expression("if(a < b, NULL, 'constant')"),
+                                                        "expression1", expression("if(a > b, 1, NULL)"),
+                                                        "expression2", expression("if(a < b, NULL, 1)")),
+                                                values("a", "b")))));
+    }
+
+    @Test
+    public void testDontReplaceConstant()
+    {
+        tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+                .on(p -> {
+                    VariableReferenceExpression input = p.variable("input", VARCHAR);
+                    return p.aggregation((builder) -> builder
+                            .addAggregation(
+                                    p.variable("output"),
+                                    p.rowExpression("approx_distinct(input)"))
+                            .globalGrouping()
+                            .step(AggregationNode.Step.SINGLE)
+                            .source(
+                                    p.project(
+                                            p.assignment(input, p.rowExpression("'constant'")),
+                                            p.values()))));
+                }).doesNotFire();
+    }
+
+    @Test
+    public void testDontReplaceVariable()
+    {
+        tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+                .on(p -> {
+                    VariableReferenceExpression input = p.variable("input", VARCHAR);
+                    VariableReferenceExpression nonconstant = p.variable("nonconstant", VARCHAR);
+                    return p.aggregation((builder) -> builder
+                            .addAggregation(
+                                    p.variable("output"),
+                                    p.rowExpression("approx_distinct(input)"))
+                            .globalGrouping()
+                            .step(AggregationNode.Step.SINGLE)
+                            .source(
+                                    p.project(
+                                            p.assignment(input, p.rowExpression("nonconstant")),
+                                            p.values(nonconstant)))));
+                }).doesNotFire();
+    }
+
+    @Test
+    public void testDontReplaceConditionalVariable()
+    {
+        tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager()))
+                .on(p -> {
+                    VariableReferenceExpression original = p.variable("original", BIGINT);
+                    VariableReferenceExpression a = p.variable("a", BIGINT);
+                    VariableReferenceExpression b = p.variable("b", BIGINT);
+                    VariableReferenceExpression nonconstant = p.variable("nonconstant", BIGINT);
+                    return p.aggregation((builder) -> builder
+                            .addAggregation(
+                                    p.variable("output"),
+                                    p.rowExpression("approx_distinct(original)"))
+                            .globalGrouping()
+                            .step(AggregationNode.Step.SINGLE)
+                            .source(
+                                    p.project(
+                                            p.assignment(
+                                                    original, p.rowExpression("if(a > b, nonconstant)")),
+                                            p.values(a, b, nonconstant))));
+                }).doesNotFire();
+    }
+}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java
new file mode 100644
index 0000000000000..c978a895ada42
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; + +public class TestApproxDistinctOptimizer + extends BasePlanTest +{ + @Test + public void testReplacesConditionalApproxDistinct() + { + assertPlan("SELECT APPROX_DISTINCT(IF(nationkey = 1, 1)) FROM nation", + output( + project( + ImmutableMap.of("output", expression("coalesce(intermediate, 0)")), + aggregation( + ImmutableMap.of("intermediate", functionCall("arbitrary", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("arbitrary", false, ImmutableList.of(anySymbol()))), + AggregationNode.Step.PARTIAL, + anyTree( + tableScan("nation")))))))); + } + + @Test + public void testReplacesConditionalApproxDistinctGrouped() + { + assertPlan("SELECT APPROX_DISTINCT(IF(nationkey = nationkey, 1)) FROM nation group by nationkey", + output( + project( + ImmutableMap.of("output", 
expression("coalesce(intermediate, 0)")), + aggregation( + ImmutableMap.of("intermediate", functionCall("arbitrary", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("arbitrary", false, ImmutableList.of(anySymbol()))), + AggregationNode.Step.PARTIAL, + anyTree( + tableScan("nation")))))))); + } + + @Test + public void testDontReplaceConstantApproxDistinct() + { + assertPlan("SELECT APPROX_DISTINCT('constant') FROM nation", + output( + aggregation( + ImmutableMap.of("final", functionCall("approx_distinct", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("approx_distinct", false, ImmutableList.of(anySymbol()))), + AggregationNode.Step.PARTIAL, + anyTree( + tableScan("nation"))))))); + } + + @Test + public void testDontReplaceVariableApproxDistinct() + { + assertPlan("SELECT APPROX_DISTINCT(nationkey) FROM nation", + output( + aggregation( + ImmutableMap.of("final", functionCall("approx_distinct", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("approx_distinct", ImmutableList.of("nationkey"))), + AggregationNode.Step.PARTIAL, + tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))))); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java index 2d6f73cda84ee..288d85930d3df 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java @@ -68,6 +68,8 @@ public interface StandardFunctionResolution FunctionHandle countFunction(Type valueType); + FunctionHandle arbitraryFunction(Type valueType); + boolean isMaxFunction(FunctionHandle 
functionHandle);
 
     FunctionHandle maxFunction(Type valueType);
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index 6be4b53757fa5..914642b118be8 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -69,6 +69,7 @@
 import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED;
 import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS;
 import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE;
+import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT;
 import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION;
 import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT;
 import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT_TIMEOUT_MS;
@@ -6803,6 +6804,41 @@ public void testMergeDuplicateAggregations()
                 "HAVING min(orderkey) < (SELECT avg(orderkey) FROM orders WHERE orderkey < 7)");
     }
 
+    @Test
+    public void testReplaceConditionalApproxDistinct()
+    {
+        Session enabled = Session.builder(getSession())
+                .setSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, "true")
+                .build();
+        Session disabled = Session.builder(getSession())
+                .setSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, "false")
+                .build();
+
+        List<String> queries = ImmutableList.of(
+                "select approx_distinct(if(%s, 1, 2)) from orders %s",
+                "select approx_distinct(if(%s, 3)) from orders %s",
+                "select approx_distinct(if(%s, 4, NULL)) from orders %s",
+                "select approx_distinct(if(%s, NULL, 5)) from orders %s",
+                "select approx_distinct(if(%s, NULL)) from orders %s",
+                "select approx_distinct(if(%s, NULL, NULL)) from orders %s");
+        List<String> conditions = ImmutableList.of(
+                "orderkey = orderkey",
+                "orderkey % 2 = 0");
+        List<String> suffixes = ImmutableList.of(
+                "",
+                "where orderkey % 3 = 0",
+                "group by orderkey");
+
+        for (String query : queries) {
+            for (String condition : conditions) {
+                for (String suffix : suffixes) {
+                    String sql = format(query, condition, suffix);
+                    assertQueryWithSameQueryRunner(enabled, sql, disabled);
+                }
+            }
+        }
+    }
+
     @Test
     public void testSameAggregationWithAndWithoutFilter()
     {