diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
index 1ecd11590f00..49ace6fc3524 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
@@ -546,7 +546,7 @@ public PlanOptimizers(
new TransformCorrelatedSingleRowSubqueryToProject(),
new RemoveAggregationInSemiJoin(),
new MergeProjectWithValues(metadata),
- new ReplaceJoinOverConstantWithProject())),
+ new ReplaceJoinOverConstantWithProject(metadata))),
new CheckSubqueryNodesAreRewritten(),
simplifyOptimizer, // Should run after MergeProjectWithValues
new StatsRecordingPlanOptimizer(
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java
index cdd0b75fda2f..e18984e3cdc1 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java
@@ -15,6 +15,7 @@
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
+import io.trino.metadata.Metadata;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
@@ -32,9 +33,11 @@
import java.util.Map;
import static com.google.common.collect.Iterables.getOnlyElement;
+import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
+import static java.util.Objects.requireNonNull;
/**
* This rule transforms plans with join where one of the sources is
@@ -49,11 +52,12 @@
* be done, because the result of the transformation would be possibly
* empty, while the single constant row should be preserved on output.
*
- * Note 2: The transformation is valid when the ValuesNode contains
- * non-deterministic expressions. This is because any expression from
- * the ValuesNode can only be used once. Assignments.Builder deduplicates
- * them in case when the JoinNode produces any of the input symbols
- * more than once.
+ * Note 2: The transformation is not valid when the ValuesNode contains
+ * a non-deterministic expression. According to the semantics of the
+ * original plan, such expression should be evaluated once, and the value
+ * should be appended to each row of the other join source. Inlining the
+ * expression would result in evaluating it for each row to a potentially
+ * different value.
*
* Note 3: The transformation is valid when the ValuesNode contains
* expressions using correlation symbols. They are constant from the
@@ -79,6 +83,13 @@ public class ReplaceJoinOverConstantWithProject
private static final Pattern PATTERN = join()
.matching(ReplaceJoinOverConstantWithProject::isUnconditional);
+ private final Metadata metadata;
+
+ public ReplaceJoinOverConstantWithProject(Metadata metadata)
+ {
+ this.metadata = requireNonNull(metadata, "metadata is null");
+ }
+
@Override
public Pattern getPattern()
{
@@ -170,6 +181,10 @@ private boolean isSingleConstantRow(PlanNode node)
Expression row = getOnlyElement(values.getRows().get());
+ if (!isDeterministic(row, metadata)) {
+ return false;
+ }
+
return row instanceof Row;
}
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java
index 24392475551e..151679e4bd8d 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java
@@ -18,6 +18,9 @@
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.JoinNode.EquiJoinClause;
+import io.trino.sql.tree.FunctionCall;
+import io.trino.sql.tree.QualifiedName;
+import io.trino.sql.tree.Row;
import org.testng.annotations.Test;
import java.util.Optional;
@@ -37,7 +40,7 @@ public class TestReplaceJoinOverConstantWithProject
@Test
public void testDoesNotFireOnJoinWithEmptySource()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -45,7 +48,7 @@ public void testDoesNotFireOnJoinWithEmptySource()
p.values(0, p.symbol("b"))))
.doesNotFire();
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -57,7 +60,7 @@ public void testDoesNotFireOnJoinWithEmptySource()
@Test
public void testDoesNotFireOnJoinWithCondition()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -66,7 +69,7 @@ public void testDoesNotFireOnJoinWithCondition()
new EquiJoinClause(p.symbol("a"), p.symbol("b"))))
.doesNotFire();
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -79,7 +82,7 @@ public void testDoesNotFireOnJoinWithCondition()
@Test
public void testDoesNotFireOnValuesWithMultipleRows()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -91,7 +94,7 @@ public void testDoesNotFireOnValuesWithMultipleRows()
@Test
public void testDoesNotFireOnValuesWithNoOutputs()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -103,7 +106,7 @@ public void testDoesNotFireOnValuesWithNoOutputs()
@Test
public void testDoesNotFireOnValuesWithNonRowExpression()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -115,7 +118,7 @@ public void testDoesNotFireOnValuesWithNonRowExpression()
@Test
public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
LEFT,
@@ -125,7 +128,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
p.values(10, p.symbol("b")))))
.doesNotFire();
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
RIGHT,
@@ -135,7 +138,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
p.values(1, p.symbol("b"))))
.doesNotFire();
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
@@ -145,7 +148,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
p.values(10, p.symbol("b")))))
.doesNotFire();
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
@@ -159,7 +162,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty()
@Test
public void testReplaceInnerJoinWithProject()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -173,7 +176,7 @@ public void testReplaceInnerJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -191,7 +194,7 @@ public void testReplaceInnerJoinWithProject()
@Test
public void testReplaceLeftJoinWithProject()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
LEFT,
@@ -205,7 +208,7 @@ public void testReplaceLeftJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
LEFT,
@@ -223,7 +226,7 @@ public void testReplaceLeftJoinWithProject()
@Test
public void testReplaceRightJoinWithProject()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
RIGHT,
@@ -237,7 +240,7 @@ public void testReplaceRightJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
RIGHT,
@@ -255,7 +258,7 @@ public void testReplaceRightJoinWithProject()
@Test
public void testReplaceFullJoinWithProject()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
@@ -269,7 +272,7 @@ public void testReplaceFullJoinWithProject()
"c", PlanMatchPattern.expression("c")),
values("c")));
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
FULL,
@@ -287,7 +290,7 @@ public void testReplaceFullJoinWithProject()
@Test
public void testRemoveOutputDuplicates()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
@@ -309,17 +312,28 @@ public void testRemoveOutputDuplicates()
@Test
public void testNonDeterministicValues()
{
- tester().assertThat(new ReplaceJoinOverConstantWithProject())
+ FunctionCall randomFunction = new FunctionCall(
+ tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of("random"), ImmutableList.of()).toQualifiedName(),
+ ImmutableList.of());
+
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
.on(p ->
p.join(
INNER,
- p.valuesOfExpressions(ImmutableList.of(p.symbol("a")), ImmutableList.of(expression("ROW(rand())"))),
+ p.valuesOfExpressions(ImmutableList.of(p.symbol("rand")), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))),
p.values(5, p.symbol("b"))))
- .matches(
- strictProject(
- ImmutableMap.of(
- "a", PlanMatchPattern.expression("rand()"),
- "b", PlanMatchPattern.expression("b")),
- values("b")));
+ .doesNotFire();
+
+ FunctionCall uuidFunction = new FunctionCall(
+ tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of("uuid"), ImmutableList.of()).toQualifiedName(),
+ ImmutableList.of());
+
+ tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata()))
+ .on(p ->
+ p.join(
+ INNER,
+ p.valuesOfExpressions(ImmutableList.of(p.symbol("uuid")), ImmutableList.of(new Row(ImmutableList.of(uuidFunction)))),
+ p.values(5, p.symbol("b"))))
+ .doesNotFire();
}
}
diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java
index 27403b606ce9..4fc70ca49612 100644
--- a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java
+++ b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java
@@ -72,6 +72,18 @@ d as (SELECT id FROM (VALUES (1)) AS t(id))
.matches("VALUES 1");
}
+ @Test
+ public void testSingleRowNonDeterministicSource()
+ {
+ assertThat(assertions.query("""
+ WITH data(id) AS (SELECT uuid())
+ SELECT COUNT(DISTINCT id)
+ FROM (VALUES 1, 2, 3, 4, 5, 6, 7, 8)
+ CROSS JOIN data
+ """))
+ .matches("VALUES BIGINT '1'");
+ }
+
@Test
public void testJoinOnNan()
{