From 083baafcb3467597a7160963943d08ea64211a8f Mon Sep 17 00:00:00 2001 From: maswin Date: Mon, 2 May 2022 12:22:46 -0700 Subject: [PATCH 1/2] Allow RowExpressionRewriteRuleSet to be enabled or disabled --- .../rule/RowExpressionRewriteRuleSet.java | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index fc6796730b002..3a0ae482e1c00 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.Session; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.AggregationNode; @@ -70,6 +71,11 @@ public RowExpressionRewriteRuleSet(PlanRowExpressionRewriter rewriter) this.rewriter = requireNonNull(rewriter, "rewriter is null"); } + public boolean isRewriterEnabled(Session session) + { + return true; + } + public Set> rules() { return ImmutableSet.of( @@ -138,6 +144,12 @@ public Rule aggregationRowExpressionRewriteRule() private final class ProjectRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -167,6 +179,12 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) private final class SpatialJoinRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -199,6 +217,12 @@ public Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Context private final class JoinRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -237,6 +261,12 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) private final class WindowRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -288,6 +318,12 @@ public Result apply(WindowNode windowNode, Captures captures, Context context) private final class ApplyRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -331,6 +367,12 @@ private Optional translateAssignments(Assignments assignments, Rule private final class FilterRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -353,6 +395,12 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) private final class ValuesRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -385,6 +433,12 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context) private final class AggregationRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -426,6 +480,12 @@ public Result apply(AggregationNode node, Captures captures, Context context) private final class TableFinishRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -477,6 +537,12 @@ private Optional translateStatisticAggregation(StatisticA private final class TableWriterRowExpressionRewrite implements Rule { + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { From 8c11aa4121cd308eb34f5fb10af85fb2064c6cda Mon Sep 17 00:00:00 2001 From: maswin Date: Mon, 2 May 2022 13:24:16 -0700 Subject: [PATCH 2/2] Optimize filter condition with CASE predicate --- .../expressions/LogicalRowExpressions.java | 36 +- .../presto/SystemSessionProperties.java | 11 + .../presto/sql/analyzer/FeaturesConfig.java | 13 + .../presto/sql/planner/PlanOptimizers.java | 8 + .../rule/RewriteCaseExpressionPredicate.java | 333 ++++++++++++++++++ .../sql/relational/FunctionResolution.java | 6 + .../sql/analyzer/TestFeaturesConfig.java | 3 + .../TestRewriteCaseExpressionPredicate.java | 197 +++++++++++ .../function/StandardFunctionResolution.java | 2 + .../presto/tests/AbstractTestQueries.java | 37 ++ 10 files changed, 645 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java index 68f9d1133aac3..44ca3b27b7f95 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java @@ -51,6 +51,7 @@ import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; import static java.lang.Math.min; import static java.util.Arrays.asList; import static java.util.Arrays.stream; @@ -631,6 +632,21 @@ private boolean isComparisonExpression(RowExpression expression) return expression instanceof CallExpression && functionResolution.isComparisonFunction(((CallExpression) expression).getFunctionHandle()); } + public boolean isEqualsExpression(RowExpression expression) + { + return expression instanceof CallExpression && functionResolution.isEqualsFunction(((CallExpression) expression).getFunctionHandle()); + } + + public boolean isCastExpression(RowExpression expression) + { + return expression instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) expression).getFunctionHandle()); + } + + public boolean isCaseExpression(RowExpression expression) + { + return expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm().equals(SWITCH); + } + /** * Extract the component predicates as a list of list in which is grouped so that the outer level has same conjunctive/disjunctive joiner as original predicate and * inner level has opposite joiner. @@ -771,11 +787,29 @@ private Optional getOperator(RowExpression expression) return Optional.empty(); } - private RowExpression notCallExpression(RowExpression argument) + public RowExpression notCallExpression(RowExpression argument) { return new CallExpression(argument.getSourceLocation(), "not", functionResolution.notFunction(), BOOLEAN, singletonList(argument)); } + public RowExpression equalsCallExpression(RowExpression left, RowExpression right) + { + return new CallExpression( + EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, left.getType(), right.getType()), + BOOLEAN, + asList(left, right)); + } + + public static RowExpression replaceArguments(CallExpression expression, RowExpression... arguments) + { + return new CallExpression( + expression.getDisplayName(), + expression.getFunctionHandle(), + expression.getType(), + asList(arguments)); + } + private static OperatorType negate(OperatorType operator) { switch (operator) { diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 3283972342279..2c3342f385d33 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -159,6 +159,7 @@ public final class SystemSessionProperties public static final String PARTIAL_AGGREGATION_STRATEGY = "partial_aggregation_strategy"; public static final String PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD = "partial_aggregation_byte_reduction_threshold"; public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number"; + public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate"; public static final String MAX_GROUPING_SETS = "max_grouping_sets"; public static final String LEGACY_UNNEST = "legacy_unnest"; public static final String STATISTICS_CPU_TIMER_ENABLED = "statistics_cpu_timer_enabled"; @@ -851,6 +852,11 @@ public SystemSessionProperties( "Use top N row number optimization", featuresConfig.isOptimizeTopNRowNumber(), false), + booleanProperty( + OPTIMIZE_CASE_EXPRESSION_PREDICATE, + "Optimize case expression predicates", + featuresConfig.isOptimizeCaseExpressionPredicate(), + false), integerProperty( MAX_GROUPING_SETS, "Maximum number of grouping sets in a GROUP BY", @@ -1795,6 +1801,11 @@ public static boolean isOptimizeTopNRowNumber(Session session) return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class); } + public static boolean isOptimizeCaseExpressionPredicate(Session session) + { + return session.getSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, Boolean.class); + } + public static boolean isDistributedSortEnabled(Session session) { return session.getSystemProperty(DISTRIBUTED_SORT, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 5ee88785060ea..d421abba72d88 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -95,6 +95,7 @@ public class FeaturesConfig private int optimizeMetadataQueriesCallThreshold = 100; private boolean optimizeHashGeneration = true; private boolean enableIntermediateAggregations; + private boolean optimizeCaseExpressionPredicate; private boolean pushTableWriteThroughUnion = true; private boolean exchangeCompressionEnabled; private boolean exchangeChecksumEnabled; @@ -874,6 +875,18 @@ public FeaturesConfig setOptimizeTopNRowNumber(boolean optimizeTopNRowNumber) return this; } + public boolean isOptimizeCaseExpressionPredicate() + { + return optimizeCaseExpressionPredicate; + } + + @Config("optimizer.optimize-case-expression-predicate") + public FeaturesConfig setOptimizeCaseExpressionPredicate(boolean optimizeCaseExpressionPredicate) + { + this.optimizeCaseExpressionPredicate = optimizeCaseExpressionPredicate; + return this; + } + public boolean isOptimizeHashGeneration() { return optimizeHashGeneration; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index f89bcc7f2e57a..699f73fb5b2ed 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -98,6 +98,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins; import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter; +import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate; import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; @@ -277,6 +278,12 @@ public PlanOptimizers( .add(new PruneRedundantProjectionAssignments()) .build()); + IterativeOptimizer caseExpressionPredicateRewriter = new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + new RewriteCaseExpressionPredicate(metadata.getFunctionAndTypeManager()).rules()); + PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser)); builder.add( @@ -405,6 +412,7 @@ public PlanOptimizers( // After this point, all planNodes should not contain OriginalExpression builder.add( + caseExpressionPredicateRewriter, new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java new file mode 100644 index 0000000000000..62fd299aa8086 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java @@ -0,0 +1,333 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.expressions.LogicalRowExpressions; +import com.facebook.presto.expressions.RowExpressionRewriter; +import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.facebook.presto.SystemSessionProperties.isOptimizeCaseExpressionPredicate; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.expressions.LogicalRowExpressions.and; +import static com.facebook.presto.expressions.LogicalRowExpressions.or; +import static com.facebook.presto.expressions.LogicalRowExpressions.replaceArguments; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * This Rule rewrites a CASE expression predicate into a series of AND/OR clauses. + * The following CASE expression + *

+ * (CASE + * WHEN expression=constant1 THEN result1 + * WHEN expression=constant2 THEN result2 + * WHEN expression=constant3 THEN result3 + * ELSE elseResult + * END) = value + *

+ * can be converted into a series AND/OR clauses as below + *

+ * (result1 = value AND expression=constant1) OR + * (result2 = value AND expression=constant2 AND !(expression=constant1)) OR + * (result3 = value AND expression=constant3 AND !(expression=constant1) AND !(expression=constant2)) OR + * (elseResult = value AND !(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)) + *

+ * The above conversion evaluates the conditions in WHEN clauses multiple times. But if we ensure these conditions are + * disjunct, we can skip all the NOT of previous WHEN conditions and simplify the expression to: + *

+ * (result1 = value AND expression=constant1) OR + * (result2 = value AND expression=constant2) OR + * (result3 = value AND expression=constant3) OR + * (elseResult = value AND !(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)) + *

+ * To ensure the WHEN conditions are disjunct, the following criteria needs to be met: + * 1. Value is either a constant or column reference or input reference and not any function + * 2. The LHS expression in all WHEN clauses are the same. + * For example, if one WHEN clause has a expression using col1 and another using col2, it will not work + * 3. The relational operator in the WHEN clause is equals. With other operators it is hard to check for exclusivity. + * 4. All the RHS expressions are a constant and unique + *

+ * This conversion is done so that it is easy for the ExpressionInterpreter & other Optimizers to further + * simplify this and construct a domain for the column that can be used by Readers . + * i.e, ExpressionInterpreter can discard all conditions in which result != value and + * RowExpressionDomainTranslator can construct a Domain for the column + */ +public class RewriteCaseExpressionPredicate + extends RowExpressionRewriteRuleSet +{ + public RewriteCaseExpressionPredicate(FunctionAndTypeManager functionAndTypeManager) + { + super(new Rewriter(functionAndTypeManager)); + } + + private static class Rewriter + implements PlanRowExpressionRewriter + { + private final CaseExpressionPredicateRewriter caseExpressionPredicateRewriter; + + public Rewriter(FunctionAndTypeManager functionAndTypeManager) + { + requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.caseExpressionPredicateRewriter = new CaseExpressionPredicateRewriter(functionAndTypeManager); + } + + @Override + public RowExpression rewrite(RowExpression expression, Rule.Context context) + { + return RowExpressionTreeRewriter.rewriteWith(caseExpressionPredicateRewriter, expression); + } + } + + private static class CaseExpressionPredicateRewriter + extends RowExpressionRewriter + { + private final FunctionResolution functionResolution; + private final LogicalRowExpressions logicalRowExpressions; + + private CaseExpressionPredicateRewriter(FunctionAndTypeManager functionAndTypeManager) + { + this.functionResolution = new FunctionResolution(functionAndTypeManager); + this.logicalRowExpressions = new LogicalRowExpressions( + new RowExpressionDeterminismEvaluator(functionAndTypeManager), + functionResolution, + functionAndTypeManager); + } + + @Override + public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter treeRewriter) + { + if (functionResolution.isComparisonFunction(node.getFunctionHandle()) && node.getArguments().size() == 2) { + RowExpression left = node.getArguments().get(0); + RowExpression right = node.getArguments().get(1); + if (isCaseExpression(left) && isSimpleExpression(right)) { + return processCaseExpression(left, expression -> replaceArguments(node, expression, right), right); + } + else if (isCaseExpression(right) && isSimpleExpression(left)) { + return processCaseExpression(right, expression -> replaceArguments(node, left, expression), left); + } + } + return null; + } + + private boolean isCaseExpression(RowExpression expression) + { + if (logicalRowExpressions.isCastExpression(expression)) { + expression = ((CallExpression) expression).getArguments().get(0); + } + return expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm().equals(SWITCH); + } + + private boolean isSimpleExpression(RowExpression expression) + { + if (logicalRowExpressions.isCastExpression(expression)) { + return isSimpleExpression(((CallExpression) expression).getArguments().get(0)); + } + return expression instanceof ConstantExpression || + expression instanceof VariableReferenceExpression || + expression instanceof InputReferenceExpression; + } + + private RowExpression processCaseExpression(RowExpression expression, + Function comparisonExpressionGenerator, + RowExpression value) + { + if (expression instanceof SpecialFormExpression) { + checkArgument(logicalRowExpressions.isCaseExpression(expression), "expression must be a CASE expression"); + return processCaseExpression( + (SpecialFormExpression) expression, + Optional.empty(), + comparisonExpressionGenerator, + value); + } + else { + checkArgument(logicalRowExpressions.isCastExpression(expression), "expression must be a CAST expression"); + checkArgument(logicalRowExpressions.isCaseExpression(((CallExpression) expression).getArguments().get(0)), "expression argument must be a CASE expression"); + return processCaseExpression( + (SpecialFormExpression) ((CallExpression) expression).getArguments().get(0), + Optional.of((CallExpression) expression), + comparisonExpressionGenerator, + value); + } + } + + /** + * RowExpression representation of Case Statement: + * SpecialFormExpression: + * form: SWITCH + * arguments: + * [0]: RowExpression (or) ConstantExpression(TRUE) // SimpleCaseExpression (or) SearchedCaseExpression + * [1..n-1 (or) n]: SpecialFormExpression(form: WHEN) // else clause is present (or) absent + * [n]: RowExpression // available if else clause is present + */ + private RowExpression processCaseExpression(SpecialFormExpression caseExpression, + Optional castExpression, + Function comparisonExpressionGenerator, + RowExpression value) + { + Optional caseOperand = getCaseOperand(caseExpression.getArguments().get(0)); + List whenClauses; + Optional elseResult = Optional.empty(); + int argumentsSize = caseExpression.getArguments().size(); + RowExpression last = caseExpression.getArguments().get(argumentsSize - 1); + + if (last instanceof SpecialFormExpression && ((SpecialFormExpression) last).getForm().equals(WHEN)) { + whenClauses = caseExpression.getArguments().subList(1, argumentsSize); + } + else { + whenClauses = caseExpression.getArguments().subList(1, argumentsSize - 1); + elseResult = Optional.of(last); + } + + if (caseOperand.isPresent() ? + !canRewriteSimpleCaseExpression(whenClauses) : + !canRewriteSearchedCaseExpression(whenClauses)) { + return null; + } + + ImmutableList.Builder andExpressions = new ImmutableList.Builder<>(); + ImmutableList.Builder invertedOperands = new ImmutableList.Builder<>(); + + for (RowExpression whenClause : whenClauses) { + RowExpression whenOperand = ((SpecialFormExpression) whenClause).getArguments().get(0); + if (caseOperand.isPresent()) { + whenOperand = logicalRowExpressions.equalsCallExpression(caseOperand.get(), whenOperand); + } + RowExpression whenResult = ((SpecialFormExpression) whenClause).getArguments().get(1); + if (castExpression.isPresent()) { + whenResult = replaceArguments(castExpression.get(), whenResult); + } + + RowExpression comparisonExpression = comparisonExpressionGenerator.apply(whenResult); + andExpressions.add(and(comparisonExpression, whenOperand)); + invertedOperands.add(logicalRowExpressions.notCallExpression(whenOperand)); + } + RowExpression elseCondition = and( + getElseExpression(castExpression, value, elseResult, comparisonExpressionGenerator), + and(invertedOperands.build())); + andExpressions.add(elseCondition); + + return or(andExpressions.build()); + } + + private RowExpression getElseExpression(Optional castExpression, + RowExpression value, + Optional elseValue, + Function comparisonExpressionGenerator) + { + return elseValue.map( + elseVal -> comparisonExpressionGenerator.apply(castExpression.map(castExp -> replaceArguments(castExp, elseVal)).orElse(elseVal) + )).orElse(new SpecialFormExpression(IS_NULL, BOOLEAN, value)); + } + + private Optional getCaseOperand(RowExpression expression) + { + boolean searchedCase = (expression instanceof ConstantExpression && expression.getType() == BOOLEAN && + ((ConstantExpression) expression).getValue() == Boolean.TRUE); + return searchedCase ? Optional.empty() : Optional.of(expression); + } + + private boolean canRewriteSimpleCaseExpression(List whenClauses) + { + List whenOperands = whenClauses.stream() + .map(x -> ((SpecialFormExpression) x).getArguments().get(0)) + .collect(Collectors.toList()); + return allExpressionsAreConstantAndUnique(whenOperands); + } + + private boolean canRewriteSearchedCaseExpression(List whenClauses) + { + if (!allAreEqualsExpression(whenClauses) || !allLHSOperandsAreUnique(whenClauses)) { + return false; + } + List rhsExpressions = whenClauses.stream() + .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0)) + .map(whenOperand -> ((CallExpression) whenOperand).getArguments().get(1)) + .collect(Collectors.toList()); + return allExpressionsAreConstantAndUnique(rhsExpressions); + } + + private boolean allLHSOperandsAreUnique(List whenClauses) + { + return whenClauses.stream() + .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0)) + .map(whenOperand -> ((CallExpression) whenOperand).getArguments().get(0)) + .distinct() + .count() == 1; + } + + private boolean allAreEqualsExpression(List whenClauses) + { + return whenClauses.stream() + .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0)) + .allMatch(logicalRowExpressions::isEqualsExpression); + } + + private boolean allExpressionsAreConstantAndUnique(List expressions) + { + Set expressionSet = new HashSet<>(); + for (RowExpression expression : expressions) { + if (!isConstantExpression(expression) || expressionSet.contains(expression)) { + return false; + } + expressionSet.add(expression); + } + return true; + } + + private boolean isConstantExpression(RowExpression expression) + { + if (logicalRowExpressions.isCastExpression(expression)) { + return isConstantExpression(((CallExpression) expression).getArguments().get(0)); + } + return expression instanceof ConstantExpression; + } + } + + @Override + public boolean isRewriterEnabled(Session session) + { + return isOptimizeCaseExpressionPredicate(session); + } + + @Override + public Set> rules() + { + return ImmutableSet.of( + filterRowExpressionRewriteRule(), + joinRowExpressionRewriteRule()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index 7c668fcc7afca..19b315b51d1d6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -235,6 +235,12 @@ public boolean isComparisonFunction(FunctionHandle functionHandle) return operatorType.isPresent() && operatorType.get().isComparisonOperator(); } + public boolean isEqualsFunction(FunctionHandle functionHandle) + { + Optional operatorType = functionAndTypeManager.getFunctionMetadata(functionHandle).getOperatorType(); + return operatorType.isPresent() && operatorType.get().getOperator().equals(EQUAL.getOperator()); + } + @Override public FunctionHandle subscriptFunction(Type baseType, Type indexType) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 97165af968f74..f1f601057f0a6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -146,6 +146,7 @@ public void testDefaults() .setPartialAggregationStrategy(PartialAggregationStrategy.ALWAYS) .setPartialAggregationByteReductionThreshold(0.5) .setOptimizeTopNRowNumber(true) + .setOptimizeCaseExpressionPredicate(false) .setHistogramGroupImplementation(HistogramGroupImplementation.NEW) .setArrayAggGroupImplementation(ArrayAggGroupImplementation.NEW) .setMultimapAggGroupImplementation(MultimapAggGroupImplementation.NEW) @@ -296,6 +297,7 @@ public void testExplicitPropertyMappings() .put("optimizer.partial-aggregation-strategy", "automatic") .put("optimizer.partial-aggregation-byte-reduction-threshold", "0.8") .put("optimizer.optimize-top-n-row-number", "false") + .put("optimizer.optimize-case-expression-predicate", "true") .put("distributed-sort", "false") .put("analyzer.max-grouping-sets", "2047") .put("deprecated.legacy-unnest-array-rows", "true") @@ -436,6 +438,7 @@ public void testExplicitPropertyMappings() .setPartialAggregationStrategy(PartialAggregationStrategy.AUTOMATIC) .setPartialAggregationByteReductionThreshold(0.8) .setOptimizeTopNRowNumber(false) + .setOptimizeCaseExpressionPredicate(true) .setHistogramGroupImplementation(HistogramGroupImplementation.LEGACY) .setArrayAggGroupImplementation(ArrayAggGroupImplementation.LEGACY) .setMultimapAggGroupImplementation(MultimapAggGroupImplementation.LEGACY) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java new file mode 100644 index 0000000000000..740dd49c01aac --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRewriteCaseExpressionPredicate + extends BaseRuleTest +{ + private static final MetadataManager METADATA = createTestMetadataManager(); + private static final Map TYPE_MAP = ImmutableMap.of("col1", INTEGER, "col2", INTEGER, "col3", VARCHAR); + + private final TestingRowExpressionTranslator testSqlToRowExpressionTranslator = new TestingRowExpressionTranslator(); + + @Test + public void testRewriterDoesNotFireOnPredicateWithoutCaseExpression() + { + assertRewriteDoesNotFire("col1 > 1"); + } + + @Test + public void testRewriterDoesNotFireOnPredicateWithoutComparisonFunction() + { + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end)"); + } + + @Test + public void testRewriterDoesNotFireOnPredicateWithFunctionCallOnComparisonValue() + { + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = upper('case1')"); + assertRewriteDoesNotFire("(case when col1=1 then 10 when col2=2 then 20 else 30 end) = ceil(col1)"); + } + + @Test + public void testRewriterDoesNotFireOnInvalidSearchCaseExpression() + { + // All LHS expressions are not the same + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = 'case1'"); + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'case1'"); + + // All expressions are not equals function + assertRewriteDoesNotFire("(case when col1>1 then 1 when col1>2 then 2 else 3 end) > 2"); + assertRewriteDoesNotFire("(case when col1<1 then 1 when col1<2 then 2 else 3 end) < 2"); + + // All RHS expressions are not Constant Expression + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=ceil(1) then 'case2' else 'default' end) = 'case1'"); + + // All RHS expressions are not unique + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=1 then 'case2' else 'default' end) = 'case1'"); + } + + @Test + public void testSimpleCaseExpressionRewrite() + { + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR ('default' = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case2'", + "('case1' = 'case2' AND col1 = 1) OR ('case2' = 'case2' AND col1 = 2) OR ('default' = 'case2' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'default'", + "('case1' = 'default' AND col1 = 1) OR ('case2' = 'default' AND col1 = 2) OR ('default' = 'default' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + } + + @Test + public void testSearchedCaseExpressionRewrite() + { + assertRewrittenExpression( + "(case when col1=1 then 'case1' when col1=2 then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR ('default' = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case when lower(col3)='a' then 'case1' when lower(col3)='b' then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND lower(col3) = 'a') OR ('case2' = 'case1' AND lower(col3) = 'b') OR ('default' = 'case1' AND (NOT(lower(col3) = 'a') AND NOT(lower(col3) = 'b')))"); + + assertRewrittenExpression( + "(case when ceil(col1)=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'default'", + "('case1' = 'default' AND ceil(col1) = 1) OR ('case2' = 'default' AND ceil(col1) = 2) OR ('default' = 'default' AND (NOT(ceil(col1) = 1) AND NOT(ceil(col1) = 2)))"); + } + + @Test + public void testRewriterOnCaseExpressionInRightSideOfComparisonFunction() + { + assertRewrittenExpression( + "(case col1 when 1 then 10 when 2 then 20 else 30 end) > 20", + "(10 > 20 AND col1 = 1) OR (20 > 20 AND col1 = 2) OR (30 > 20 AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "25 < (case col1 when 1 then 10 when 2 then 20 else 30 end)", + "(25 < 10 AND col1 = 1) OR (25 < 20 AND col1 = 2) OR (25 < 30 AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + } + + @Test + public void testRewriterWhenMoreThanOneConditionMatches() + { + assertRewrittenExpression( + "(case col1 when 1 then 'case' when 2 then 'case' else 'default' end) = 'case'", + "('case' = 'case' AND col1 = 1) OR ('case' = 'case' AND col1 = 2) OR ('default' = 'case' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col1 when 1 then concat('default', 'AndCase1') when 2 then 'case2' else 'defaultAndCase1' end) = 'defaultAndCase1'", + "(concat('default', 'AndCase1') = 'defaultAndCase1' AND col1 = 1) OR ('case2' = 'defaultAndCase1' AND col1 = 2) OR ('defaultAndCase1' = 'defaultAndCase1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col3 when 'data1' then 'case1' when 'data2' then 'case2' else col3 end) = 'case1'", + "('case1' = 'case1' AND col3 = 'data1') OR ('case2' = 'case1' AND col3 = 'data2') OR (col3 = 'case1' AND (NOT(col3 = 'data1') AND NOT(col3 = 'data2')))"); + } + + @Test + public void testRewriterOnCaseExpressionWithoutElseClause() + { + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case1'", + "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR (null = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case3'", + "('case1' = 'case3' AND col1 = 1) OR ('case2' = 'case3' AND col1 = 2) OR (null = 'case3' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case2'", + "('case1' = 'case2' AND col1 = 1) OR ('case2' = 'case2' AND col1 = 2) OR (null = 'case2' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + } + + @Test + public void testRewriterOnCaseExpressionWithCastFunction() + { + // When left hand and right hand side of the expression are of different types, RowExpressionInterpreter identifies the common super type and adds a CAST function + assertRewrittenExpression( + "cast((case col1 when 1 then 'case11' when 2 then 'case2' else 'def' end) as VARCHAR(6)) = 'case11'", + "(cast('case11' as VARCHAR(6)) = 'case11' AND col1 = 1) OR (cast('case2' as VARCHAR(6)) = 'case11' AND col1 = 2) OR (cast('def' as VARCHAR(6)) = 'case11' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = cast('case1' AS VARCHAR)", + "('case1' = cast('case1' AS VARCHAR) AND col1 = 1) OR ('case2' = cast('case1' AS VARCHAR) AND col1 = 2) OR ('default' = cast('case1' AS VARCHAR) AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + + assertRewrittenExpression( + "(case when col1=cast('1' as INTEGER) then 'case1' when col1=cast('2' as INTEGER) then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col1 = cast('1' as INTEGER)) OR ('case2' = 'case1' AND col1 = cast('2' as INTEGER)) OR ('default' = 'case1' AND (NOT(col1 = cast('1' as INTEGER)) AND NOT(col1 = cast('2' as INTEGER))))"); + } + + @Test + public void testIfSubExpressionsAreRewritten() + { + assertRewrittenExpression( + "((case col1 when 1 then 'a' else 'b' end) = 'a') = true", + "(('a' = 'a' AND col1 = 1) OR ('b' = 'a' AND NOT(col1=1))) = true"); + } + + private void assertRewriteDoesNotFire(String expression) + { + tester().assertThat(new RewriteCaseExpressionPredicate(METADATA.getFunctionAndTypeManager()).filterRowExpressionRewriteRule()) + .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") + .on(p -> p.filter(testSqlToRowExpressionTranslator.translate(expression, TYPE_MAP), p.values())) + .doesNotFire(); + } + + private void assertRewrittenExpression(String inputExpressionStr, + String expectedExpressionStr) + { + RowExpression inputExpression = testSqlToRowExpressionTranslator.translate(inputExpressionStr, TYPE_MAP); + + tester().assertThat(new RewriteCaseExpressionPredicate(METADATA.getFunctionAndTypeManager()).filterRowExpressionRewriteRule()) + .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") + .on(p -> p.filter(inputExpression, p.values(p.variable("col1"), p.variable("col2"), p.variable("col3")))) + .matches(filter(expectedExpressionStr, values("col1", "col2", "col3"))); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java index 5cffddd8337d5..6f3e2956090cc 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java @@ -48,6 +48,8 @@ public interface StandardFunctionResolution boolean isComparisonFunction(FunctionHandle functionHandle); + boolean isEqualsFunction(FunctionHandle functionHandle); + FunctionHandle betweenFunction(Type valueType, Type lowerBoundType, Type upperBoundType); boolean isBetweenFunction(FunctionHandle functionHandle); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index a500dfbc21d93..bdd219b47665e 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -55,6 +55,7 @@ import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION; import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_PERCENTAGE; import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE; import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_JOINS_WITH_EMPTY_SOURCES; import static com.facebook.presto.SystemSessionProperties.QUICK_DISTINCT_LIMIT_ENABLED; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -5499,6 +5500,42 @@ public void testSwitchOptimization() "SELECT CASE x WHEN 1 THEN 1 WHEN 5 THEN 5 WHEN 3 THEN 10 ELSE -1 END FROM (SELECT ORDERKEY x FROM orders where orderkey <= 10)"); } + @Test + public void testCasePredicateRewrite() + { + Session caseExpressionRewriteEnabled = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") + .build(); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT LINENUMBER FROM LINEITEM WHERE (CASE WHEN QUANTITY <= 15 THEN 'SMALL' WHEN (QUANTITY > 15 AND QUANTITY <= 30) THEN 'MEDIUM' ELSE 'LARGE' END) = 'SMALL'"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT LINENUMBER FROM LINEITEM WHERE 'MEDIUM' = (CASE WHEN QUANTITY <= 15 THEN 'SMALL' WHEN (QUANTITY > 15 AND QUANTITY <= 30) THEN 'MEDIUM' ELSE 'LARGE' END)"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT LINENUMBER FROM LINEITEM WHERE (CASE WHEN QUANTITY <= 15 THEN 'SMALL' WHEN (QUANTITY > 15 AND QUANTITY <= 30) THEN 'MEDIUM' ELSE 'LARGE' END) = 'LARGE'"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT SUM(TOTALPRICE) FROM ORDERS WHERE (CASE ORDERSTATUS WHEN 'F' THEN 1 WHEN 'O' THEN 2 WHEN 'P' THEN 3 ELSE -1 END) = 2"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT SUM(TOTALPRICE) FROM ORDERS WHERE 1 < (CASE ORDERSTATUS WHEN 'F' THEN 1 WHEN 'O' THEN 2 WHEN 'P' THEN 3 ELSE -1 END)"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT SUM(TOTALPRICE) FROM ORDERS WHERE (CASE ORDERSTATUS WHEN 'F' THEN 1 WHEN 'O' THEN 2 WHEN 'P' THEN 3 ELSE 2 END) = 2"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT ORDERSTATUS, ORDERPRIORITY, TOTALPRICE FROM ORDERS WHERE (CASE WHEN ORDERSTATUS='F' THEN 1 WHEN (CASE WHEN ORDERPRIORITY = '5-LOW' THEN true ELSE false END) THEN 2 WHEN ORDERSTATUS='O' THEN 3 ELSE -1 END) > 1"); + } + @Test public void testSwitchReturnsNull() {