diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index ce28de087701..b42be381c117 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -99,6 +99,7 @@ public final class SystemSessionProperties public static final String QUERY_PRIORITY = "query_priority"; public static final String SPILL_ENABLED = "spill_enabled"; public static final String AGGREGATION_OPERATOR_UNSPILL_MEMORY_LIMIT = "aggregation_operator_unspill_memory_limit"; + public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate"; public static final String OPTIMIZE_DISTINCT_AGGREGATIONS = "optimize_mixed_distinct_aggregations"; public static final String ITERATIVE_OPTIMIZER_TIMEOUT = "iterative_optimizer_timeout"; public static final String ENABLE_FORCED_EXCHANGE_BELOW_GROUP_ID = "enable_forced_exchange_below_group_id"; @@ -448,6 +449,11 @@ public SystemSessionProperties( "How much memory should be allocated per aggregation operator in unspilling process", featuresConfig.getAggregationOperatorUnspillMemoryLimit(), false), + booleanProperty( + OPTIMIZE_CASE_EXPRESSION_PREDICATE, + "Optimize case expression predicates", + optimizerConfig.isOptimizeCaseExpressionPredicate(), + false), booleanProperty( OPTIMIZE_DISTINCT_AGGREGATIONS, "Optimize mixed non-distinct and distinct aggregations", @@ -1112,6 +1118,11 @@ public static DataSize getAggregationOperatorUnspillMemoryLimit(Session session) return memoryLimitForMerge; } + public static boolean isOptimizeCaseExpressionPredicate(Session session) + { + return session.getSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, Boolean.class); + } + public static boolean isOptimizeDistinctAggregationEnabled(Session session) { return session.getSystemProperty(OPTIMIZE_DISTINCT_AGGREGATIONS, Boolean.class); diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/CaseExpressionPredicateRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/CaseExpressionPredicateRewriter.java new file mode 100644 index 000000000000..a00615986463 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/CaseExpressionPredicateRewriter.java @@ -0,0 +1,273 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.ExpressionRewriter; +import io.trino.sql.tree.ExpressionTreeRewriter; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; +import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.NotExpression; +import io.trino.sql.tree.SearchedCaseExpression; +import io.trino.sql.tree.SimpleCaseExpression; +import io.trino.sql.tree.WhenClause; + +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static io.trino.sql.ExpressionUtils.and; 
+import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; +import static io.trino.sql.ExpressionUtils.or; +import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; +import static java.util.Objects.requireNonNull; + +public class CaseExpressionPredicateRewriter +{ + private CaseExpressionPredicateRewriter() {} + + public static Expression rewrite( + Expression expression, + Rule.Context context, + PlannerContext plannerContext, + TypeAnalyzer typeAnalyzer) + { + requireNonNull(context, "context is null"); + requireNonNull(plannerContext, "plannerContext is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + + return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, typeAnalyzer, context.getSession(), context.getSymbolAllocator()), expression); + } + + private static class Visitor + extends ExpressionRewriter + { + private final PlannerContext plannerContext; + private final LiteralEncoder literalEncoder; + private final TypeAnalyzer typeAnalyzer; + private final Session session; + private final SymbolAllocator symbolAllocator; + + public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, SymbolAllocator symbolAllocator) + { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.literalEncoder = new LiteralEncoder(this.plannerContext); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + this.session = requireNonNull(session, "session is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + } + + @Override + public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + Expression rewritten = node; + if (!DeterminismEvaluator.isDeterministic(node, plannerContext.getMetadata())) { + return treeRewriter.defaultRewrite(rewritten, context); + } 
+ else if (isCaseExpression(node.getLeft()) && isEffectivelyLiteral(plannerContext, session, node.getRight())) { + rewritten = processCaseExpression(node.getLeft(), node.getRight(), node.getOperator()).orElse(node); + } + else if (isCaseExpression(node.getRight()) && isEffectivelyLiteral(plannerContext, session, node.getLeft())) { + rewritten = processCaseExpression(node.getRight(), node.getLeft(), node.getOperator().flip()).orElse(node); + } + return treeRewriter.defaultRewrite(rewritten, context); + } + + private boolean isCaseExpression(Expression expression) + { + if (expression instanceof Cast) { + expression = ((Cast) expression).getExpression(); + } + return expression instanceof SimpleCaseExpression || expression instanceof SearchedCaseExpression; + } + + private Optional processCaseExpression(Expression expression, Expression otherExpression, ComparisonExpression.Operator operator) + { + Expression caseExpression = expression; + Optional castExpression = Optional.empty(); + if (expression instanceof Cast) { + castExpression = Optional.of((Cast) expression); + caseExpression = castExpression.get().getExpression(); + } + return caseExpression instanceof SimpleCaseExpression ? 
+ processSimpleCaseExpression((SimpleCaseExpression) caseExpression, castExpression, otherExpression, operator) : + processSearchedCaseExpression((SearchedCaseExpression) caseExpression, castExpression, otherExpression, operator); + } + + private Optional processSimpleCaseExpression( + SimpleCaseExpression caseExpression, + Optional castExpression, + Expression otherExpression, + ComparisonExpression.Operator operator) + { + if (!canRewriteSimpleCaseExpression(caseExpression)) { + return Optional.empty(); + } + return processCaseExpression( + castExpression, + caseExpression.getWhenClauses(), + caseExpression.getDefaultValue(), + whenClause -> new ComparisonExpression(EQUAL, caseExpression.getOperand(), whenClause.getOperand()), + otherExpression, + operator, + caseExpression.getOperand()); + } + + private boolean canRewriteSimpleCaseExpression(SimpleCaseExpression caseExpression) + { + List whenOperands = caseExpression.getWhenClauses().stream() + .map(WhenClause::getOperand) + .collect(Collectors.toList()); + return checkNonNullUniqueLiterals(whenOperands, getType(whenOperands)); + } + + private Type getType(List expressions) + { + Expression expression = expressions.stream().findFirst().orElseThrow(); + return typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression).get(NodeRef.of(expression)); + } + + private Optional processSearchedCaseExpression( + SearchedCaseExpression caseExpression, + Optional castExpression, + Expression otherExpression, + ComparisonExpression.Operator operator) + { + if (!canRewriteSearchedCaseExpression(caseExpression)) { + return Optional.empty(); + } + return processCaseExpression( + castExpression, + caseExpression.getWhenClauses(), + caseExpression.getDefaultValue(), + WhenClause::getOperand, + otherExpression, + operator, + getCommonOperand(caseExpression)); + } + + private Expression getCommonOperand(SearchedCaseExpression caseExpression) + { + return caseExpression.getWhenClauses().stream() + .map(x -> 
((ComparisonExpression) x.getOperand()).getLeft()) + .findFirst().orElseThrow(); + } + + private boolean canRewriteSearchedCaseExpression(SearchedCaseExpression caseExpression) + { + ImmutableList.Builder rightHandSideExpressions = ImmutableList.builder(); + ImmutableList.Builder leftHandSideExpressions = ImmutableList.builder(); + for (WhenClause whenClause : caseExpression.getWhenClauses()) { + Expression whenOperand = whenClause.getOperand(); + if (!(whenOperand instanceof ComparisonExpression)) { + return false; + } + ComparisonExpression whenComparisonFunction = (ComparisonExpression) whenOperand; + Expression left = whenComparisonFunction.getLeft(); + Expression right = whenComparisonFunction.getRight(); + + if (!whenComparisonFunction.getOperator().equals(EQUAL)) { + return false; + } + leftHandSideExpressions.add(left); + rightHandSideExpressions.add(right); + } + List rightHandExpressions = rightHandSideExpressions.build(); + return checkAllAreSimilar(leftHandSideExpressions.build()) && + checkNonNullUniqueLiterals(rightHandExpressions, getType(rightHandExpressions)); + } + + private boolean checkAllAreSimilar(List expressions) + { + return expressions.stream().distinct().count() <= 1; + } + + private boolean checkNonNullUniqueLiterals(List expressions, Type type) + { + Set literals = new HashSet<>(); + for (Expression expression : expressions) { + if (!isEffectivelyLiteral(plannerContext, session, expression)) { + return false; + } + Object constantExpression = evaluateConstantExpression(expression, type, plannerContext, session, new AllowAllAccessControl(), ImmutableMap.of()); + if (constantExpression == null || literals.contains(constantExpression)) { + return false; + } + literals.add(constantExpression); + } + return true; + } + + private Optional processCaseExpression( + Optional castExpression, + List whenClauses, + Optional defaultValue, + Function operandExtractor, + Expression otherExpression, + ComparisonExpression.Operator operator, + 
Expression commonOperand) + { + ImmutableList.Builder andExpressions = ImmutableList.builder(); + ImmutableList.Builder invertedOperands = ImmutableList.builder(); + + for (WhenClause whenClause : whenClauses) { + Expression whenOperand = operandExtractor.apply(whenClause); + Expression whenResult = getCastExpression(castExpression, whenClause.getResult()); + + andExpressions.add(and( + new ComparisonExpression(operator, whenResult, otherExpression), + new IsNotNullPredicate(commonOperand), + whenOperand)); + invertedOperands.add(new NotExpression(whenOperand)); + } + + Expression defaultExpression = defaultValue + .map(value -> getCastExpression(castExpression, value)) + .orElse(getNullLiteral(otherExpression)); + + Expression elseResult = new ComparisonExpression(operator, defaultExpression, otherExpression); + + andExpressions.add(and(elseResult, or(new IsNullPredicate(commonOperand), and(invertedOperands.build())))); + + return Optional.of(or(andExpressions.build())); + } + + private Expression getCastExpression(Optional castExpression, Expression expression) + { + return castExpression + .map(cast -> (Expression) new Cast(expression, cast.getType(), cast.isSafe(), cast.isTypeOnly())) + .orElse(expression); + } + + private Expression getNullLiteral(Expression expression) + { + Type type = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression).get(NodeRef.of(expression)); + return this.literalEncoder.toExpression(session, null, type); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java index faf821a80da4..dd1c861eccb2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java @@ -83,6 +83,7 @@ public class OptimizerConfig private boolean mergeProjectWithValues = true; private boolean forceSingleNodeOutput; private boolean 
useExactPartitioning; + private boolean optimizeCaseExpressionPredicate; // adaptive partial aggregation private boolean adaptivePartialAggregationEnabled = true; private long adaptivePartialAggregationMinRows = 100_000; @@ -692,6 +693,18 @@ public OptimizerConfig setForceSingleNodeOutput(boolean value) return this; } + public boolean isOptimizeCaseExpressionPredicate() + { + return optimizeCaseExpressionPredicate; + } + + @Config("optimizer.optimize-case-expression-predicate") + public OptimizerConfig setOptimizeCaseExpressionPredicate(boolean optimizeCaseExpressionPredicate) + { + this.optimizeCaseExpressionPredicate = optimizeCaseExpressionPredicate; + return this; + } + public boolean isAdaptivePartialAggregationEnabled() { return adaptivePartialAggregationEnabled; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 5787730d0133..c691de97e822 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -212,6 +212,7 @@ import io.trino.sql.planner.iterative.rule.ReplaceRedundantJoinWithProject; import io.trino.sql.planner.iterative.rule.ReplaceRedundantJoinWithSource; import io.trino.sql.planner.iterative.rule.ReplaceWindowWithRowNumber; +import io.trino.sql.planner.iterative.rule.RewriteCaseExpressionPredicate; import io.trino.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; import io.trino.sql.planner.iterative.rule.RewriteTableFunctionToTableScan; import io.trino.sql.planner.iterative.rule.SimplifyCountOverConstant; @@ -373,6 +374,7 @@ public PlanOptimizers( .addAll(new RemoveRedundantDateTrunc(plannerContext, typeAnalyzer).rules()) .addAll(new ArraySortAfterArrayDistinct(plannerContext).rules()) .add(new RemoveTrivialFilters()) + .addAll(new RewriteCaseExpressionPredicate(plannerContext, typeAnalyzer).rules()) .build(); 
IterativeOptimizer simplifyOptimizer = new IterativeOptimizer( plannerContext, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index ee20ff805798..448e8f0de0b1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.trino.Session; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.sql.planner.OrderingScheme; @@ -79,6 +80,11 @@ public ExpressionRewriteRuleSet(ExpressionRewriter rewriter) this.rewriter = requireNonNull(rewriter, "rewriter is null"); } + public boolean isRewriterEnabled(Session session) + { + return true; + } + public Set> rules() { return ImmutableSet.of( @@ -120,7 +126,7 @@ public Rule patternRecognitionExpressionRewrite() return new PatternRecognitionExpressionRewrite(rewriter); } - private static final class ProjectExpressionRewrite + private final class ProjectExpressionRewrite implements Rule { private final ExpressionRewriter rewriter; @@ -130,6 +136,12 @@ private static final class ProjectExpressionRewrite this.rewriter = rewriter; } + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -153,7 +165,7 @@ public String toString() } } - private static final class AggregationExpressionRewrite + private final class AggregationExpressionRewrite implements Rule { private final ExpressionRewriter rewriter; @@ -163,6 +175,12 @@ private static final class AggregationExpressionRewrite this.rewriter = rewriter; } + @Override + public boolean 
isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -223,7 +241,7 @@ public String toString() } } - private static final class FilterExpressionRewrite + private final class FilterExpressionRewrite implements Rule { private final ExpressionRewriter rewriter; @@ -233,6 +251,12 @@ private static final class FilterExpressionRewrite this.rewriter = rewriter; } + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -256,7 +280,7 @@ public String toString() } } - private static final class JoinExpressionRewrite + private final class JoinExpressionRewrite implements Rule { private final ExpressionRewriter rewriter; @@ -266,6 +290,12 @@ private static final class JoinExpressionRewrite this.rewriter = rewriter; } + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -304,7 +334,7 @@ public String toString() } } - private static final class ValuesExpressionRewrite + private final class ValuesExpressionRewrite implements Rule { private final ExpressionRewriter rewriter; @@ -314,6 +344,12 @@ private static final class ValuesExpressionRewrite this.rewriter = rewriter; } + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { @@ -358,7 +394,7 @@ public String toString() } } - private static final class PatternRecognitionExpressionRewrite + private final class PatternRecognitionExpressionRewrite implements Rule { private final ExpressionRewriter rewriter; @@ -368,6 +404,12 @@ private static final class PatternRecognitionExpressionRewrite this.rewriter = rewriter; } + @Override + public boolean isEnabled(Session session) + { + return isRewriterEnabled(session); + } + @Override public Pattern getPattern() { diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java new file mode 100644 index 000000000000..6746d0d67cf1 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteCaseExpressionPredicate.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.sql.planner.iterative.rule; + +import io.trino.Session; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.CaseExpressionPredicateRewriter; +import io.trino.sql.planner.TypeAnalyzer; + +import static io.trino.SystemSessionProperties.isOptimizeCaseExpressionPredicate; + +/** + * This Rule rewrites a CASE expression predicate into a series of AND/OR clauses. + * The following CASE expression + *

+ * (CASE + * WHEN expression=constant1 THEN result1 + * WHEN expression=constant2 THEN result2 + * WHEN expression=constant3 THEN result3 + * ELSE elseResult + * END) = value + *

+ * can be converted into a series of AND/OR clauses as below: + *

+ * (result1 = value AND expression IS NOT NULL AND expression=constant1) OR + * (result2 = value AND expression IS NOT NULL AND expression=constant2 AND !(expression=constant1)) OR + * (result3 = value AND expression IS NOT NULL AND expression=constant3 AND !(expression=constant1) AND !(expression=constant2)) OR + * (elseResult = value AND ((expression IS NULL) OR (!(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)))) + *

+ * The above conversion evaluates the conditions in WHEN clauses multiple times. But if we ensure these conditions are + * mutually exclusive, we can skip the negations of the previous WHEN conditions and simplify the expression to: + *

+ * (result1 = value AND expression IS NOT NULL AND expression=constant1) OR + * (result2 = value AND expression IS NOT NULL AND expression=constant2) OR + * (result3 = value AND expression IS NOT NULL AND expression=constant3) OR + * (elseResult = value AND ((expression IS NULL) OR (!(expression=constant1) AND !(expression=constant2) AND !(expression=constant3)))) + *

+ * To ensure the WHEN conditions are mutually exclusive, the following criteria need to be met: + * 1. Value is effectively a literal per ExpressionUtils.isEffectivelyLiteral (e.g. a constant or a cast of a constant), not a function call + * 2. The LHS expressions in all WHEN clauses are the same and deterministic + * For example, if one WHEN clause has an expression using col1 and another using col2, the rewrite is not applied + * 3. The relational operator in every WHEN clause is equality (=). With other operators it is hard to check for exclusivity. + * 4. All the RHS expressions are constant, non-null and unique + *

+ * This conversion is done so that it is easy for the ExpressionInterpreter & other Optimizers to further + * simplify this and construct a domain for the column that can be used by Readers . + * i.e, ExpressionInterpreter can discard all conditions in which result != value and + * RowExpressionDomainTranslator can construct a Domain for the column + */ +public class RewriteCaseExpressionPredicate + extends ExpressionRewriteRuleSet +{ + public RewriteCaseExpressionPredicate(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) + { + super((expression, context) -> CaseExpressionPredicateRewriter.rewrite(expression, context, plannerContext, typeAnalyzer)); + } + + @Override + public boolean isRewriterEnabled(Session session) + { + return isOptimizeCaseExpressionPredicate(session); + } +} diff --git a/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java b/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java index 12306998472d..63ee7b1305ac 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java @@ -85,6 +85,7 @@ public void testDefaults() .setTableScanNodePartitioningMinBucketToTaskRatio(0.5) .setMergeProjectWithValues(true) .setForceSingleNodeOutput(false) + .setOptimizeCaseExpressionPredicate(false) .setAdaptivePartialAggregationEnabled(true) .setAdaptivePartialAggregationMinRows(100_000) .setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.8) @@ -142,6 +143,7 @@ public void testExplicitPropertyMappings() .put("optimizer.use-table-scan-node-partitioning", "false") .put("optimizer.table-scan-node-partitioning-min-bucket-to-task-ratio", "0.0") .put("optimizer.merge-project-with-values", "false") + .put("optimizer.optimize-case-expression-predicate", "true") .put("adaptive-partial-aggregation.enabled", "false") .put("adaptive-partial-aggregation.min-rows", "1") .put("adaptive-partial-aggregation.unique-rows-ratio-threshold", 
"0.99") @@ -196,6 +198,7 @@ public void testExplicitPropertyMappings() .setTableScanNodePartitioningMinBucketToTaskRatio(0.0) .setMergeProjectWithValues(false) .setForceSingleNodeOutput(true) + .setOptimizeCaseExpressionPredicate(true) .setAdaptivePartialAggregationEnabled(false) .setAdaptivePartialAggregationMinRows(1) .setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.99) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java new file mode 100644 index 000000000000..d76bf460c608 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRewriteCaseExpressionPredicate.java @@ -0,0 +1,218 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableMap; +import io.trino.spi.type.Type; +import io.trino.sql.ExpressionTestUtils; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.tree.Expression; +import org.testng.annotations.Test; + +import static io.trino.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRewriteCaseExpressionPredicate + extends BaseRuleTest +{ + private static final TypeProvider INPUT_TYPES = TypeProvider.copyOf(ImmutableMap.builder() + .put(new Symbol("col1"), INTEGER) + .put(new Symbol("col2"), INTEGER) + .put(new Symbol("col3"), VARCHAR) + .buildOrThrow()); + + @Test + public void testRewriterDoesNotFireOnPredicateWithoutCaseExpression() + { + assertRewriteDoesNotFire("col1 > 1"); + } + + @Test + public void testRewriterDoesNotFireOnPredicateWithoutComparisonFunction() + { + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end)"); + } + + @Test + public void testRewriterDoesNotFireOnPredicateWithFunctionCallOnComparisonValue() + { + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = UPPER('case1')"); + assertRewriteDoesNotFire("(case when col1=1 then 10 when col1=2 then 20 else 30 end) = ceil(col1)"); + assertRewriteDoesNotFire("(case when col1=1 then 10 when col2=2 then 20 else 30 end) = ceil(col1)"); + } + + @Test + public void testRewriterDoesNotFireOnSearchCaseExpressionThatDoesNotMeetRewriteConditions() + { + // All LHS expressions are not the 
same + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = 'case1'"); + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'case1'"); + + // Any expression is non deterministic + assertRewriteDoesNotFire("(case when random(col1)=1 then 'case1' when random(col1)=2 then 'case2' else 'default' end) = 'case1'"); + assertRewriteDoesNotFire("(case when col1=1 then 1.0 when col1=2 then 2.0 else 3.0 end) = rand()"); + + // All expressions are not equals function + assertRewriteDoesNotFire("(case when col1>1 then 1 when col1>2 then 2 else 3 end) > 2"); + assertRewriteDoesNotFire("(case when col1<1 then 1 when col1<2 then 2 else 3 end) < 2"); + + // All RHS expressions are not Constant Expression + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=ceil(1) then 'case2' else 'default' end) = 'case1'"); + + // All RHS expressions are not unique + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=1 then 'case2' else 'default' end) = 'case1'"); + assertRewriteDoesNotFire("(case when col1=CAST(1 as SMALLINT) then 'case1' when col1=CAST(1 as TINYINT) then 'case2' else 'default' end) = 'case1'"); + + // RHS expression is NULL + assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=NULL then 'case2' else 'default' end) = 'case1'"); + } + + @Test + public void testSimpleCaseExpressionRewrite() + { + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case2'", + "('case1' = 'case2' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case2' AND col1 IS NOT NULL AND col1 = 
2) OR ('default' = 'case2' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'default'", + "('case1' = 'default' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'default' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'default' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + } + + @Test + public void testSearchedCaseExpressionRewrite() + { + assertRewrittenExpression( + "(case when col1=1 then 'case1' when col1=2 then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case when col3='a' then 'case1' when col3='b' then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col3 IS NOT NULL AND col3 = 'a') OR ('case2' = 'case1' AND col3 IS NOT NULL AND col3 = 'b') OR ('default' = 'case1' AND (col3 IS NULL OR (NOT(col3 = 'a') AND NOT(col3 = 'b'))))"); + + assertRewrittenExpression( + "(case when col1=1 then 'case1' when col1=2 then 'case2' else 'default' end) = 'default'", + "('case1' = 'default' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'default' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'default' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + } + + @Test + public void testRewriterOnCaseExpressionInRightSideOfComparisonFunction() + { + assertRewrittenExpression( + "(case col1 when 1 then 10 when 2 then 20 else 30 end) > 20", + "(10 > 20 AND col1 IS NOT NULL AND col1 = 1) OR (20 > 20 AND col1 IS NOT NULL AND col1 = 2) OR (30 > 20 AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "25 < (case col1 when 1 then 10 when 2 then 20 else 30 end)", + "(25 < 10 AND col1 IS NOT NULL AND col1 = 1) OR (25 < 20 AND col1 IS NOT NULL AND col1 = 2) 
OR (25 < 30 AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + } + + @Test + public void testRewriterWhenMoreThanOneConditionMatches() + { + assertRewrittenExpression( + "(case col1 when 1 then 'case' when 2 then 'case' else 'default' end) = 'case'", + "('case' = 'case' AND col1 IS NOT NULL AND col1 = 1) OR ('case' = 'case' AND col1 IS NOT NULL AND col1 = 2) OR ('default' = 'case' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'defaultAndCase1' when 2 then 'case2' else 'defaultAndCase1' end) = 'defaultAndCase1'", + "('defaultAndCase1' = 'defaultAndCase1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'defaultAndCase1' AND col1 IS NOT NULL AND col1 = 2) OR ('defaultAndCase1' = 'defaultAndCase1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col3 when 'data1' then 'case1' when 'data2' then 'case2' else col3 end) = 'case1'", + "('case1' = 'case1' AND col3 IS NOT NULL AND col3 = 'data1') OR ('case2' = 'case1' AND col3 IS NOT NULL AND col3 = 'data2') OR (col3 = 'case1' AND (col3 IS NULL OR (NOT(col3 = 'data1') AND NOT(col3 = 'data2'))))"); + } + + @Test + public void testRewriterOnCaseExpressionWithoutElseClause() + { + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case1'", + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = 2) OR (CAST(null as VARCHAR(5)) = 'case1' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case3'", + "('case1' = 'case3' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case3' AND col1 IS NOT NULL AND col1 = 2) OR (CAST(null as VARCHAR(5)) = 'case3' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case2'", + 
"('case1' = 'case2' AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = 'case2' AND col1 IS NOT NULL AND col1 = 2) OR (CAST(null as VARCHAR(5)) = 'case2' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + } + + @Test + public void testRewriterOnCaseExpressionWithCastFunction() + { + // When left hand and right hand side of the expression are of different types, RowExpressionInterpreter identifies the common super type and adds a CAST function + assertRewrittenExpression( + "cast((case col1 when 1 then 'case11' when 2 then 'case2' else 'def' end) as VARCHAR(6)) = 'case11'", + "(cast('case11' as VARCHAR(6)) = 'case11' AND col1 IS NOT NULL AND col1 = 1) OR (cast('case2' as VARCHAR(6)) = 'case11' AND col1 IS NOT NULL AND col1 = 2) OR (cast('def' as VARCHAR(6)) = 'case11' AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = cast('case1' AS VARCHAR)", + "('case1' = cast('case1' AS VARCHAR) AND col1 IS NOT NULL AND col1 = 1) OR ('case2' = cast('case1' AS VARCHAR) AND col1 IS NOT NULL AND col1 = 2) OR ('default' = cast('case1' AS VARCHAR) AND (col1 IS NULL OR (NOT(col1 = 1) AND NOT(col1 = 2))))"); + + assertRewrittenExpression( + "(case when col1=cast('1' as INTEGER) then 'case1' when col1=cast('2' as INTEGER) then 'case2' else 'default' end) = 'case1'", + "('case1' = 'case1' AND col1 IS NOT NULL AND col1 = cast('1' as INTEGER)) OR ('case2' = 'case1' AND col1 IS NOT NULL AND col1 = cast('2' as INTEGER)) OR ('default' = 'case1' AND (col1 IS NULL OR (NOT(col1 = cast('1' as INTEGER)) AND NOT(col1 = cast('2' as INTEGER)))))"); + } + + @Test + public void testIfSubExpressionsAreRewritten() + { + assertRewrittenExpression( + "((case col1 when 1 then 'a' else 'b' end) = 'a') = true", + "(('a' = 'a' AND col1 IS NOT NULL AND col1 = 1) OR ('b' = 'a' AND (col1 IS NULL OR NOT(col1=1)))) = true"); + + assertRewrittenExpression( + "(case when col2=2 then 'a' when 
((case col1 when 1 then 'a' else 'b' end) = 'a') then 'b' else 'c' end) = 'a'", + "(case when col2=2 then 'a' when (('a' = 'a' AND col1 IS NOT NULL AND col1 = 1) OR ('b' = 'a' AND (col1 IS NULL OR NOT(col1=1)))) then 'b' else 'c' end) = 'a'"); + } + + private void assertRewriteDoesNotFire(String expression) + { + tester().assertThat(new RewriteCaseExpressionPredicate(tester().getPlannerContext(), tester().getTypeAnalyzer()).filterExpressionRewrite()) + .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") + .on(p -> p.filter(ExpressionTestUtils.createExpression(tester().getSession(), expression, tester().getPlannerContext(), INPUT_TYPES), p.values())) + .doesNotFire(); + } + + private void assertRewrittenExpression( + String inputExpression, + String expectedRewritten) + { + assertRewrittenExpression( + PlanBuilder.expression(inputExpression), + PlanBuilder.expression(expectedRewritten)); + } + + private void assertRewrittenExpression( + Expression inputExpression, + Expression expectedRewritten) + { + tester().assertThat(new RewriteCaseExpressionPredicate(tester().getPlannerContext(), tester().getTypeAnalyzer()).filterExpressionRewrite()) + .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") + .on(p -> p.filter(inputExpression, p.values(p.symbol("col1"), p.symbol("col2"), p.symbol("col3")))) + .matches(filter(expectedRewritten, values("col1", "col2", "col3"))); + } +} diff --git a/docs/src/main/sphinx/admin/properties-optimizer.rst b/docs/src/main/sphinx/admin/properties-optimizer.rst index 43c6292780dc..69865c766c86 100644 --- a/docs/src/main/sphinx/admin/properties-optimizer.rst +++ b/docs/src/main/sphinx/admin/properties-optimizer.rst @@ -210,3 +210,41 @@ The minimum number of join build side rows required to use partitioned join look If the build side of a join is estimated to be smaller than the configured threshold, single threaded join lookup is used to improve join performance. A value of ``0`` disables this optimization. 
+ +``optimizer.optimize-case-expression-predicate`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``false`` + +When set to ``true``, a CASE expression predicate is simplified into a series of AND/OR clauses. +For example: + +.. code-block:: sql + + SELECT * FROM orders + WHERE (CASE + WHEN status=0 THEN 'Pending' + WHEN status=1 THEN 'Complete' + WHEN status=2 THEN 'Returned' + ELSE 'Unknown' + END) = 'Pending'; + + +is simplified into:: + + SELECT * FROM orders + WHERE status IS NOT NULL AND status=0; + +If the filter condition were to match the ELSE clause ``'Unknown'``, it is translated into:: + + SELECT * FROM orders + WHERE (status IS NULL OR (status!=0 AND status!=1 AND status!=2)); + +The simplification avoids branching and string operations, making the predicate more efficient, and also allows +predicate pushdown to happen, avoiding a full table scan. This optimization mainly addresses queries +generated by business intelligence tools like Looker that support human-readable labels through +``CASE`` statements. + +The optimization currently only applies to simple CASE expressions where the WHEN clause conditions are +unambiguous, deterministic, and on the same column, with the comparison operator being equals. 
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java index fc7c86e16d67..631b7dad4949 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueries.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.Session; import io.trino.metadata.FunctionBundle; import io.trino.metadata.InternalFunctionBundle; import io.trino.tpch.TpchTable; @@ -27,6 +28,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE; import static io.trino.connector.informationschema.InformationSchemaTable.INFORMATION_SCHEMA; import static io.trino.operator.scalar.ApplyFunction.APPLY_FUNCTION; import static io.trino.operator.scalar.InvokeFunction.INVOKE_FUNCTION; @@ -502,4 +504,55 @@ public void testFilterPushdownWithAggregation() assertQuery("SELECT * FROM (SELECT count(*) FROM orders) WHERE 0=1"); assertQuery("SELECT * FROM (SELECT count(*) FROM orders) WHERE null"); } + + @Test + public void testCasePredicateRewrite() + { + Session caseExpressionRewriteEnabled = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true") + .build(); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT nationkey FROM NATION WHERE (CASE WHEN regionkey <= 1 THEN 'SMALL' WHEN (regionkey > 1 AND regionkey <= 3) THEN 'MEDIUM' ELSE 'LARGE' END) = 'SMALL'"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT nationkey FROM NATION WHERE 'MEDIUM' = (CASE WHEN regionkey <= 1 THEN 'SMALL' WHEN (regionkey > 1 AND regionkey <= 3) THEN 'MEDIUM' ELSE 'LARGE' END)"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT 
nationkey FROM NATION WHERE (CASE WHEN regionkey <= 1 THEN 'SMALL' WHEN (regionkey > 1 AND regionkey <= 3) THEN 'MEDIUM' ELSE 'LARGE' END) = 'LARGE'"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT nationkey FROM NATION WHERE (CASE name WHEN 'PERU' THEN 1 WHEN 'CHINA' THEN 2 WHEN 'P' THEN 3 ELSE -1 END) = 2"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT nationkey FROM NATION WHERE 1 < (CASE name WHEN 'PERU' THEN 1 WHEN 'CHINA' THEN 2 WHEN 'P' THEN 3 ELSE -1 END)"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT nationkey FROM NATION WHERE (CASE name WHEN 'FRANCE' THEN 1 WHEN 'KENYA' THEN 2 WHEN 'P' THEN 3 ELSE 2 END) = 2"); + + assertQuery( + caseExpressionRewriteEnabled, + "SELECT nationkey FROM NATION WHERE (CASE WHEN regionkey=1 THEN 1 WHEN (CASE WHEN name = 'FRANCE' THEN true ELSE false END) THEN 2 WHEN regionkey=2 THEN 3 ELSE -1 END) > 1"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' ELSE 'c' END) = 'a' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' ELSE 'c' END) = 'b' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' ELSE 'c' END) = 'c' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = NULL THEN 'a' WHEN col = 1 THEN 'b' ELSE 'c' END) = 'a' FROM (VALUES NULL, 1, 2, 3) t(col)"); + + assertQuery(caseExpressionRewriteEnabled, + "SELECT (CASE WHEN col = 1 THEN 'a' WHEN col = 2 THEN 'b' END) = NULL FROM (VALUES NULL, 1, 2, 3) t(col)"); + } }