Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,16 @@ public List<ProjectionContext> visitCall(CallExpression call, Void context)
boolean local = !functionMetadata.getImplementationType().isExternal();

// Break function arguments into local and remote projections first
ImmutableList.Builder<RowExpression> newArgumentsBuilder = ImmutableList.builder();
ImmutableList.Builder<VariableReferenceExpression> newArgumentsBuilder = ImmutableList.builder();
List<ProjectionContext> processedArguments = processArguments(call.getArguments(), newArgumentsBuilder);
List<RowExpression> newArguments = newArgumentsBuilder.build();
List<VariableReferenceExpression> newArguments = newArgumentsBuilder.build();
CallExpression newCall = new CallExpression(
call.getDisplayName(),
call.getFunctionHandle(),
call.getType(),
newArguments);
newArguments.stream()
.map(RowExpression.class::cast)
.collect(toImmutableList()));

if (local) {
if (processedArguments.size() == 1 && !processedArguments.get(0).isRemote()) {
Expand All @@ -277,7 +279,7 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
call.getFunctionHandle(),
call.getType(),
newArguments.stream()
.map(argument -> argument instanceof VariableReferenceExpression ? last.getProjections().get(argument) : argument)
.map(last.getProjections()::get)
.collect(toImmutableList()))),
false));
return projectionContextBuilder.build();
Expand All @@ -303,13 +305,13 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
@Override
public List<ProjectionContext> visitInputReference(InputReferenceExpression reference, Void context)
{
throw new IllegalStateException("Optimizers should not see InputReferenceExpression");
return ImmutableList.of();
}

@Override
public List<ProjectionContext> visitConstant(ConstantExpression literal, Void context)
{
throw new IllegalStateException("We should not create ProjectionContext for constants");
return ImmutableList.of();
}

@Override
Expand All @@ -327,9 +329,9 @@ public List<ProjectionContext> visitVariableReference(VariableReferenceExpressio
@Override
public List<ProjectionContext> visitSpecialForm(SpecialFormExpression specialForm, Void context)
{
ImmutableList.Builder<RowExpression> newArgumentsBuilder = ImmutableList.builder();
ImmutableList.Builder<VariableReferenceExpression> newArgumentsBuilder = ImmutableList.builder();
List<ProjectionContext> processedArguments = processArguments(specialForm.getArguments(), newArgumentsBuilder);
List<RowExpression> newArguments = newArgumentsBuilder.build();
List<VariableReferenceExpression> newArguments = newArgumentsBuilder.build();
if (processedArguments.size() == 1 && !processedArguments.get(0).isRemote()) {
// Arguments do not contain remote projection
return ImmutableList.of();
Expand All @@ -346,7 +348,7 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
specialForm.getForm(),
specialForm.getType(),
newArguments.stream()
.map(argument -> argument instanceof VariableReferenceExpression ? last.getProjections().get(argument) : argument)
.map(last.getProjections()::get)
.collect(toImmutableList()))),
false));
return projectionContextBuilder.build();
Expand All @@ -361,30 +363,27 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
new SpecialFormExpression(
specialForm.getForm(),
specialForm.getType(),
newArguments)),
newArguments.stream()
.map(RowExpression.class::cast)
.collect(toImmutableList()))),
false));
return projectionContextBuilder.build();
}
}

private List<ProjectionContext> processArguments(List<RowExpression> arguments, ImmutableList.Builder<RowExpression> newArguments)
private List<ProjectionContext> processArguments(List<RowExpression> arguments, ImmutableList.Builder<VariableReferenceExpression> newArguments)
{
// Break function arguments into local and remote projections first
ImmutableList.Builder<List<ProjectionContext>> argumentProjections = ImmutableList.builder();

for (RowExpression argument : arguments) {
if (argument instanceof ConstantExpression) {
newArguments.add(argument);
}
else {
List<ProjectionContext> argumentProjection = argument.accept(this, null);
if (argumentProjection.isEmpty()) {
VariableReferenceExpression variable = variableAllocator.newVariable(argument);
argumentProjection = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false));
}
argumentProjections.add(argumentProjection);
newArguments.add(getAssignedArgument(argumentProjection));
List<ProjectionContext> argumentProjection = argument.accept(this, null);
if (argumentProjection.isEmpty()) {
VariableReferenceExpression variable = variableAllocator.newVariable(argument);
argumentProjection = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false));
}
argumentProjections.add(argumentProjection);
newArguments.add(getAssignedArgument(argumentProjection));
}
return mergeProjectionContexts(argumentProjections.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,6 @@ void testLocalOnly()
assertEquals(rewritten.get(0).getProjections().size(), 2);
}

@Test
void testRemoteWithConstantArgument()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), getMetadata());

PlanRemotePojections rule = new PlanRemotePojections(getFunctionAndTypeManager());
List<ProjectionContext> rewritten = rule.planRemoteAssignments(Assignments.builder()
.put(planBuilder.variable("a"), planBuilder.rowExpression("unittest.memory.remote_foo(0)"))
.put(planBuilder.variable("b"), planBuilder.rowExpression("unittest.memory.remote_foo()"))
.build(), new PlanVariableAllocator(planBuilder.getTypes().allVariables()));
assertEquals(rewritten.size(), 1);
}

@Test
void testRemoteOnly()
{
Expand Down Expand Up @@ -236,7 +223,7 @@ void testMixedExpressionRewrite()
p.variable("y", INTEGER);
return p.project(
Assignments.builder()
.put(p.variable("a"), p.rowExpression("unittest.memory.remote_foo(1, y + unittest.memory.remote_foo(x))")) // identity
.put(p.variable("a"), p.rowExpression("unittest.memory.remote_foo(x, y + unittest.memory.remote_foo(x))")) // identity
.put(p.variable("b"), p.rowExpression("x IS NULL OR y IS NULL")) // complex expression referenced multiple times
.put(p.variable("c"), p.rowExpression("abs(unittest.memory.remote_foo()) > 0")) // complex expression referenced multiple times
.put(p.variable("d"), p.rowExpression("unittest.memory.remote_foo(x + y, abs(x))")) // literal referenced multiple times
Expand All @@ -246,31 +233,35 @@ void testMixedExpressionRewrite()
.matches(
project(
ImmutableMap.of(
"a", PlanMatchPattern.expression("unittest.memory.remote_foo(1, add)"),
"a", PlanMatchPattern.expression("unittest.memory.remote_foo(x, add)"),
"b", PlanMatchPattern.expression("b"),
"c", PlanMatchPattern.expression("c"),
"d", PlanMatchPattern.expression("d")),
project(
ImmutableMap.of(
"x", PlanMatchPattern.expression("x"),
"add", PlanMatchPattern.expression("y + unittest_memory_remote_foo"),
"b", PlanMatchPattern.expression("b"),
"c", PlanMatchPattern.expression("abs(unittest_memory_remote_foo_7) > 0"),
"c", PlanMatchPattern.expression("abs(unittest_memory_remote_foo_7) > expr_8"),
"d", PlanMatchPattern.expression("d")),
project(
ImmutableMap.<String, ExpressionMatcher>builder()
.put("x", PlanMatchPattern.expression("x"))
.put("y", PlanMatchPattern.expression("y"))
.put("unittest_memory_remote_foo", PlanMatchPattern.expression("unittest.memory.remote_foo(x)"))
.put("b", PlanMatchPattern.expression("b"))
.put("unittest_memory_remote_foo_7", PlanMatchPattern.expression("unittest.memory.remote_foo()"))
.put("d", PlanMatchPattern.expression("unittest.memory.remote_foo(add_9, abs_11)"))
.put("expr_8", PlanMatchPattern.expression("expr_8"))
.put("d", PlanMatchPattern.expression("unittest.memory.remote_foo(add_14, abs_16)"))
.build(),
project(
ImmutableMap.<String, ExpressionMatcher>builder()
.put("x", PlanMatchPattern.expression("x"))
.put("y", PlanMatchPattern.expression("y"))
.put("b", PlanMatchPattern.expression("x IS NULL OR y is NULL"))
.put("add_9", PlanMatchPattern.expression("x + y"))
.put("abs_11", PlanMatchPattern.expression("abs(x)"))
.put("expr_8", PlanMatchPattern.expression("0"))
.put("add_14", PlanMatchPattern.expression("x + y"))
.put("abs_16", PlanMatchPattern.expression("abs(x)"))
.build(),
values(ImmutableMap.of("x", 0, "y", 1)))))));
}
Expand Down