diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PlanRemotePojections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PlanRemotePojections.java index 6c8194f023539..f8d96e41a3ce3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PlanRemotePojections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PlanRemotePojections.java @@ -250,14 +250,16 @@ public List visitCall(CallExpression call, Void context) boolean local = !functionMetadata.getImplementationType().isExternal(); // Break function arguments into local and remote projections first - ImmutableList.Builder newArgumentsBuilder = ImmutableList.builder(); + ImmutableList.Builder newArgumentsBuilder = ImmutableList.builder(); List processedArguments = processArguments(call.getArguments(), newArgumentsBuilder); - List newArguments = newArgumentsBuilder.build(); + List newArguments = newArgumentsBuilder.build(); CallExpression newCall = new CallExpression( call.getDisplayName(), call.getFunctionHandle(), call.getType(), - newArguments); + newArguments.stream() + .map(RowExpression.class::cast) + .collect(toImmutableList())); if (local) { if (processedArguments.size() == 1 && !processedArguments.get(0).isRemote()) { @@ -277,7 +279,7 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) { call.getFunctionHandle(), call.getType(), newArguments.stream() - .map(argument -> argument instanceof VariableReferenceExpression ? last.getProjections().get(argument) : argument) + .map(last.getProjections()::get) .collect(toImmutableList()))), false)); return projectionContextBuilder.build(); @@ -303,13 +305,13 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) { @Override public List visitInputReference(InputReferenceExpression reference, Void context) { - throw new IllegalStateException("Optimizers should not see InputReferenceExpression"); + return ImmutableList.of(); } @Override public List visitConstant(ConstantExpression literal, Void context) { - throw new IllegalStateException("We should not create ProjectionContext for constants"); + return ImmutableList.of(); } @Override @@ -327,9 +329,9 @@ public List visitVariableReference(VariableReferenceExpressio @Override public List visitSpecialForm(SpecialFormExpression specialForm, Void context) { - ImmutableList.Builder newArgumentsBuilder = ImmutableList.builder(); + ImmutableList.Builder newArgumentsBuilder = ImmutableList.builder(); List processedArguments = processArguments(specialForm.getArguments(), newArgumentsBuilder); - List newArguments = newArgumentsBuilder.build(); + List newArguments = newArgumentsBuilder.build(); if (processedArguments.size() == 1 && !processedArguments.get(0).isRemote()) { // Arguments do not contain remote projection return ImmutableList.of(); @@ -346,7 +348,7 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) { specialForm.getForm(), specialForm.getType(), newArguments.stream() - .map(argument -> argument instanceof VariableReferenceExpression ? last.getProjections().get(argument) : argument) + .map(last.getProjections()::get) .collect(toImmutableList()))), false)); return projectionContextBuilder.build(); @@ -361,30 +363,27 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) { new SpecialFormExpression( specialForm.getForm(), specialForm.getType(), - newArguments)), + newArguments.stream() + .map(RowExpression.class::cast) + .collect(toImmutableList()))), false)); return projectionContextBuilder.build(); } } - private List processArguments(List arguments, ImmutableList.Builder newArguments) + private List processArguments(List arguments, ImmutableList.Builder newArguments) { // Break function arguments into local and remote projections first ImmutableList.Builder> argumentProjections = ImmutableList.builder(); for (RowExpression argument : arguments) { - if (argument instanceof ConstantExpression) { - newArguments.add(argument); - } - else { - List argumentProjection = argument.accept(this, null); - if (argumentProjection.isEmpty()) { - VariableReferenceExpression variable = variableAllocator.newVariable(argument); - argumentProjection = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false)); - } - argumentProjections.add(argumentProjection); - newArguments.add(getAssignedArgument(argumentProjection)); + List argumentProjection = argument.accept(this, null); + if (argumentProjection.isEmpty()) { + VariableReferenceExpression variable = variableAllocator.newVariable(argument); + argumentProjection = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false)); } + argumentProjections.add(argumentProjection); + newArguments.add(getAssignedArgument(argumentProjection)); } return mergeProjectionContexts(argumentProjections.build()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPlanRemoteProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPlanRemoteProjections.java index dc8273bc123c6..34696b8c02beb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPlanRemoteProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPlanRemoteProjections.java @@ -136,19 +136,6 @@ void testLocalOnly() assertEquals(rewritten.get(0).getProjections().size(), 2); } - @Test - void testRemoteWithConstantArgument() - { - PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), getMetadata()); - - PlanRemotePojections rule = new PlanRemotePojections(getFunctionAndTypeManager()); - List rewritten = rule.planRemoteAssignments(Assignments.builder() - .put(planBuilder.variable("a"), planBuilder.rowExpression("unittest.memory.remote_foo(0)")) - .put(planBuilder.variable("b"), planBuilder.rowExpression("unittest.memory.remote_foo()")) - .build(), new PlanVariableAllocator(planBuilder.getTypes().allVariables())); - assertEquals(rewritten.size(), 1); - } - @Test void testRemoteOnly() { @@ -236,7 +223,7 @@ void testMixedExpressionRewrite() p.variable("y", INTEGER); return p.project( Assignments.builder() - .put(p.variable("a"), p.rowExpression("unittest.memory.remote_foo(1, y + unittest.memory.remote_foo(x))")) // identity + .put(p.variable("a"), p.rowExpression("unittest.memory.remote_foo(x, y + unittest.memory.remote_foo(x))")) // identity .put(p.variable("b"), p.rowExpression("x IS NULL OR y IS NULL")) // complex expression referenced multiple times .put(p.variable("c"), p.rowExpression("abs(unittest.memory.remote_foo()) > 0")) // complex expression referenced multiple times .put(p.variable("d"), p.rowExpression("unittest.memory.remote_foo(x + y, abs(x))")) // literal referenced multiple times @@ -246,31 +233,35 @@ void testMixedExpressionRewrite() .matches( project( ImmutableMap.of( - "a", PlanMatchPattern.expression("unittest.memory.remote_foo(1, add)"), + "a", PlanMatchPattern.expression("unittest.memory.remote_foo(x, add)"), "b", PlanMatchPattern.expression("b"), "c", PlanMatchPattern.expression("c"), "d", PlanMatchPattern.expression("d")), project( ImmutableMap.of( + "x", PlanMatchPattern.expression("x"), "add", PlanMatchPattern.expression("y + unittest_memory_remote_foo"), "b", PlanMatchPattern.expression("b"), - "c", PlanMatchPattern.expression("abs(unittest_memory_remote_foo_7) > 0"), + "c", PlanMatchPattern.expression("abs(unittest_memory_remote_foo_7) > expr_8"), "d", PlanMatchPattern.expression("d")), project( ImmutableMap.builder() + .put("x", PlanMatchPattern.expression("x")) .put("y", PlanMatchPattern.expression("y")) .put("unittest_memory_remote_foo", PlanMatchPattern.expression("unittest.memory.remote_foo(x)")) .put("b", PlanMatchPattern.expression("b")) .put("unittest_memory_remote_foo_7", PlanMatchPattern.expression("unittest.memory.remote_foo()")) - .put("d", PlanMatchPattern.expression("unittest.memory.remote_foo(add_9, abs_11)")) + .put("expr_8", PlanMatchPattern.expression("expr_8")) + .put("d", PlanMatchPattern.expression("unittest.memory.remote_foo(add_14, abs_16)")) .build(), project( ImmutableMap.builder() .put("x", PlanMatchPattern.expression("x")) .put("y", PlanMatchPattern.expression("y")) .put("b", PlanMatchPattern.expression("x IS NULL OR y is NULL")) - .put("add_9", PlanMatchPattern.expression("x + y")) - .put("abs_11", PlanMatchPattern.expression("abs(x)")) + .put("expr_8", PlanMatchPattern.expression("0")) + .put("add_14", PlanMatchPattern.expression("x + y")) + .put("abs_16", PlanMatchPattern.expression("abs(x)")) .build(), values(ImmutableMap.of("x", 0, "y", 1))))))); }