diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 4e32a6e6918ed..ee25022f84c41 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -36,6 +36,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.ProjectNode.Locality; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; @@ -182,6 +183,11 @@ public static Optional getHashExpression(FunctionAndTypeManager f } public static PlanNode addProjections(PlanNode source, PlanNodeIdAllocator planNodeIdAllocator, Map variableMap) + { + return addProjections(source, planNodeIdAllocator, variableMap, LOCAL); + } + + public static PlanNode addProjections(PlanNode source, PlanNodeIdAllocator planNodeIdAllocator, Map variableMap, Locality locality) { Assignments.Builder assignments = Assignments.builder(); for (VariableReferenceExpression variableReferenceExpression : source.getOutputVariables()) { @@ -194,7 +200,7 @@ public static PlanNode addProjections(PlanNode source, PlanNodeIdAllocator planN planNodeIdAllocator.getNextId(), source, assignments.build(), - LOCAL); + locality); } // Add a projection node, which assignment new value if output exists in variableMap, otherwise identity assignment diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveCrossJoinWithConstantInput.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveCrossJoinWithConstantInput.java index a297fe0b6cd82..0b036bc75879b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveCrossJoinWithConstantInput.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveCrossJoinWithConstantInput.java @@ -36,6 +36,7 @@ import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.isRemoveCrossJoinWithConstantSingleRowInputEnabled; +import static com.facebook.presto.spi.plan.ProjectNode.Locality.UNKNOWN; import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; import static com.facebook.presto.sql.planner.plan.Patterns.join; import static com.google.common.base.Preconditions.checkState; @@ -103,7 +104,7 @@ else if (isOutputSingleConstantRow(leftInput, context)) { if (!mapping.isPresent()) { return Result.empty(); } - PlanNode resultNode = addProjections(joinInput, context.getIdAllocator(), mapping.get()); + PlanNode resultNode = addProjections(joinInput, context.getIdAllocator(), mapping.get(), UNKNOWN); if (node.getFilter().isPresent()) { resultNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), resultNode, node.getFilter().get()); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveCrossJoinWithConstantInput.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveCrossJoinWithConstantInput.java index 9514346f1b0db..98bcff8ed68cd 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveCrossJoinWithConstantInput.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveCrossJoinWithConstantInput.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.TestingColumnHandle; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -38,6 +39,7 @@ import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.plan.ProjectNode.Locality.UNKNOWN; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; @@ -45,6 +47,8 @@ import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; public class TestRemoveCrossJoinWithConstantInput extends BaseRuleTest @@ -296,10 +300,57 @@ public void testOneColumnValuesNodeExpression() p.values(ImmutableList.of(leftKey), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)), ImmutableList.of(constant(2L, BIGINT)))), p.project( assignment(rightKey2, p.rowExpression("cast(right_k1 as varchar)")), - p.values(ImmutableList.of(rightKey1), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)))))); + p.values(ImmutableList.of(rightKey1), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)))))); }) .matches( project(ImmutableMap.of("left_k1", expression("left_k1"), "right_k2", expression("cast(1 as varchar)")), values("left_k1"))); } + + @Test + public void testProjectNodeLocalityIsUnknown() + { + // Test that the generated ProjectNode has UNKNOWN locality, which allows subsequent optimizers + // to determine the optimal locality based on the context (e.g., if the projection involves remote functions) + PlanNode result = tester().assertThat(new RemoveCrossJoinWithConstantInput(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true") + .on(p -> + { + VariableReferenceExpression leftKey = p.variable("left_k1", BIGINT); + p.variable("right_k1", BIGINT); + return p.join(JoinType.INNER, + p.tableScan(ImmutableList.of(leftKey), ImmutableMap.of(leftKey, new TestingColumnHandle("col"))), + p.values(ImmutableList.of(p.variable("right_k1")), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT))))); + }) + .get(); + + assertTrue(result instanceof ProjectNode, "Expected result to be ProjectNode"); + ProjectNode projectNode = (ProjectNode) result; + assertEquals(projectNode.getLocality(), UNKNOWN, "ProjectNode locality should be UNKNOWN to allow subsequent optimizers to set it"); + } + + @Test + public void testProjectNodeLocalityIsUnknownWithFilter() + { + // Test that when there's a join filter, the ProjectNode underneath the FilterNode has UNKNOWN locality + PlanNode result = tester().assertThat(new RemoveCrossJoinWithConstantInput(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true") + .on(p -> + { + VariableReferenceExpression leftKey = p.variable("left_k1", BIGINT); + p.variable("right_k1", BIGINT); + return p.join(JoinType.INNER, + p.tableScan(ImmutableList.of(leftKey), ImmutableMap.of(leftKey, new TestingColumnHandle("col"))), + p.values(ImmutableList.of(p.variable("right_k1")), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)))), + p.rowExpression("left_k1 + right_k1 > 2")); + }) + .get(); + + assertTrue(result instanceof FilterNode, "Expected result to be FilterNode"); + FilterNode filterNode = (FilterNode) result; + PlanNode source = filterNode.getSource(); + assertTrue(source instanceof ProjectNode, "Expected FilterNode source to be ProjectNode"); + ProjectNode projectNode = (ProjectNode) source; + assertEquals(projectNode.getLocality(), UNKNOWN, "ProjectNode locality should be UNKNOWN"); + } }