Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ public PlanOptimizers(
new TransformCorrelatedSingleRowSubqueryToProject(),
new RemoveAggregationInSemiJoin(),
new MergeProjectWithValues(metadata),
new ReplaceJoinOverConstantWithProject())),
new ReplaceJoinOverConstantWithProject(metadata))),
new CheckSubqueryNodesAreRewritten(),
simplifyOptimizer, // Should run after MergeProjectWithValues
new StatsRecordingPlanOptimizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
Expand All @@ -32,9 +33,11 @@
import java.util.Map;

import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static java.util.Objects.requireNonNull;

/**
* This rule transforms plans with join where one of the sources is
Expand All @@ -49,11 +52,12 @@
* be done, because the result of the transformation would be possibly
* empty, while the single constant row should be preserved on output.
* <p>
* Note 2: The transformation is valid when the ValuesNode contains
* non-deterministic expressions. This is because any expression from
* the ValuesNode can only be used once. Assignments.Builder deduplicates
* them in case when the JoinNode produces any of the input symbols
* more than once.
* Note 2: The transformation is not valid when the ValuesNode contains
* a non-deterministic expression. According to the semantics of the
* original plan, such expression should be evaluated once, and the value
* should be appended to each row of the other join source. Inlining the
* expression would result in evaluating it for each row to a potentially
* different value.
* <p>
* Note 3: The transformation is valid when the ValuesNode contains
* expressions using correlation symbols. They are constant from the
Expand All @@ -79,6 +83,13 @@ public class ReplaceJoinOverConstantWithProject
private static final Pattern<JoinNode> PATTERN = join()
.matching(ReplaceJoinOverConstantWithProject::isUnconditional);

private final Metadata metadata;

public ReplaceJoinOverConstantWithProject(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
}

@Override
public Pattern<JoinNode> getPattern()
{
Expand Down Expand Up @@ -170,6 +181,10 @@ private boolean isSingleConstantRow(PlanNode node)

Expression row = getOnlyElement(values.getRows().get());

if (!isDeterministic(row, metadata)) {
return false;
}

return row instanceof Row;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.JoinNode.EquiJoinClause;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Row;
import org.testng.annotations.Test;

import java.util.Optional;
Expand All @@ -37,15 +40,15 @@ public class TestReplaceJoinOverConstantWithProject
@Test
public void testDoesNotFireOnJoinWithEmptySource()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
p.values(1, p.symbol("a")),
p.values(0, p.symbol("b"))))
.doesNotFire();

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -57,7 +60,7 @@ public void testDoesNotFireOnJoinWithEmptySource()
@Test
public void testDoesNotFireOnJoinWithCondition()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -66,7 +69,7 @@ public void testDoesNotFireOnJoinWithCondition()
new EquiJoinClause(p.symbol("a"), p.symbol("b"))))
.doesNotFire();

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -79,7 +82,7 @@ public void testDoesNotFireOnJoinWithCondition()
@Test
public void testDoesNotFireOnValuesWithMultipleRows()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -91,7 +94,7 @@ public void testDoesNotFireOnValuesWithMultipleRows()
@Test
public void testDoesNotFireOnValuesWithNoOutputs()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -103,7 +106,7 @@ public void testDoesNotFireOnValuesWithNoOutputs()
@Test
public void testDoesNotFireOnValuesWithNonRowExpression()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -115,7 +118,7 @@ public void testDoesNotFireOnValuesWithNonRowExpression()
@Test
public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
LEFT,
Expand All @@ -125,7 +128,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
p.values(10, p.symbol("b")))))
.doesNotFire();

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
RIGHT,
Expand All @@ -135,7 +138,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
p.values(1, p.symbol("b"))))
.doesNotFire();

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
Expand All @@ -145,7 +148,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
p.values(10, p.symbol("b")))))
.doesNotFire();

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
Expand All @@ -159,7 +162,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
@Test
public void testReplaceInnerJoinWithProject()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -173,7 +176,7 @@ public void testReplaceInnerJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -191,7 +194,7 @@ public void testReplaceInnerJoinWithProject()
@Test
public void testReplaceLeftJoinWithProject()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
LEFT,
Expand All @@ -205,7 +208,7 @@ public void testReplaceLeftJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
LEFT,
Expand All @@ -223,7 +226,7 @@ public void testReplaceLeftJoinWithProject()
@Test
public void testReplaceRightJoinWithProject()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
RIGHT,
Expand All @@ -237,7 +240,7 @@ public void testReplaceRightJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
RIGHT,
Expand All @@ -255,7 +258,7 @@ public void testReplaceRightJoinWithProject()
@Test
public void testReplaceFullJoinWithProject()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
Expand All @@ -269,7 +272,7 @@ public void testReplaceFullJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));

tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
Expand All @@ -287,7 +290,7 @@ public void testReplaceFullJoinWithProject()
@Test
public void testRemoveOutputDuplicates()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
Expand All @@ -309,17 +312,28 @@ public void testRemoveOutputDuplicates()
@Test
public void testNonDeterministicValues()
{
tester().assertThat(new ReplaceJoinOverConstantWithProject())
FunctionCall randomFunction = new FunctionCall(
tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of("random"), ImmutableList.of()).toQualifiedName(),
ImmutableList.of());

tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
p.valuesOfExpressions(ImmutableList.of(p.symbol("a")), ImmutableList.of(expression("ROW(rand())"))),
p.valuesOfExpressions(ImmutableList.of(p.symbol("rand")), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))),
p.values(5, p.symbol("b"))))
.matches(
strictProject(
ImmutableMap.of(
"a", PlanMatchPattern.expression("rand()"),
"b", PlanMatchPattern.expression("b")),
values("b")));
.doesNotFire();

FunctionCall uuidFunction = new FunctionCall(
tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of("uuid"), ImmutableList.of()).toQualifiedName(),
ImmutableList.of());

tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
p.valuesOfExpressions(ImmutableList.of(p.symbol("uuid")), ImmutableList.of(new Row(ImmutableList.of(uuidFunction)))),
p.values(5, p.symbol("b"))))
.doesNotFire();
}
}
12 changes: 12 additions & 0 deletions core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ d as (SELECT id FROM (VALUES (1)) AS t(id))
.matches("VALUES 1");
}

@Test
public void testSingleRowNonDeterministicSource()
{
assertThat(assertions.query("""
WITH data(id) AS (SELECT uuid())
SELECT COUNT(DISTINCT id)
FROM (VALUES 1, 2, 3, 4, 5, 6, 7, 8)
CROSS JOIN data
"""))
.matches("VALUES BIGINT '1'");
}

@Test
public void testJoinOnNan()
{
Expand Down