Normalize OR expressions: merge equality comparisons and IN-list predicates.

Within an OR, terms of the form `x = v` and `x IN (v1, ...)` that share the
same left-hand expression are merged into one IN predicate when the group has
more than one term. Merged values are de-duplicated in first-seen order. IN
predicates whose value list is not an InListExpression (e.g. a subquery) are
left untouched.

NOTE(review): line breaks, indentation and generic type arguments in this patch
were reconstructed from a collapsed copy; verify against the upstream commit
before applying.

diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java
index 1c5f2aaa4e0f..e26b23bf2878 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/NormalizeOrExpressionRewriter.java
@@ -14,6 +14,8 @@
 package io.trino.sql.planner.iterative.rule;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMultimap;
+import com.google.common.collect.ImmutableSet;
 import io.trino.sql.tree.ComparisonExpression;
 import io.trino.sql.tree.Expression;
 import io.trino.sql.tree.ExpressionRewriter;
@@ -22,19 +24,17 @@
 import io.trino.sql.tree.InPredicate;
 import io.trino.sql.tree.LogicalExpression;
 
-import java.util.LinkedHashMap;
+import java.util.Collection;
+import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
-import java.util.stream.Collectors;
 
 import static com.google.common.collect.ImmutableList.toImmutableList;
-import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static io.trino.sql.ExpressionUtils.and;
 import static io.trino.sql.ExpressionUtils.or;
 import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
 import static io.trino.sql.tree.LogicalExpression.Operator.AND;
-import static java.util.stream.Collectors.groupingBy;
-import static java.util.stream.Collectors.mapping;
 
 public final class NormalizeOrExpressionRewriter
 {
@@ -59,35 +59,70 @@ public Expression rewriteLogicalExpression(LogicalExpression node, Void context,
                 return and(terms);
             }
 
-            List<InPredicate> comparisons = terms.stream()
-                    .filter(NormalizeOrExpressionRewriter::isEqualityComparisonExpression)
-                    .map(ComparisonExpression.class::cast)
-                    .collect(groupingBy(
-                            ComparisonExpression::getLeft,
-                            LinkedHashMap::new,
-                            mapping(ComparisonExpression::getRight, Collectors.toList())))
-                    .entrySet().stream()
-                    .filter(entry -> entry.getValue().size() > 1)
-                    .map(entry -> new InPredicate(entry.getKey(), new InListExpression(entry.getValue())))
-                    .collect(Collectors.toList());
+            ImmutableList.Builder<InPredicate> inPredicateBuilder = ImmutableList.builder();
+            ImmutableSet.Builder<Expression> expressionToSkipBuilder = ImmutableSet.builder();
+            ImmutableList.Builder<Expression> othersExpressionBuilder = ImmutableList.builder();
+            groupComparisonAndInPredicate(terms).forEach((expression, values) -> {
+                if (values.size() > 1) {
+                    inPredicateBuilder.add(new InPredicate(expression, mergeToInListExpression(values)));
+                    expressionToSkipBuilder.add(expression);
+                }
+            });
 
-            Set<Expression> expressionToSkip = comparisons.stream()
-                    .map(InPredicate::getValue)
-                    .collect(toImmutableSet());
-
-            List<Expression> others = terms.stream()
-                    .filter(expression -> !isEqualityComparisonExpression(expression) || !expressionToSkip.contains(((ComparisonExpression) expression).getLeft()))
-                    .collect(Collectors.toList());
+            Set<Expression> expressionToSkip = expressionToSkipBuilder.build();
+            for (Expression expression : terms) {
+                if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) {
+                    if (!expressionToSkip.contains(comparisonExpression.getLeft())) {
+                        othersExpressionBuilder.add(expression);
+                    }
+                }
+                else if (expression instanceof InPredicate inPredicate && inPredicate.getValueList() instanceof InListExpression) {
+                    if (!expressionToSkip.contains(inPredicate.getValue())) {
+                        othersExpressionBuilder.add(expression);
+                    }
+                }
+                else {
+                    othersExpressionBuilder.add(expression);
+                }
+            }
 
             return or(ImmutableList.<Expression>builder()
-                    .addAll(others)
-                    .addAll(comparisons)
+                    .addAll(othersExpressionBuilder.build())
+                    .addAll(inPredicateBuilder.build())
                     .build());
         }
-    }
 
-    private static boolean isEqualityComparisonExpression(Expression expression)
-    {
-        return expression instanceof ComparisonExpression && ((ComparisonExpression) expression).getOperator() == EQUAL;
+        private InListExpression mergeToInListExpression(Collection<Expression> expressions)
+        {
+            LinkedHashSet<Expression> expressionValues = new LinkedHashSet<>();
+            for (Expression expression : expressions) {
+                if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) {
+                    expressionValues.add(comparisonExpression.getRight());
+                }
+                else if (expression instanceof InPredicate inPredicate && inPredicate.getValueList() instanceof InListExpression valueList) {
+                    expressionValues.addAll(valueList.getValues());
+                }
+                else {
+                    throw new IllegalStateException("Unexpected expression: " + expression);
+                }
+            }
+
+            return new InListExpression(ImmutableList.copyOf(expressionValues));
+        }
+
+        private Map<Expression, Collection<Expression>> groupComparisonAndInPredicate(List<Expression> terms)
+        {
+            ImmutableMultimap.Builder<Expression, Expression> expressionBuilder = ImmutableMultimap.builder();
+            for (Expression expression : terms) {
+                if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) {
+                    expressionBuilder.put(comparisonExpression.getLeft(), comparisonExpression);
+                }
+                else if (expression instanceof InPredicate inPredicate && inPredicate.getValueList() instanceof InListExpression) {
+                    expressionBuilder.put(inPredicate.getValue(), inPredicate);
+                }
+            }
+
+            return expressionBuilder.build().asMap();
+        }
     }
 }
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java
index eeca60d51d3c..127cb83de946 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java
@@ -383,10 +383,23 @@ public void testPushesDownNegationsNumericTypes()
     public void testRewriteOrExpression()
     {
         assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 ", "I1 IN (1, 2)");
-        // TODO: Implement rule for Merging IN expression
-        assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 IN (3, 4)", "I1 IN (3, 4) OR I1 IN (1, 2)");
+        assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 IN (3, 4)", "I1 IN (1, 2, 3, 4)");
         assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 = I2", "I1 IN (1, 2, I2)");
         assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I2 = 3 OR I2 = 4", "I1 IN (1, 2) OR I2 IN (3, 4)");
+        assertSimplifiesNumericTypes("I1 = 1 OR I1 IN (1, 2)", "I1 IN (1, 2)");
+        assertSimplifiesNumericTypes("I1 = 1 OR I2 IN (1, 2) OR I2 IN (2, 3)", "I1 = 1 OR I2 IN (1, 2, 3)");
+        assertSimplifiesNumericTypes("I1 IN (1)", "I1 = 1");
+        assertSimplifiesNumericTypes("I1 = 1 OR I1 IN (1)", "I1 = 1");
+        assertSimplifiesNumericTypes("I1 = 1 OR I1 IN (2)", "I1 IN (1, 2)");
+        assertSimplifiesNumericTypes("I1 IN (1, 2) OR I1 = 1", "I1 IN (1, 2)");
+        assertSimplifiesNumericTypes("I1 IN (1, 2) OR I2 = 1 OR I1 = 3 OR I2 = 4", "I1 IN (1, 2, 3) OR I2 IN (1, 4)");
+        assertSimplifiesNumericTypes("I1 IN (1, 2) OR I1 = 3 OR I1 IN (4, 5, 6) OR I2 = 3 OR I2 IN (3, 4)", "I1 IN (1, 2, 3, 4, 5, 6) OR I2 IN (3, 4)");
+
+        assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 IN (3, 4) OR I1 IN (SELECT 1)", "I1 IN (1, 2, 3, 4) OR I1 IN (SELECT 1)");
+        assertSimplifiesNumericTypes("I1 = 1 OR I2 = 2 OR I1 = 3 OR I2 = 4", "I1 IN (1, 3) OR I2 IN (2, 4)");
+        assertSimplifiesNumericTypes("I1 = 1 OR I2 = 2 OR I1 = 3 OR I2 IS NULL", "I1 IN(1, 3) OR I2 = 2 OR I2 IS NULL");
+        assertSimplifiesNumericTypes("I1 = 1 OR I2 IN (2, 3) OR I1 = 4 OR I2 IN (5, 6)", "I1 IN (1, 4) OR I2 IN (2, 3, 5, 6)");
+        assertSimplifiesNumericTypes("I1 = 1 OR I2 = 2 OR I1 = 3 OR I2 = I1", "I1 IN (1, 3) OR I2 IN (2, I1)");
     }
 
     private static void assertSimplifiesNumericTypes(String expression, String expected)