diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 9920ac65bf80..c5ae41c93e16 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -89,7 +89,6 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.Deque; import java.util.HashSet; import java.util.LinkedHashMap; @@ -112,6 +111,7 @@ import static java.lang.Boolean.FALSE; import static java.lang.String.format; import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; @@ -1073,7 +1073,7 @@ public void addRowFilter(Table table, Expression filter) public List getRowFilters(Table node) { - return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of()); + return unmodifiableList(rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of())); } public boolean hasColumnMask(QualifiedObjectName table, String column, String identity) @@ -1101,7 +1101,7 @@ public void addColumnMask(Table table, String column, Expression mask) public Map getColumnMasks(Table table) { - return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()); + return unmodifiableMap(columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of())); } public List getReferencedTables() @@ -1571,22 +1571,22 @@ public void addQuantifiedComparisons(List expres public List getInPredicatesSubqueries() { - return Collections.unmodifiableList(inPredicatesSubqueries); + return unmodifiableList(inPredicatesSubqueries); } public List getSubqueries() { - return Collections.unmodifiableList(subqueries); + return unmodifiableList(subqueries); } public List getExistsSubqueries() { - return Collections.unmodifiableList(existsSubqueries); + return unmodifiableList(existsSubqueries); } public List getQuantifiedComparisonSubqueries() { - return Collections.unmodifiableList(quantifiedComparisonSubqueries); + return unmodifiableList(quantifiedComparisonSubqueries); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index b1fbf93b1e42..70bc48c193ff 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -1033,7 +1033,7 @@ private PlanBuilder filter(PlanBuilder subPlan, Expression predicate, Node node) subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, analysis.getSubqueries(node)); - return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), subPlan.rewrite(predicate))); + return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), coerceIfNecessary(analysis, predicate, subPlan.rewrite(predicate)))); } private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index d3e7c56fb3dd..5d213b8f5385 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -279,7 +279,7 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function getPattern() public Result apply(FilterNode filterNode, Captures captures, Context context) { Expression predicate = filterNode.getPredicate(); + checkArgument(!(predicate instanceof NullLiteral), "Unexpected null literal without a cast to boolean"); if (predicate.equals(TRUE_LITERAL)) { return Result.ofPlanNode(filterNode.getSource()); } - if (predicate.equals(FALSE_LITERAL) || predicate instanceof NullLiteral) { + if (predicate.equals(FALSE_LITERAL) || + (predicate instanceof Cast cast && cast.getExpression() instanceof NullLiteral)) { return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputSymbols(), emptyList())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java index ed935944f519..60441272eacf 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java @@ -20,6 +20,7 @@ import io.trino.sql.tree.Expression; import io.trino.sql.tree.Join; import io.trino.sql.tree.Node; +import io.trino.sql.tree.NullLiteral; import javax.annotation.concurrent.Immutable; @@ -100,6 +101,9 @@ public CorrelatedJoinNode( requireNonNull(subquery, "subquery is null"); requireNonNull(correlation, "correlation is null"); requireNonNull(filter, "filter is null"); + // The condition doesn't guarantee that filter is of type boolean, but was found to be a practical way to identify + // places where CorrelatedJoinNode could be created without appropriate coercions. + checkArgument(!(filter instanceof NullLiteral), "Filter must be an expression of boolean type: %s", filter); requireNonNull(originSubquery, "originSubquery is null"); checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java index a589aa5f5634..e7e8d935829d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java @@ -19,11 +19,15 @@ import com.google.common.collect.Iterables; import io.trino.sql.planner.Symbol; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.NullLiteral; import javax.annotation.concurrent.Immutable; import java.util.List; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + @Immutable public class FilterNode extends PlanNode @@ -39,6 +43,10 @@ public FilterNode(@JsonProperty("id") PlanNodeId id, super(id); this.source = source; + requireNonNull(predicate, "predicate is null"); + // The condition doesn't guarantee that predicate is of type boolean, but was found to be a practical way to identify + // places where FilterNode was created without appropriate coercions. + checkArgument(!(predicate instanceof NullLiteral), "Predicate must be an expression of boolean type: %s", predicate); this.predicate = predicate; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java index 6b8c00f7f49c..e470e443d4e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java @@ -23,6 +23,7 @@ import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Join; +import io.trino.sql.tree.NullLiteral; import javax.annotation.concurrent.Immutable; @@ -89,6 +90,9 @@ public JoinNode( requireNonNull(leftOutputSymbols, "leftOutputSymbols is null"); requireNonNull(rightOutputSymbols, "rightOutputSymbols is null"); requireNonNull(filter, "filter is null"); + // The condition doesn't guarantee that filter is of type boolean, but was found to be a practical way to identify + // places where JoinNode could be created without appropriate coercions. + checkArgument(filter.isEmpty() || !(filter.get() instanceof NullLiteral), "Filter must be an expression of boolean type: %s", filter); requireNonNull(leftHashSymbol, "leftHashSymbol is null"); requireNonNull(rightHashSymbol, "rightHashSymbol is null"); requireNonNull(distributionType, "distributionType is null"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index ef5d17a2ddf9..16bc55c7d079 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -29,7 +29,6 @@ import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.optimizations.AddLocalExchanges; @@ -106,6 +105,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId; import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout; +import static io.trino.sql.planner.assertions.PlanMatchPattern.correlatedJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -813,6 +813,31 @@ public void testCorrelatedJoinWithTopN() anyTree(tableScan("nation", ImmutableMap.of("nation_name", "name", "nation_regionkey", "regionkey"))))))))); } + @Test + public void testCorrelatedJoinWithNullCondition() + { + assertPlan( + "SELECT regionkey, n.name FROM region LEFT JOIN LATERAL (SELECT name FROM nation) n ON NULL", + CREATED, + anyTree( + correlatedJoin( + List.of("r_row_number", "r_regionkey", "r_name", "r_comment"), + "CAST(null AS boolean)", + tableScan("region", Map.of( + "r_row_number", "row_number", + "r_regionkey", "regionkey", + "r_name", "name", + "r_comment", "comment")), + anyTree(tableScan("nation"))))); + assertPlan( + "SELECT regionkey, n.name FROM region LEFT JOIN LATERAL (SELECT name FROM nation) n ON NULL", + any( + join(LEFT, builder -> builder + .equiCriteria(List.of()) + .left(tableScan("region")) + .right(values("name"))))); + } + @Test public void testCorrelatedScalarSubqueryInSelect() { @@ -1109,6 +1134,14 @@ public void testRemovesTrivialFilters() "SELECT * FROM nation WHERE 1 = 0", output( values("nationkey", "name", "regionkey", "comment"))); + assertPlan( + "SELECT * FROM nation WHERE null", + output( + values("nationkey", "name", "regionkey", "comment"))); + assertPlan( + "SELECT * FROM nation WHERE nationkey = null", + output( + values("nationkey", "name", "regionkey", "comment"))); } @Test @@ -1488,7 +1521,7 @@ public void testOffset() "SELECT name FROM nation OFFSET 2 ROWS", any( strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name")), + ImmutableMap.of("name", expression("name")), filter( "row_num > BIGINT '2'", rowNumber( @@ -1502,7 +1535,7 @@ public void testOffset() "SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS", any( strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name")), + ImmutableMap.of("name", expression("name")), filter( "row_num > BIGINT '2'", rowNumber( @@ -1519,7 +1552,7 @@ public void testOffset() "SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY", any( strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name")), + ImmutableMap.of("name", expression("name")), filter( "row_num > BIGINT '2'", rowNumber( @@ -1538,7 +1571,7 @@ public void testOffset() "SELECT name FROM nation OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY", any( strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name")), + ImmutableMap.of("name", expression("name")), filter( "row_num > BIGINT '2'", rowNumber( @@ -1558,7 +1591,7 @@ public void testWithTies() "SELECT name, regionkey FROM nation ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES", any( strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name"), "regionkey", new ExpressionMatcher("regionkey")), + ImmutableMap.of("name", expression("name"), "regionkey", expression("regionkey")), topNRanking( pattern -> pattern .specification( @@ -1578,14 +1611,14 @@ public void testWithTies() "SELECT name, regionkey FROM nation ORDER BY regionkey OFFSET 10 ROWS FETCH FIRST 6 ROWS WITH TIES", any( strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name"), "regionkey", new ExpressionMatcher("regionkey")), + ImmutableMap.of("name", expression("name"), "regionkey", expression("regionkey")), filter( "row_num > BIGINT '10'", rowNumber( pattern -> pattern .partitionBy(ImmutableList.of()), strictProject( - ImmutableMap.of("name", new ExpressionMatcher("name"), "regionkey", new ExpressionMatcher("regionkey")), + ImmutableMap.of("name", expression("name"), "regionkey", expression("regionkey")), topNRanking( pattern -> pattern .specification( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java index b55b5594e39a..73ced38af4ec 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java @@ -791,32 +791,19 @@ private void testNoUnwrap(String inputType, String inputPredicate, String expect private void testNoUnwrap(Session session, String inputType, String inputPredicate, String expectedCastType) { - String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE a %s", inputType, inputPredicate); - try { - assertPlan(sql, - session, - output( - filter(format("CAST(a AS %s) %s", expectedCastType, inputPredicate), - values("a")))); - } - catch (Throwable e) { - e.addSuppressed(new Exception("Query: " + sql)); - throw e; - } + assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE a %s", inputType, inputPredicate), + session, + output( + filter(format("CAST(a AS %s) %s", expectedCastType, inputPredicate), + values("a")))); } private void testRemoveFilter(String inputType, String inputPredicate) { - String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s", inputType, inputPredicate); - try { - assertPlan(sql, - output( - values("a"))); - } - catch (Throwable e) { - e.addSuppressed(new Exception("Query: " + sql)); - throw e; - } + assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s AND rand() = 42", inputType, inputPredicate), + output( + filter("rand() = 42e0", + values("a")))); } private void testUnwrap(String inputType, String inputPredicate, String expectedPredicate) @@ -826,18 +813,11 @@ private void testUnwrap(String inputType, String inputPredicate, String expected private void testUnwrap(Session session, String inputType, String inputPredicate, String expectedPredicate) { - String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s", inputType, inputPredicate); - try { - assertPlan(sql, - session, - output( - filter(expectedPredicate, - values("a")))); - } - catch (Throwable e) { - e.addSuppressed(new Exception("Query: " + sql)); - throw e; - } + assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s OR rand() = 42", inputType, inputPredicate), + session, + output( + filter(format("%s OR rand() = 42e0", expectedPredicate), + values("a")))); } private static Session withZone(Session session, TimeZoneKey timeZoneKey) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java index cdcc083877aa..42c84ece036e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java @@ -142,11 +142,17 @@ protected void assertPlan(@Language("SQL") String sql, LogicalPlanner.Stage stag protected void assertPlan(@Language("SQL") String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern, List optimizers) { - queryRunner.inTransaction(transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, stage, WarningCollector.NOOP); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actualPlan, pattern); - return null; - }); + try { + queryRunner.inTransaction(transactionSession -> { + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, stage, WarningCollector.NOOP); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actualPlan, pattern); + return null; + }); + } + catch (Throwable e) { + e.addSuppressed(new Exception("Query: " + sql)); + throw e; + } } protected void assertDistributedPlan(@Language("SQL") String sql, PlanMatchPattern pattern) @@ -178,21 +184,33 @@ protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMat protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean forceSingleNode, PlanMatchPattern pattern) { - queryRunner.inTransaction(session, transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, OPTIMIZED_AND_VALIDATED, forceSingleNode, WarningCollector.NOOP); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actualPlan, pattern); - return null; - }); + try { + queryRunner.inTransaction(session, transactionSession -> { + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, OPTIMIZED_AND_VALIDATED, forceSingleNode, WarningCollector.NOOP); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actualPlan, pattern); + return null; + }); + } + catch (Throwable e) { + e.addSuppressed(new Exception("Query: " + sql)); + throw e; + } } protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean forceSingleNode, PlanMatchPattern pattern, Consumer planValidator) { - queryRunner.inTransaction(session, transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, OPTIMIZED_AND_VALIDATED, forceSingleNode, WarningCollector.NOOP); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actualPlan, pattern); - planValidator.accept(actualPlan); - return null; - }); + try { + queryRunner.inTransaction(session, transactionSession -> { + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, OPTIMIZED_AND_VALIDATED, forceSingleNode, WarningCollector.NOOP); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actualPlan, pattern); + planValidator.accept(actualPlan); + return null; + }); + } + catch (Throwable e) { + e.addSuppressed(new Exception("Query: " + sql)); + throw e; + } } protected Plan plan(@Language("SQL") String sql) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java new file mode 100644 index 000000000000..56aa90730d5b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.assertions; + +import io.trino.Session; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.sql.DynamicFilters; +import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.tree.Expression; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.trino.sql.DynamicFilters.extractDynamicFilters; +import static io.trino.sql.ExpressionUtils.combineConjuncts; +import static java.util.Objects.requireNonNull; + +final class CorrelatedJoinMatcher + implements Matcher +{ + private final Expression filter; + + CorrelatedJoinMatcher(Expression filter) + { + this.filter = requireNonNull(filter, "filter is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + // However this is used for CorrelatedJoinNode only + return true; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + if (!(node instanceof CorrelatedJoinNode correlatedJoinNode)) { + throw new IllegalStateException("This is a detailed matcher for CorrelatedJoinNode, got: " + node); + } + Expression filter = correlatedJoinNode.getFilter(); + ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); + DynamicFilters.ExtractResult extractResult = extractDynamicFilters(filter); + return new MatchResult(verifier.process(combineConjuncts(metadata, extractResult.getStaticConjuncts()), filter)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("filter", filter) + .toString(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java index 6b6bed794455..890e727c8c5c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java @@ -25,6 +25,7 @@ import io.trino.sql.tree.Expression; import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.SymbolReference; +import org.intellij.lang.annotations.Language; import java.util.List; import java.util.Map; @@ -40,13 +41,13 @@ public class ExpressionMatcher private final String sql; private final Expression expression; - public ExpressionMatcher(String expression) + ExpressionMatcher(@Language("SQL") String expression) { this.sql = requireNonNull(expression, "expression is null"); this.expression = expression(expression); } - public ExpressionMatcher(Expression expression) + ExpressionMatcher(Expression expression) { this.expression = requireNonNull(expression, "expression is null"); this.sql = expression.toString(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 42a5fdcb833a..4d52921ed152 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -508,23 +508,23 @@ public static PlanMatchPattern semiJoin(String sourceSymbolAlias, String filteri return node(SemiJoinNode.class, source, filtering).with(new SemiJoinMatcher(sourceSymbolAlias, filteringSymbolAlias, outputAlias, distributionType, hasDynamicFilter)); } - public static PlanMatchPattern spatialJoin(String expectedFilter, PlanMatchPattern left, PlanMatchPattern right) + public static PlanMatchPattern spatialJoin(@Language("SQL") String expectedFilter, PlanMatchPattern left, PlanMatchPattern right) { return spatialJoin(expectedFilter, Optional.empty(), left, right); } - public static PlanMatchPattern spatialJoin(String expectedFilter, Optional kdbTree, PlanMatchPattern left, PlanMatchPattern right) + public static PlanMatchPattern spatialJoin(@Language("SQL") String expectedFilter, Optional kdbTree, PlanMatchPattern left, PlanMatchPattern right) { return spatialJoin(expectedFilter, kdbTree, Optional.empty(), left, right); } - public static PlanMatchPattern spatialJoin(String expectedFilter, Optional kdbTree, Optional> outputSymbols, PlanMatchPattern left, PlanMatchPattern right) + public static PlanMatchPattern spatialJoin(@Language("SQL") String expectedFilter, Optional kdbTree, Optional> outputSymbols, PlanMatchPattern left, PlanMatchPattern right) { return node(SpatialJoinNode.class, left, right).with( new SpatialJoinMatcher(SpatialJoinNode.Type.INNER, PlanBuilder.expression(expectedFilter), kdbTree, outputSymbols)); } - public static PlanMatchPattern spatialLeftJoin(String expectedFilter, PlanMatchPattern left, PlanMatchPattern right) + public static PlanMatchPattern spatialLeftJoin(@Language("SQL") String expectedFilter, PlanMatchPattern left, PlanMatchPattern right) { return node(SpatialJoinNode.class, left, right).with( new SpatialJoinMatcher(SpatialJoinNode.Type.LEFT, PlanBuilder.expression(expectedFilter), Optional.empty(), Optional.empty())); @@ -719,6 +719,12 @@ public static PlanMatchPattern correlatedJoin(List correlationSymbolAlia .with(new CorrelationMatcher(correlationSymbolAliases)); } + public static PlanMatchPattern correlatedJoin(List correlationSymbolAliases, @Language("SQL") String filter, PlanMatchPattern inputPattern, PlanMatchPattern subqueryPattern) + { + return correlatedJoin(correlationSymbolAliases, inputPattern, subqueryPattern) + .with(new CorrelatedJoinMatcher(PlanBuilder.expression(filter))); + } + public static PlanMatchPattern groupId(List> groupingSets, String groupIdSymbol, PlanMatchPattern source) { return groupId(groupingSets, ImmutableList.of(), groupIdSymbol, source); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java index 7cef893773d6..f4cb445898b3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementLimitWithTies.java @@ -17,12 +17,12 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.SortOrder; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.Test; import java.util.Optional; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.functionCall; import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; @@ -47,7 +47,7 @@ public void testReplaceLimitWithTies() }) .matches( strictProject( - ImmutableMap.of("a", new ExpressionMatcher("a"), "b", new ExpressionMatcher("b")), + ImmutableMap.of("a", expression("a"), "b", expression("b")), filter( "rank_num <= BIGINT '2'", window( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java index 103c021b479e..a77cd8ca9309 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementOffset.java @@ -16,11 +16,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.assertions.RowNumberSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.Test; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.rowNumber; import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; @@ -45,7 +45,7 @@ public void testReplaceOffsetOverValues() }) .matches( strictProject( - ImmutableMap.of("a", new ExpressionMatcher("a"), "b", new ExpressionMatcher("b")), + ImmutableMap.of("a", expression("a"), "b", expression("b")), filter( "row_num > BIGINT '2'", rowNumber( @@ -70,7 +70,7 @@ public void testReplaceOffsetOverSort() }) .matches( strictProject( - ImmutableMap.of("a", new ExpressionMatcher("a"), "b", new ExpressionMatcher("b")), + ImmutableMap.of("a", expression("a"), "b", expression("b")), filter( "row_num > BIGINT '2'", rowNumber( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java index bc970a37e0fc..5fd8b63a496a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeLimitOverProjectWithSort.java @@ -16,13 +16,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import org.testng.annotations.Test; import java.util.List; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; import static io.trino.sql.planner.assertions.PlanMatchPattern.topN; @@ -50,7 +50,7 @@ public void testMergeLimitOverProjectWithSort() }) .matches( project( - ImmutableMap.of("b", new ExpressionMatcher("b")), + ImmutableMap.of("b", expression("b")), topN( 1, ImmutableList.of(sort("a", ASCENDING, FIRST)), @@ -94,7 +94,7 @@ public void testLimitWithPreSortedInputs() }) .matches( project( - ImmutableMap.of("b", new ExpressionMatcher("b")), + ImmutableMap.of("b", expression("b")), topN( 1, ImmutableList.of(sort("a", ASCENDING, FIRST)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java index da9a820ef4a3..278dd20f6dc4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.type.RowType; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.tree.ArithmeticBinaryExpression; @@ -78,7 +77,7 @@ public void testPushdownLimitWithTiesNNonIdentityProjection() }) .matches( project( - ImmutableMap.of("projectedA", new ExpressionMatcher("a"), "projectedB", new ExpressionMatcher("b")), + ImmutableMap.of("projectedA", expression("a"), "projectedB", expression("b")), limit(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -102,7 +101,7 @@ projectedC, new ArithmeticBinaryExpression(ADD, new SymbolReference("a"), new Sy }) .matches( project( - ImmutableMap.of("projectedA", new ExpressionMatcher("a"), "projectedC", new ExpressionMatcher("a + b")), + ImmutableMap.of("projectedA", expression("a"), "projectedC", expression("a + b")), limit(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -199,7 +198,7 @@ projectedC, new ArithmeticBinaryExpression(ADD, new SymbolReference("a"), new Sy }) .matches( project( - ImmutableMap.of("projectedA", new ExpressionMatcher("a"), "projectedC", new ExpressionMatcher("a + b")), + ImmutableMap.of("projectedA", expression("a"), "projectedC", expression("a + b")), limit(1, ImmutableList.of(), true, ImmutableList.of("a"), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java index 159c0cd3cd9f..ef186c52bcf4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.type.RowType; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.tree.ArithmeticBinaryExpression; @@ -65,7 +64,7 @@ public void testPushdownTopNNonIdentityProjection() }) .matches( project( - ImmutableMap.of("projectedA", new ExpressionMatcher("a"), "projectedB", new ExpressionMatcher("b")), + ImmutableMap.of("projectedA", expression("a"), "projectedB", expression("b")), topN(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } @@ -89,7 +88,7 @@ projectedC, new ArithmeticBinaryExpression(ADD, new SymbolReference("a"), new Sy }) .matches( project( - ImmutableMap.of("projectedA", new ExpressionMatcher("a"), "projectedC", new ExpressionMatcher("a + b")), + ImmutableMap.of("projectedA", expression("a"), "projectedC", expression("a + b")), topN(1, ImmutableList.of(sort("a", ASCENDING, FIRST)), values("a", "b")))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java index e415cb53e384..c641d37a3a41 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveTrivialFilters.java @@ -57,10 +57,12 @@ public void testRemovesNullFilter() { tester().assertThat(new RemoveTrivialFilters()) .on(p -> p.filter( - expression("null"), + expression("CAST(null AS boolean)"), p.values( ImmutableList.of(p.symbol("a")), ImmutableList.of(expressions("1"))))) - .matches(values("a")); + .matches(values( + ImmutableList.of("a"), + ImmutableList.of())); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java index 8d87edc0a274..adca98d2ccb7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java @@ -16,11 +16,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.tree.ComparisonExpression; import org.testng.annotations.Test; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; @@ -120,8 +120,8 @@ public void testRewriteRightCorrelatedJoin() .matches( project( ImmutableMap.of( - "a", new ExpressionMatcher("if(b > a, a, null)"), - "b", new ExpressionMatcher("b")), + "a", expression("if(b > a, a, null)"), + "b", expression("b")), join(Type.INNER, builder -> builder .left(values("a")) .right(values("b")))));