diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml index e998e698f929e..6a6298953d51c 100644 --- a/presto-native-sidecar-plugin/pom.xml +++ b/presto-native-sidecar-plugin/pom.xml @@ -79,6 +79,12 @@ provided + + com.facebook.presto + presto-analyzer + test + + com.facebook.airlift units @@ -288,6 +294,7 @@ **/TestNativeSidecar*.java **/TestNativeExpressionInterpreter.java + **/TestNativeExpressionOptimizer.java diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java index bb669b43782ba..d5e5702033a16 100644 --- a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java @@ -345,12 +345,15 @@ public RowExpression visitExpression(RowExpression originalExpression, Void cont @Override public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) { - if (canBeReplaced(lambda.getBody())) { + if (canBeReplaced(lambda)) { + RowExpression replacement = resolver.apply(lambda); + // Sidecar optimizes only the body of lambda expression. + RowExpression optimizedBody = ((LambdaDefinitionExpression) replacement).getBody().accept(this, context); return new LambdaDefinitionExpression( lambda.getSourceLocation(), lambda.getArgumentTypes(), lambda.getArguments(), - toRowExpression(lambda.getSourceLocation(), resolver.apply(lambda.getBody()), lambda.getBody().getType())); + toRowExpression(lambda.getSourceLocation(), optimizedBody, optimizedBody.getType())); } return lambda; } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionOptimizer.java new file mode 100644 index 0000000000000..1df9d84e123c7 --- /dev/null +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionOptimizer.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.sidecar.NativeSidecarPluginQueryRunner; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import java.util.function.Function; + +import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.sidecar.expressions.NativeExpressionOptimizerFactory.NAME; +import static com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter.SYMBOL_TYPES; +import static com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter.assertRowExpressionEvaluationEquals; + +public class TestNativeExpressionOptimizer +{ + private final DistributedQueryRunner queryRunner; + private final MetadataManager metadata; + private final TestingRowExpressionTranslator translator; + private final NativeExpressionOptimizer expressionOptimizer; + + public TestNativeExpressionOptimizer() + throws Exception + { + this.queryRunner = NativeSidecarPluginQueryRunner.getQueryRunner(); + FunctionAndTypeManager functionAndTypeManager = queryRunner.getCoordinator().getFunctionAndTypeManager(); + this.metadata = createTestMetadataManager(functionAndTypeManager); + this.translator = new TestingRowExpressionTranslator(metadata); + this.expressionOptimizer = (NativeExpressionOptimizer) queryRunner.getCoordinator() + .getExpressionManager() + .getExpressionOptimizer(NAME); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + closeAllRuntimeException(queryRunner); + } + + @Test + public void testLambdaBodyConstantFolding() + { + // Simple lambda constant folding. + assertOptimizedEquals( + "transform(ARRAY[unbound_long, unbound_long2], x -> 1 + 1)", + "transform(ARRAY[unbound_long, unbound_long2], x -> 2)"); + assertOptimizedEquals( + "transform(ARRAY[unbound_long, unbound_long2], x -> cast('123' AS integer))", + "transform(ARRAY[unbound_long, unbound_long2], x -> 123)"); + assertOptimizedEquals( + "transform(ARRAY[unbound_long, unbound_long2], x -> cast(json_parse('[1, 2]') AS ARRAY)[1] + 1)", + "transform(ARRAY[unbound_long, unbound_long2], x -> 2)"); + + // Nested lambda constant folding. + assertOptimizedEquals( + "transform(ARRAY[unbound_long, unbound_long2], x -> transform(ARRAY[1, 2], y -> 1 + 1))", + "transform(ARRAY[unbound_long, unbound_long2], x -> transform(ARRAY[1, 2], y -> 2))"); + // Multiple lambda occurrences constant folding. + assertOptimizedEquals( + "filter(transform(ARRAY[unbound_long, unbound_long2], x -> 1 + 1), x -> true and false)", + "filter(transform(ARRAY[unbound_long, unbound_long2], x -> 2), x -> false)"); + } + + private void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + RowExpression optimizedActual = optimize(actual, ExpressionOptimizer.Level.OPTIMIZED); + RowExpression optimizedExpected = optimize(expected, ExpressionOptimizer.Level.OPTIMIZED); + assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected); + } + + private RowExpression optimize(@Language("SQL") String expression, ExpressionOptimizer.Level level) + { + RowExpression parsedExpression = sqlToRowExpression(expression); + Function variableResolver = variable -> null; + return expressionOptimizer.optimize(parsedExpression, level, TEST_SESSION.toConnectorSession(), variableResolver); + } + + private RowExpression sqlToRowExpression(String expression) + { + Expression parsedExpression = FunctionAssertions.createExpression(expression, metadata, SYMBOL_TYPES); + return translator.translate(parsedExpression, SYMBOL_TYPES); + } +}