diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 1ecd11590f00..49ace6fc3524 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -546,7 +546,7 @@ public PlanOptimizers( new TransformCorrelatedSingleRowSubqueryToProject(), new RemoveAggregationInSemiJoin(), new MergeProjectWithValues(metadata), - new ReplaceJoinOverConstantWithProject())), + new ReplaceJoinOverConstantWithProject(metadata))), new CheckSubqueryNodesAreRewritten(), simplifyOptimizer, // Should run after MergeProjectWithValues new StatsRecordingPlanOptimizer( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java index cdd0b75fda2f..e18984e3cdc1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java @@ -15,6 +15,7 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.metadata.Metadata; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -32,9 +33,11 @@ import java.util.Map; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; import static io.trino.sql.planner.plan.Patterns.join; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static java.util.Objects.requireNonNull; /** * This rule transforms plans with join where one of the sources is @@ -49,11 +52,12 @@ * be done, because the result of the transformation would be possibly * empty, while the single constant row should be preserved on output. *

- * Note 2: The transformation is valid when the ValuesNode contains - * non-deterministic expressions. This is because any expression from - * the ValuesNode can only be used once. Assignments.Builder deduplicates - * them in case when the JoinNode produces any of the input symbols - * more than once. + * Note 2: The transformation is not valid when the ValuesNode contains + * a non-deterministic expression. According to the semantics of the + * original plan, such expression should be evaluated once, and the value + * should be appended to each row of the other join source. Inlining the + * expression would result in evaluating it for each row to a potentially + * different value. *

* Note 3: The transformation is valid when the ValuesNode contains * expressions using correlation symbols. They are constant from the @@ -79,6 +83,13 @@ public class ReplaceJoinOverConstantWithProject private static final Pattern PATTERN = join() .matching(ReplaceJoinOverConstantWithProject::isUnconditional); + private final Metadata metadata; + + public ReplaceJoinOverConstantWithProject(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + @Override public Pattern getPattern() { @@ -170,6 +181,10 @@ private boolean isSingleConstantRow(PlanNode node) Expression row = getOnlyElement(values.getRows().get()); + if (!isDeterministic(row, metadata)) { + return false; + } + return row instanceof Row; } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java index 24392475551e..151679e4bd8d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java @@ -18,6 +18,9 @@ import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.JoinNode.EquiJoinClause; +import io.trino.sql.tree.FunctionCall; +import io.trino.sql.tree.QualifiedName; +import io.trino.sql.tree.Row; import org.testng.annotations.Test; import java.util.Optional; @@ -37,7 +40,7 @@ public class TestReplaceJoinOverConstantWithProject @Test public void testDoesNotFireOnJoinWithEmptySource() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -45,7 +48,7 @@ public void testDoesNotFireOnJoinWithEmptySource() p.values(0, p.symbol("b")))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -57,7 +60,7 @@ public void testDoesNotFireOnJoinWithEmptySource() @Test public void testDoesNotFireOnJoinWithCondition() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -66,7 +69,7 @@ public void testDoesNotFireOnJoinWithCondition() new EquiJoinClause(p.symbol("a"), p.symbol("b")))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -79,7 +82,7 @@ public void testDoesNotFireOnJoinWithCondition() @Test public void testDoesNotFireOnValuesWithMultipleRows() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -91,7 +94,7 @@ public void testDoesNotFireOnValuesWithMultipleRows() @Test public void testDoesNotFireOnValuesWithNoOutputs() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -103,7 +106,7 @@ public void testDoesNotFireOnValuesWithNoOutputs() @Test public void testDoesNotFireOnValuesWithNonRowExpression() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -115,7 +118,7 @@ public void testDoesNotFireOnValuesWithNonRowExpression() @Test public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( LEFT, @@ -125,7 +128,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.values(10, p.symbol("b"))))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( RIGHT, @@ -135,7 +138,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.values(1, p.symbol("b")))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( FULL, @@ -145,7 +148,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.values(10, p.symbol("b"))))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( FULL, @@ -159,7 +162,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() @Test public void testReplaceInnerJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -173,7 +176,7 @@ public void testReplaceInnerJoinWithProject() "c", PlanMatchPattern.expression("c")), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -191,7 +194,7 @@ public void testReplaceInnerJoinWithProject() @Test public void testReplaceLeftJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( LEFT, @@ -205,7 +208,7 @@ public void testReplaceLeftJoinWithProject() "c", PlanMatchPattern.expression("c")), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( LEFT, @@ -223,7 +226,7 @@ public void testReplaceLeftJoinWithProject() @Test public void testReplaceRightJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( RIGHT, @@ -237,7 +240,7 @@ public void testReplaceRightJoinWithProject() "c", PlanMatchPattern.expression("c")), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( RIGHT, @@ -255,7 +258,7 @@ public void testReplaceRightJoinWithProject() @Test public void testReplaceFullJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( FULL, @@ -269,7 +272,7 @@ public void testReplaceFullJoinWithProject() "c", PlanMatchPattern.expression("c")), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( FULL, @@ -287,7 +290,7 @@ public void testReplaceFullJoinWithProject() @Test public void testRemoveOutputDuplicates() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, @@ -309,17 +312,28 @@ public void testRemoveOutputDuplicates() @Test public void testNonDeterministicValues() { - tester().assertThat(new ReplaceJoinOverConstantWithProject()) + FunctionCall randomFunction = new FunctionCall( + tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of("random"), ImmutableList.of()).toQualifiedName(), + ImmutableList.of()); + + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) .on(p -> p.join( INNER, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a")), ImmutableList.of(expression("ROW(rand())"))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("rand")), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))), p.values(5, p.symbol("b")))) - .matches( - strictProject( - ImmutableMap.of( - "a", PlanMatchPattern.expression("rand()"), - "b", PlanMatchPattern.expression("b")), - values("b"))); + .doesNotFire(); + + FunctionCall uuidFunction = new FunctionCall( + tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of("uuid"), ImmutableList.of()).toQualifiedName(), + ImmutableList.of()); + + tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + .on(p -> + p.join( + INNER, + p.valuesOfExpressions(ImmutableList.of(p.symbol("uuid")), ImmutableList.of(new Row(ImmutableList.of(uuidFunction)))), + p.values(5, p.symbol("b")))) + .doesNotFire(); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java index 27403b606ce9..4fc70ca49612 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java @@ -72,6 +72,18 @@ d as (SELECT id FROM (VALUES (1)) AS t(id)) .matches("VALUES 1"); } + @Test + public void testSingleRowNonDeterministicSource() + { + assertThat(assertions.query(""" + WITH data(id) AS (SELECT uuid()) + SELECT COUNT(DISTINCT id) + FROM (VALUES 1, 2, 3, 4, 5, 6, 7, 8) + CROSS JOIN data + """)) + .matches("VALUES BIGINT '1'"); + } + @Test public void testJoinOnNan() {