Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
Expand All @@ -112,6 +111,7 @@
import static java.lang.Boolean.FALSE;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
import static java.util.Collections.unmodifiableSet;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -1073,7 +1073,7 @@ public void addRowFilter(Table table, Expression filter)

public List<Expression> getRowFilters(Table node)
{
return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of());
return unmodifiableList(rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of()));
}

public boolean hasColumnMask(QualifiedObjectName table, String column, String identity)
Expand Down Expand Up @@ -1101,7 +1101,7 @@ public void addColumnMask(Table table, String column, Expression mask)

public Map<String, Expression> getColumnMasks(Table table)
{
return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of());
return unmodifiableMap(columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()));
}

public List<TableInfo> getReferencedTables()
Expand Down Expand Up @@ -1571,22 +1571,22 @@ public void addQuantifiedComparisons(List<QuantifiedComparisonExpression> expres

public List<InPredicate> getInPredicatesSubqueries()
{
return Collections.unmodifiableList(inPredicatesSubqueries);
return unmodifiableList(inPredicatesSubqueries);
}

public List<SubqueryExpression> getSubqueries()
{
return Collections.unmodifiableList(subqueries);
return unmodifiableList(subqueries);
}

public List<ExistsPredicate> getExistsSubqueries()
{
return Collections.unmodifiableList(existsSubqueries);
return unmodifiableList(existsSubqueries);
}

public List<QuantifiedComparisonExpression> getQuantifiedComparisonSubqueries()
{
return Collections.unmodifiableList(quantifiedComparisonSubqueries);
return unmodifiableList(quantifiedComparisonSubqueries);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ private PlanBuilder filter(PlanBuilder subPlan, Expression predicate, Node node)

subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, analysis.getSubqueries(node));

return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), subPlan.rewrite(predicate)));
return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), coerceIfNecessary(analysis, predicate, subPlan.rewrite(predicate))));
}

private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function<Expres
for (Expression filter : filters) {
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, filter, analysis.getSubqueries(filter));

Expression predicate = planBuilder.rewrite(filter);
Expression predicate = coerceIfNecessary(analysis, filter, planBuilder.rewrite(filter));
predicate = predicateTransformation.apply(predicate);
planBuilder = planBuilder.withNewRoot(new FilterNode(
idAllocator.getNextId(),
Expand Down Expand Up @@ -809,7 +809,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende
rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, subqueries);

for (Expression expression : complexJoinExpressions) {
postInnerJoinConditions.add(rootPlanBuilder.rewrite(expression));
postInnerJoinConditions.add(coerceIfNecessary(analysis, expression, rootPlanBuilder.rewrite(expression)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please apply it also to CorrelatedJoinNode?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Will do & add a test in TestLogicalPlanner as well

}
root = rootPlanBuilder.getRoot();

Expand Down Expand Up @@ -994,7 +994,7 @@ private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Latera
.withAdditionalMappings(leftPlanBuilder.getTranslations().getMappings())
.withAdditionalMappings(rightPlanBuilder.getTranslations().getMappings());

Expression rewrittenFilterCondition = translationMap.rewrite(filterExpression);
Expression rewrittenFilterCondition = coerceIfNecessary(analysis, filterExpression, translationMap.rewrite(filterExpression));

PlanBuilder planBuilder = subqueryPlanner.appendCorrelatedJoin(
leftPlanBuilder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NullLiteral;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.sql.planner.plan.Patterns.filter;
import static io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
Expand All @@ -41,12 +43,14 @@ public Pattern<FilterNode> getPattern()
public Result apply(FilterNode filterNode, Captures captures, Context context)
{
Expression predicate = filterNode.getPredicate();
checkArgument(!(predicate instanceof NullLiteral), "Unexpected null literal without a cast to boolean");

if (predicate.equals(TRUE_LITERAL)) {
return Result.ofPlanNode(filterNode.getSource());
}

if (predicate.equals(FALSE_LITERAL) || predicate instanceof NullLiteral) {
if (predicate.equals(FALSE_LITERAL) ||
(predicate instanceof Cast cast && cast.getExpression() instanceof NullLiteral)) {
return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputSymbols(), emptyList()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Join;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NullLiteral;

import javax.annotation.concurrent.Immutable;

Expand Down Expand Up @@ -100,6 +101,9 @@ public CorrelatedJoinNode(
requireNonNull(subquery, "subquery is null");
requireNonNull(correlation, "correlation is null");
requireNonNull(filter, "filter is null");
// The condition doesn't guarantee that filter is of type boolean, but was found to be a practical way to identify
// places where CorrelatedJoinNode could be created without appropriate coercions.
checkArgument(!(filter instanceof NullLiteral), "Filter must be an expression of boolean type: %s", filter);
requireNonNull(originSubquery, "originSubquery is null");

checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
import com.google.common.collect.Iterables;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NullLiteral;

import javax.annotation.concurrent.Immutable;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

@Immutable
public class FilterNode
extends PlanNode
Expand All @@ -39,6 +43,10 @@ public FilterNode(@JsonProperty("id") PlanNodeId id,
super(id);

this.source = source;
requireNonNull(predicate, "predicate is null");
// The condition doesn't guarantee that predicate is of type boolean, but was found to be a practical way to identify
// places where FilterNode was created without appropriate coercions.
checkArgument(!(predicate instanceof NullLiteral), "Predicate must be an expression of boolean type: %s", predicate);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do something more generic here? Like checking if predicate is of type Literal but not a BooleanLiteral?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make it more complicated, yes.
However we cannot analyze expression type fully, as we don't have symbol types here, so we fail short of being really generic. I draw the line here, we can draw the line somewhere else.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not add the check for NullLiteral. We don't do this in other PlanNodes which require boolean expressions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can all agree that FilterNode.predicate should be of boolean type.
Yet, the code base required updates in quite a few places to bring reality in line with the design.
The only known to me way to effectively find those places was to add the check here.
Now, we can hope that codebase is now correct and we can remove the check, but how will we ensure the problems do not come back with some new code?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see that the PlanNode constructor seems to be the best place for the check. If we do the check immediately after the initial planning, we might miss wrong filters created in the Optimizer. If we wait until execution, we might miss / break optimizations.

This check is not satisfying though, as it only detects NullLiteral. If you want to keep it, then maybe add a TODO explaining what's missing. That TODO will be easy to address in the future, when we have the new IR with types included.

Also, consider adding a similar check in other PlanNodes, at least those covered by this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is not satisfying though, as it only detects NullLiteral.

Correct, the check is a required check, but doesn't guarantee well-formedness.

consider adding a similar check in other PlanNodes

I would prefer to defer that, if i were to choose

If you want to keep it, then maybe add a TODO explaining what's missing.

Good idea, will add a comment!

this.predicate = predicate;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Join;
import io.trino.sql.tree.NullLiteral;

import javax.annotation.concurrent.Immutable;

Expand Down Expand Up @@ -89,6 +90,9 @@ public JoinNode(
requireNonNull(leftOutputSymbols, "leftOutputSymbols is null");
requireNonNull(rightOutputSymbols, "rightOutputSymbols is null");
requireNonNull(filter, "filter is null");
// The condition doesn't guarantee that filter is of type boolean, but was found to be a practical way to identify
// places where JoinNode could be created without appropriate coercions.
checkArgument(filter.isEmpty() || !(filter.get() instanceof NullLiteral), "Filter must be an expression of boolean type: %s", filter);
requireNonNull(leftHashSymbol, "leftHashSymbol is null");
requireNonNull(rightHashSymbol, "rightHashSymbol is null");
requireNonNull(distributionType, "distributionType is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import io.trino.sql.planner.OptimizerConfig.JoinDistributionType;
import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.RowNumberSymbolMatcher;
import io.trino.sql.planner.optimizations.AddLocalExchanges;
Expand Down Expand Up @@ -106,6 +105,7 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.assignUniqueId;
import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScan;
import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout;
import static io.trino.sql.planner.assertions.PlanMatchPattern.correlatedJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.expression;
Expand Down Expand Up @@ -813,6 +813,31 @@ public void testCorrelatedJoinWithTopN()
anyTree(tableScan("nation", ImmutableMap.of("nation_name", "name", "nation_regionkey", "regionkey")))))))));
}

@Test
public void testCorrelatedJoinWithNullCondition()
{
assertPlan(
"SELECT regionkey, n.name FROM region LEFT JOIN LATERAL (SELECT name FROM nation) n ON NULL",
CREATED,
anyTree(
correlatedJoin(
List.of("r_row_number", "r_regionkey", "r_name", "r_comment"),
"CAST(null AS boolean)",
tableScan("region", Map.of(
"r_row_number", "row_number",
"r_regionkey", "regionkey",
"r_name", "name",
"r_comment", "comment")),
anyTree(tableScan("nation")))));
assertPlan(
"SELECT regionkey, n.name FROM region LEFT JOIN LATERAL (SELECT name FROM nation) n ON NULL",
any(
join(LEFT, builder -> builder
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a join filter in the expected plan

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea, will do!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the filter, not the equiCriteria. However, the filter is empty at that planning phase, as the result of PredicatePushDown. To see how the filter is planned initially, please test with CREATED.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filter is empty so cannot assert that it's there, since it isn't.
Sure, will add another case with CREATED

.equiCriteria(List.of())
.left(tableScan("region"))
.right(values("name")))));
}

@Test
public void testCorrelatedScalarSubqueryInSelect()
{
Expand Down Expand Up @@ -1109,6 +1134,14 @@ public void testRemovesTrivialFilters()
"SELECT * FROM nation WHERE 1 = 0",
output(
values("nationkey", "name", "regionkey", "comment")));
assertPlan(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this test correspond to the change? Could you add a test with WHERE null?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this test correspond to the change?

as a side effect of the fixes, this now works

Could you add a test with WHERE null?

absolutely!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a side effect of the fixes, this now works

Could you please explain what changes for this case? I can't see a link, so perhaps I'm reading the code wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain what changes for this case?

I will separately adding the test and making the code change, so it will be visible what changed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it's clear. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The appreciation of the clearness is highly appreciated

"SELECT * FROM nation WHERE null",
output(
values("nationkey", "name", "regionkey", "comment")));
assertPlan(
"SELECT * FROM nation WHERE nationkey = null",
output(
values("nationkey", "name", "regionkey", "comment")));
}

@Test
Expand Down Expand Up @@ -1488,7 +1521,7 @@ public void testOffset()
"SELECT name FROM nation OFFSET 2 ROWS",
any(
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name")),
ImmutableMap.of("name", expression("name")),
filter(
"row_num > BIGINT '2'",
rowNumber(
Expand All @@ -1502,7 +1535,7 @@ public void testOffset()
"SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS",
any(
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name")),
ImmutableMap.of("name", expression("name")),
filter(
"row_num > BIGINT '2'",
rowNumber(
Expand All @@ -1519,7 +1552,7 @@ public void testOffset()
"SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY",
any(
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name")),
ImmutableMap.of("name", expression("name")),
filter(
"row_num > BIGINT '2'",
rowNumber(
Expand All @@ -1538,7 +1571,7 @@ public void testOffset()
"SELECT name FROM nation OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY",
any(
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name")),
ImmutableMap.of("name", expression("name")),
filter(
"row_num > BIGINT '2'",
rowNumber(
Expand All @@ -1558,7 +1591,7 @@ public void testWithTies()
"SELECT name, regionkey FROM nation ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES",
any(
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name"), "regionkey", new ExpressionMatcher("regionkey")),
ImmutableMap.of("name", expression("name"), "regionkey", expression("regionkey")),
topNRanking(
pattern -> pattern
.specification(
Expand All @@ -1578,14 +1611,14 @@ public void testWithTies()
"SELECT name, regionkey FROM nation ORDER BY regionkey OFFSET 10 ROWS FETCH FIRST 6 ROWS WITH TIES",
any(
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name"), "regionkey", new ExpressionMatcher("regionkey")),
ImmutableMap.of("name", expression("name"), "regionkey", expression("regionkey")),
filter(
"row_num > BIGINT '10'",
rowNumber(
pattern -> pattern
.partitionBy(ImmutableList.of()),
strictProject(
ImmutableMap.of("name", new ExpressionMatcher("name"), "regionkey", new ExpressionMatcher("regionkey")),
ImmutableMap.of("name", expression("name"), "regionkey", expression("regionkey")),
topNRanking(
pattern -> pattern
.specification(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -791,32 +791,19 @@ private void testNoUnwrap(String inputType, String inputPredicate, String expect

private void testNoUnwrap(Session session, String inputType, String inputPredicate, String expectedCastType)
{
String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE a %s", inputType, inputPredicate);
try {
assertPlan(sql,
session,
output(
filter(format("CAST(a AS %s) %s", expectedCastType, inputPredicate),
values("a"))));
}
catch (Throwable e) {
e.addSuppressed(new Exception("Query: " + sql));
throw e;
}
assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE a %s", inputType, inputPredicate),
session,
output(
filter(format("CAST(a AS %s) %s", expectedCastType, inputPredicate),
values("a"))));
}

private void testRemoveFilter(String inputType, String inputPredicate)
{
String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s", inputType, inputPredicate);
try {
assertPlan(sql,
output(
values("a")));
}
catch (Throwable e) {
e.addSuppressed(new Exception("Query: " + sql));
throw e;
}
assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s AND rand() = 42", inputType, inputPredicate),
output(
filter("rand() = 42e0",
values("a"))));
}

private void testUnwrap(String inputType, String inputPredicate, String expectedPredicate)
Expand All @@ -826,18 +813,11 @@ private void testUnwrap(String inputType, String inputPredicate, String expected

private void testUnwrap(Session session, String inputType, String inputPredicate, String expectedPredicate)
{
String sql = format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s", inputType, inputPredicate);
try {
assertPlan(sql,
session,
output(
filter(expectedPredicate,
values("a"))));
}
catch (Throwable e) {
e.addSuppressed(new Exception("Query: " + sql));
throw e;
}
assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s OR rand() = 42", inputType, inputPredicate),
session,
output(
filter(format("%s OR rand() = 42e0", expectedPredicate),
values("a"))));
}

private static Session withZone(Session session, TimeZoneKey timeZoneKey)
Expand Down
Loading