Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ProjectNode.Locality;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
Expand Down Expand Up @@ -182,6 +183,11 @@ public static Optional<RowExpression> getHashExpression(FunctionAndTypeManager f
}

public static PlanNode addProjections(PlanNode source, PlanNodeIdAllocator planNodeIdAllocator, Map<VariableReferenceExpression, RowExpression> variableMap)
{
return addProjections(source, planNodeIdAllocator, variableMap, LOCAL);
}

public static PlanNode addProjections(PlanNode source, PlanNodeIdAllocator planNodeIdAllocator, Map<VariableReferenceExpression, RowExpression> variableMap, Locality locality)
{
Assignments.Builder assignments = Assignments.builder();
for (VariableReferenceExpression variableReferenceExpression : source.getOutputVariables()) {
Expand All @@ -194,7 +200,7 @@ public static PlanNode addProjections(PlanNode source, PlanNodeIdAllocator planN
planNodeIdAllocator.getNextId(),
source,
assignments.build(),
LOCAL);
locality);
}

// Add a projection node, which assignment new value if output exists in variableMap, otherwise identity assignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isRemoveCrossJoinWithConstantSingleRowInputEnabled;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.UNKNOWN;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -103,7 +104,7 @@ else if (isOutputSingleConstantRow(leftInput, context)) {
if (!mapping.isPresent()) {
return Result.empty();
}
PlanNode resultNode = addProjections(joinInput, context.getIdAllocator(), mapping.get());
PlanNode resultNode = addProjections(joinInput, context.getIdAllocator(), mapping.get(), UNKNOWN);
if (node.getFilter().isPresent()) {
resultNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), resultNode, node.getFilter().get());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.facebook.presto.spi.TestingColumnHandle;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
Expand All @@ -38,13 +39,16 @@
import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.UNKNOWN;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

public class TestRemoveCrossJoinWithConstantInput
extends BaseRuleTest
Expand Down Expand Up @@ -296,10 +300,57 @@ public void testOneColumnValuesNodeExpression()
p.values(ImmutableList.of(leftKey), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)), ImmutableList.of(constant(2L, BIGINT)))),
p.project(
assignment(rightKey2, p.rowExpression("cast(right_k1 as varchar)")),
p.values(ImmutableList.of(rightKey1), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT))))));
p.values(ImmutableList.of(rightKey1), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT))))));
})
.matches(
project(ImmutableMap.of("left_k1", expression("left_k1"), "right_k2", expression("cast(1 as varchar)")),
values("left_k1")));
}

@Test
public void testProjectNodeLocalityIsUnknown()
{
// Test that the generated ProjectNode has UNKNOWN locality, which allows subsequent optimizers
// to determine the optimal locality based on the context (e.g., if the projection involves remote functions)
PlanNode result = tester().assertThat(new RemoveCrossJoinWithConstantInput(getMetadata().getFunctionAndTypeManager()))
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true")
.on(p ->
{
VariableReferenceExpression leftKey = p.variable("left_k1", BIGINT);
p.variable("right_k1", BIGINT);
return p.join(JoinType.INNER,
p.tableScan(ImmutableList.of(leftKey), ImmutableMap.of(leftKey, new TestingColumnHandle("col"))),
p.values(ImmutableList.of(p.variable("right_k1")), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)))));
})
.get();

assertTrue(result instanceof ProjectNode, "Expected result to be ProjectNode");
ProjectNode projectNode = (ProjectNode) result;
assertEquals(projectNode.getLocality(), UNKNOWN, "ProjectNode locality should be UNKNOWN to allow subsequent optimizers to set it");
}

@Test
public void testProjectNodeLocalityIsUnknownWithFilter()
{
// Test that when there's a join filter, the ProjectNode underneath the FilterNode has UNKNOWN locality
PlanNode result = tester().assertThat(new RemoveCrossJoinWithConstantInput(getMetadata().getFunctionAndTypeManager()))
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true")
.on(p ->
{
VariableReferenceExpression leftKey = p.variable("left_k1", BIGINT);
p.variable("right_k1", BIGINT);
return p.join(JoinType.INNER,
p.tableScan(ImmutableList.of(leftKey), ImmutableMap.of(leftKey, new TestingColumnHandle("col"))),
p.values(ImmutableList.of(p.variable("right_k1")), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)))),
p.rowExpression("left_k1 + right_k1 > 2"));
})
.get();

assertTrue(result instanceof FilterNode, "Expected result to be FilterNode");
FilterNode filterNode = (FilterNode) result;
PlanNode source = filterNode.getSource();
assertTrue(source instanceof ProjectNode, "Expected FilterNode source to be ProjectNode");
ProjectNode projectNode = (ProjectNode) source;
assertEquals(projectNode.getLocality(), UNKNOWN, "ProjectNode locality should be UNKNOWN");
}
}
Loading