diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 4e6181b8de475..98a3d749c4135 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -737,6 +737,42 @@ Optimizer Properties .. warning:: The number of possible join orders scales factorially with the number of relations, so increasing this value can cause serious performance issues. +``optimizer.optimize-case-expression-predicate`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``false`` + + When set to true, a CASE expression predicate is simplified into a series of AND/OR clauses. + For example:: + + SELECT * FROM orders + WHERE (CASE + WHEN status=0 THEN 'Pending' + WHEN status=1 THEN 'Complete' + WHEN status=2 THEN 'Returned' + ELSE 'Unknown' + END) = 'Pending' + + + will be simplified into:: + + SELECT * FROM orders + WHERE status IS NOT NULL AND status=0; + + If the filter condition were to match the ELSE clause 'Unknown', it would be translated into:: + + SELECT * FROM orders + WHERE (status IS NULL OR (status!=0 AND status!=1 AND status!=2)); + + The simplification avoids branching and string operations, making the query more efficient, and also allows + predicate pushdown to happen, avoiding a full table scan. This optimization mainly addresses queries + generated by business intelligence tools like Looker that support human-readable labels through + ``case`` statements. + + The optimization currently only applies to simple CASE expressions where the WHEN clause conditions are + unambiguous and deterministic and on the same column with the comparison operator being equals. 
+ Planner Properties -------------------------------------- diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index ab74e105fe698..bfec7cbfe8f23 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -292,7 +292,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new RewriteCaseExpressionPredicate(metadata.getFunctionAndTypeManager()).rules()); + new RewriteCaseExpressionPredicate(metadata).rules()); PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java index 62fd299aa8086..0d29b5c925373 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java @@ -17,9 +17,10 @@ import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; -import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.InputReferenceExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; 
@@ -45,6 +46,7 @@ import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN; +import static com.facebook.presto.sql.planner.RowExpressionInterpreter.evaluateConstantRowExpression; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -61,18 +63,18 @@ *

* can be converted into a series AND/OR clauses as below *

- * (result1 = value AND expression=constant1) OR - * (result2 = value AND expression=constant2 AND !(expression=constant1)) OR - * (result3 = value AND expression=constant3 AND !(expression=constant1) AND !(expression=constant2)) OR - * (elseResult = value AND !(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)) + * (result1 = value AND expression IS NOT NULL AND expression=constant1) OR + * (result2 = value AND expression IS NOT NULL AND expression=constant2 AND !(expression=constant1)) OR + * (result3 = value AND expression IS NOT NULL AND expression=constant3 AND !(expression=constant1) AND !(expression=constant2)) OR + * (elseResult = value AND ((expression IS NULL) OR (!(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)))) *

* The above conversion evaluates the conditions in WHEN clauses multiple times. But if we ensure these conditions are * disjunct, we can skip all the NOT of previous WHEN conditions and simplify the expression to: *

- * (result1 = value AND expression=constant1) OR - * (result2 = value AND expression=constant2) OR - * (result3 = value AND expression=constant3) OR - * (elseResult = value AND !(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)) + * (result1 = value AND expression IS NOT NULL AND expression=constant1) OR + * (result2 = value AND expression IS NOT NULL AND expression=constant2) OR + * (result3 = value AND expression IS NOT NULL AND expression=constant3) OR + * (elseResult = value AND ((expression IS NULL) OR (!(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)))) *

* To ensure the WHEN conditions are disjunct, the following criteria needs to be met: * 1. Value is either a constant or column reference or input reference and not any function @@ -89,47 +91,56 @@ public class RewriteCaseExpressionPredicate extends RowExpressionRewriteRuleSet { - public RewriteCaseExpressionPredicate(FunctionAndTypeManager functionAndTypeManager) + public RewriteCaseExpressionPredicate(Metadata metadata) { - super(new Rewriter(functionAndTypeManager)); + super(new Rewriter(metadata)); } private static class Rewriter implements PlanRowExpressionRewriter { - private final CaseExpressionPredicateRewriter caseExpressionPredicateRewriter; + private final Metadata metadata; - public Rewriter(FunctionAndTypeManager functionAndTypeManager) + public Rewriter(Metadata metadata) { - requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); - this.caseExpressionPredicateRewriter = new CaseExpressionPredicateRewriter(functionAndTypeManager); + this.metadata = requireNonNull(metadata, "metadata is null"); } @Override public RowExpression rewrite(RowExpression expression, Rule.Context context) { - return RowExpressionTreeRewriter.rewriteWith(caseExpressionPredicateRewriter, expression); + return RowExpressionTreeRewriter.rewriteWith(new CaseExpressionPredicateRewriter(this.metadata, context.getSession()), expression); } } private static class CaseExpressionPredicateRewriter extends RowExpressionRewriter { + private final Metadata metadata; + private final Session session; private final FunctionResolution functionResolution; private final LogicalRowExpressions logicalRowExpressions; + private final DeterminismEvaluator determinismEvaluator; - private CaseExpressionPredicateRewriter(FunctionAndTypeManager functionAndTypeManager) + private CaseExpressionPredicateRewriter(Metadata metadata, Session session) { - this.functionResolution = new FunctionResolution(functionAndTypeManager); + this.metadata = requireNonNull(metadata, "metadata is null"); + 
this.session = requireNonNull(session, "session is null"); + this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager()); this.logicalRowExpressions = new LogicalRowExpressions( - new RowExpressionDeterminismEvaluator(functionAndTypeManager), + new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()), functionResolution, - functionAndTypeManager); + metadata.getFunctionAndTypeManager()); + this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()); } @Override public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter treeRewriter) { + RowExpression rewritten = node; + if (!determinismEvaluator.isDeterministic(node)) { + return treeRewriter.defaultRewrite(rewritten, context); + } if (functionResolution.isComparisonFunction(node.getFunctionHandle()) && node.getArguments().size() == 2) { RowExpression left = node.getArguments().get(0); RowExpression right = node.getArguments().get(1); @@ -140,7 +151,7 @@ else if (isCaseExpression(right) && isSimpleExpression(left)) { return processCaseExpression(right, expression -> replaceArguments(node, left, expression), left); } } - return null; + return treeRewriter.defaultRewrite(rewritten, context); } private boolean isCaseExpression(RowExpression expression) @@ -221,6 +232,17 @@ private RowExpression processCaseExpression(SpecialFormExpression caseExpression ImmutableList.Builder andExpressions = new ImmutableList.Builder<>(); ImmutableList.Builder invertedOperands = new ImmutableList.Builder<>(); + RowExpression nullCheckExpression; + if (caseOperand.isPresent()) { + nullCheckExpression = new SpecialFormExpression(IS_NULL, BOOLEAN, caseOperand.get()); + } + else { + RowExpression whenOperand = whenClauses.stream().findFirst() + .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0)) + .orElseThrow(() -> new IllegalArgumentException("When clause is empty")); + nullCheckExpression 
= new SpecialFormExpression(IS_NULL, BOOLEAN, ((CallExpression) whenOperand).getArguments().get(0)); + } + for (RowExpression whenClause : whenClauses) { RowExpression whenOperand = ((SpecialFormExpression) whenClause).getArguments().get(0); if (caseOperand.isPresent()) { @@ -232,12 +254,15 @@ private RowExpression processCaseExpression(SpecialFormExpression caseExpression } RowExpression comparisonExpression = comparisonExpressionGenerator.apply(whenResult); - andExpressions.add(and(comparisonExpression, whenOperand)); + andExpressions.add(and( + comparisonExpression, + logicalRowExpressions.notCallExpression(nullCheckExpression), + whenOperand)); invertedOperands.add(logicalRowExpressions.notCallExpression(whenOperand)); } RowExpression elseCondition = and( getElseExpression(castExpression, value, elseResult, comparisonExpressionGenerator), - and(invertedOperands.build())); + or(nullCheckExpression, and(invertedOperands.build()))); andExpressions.add(elseCondition); return or(andExpressions.build()); @@ -298,12 +323,16 @@ private boolean allAreEqualsExpression(List whenClauses) private boolean allExpressionsAreConstantAndUnique(List expressions) { - Set expressionSet = new HashSet<>(); + Set literals = new HashSet<>(); for (RowExpression expression : expressions) { - if (!isConstantExpression(expression) || expressionSet.contains(expression)) { + if (!isConstantExpression(expression)) { + return false; + } + Object constantExpression = evaluateConstantRowExpression(expression, metadata, session.toConnectorSession()); + if (constantExpression == null || literals.contains(constantExpression)) { return false; } - expressionSet.add(expression); + literals.add(constantExpression); } return true; } @@ -327,6 +356,7 @@ public boolean isRewriterEnabled(Session session) public Set> rules() { return ImmutableSet.of( + projectRowExpressionRewriteRule(), filterRowExpressionRewriteRule(), joinRowExpressionRewriteRule()); } diff --git 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java index 740dd49c01aac..674dc4b1a8e9c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java @@ -53,17 +53,22 @@ public void testRewriterDoesNotFireOnPredicateWithoutComparisonFunction() @Test public void testRewriterDoesNotFireOnPredicateWithFunctionCallOnComparisonValue() { - assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = upper('case1')"); + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = UPPER('case1')"); + assertRewriteDoesNotFire("(case when col1=1 then 10 when col1=2 then 20 else 30 end) = ceil(col1)"); assertRewriteDoesNotFire("(case when col1=1 then 10 when col2=2 then 20 else 30 end) = ceil(col1)"); } @Test - public void testRewriterDoesNotFireOnInvalidSearchCaseExpression() + public void testRewriterDoesNotFireOnSearchCaseExpressionThatDoesNotMeetRewriteConditions() { // All LHS expressions are not the same assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = 'case1'"); assertRewriteDoesNotFire("(case when col1=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'case1'"); + // Any expression is non deterministic + assertRewriteDoesNotFire("(case when random(col1)=1 then 'case1' when random(col1)=2 then 'case2' else 'default' end) = 'case1'"); + assertRewriteDoesNotFire("(case when col1=1 then 2 when col1=2 then 3 else 4 end) = rand()"); + // All expressions are not equals function assertRewriteDoesNotFire("(case when col1>1 then 1 when col1>2 then 2 else 3 end) > 2"); 
assertRewriteDoesNotFire("(case when col1<1 then 1 when col1<2 then 2 else 3 end) < 2"); @@ -73,6 +78,10 @@ public void testRewriterDoesNotFireOnInvalidSearchCaseExpression() // All RHS expressions are not unique assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=1 then 'case2' else 'default' end) = 'case1'"); + assertRewriteDoesNotFire("(case when col1=CAST(1 as SMALLINT) then 'case1' when col1=CAST(1 as TINYINT) then 'case2' else 'default' end) = 'case1'"); + + // RHS expression is NULL + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=NULL then 'case2' else 'default' end) = 'case1'"); } @Test @@ -80,15 +89,15 @@ public void testSimpleCaseExpressionRewrite() { assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case1'", - "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR ('default' = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case2'", - "('case1' = 'case2' AND col1 = 1) OR ('case2' = 'case2' AND col1 = 2) OR ('default' = 'case2' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'case2' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case2' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case2' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'default'", - "('case1' = 'default' AND col1 = 1) OR ('case2' = 'default' AND col1 = 2) OR ('default' = 'default' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'default' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'default' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 
'default' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); } @Test @@ -96,15 +105,15 @@ public void testSearchedCaseExpressionRewrite() { assertRewrittenExpression( "(case when col1=1 then 'case1' when col1=2 then 'case2' else 'default' end) = 'case1'", - "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR ('default' = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( - "(case when lower(col3)='a' then 'case1' when lower(col3)='b' then 'case2' else 'default' end) = 'case1'", - "('case1' = 'case1' AND lower(col3) = 'a') OR ('case2' = 'case1' AND lower(col3) = 'b') OR ('default' = 'case1' AND (NOT(lower(col3) = 'a') AND NOT(lower(col3) = 'b')))"); + "(case when col3='a' then 'case1' when col3='b' then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col3 IS NOT NULL AND col3 = 'a') OR ('case2' = 'case1' AND col3 IS NOT NULL AND col3 = 'b') OR ('default' = 'case1' AND (col3 IS NULL OR (NOT(col3 = 'a') AND NOT(col3 = 'b'))))"); assertRewrittenExpression( - "(case when ceil(col1)=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'default'", - "('case1' = 'default' AND ceil(col1) = 1) OR ('case2' = 'default' AND ceil(col1) = 2) OR ('default' = 'default' AND (NOT(ceil(col1) = 1) AND NOT(ceil(col1) = 2)))"); + "(case when col1=1 then 'case1' when col1=2 then 'case2' else 'default' end) = 'default'", + "('case1' = 'default' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'default' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'default' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); } @Test @@ -112,11 +121,11 @@ public void testRewriterOnCaseExpressionInRightSideOfComparisonFunction() { assertRewrittenExpression( "(case col1 when 1 then 10 when 2 then 20 else 30 end) > 
20", - "(10 > 20 AND col1 = 1) OR (20 > 20 AND col1 = 2) OR (30 > 20 AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "(10 > 20 AND col1 IS NOT NULL AND col1 = 1) OR (20 > 20 AND col1 IS NOT NULL AND col1 = 2) OR (30 > 20 AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "25 < (case col1 when 1 then 10 when 2 then 20 else 30 end)", - "(25 < 10 AND col1 = 1) OR (25 < 20 AND col1 = 2) OR (25 < 30 AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "(25 < 10 AND col1 IS NOT NULL AND col1 = 1) OR (25 < 20 AND col1 IS NOT NULL AND col1 = 2) OR (25 < 30 AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); } @Test @@ -124,15 +133,15 @@ public void testRewriterWhenMoreThanOneConditionMatches() { assertRewrittenExpression( "(case col1 when 1 then 'case' when 2 then 'case' else 'default' end) = 'case'", - "('case' = 'case' AND col1 = 1) OR ('case' = 'case' AND col1 = 2) OR ('default' = 'case' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case' = 'case' AND col1 IS NOT NULL AND col1 = 1) OR ('case' = 'case' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col1 when 1 then concat('default', 'AndCase1') when 2 then 'case2' else 'defaultAndCase1' end) = 'defaultAndCase1'", - "(concat('default', 'AndCase1') = 'defaultAndCase1' AND col1 = 1) OR ('case2' = 'defaultAndCase1' AND col1 = 2) OR ('defaultAndCase1' = 'defaultAndCase1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "(concat('default', 'AndCase1') = 'defaultAndCase1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'defaultAndCase1' AND col1 IS NOT NULL AND col1 = 2) OR ('defaultAndCase1' = 'defaultAndCase1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col3 when 'data1' then 'case1' when 'data2' then 'case2' else col3 end) = 'case1'", - "('case1' = 'case1' AND col3 = 'data1') OR ('case2' = 'case1' AND col3 = 'data2') OR (col3 = 'case1' AND 
(NOT(col3 = 'data1') AND NOT(col3 = 'data2')))"); + "('case1' = 'case1' AND col3 IS NOT NULL AND col3 = 'data1') OR ('case2' = 'case1' AND col3 IS NOT NULL AND col3 = 'data2') OR (col3 = 'case1' AND (col3 IS NULL OR (NOT(col3 = 'data1') AND NOT(col3 = 'data2'))))"); } @Test @@ -140,15 +149,15 @@ public void testRewriterOnCaseExpressionWithoutElseClause() { assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case1'", - "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR (null = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = 2) OR (null = 'case1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case3'", - "('case1' = 'case3' AND col1 = 1) OR ('case2' = 'case3' AND col1 = 2) OR (null = 'case3' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'case3' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case3' AND col1 IS NOT NULL AND col1 = 2) OR (null = 'case3' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case2'", - "('case1' = 'case2' AND col1 = 1) OR ('case2' = 'case2' AND col1 = 2) OR (null = 'case2' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = 'case2' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case2' AND col1 IS NOT NULL AND col1 = 2) OR (null = 'case2' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); } @Test @@ -157,15 +166,15 @@ public void testRewriterOnCaseExpressionWithCastFunction() // When left hand and right hand side of the expression are of different types, RowExpressionInterpreter identifies the common super type and adds a CAST function assertRewrittenExpression( "cast((case col1 when 1 then 'case11' when 2 then 'case2' else 'def' end) as VARCHAR(6)) = 
'case11'", - "(cast('case11' as VARCHAR(6)) = 'case11' AND col1 = 1) OR (cast('case2' as VARCHAR(6)) = 'case11' AND col1 = 2) OR (cast('def' as VARCHAR(6)) = 'case11' AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "(cast('case11' as VARCHAR(6)) = 'case11' AND col1 IS NOT NULL AND col1 = 1) OR (cast('case2' as VARCHAR(6)) = 'case11' AND col1 IS NOT NULL AND col1 = 2) OR (cast('def' as VARCHAR(6)) = 'case11' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = cast('case1' AS VARCHAR)", - "('case1' = cast('case1' AS VARCHAR) AND col1 = 1) OR ('case2' = cast('case1' AS VARCHAR) AND col1 = 2) OR ('default' = cast('case1' AS VARCHAR) AND (NOT(col1 = 1) AND NOT(col1 = 2)))"); + "('case1' = cast('case1' AS VARCHAR) AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = cast('case1' AS VARCHAR) AND col1 IS NOT NULL AND col1 = 2) OR ('default' = cast('case1' AS VARCHAR) AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); assertRewrittenExpression( "(case when col1=cast('1' as INTEGER) then 'case1' when col1=cast('2' as INTEGER) then 'case2' else 'default' end) = 'case1'", - "('case1' = 'case1' AND col1 = cast('1' as INTEGER)) OR ('case2' = 'case1' AND col1 = cast('2' as INTEGER)) OR ('default' = 'case1' AND (NOT(col1 = cast('1' as INTEGER)) AND NOT(col1 = cast('2' as INTEGER))))"); + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = cast('1' as INTEGER)) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = cast('2' as INTEGER)) OR ('default' = 'case1' AND (col1 IS NULL OR (NOT(col1 = cast('1' as INTEGER)) AND NOT(col1 = cast('2' as INTEGER)))))"); } @Test @@ -173,12 +182,12 @@ public void testIfSubExpressionsAreRewritten() { assertRewrittenExpression( "((case col1 when 1 then 'a' else 'b' end) = 'a') = true", - "(('a' = 'a' AND col1 = 1) OR ('b' = 'a' AND NOT(col1=1))) = true"); + "(('a' = 'a' AND col1 IS NOT NULL AND col1 = 1) OR ('b' = 'a' AND (col1 IS NULL OR 
NOT(col1=1)))) = true"); } private void assertRewriteDoesNotFire(String expression) { - tester().assertThat(new RewriteCaseExpressionPredicate(METADATA.getFunctionAndTypeManager()).filterRowExpressionRewriteRule()) + tester().assertThat(new RewriteCaseExpressionPredicate(METADATA).filterRowExpressionRewriteRule()) .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") .on(p -> p.filter(testSqlToRowExpressionTranslator.translate(expression, TYPE_MAP), p.values())) .doesNotFire(); @@ -189,7 +198,7 @@ private void assertRewrittenExpression(String inputExpressionStr, { RowExpression inputExpression = testSqlToRowExpressionTranslator.translate(inputExpressionStr, TYPE_MAP); - tester().assertThat(new RewriteCaseExpressionPredicate(METADATA.getFunctionAndTypeManager()).filterRowExpressionRewriteRule()) + tester().assertThat(new RewriteCaseExpressionPredicate(METADATA).filterRowExpressionRewriteRule()) .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") .on(p -> p.filter(inputExpression, p.values(p.variable("col1"), p.variable("col2"), p.variable("col3")))) .matches(filter(expectedExpressionStr, values("col1", "col2", "col3"))); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index aa6e94d59c0b5..bf5c85fce6aed 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -5599,6 +5599,21 @@ public void testCasePredicateRewrite() assertQuery( caseExpressionRewriteEnabled, "SELECT ORDERSTATUS, ORDERPRIORITY, TOTALPRICE FROM ORDERS WHERE (CASE WHEN ORDERSTATUS='F' THEN 1 WHEN (CASE WHEN ORDERPRIORITY = '5-LOW' THEN true ELSE false END) THEN 2 WHEN ORDERSTATUS='O' THEN 3 ELSE -1 END) > 1"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' ELSE 'c' END) = 'a' FROM 
(VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' ELSE 'c' END) = 'b' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' ELSE 'c' END) = 'c' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = NULL THEN 'a' WHEN col = 1 THEN 'b' ELSE 'c' END) = 'a' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' END) = NULL FROM (VALUES NULL, 1, 2, 3) t(col)"); } @Test