diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 382d39ee170cc..b1dcab354324d 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -167,6 +167,7 @@ public final class SystemSessionProperties public static final String DYNAMIC_FILTERING_MAX_PER_DRIVER_ROW_COUNT = "dynamic_filtering_max_per_driver_row_count"; public static final String DYNAMIC_FILTERING_MAX_PER_DRIVER_SIZE = "dynamic_filtering_max_per_driver_size"; public static final String LEGACY_TYPE_COERCION_WARNING_ENABLED = "legacy_type_coercion_warning_enabled"; + public static final String INLINE_SQL_FUNCTIONS = "inline_sql_functions"; private final List> sessionProperties; @@ -869,7 +870,12 @@ public SystemSessionProperties( LEGACY_TYPE_COERCION_WARNING_ENABLED, "Enable warning for query relying on legacy type coercion", featuresConfig.isLegacyDateTimestampToVarcharCoercion(), - true)); + true), + booleanProperty( + INLINE_SQL_FUNCTIONS, + "Inline SQL function definition at plan time", + featuresConfig.isInlineSqlFunctions(), + false)); } public List> getSessionProperties() @@ -1470,4 +1476,9 @@ public static boolean isLegacyTypeCoercionWarningEnabled(Session session) { return session.getSystemProperty(LEGACY_TYPE_COERCION_WARNING_ENABLED, Boolean.class); } + + public static boolean isInlineSqlFunctions(Session session) + { + return session.getSystemProperty(INLINE_SQL_FUNCTIONS, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java index 624497b35d54f..122372588027d 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java @@ -29,6 +29,8 @@ import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CreateFunction; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.Return; import com.facebook.presto.sql.tree.RoutineBody; import com.facebook.presto.transaction.TransactionManager; @@ -103,10 +105,32 @@ private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement, Me Expression bodyExpression = ((Return) statement.getBody()).getExpression(); Type bodyType = analysis.getType(bodyExpression); + // Coerce expressions in body if necessary + bodyExpression = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { + @Override + public Expression rewriteExpression(Expression expression, Void context, ExpressionTreeRewriter treeRewriter) + { + Expression rewritten = treeRewriter.defaultRewrite(expression, null); + + Type coercion = analysis.getCoercion(expression); + if (coercion != null) { + return new Cast( + rewritten, + coercion.getTypeSignature().toString(), + false, + analysis.isTypeOnlyCoercion(expression)); + } + return rewritten; + } + }, bodyExpression, null); + if (!bodyType.equals(metadata.getType(returnType))) { - // Casting is safe-here, since we have verified that the actual type of the body is coercible to declared return type. - body = new Return(new Cast(bodyExpression, statement.getReturnType())); + // Casting is safe here, since we have verified at analysis time that the actual type of the body is coercible to declared return type. + bodyExpression = new Cast(bodyExpression, statement.getReturnType()); } + + body = new Return(bodyExpression); } return new SqlInvokedFunction( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 5a996ce8aad91..f6c4de60ccef8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -161,6 +161,7 @@ public class FeaturesConfig private boolean preferDistributedUnion = true; private boolean optimizeNullsInJoin; private boolean pushdownDereferenceEnabled; + private boolean inlineSqlFunctions = true; private String warnOnNoTableLayoutFilter = ""; @@ -1385,4 +1386,16 @@ public FeaturesConfig setWarnOnNoTableLayoutFilter(String warnOnNoTableLayoutFil this.warnOnNoTableLayoutFilter = warnOnNoTableLayoutFilter; return this; } + + public boolean isInlineSqlFunctions() + { + return inlineSqlFunctions; + } + + @Config("inline-sql-functions") + public FeaturesConfig setInlineSqlFunctions(boolean inlineSqlFunctions) + { + this.inlineSqlFunctions = inlineSqlFunctions; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index cc34012e3bd1a..2a1a24d2394b7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -932,7 +932,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context) result = functionInvoker.invoke(functionHandle, session.getSqlFunctionProperties(), argumentValues); break; case SQL: - Expression function = getSqlFunctionExpression(functionMetadata, (SqlInvokedScalarFunctionImplementation) metadata.getFunctionManager().getScalarFunctionImplementation(functionHandle), session.getSqlFunctionProperties(), node.getArguments()); + Expression function = getSqlFunctionExpression(functionMetadata, (SqlInvokedScalarFunctionImplementation) metadata.getFunctionManager().getScalarFunctionImplementation(functionHandle), metadata, session.getSqlFunctionProperties(), node.getArguments()); ExpressionInterpreter functionInterpreter = new ExpressionInterpreter( function, metadata, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 8312ea9b30fe5..e38a3c5a9ef92 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -42,6 +42,7 @@ import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations; import com.facebook.presto.sql.planner.iterative.rule.InlineProjections; +import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions; import com.facebook.presto.sql.planner.iterative.rule.MergeFilters; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort; @@ -330,6 +331,11 @@ public PlanOptimizers( new LimitPushDown(), // Run the LimitPushDown after flattening set operators to make it easier to do the set flattening new PruneUnreferencedOutputs(), inlineProjections, + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + new InlineSqlFunctions(metadata, sqlParser).rules()), new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java new file mode 100644 index 0000000000000..e42414f236058 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineSqlFunctions.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.NodeRef; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.facebook.presto.SystemSessionProperties.isInlineSqlFunctions; +import static com.facebook.presto.metadata.FunctionManager.qualifyFunctionName; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.relational.SqlFunctionUtils.getSqlFunctionExpression; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; + +public class InlineSqlFunctions + extends ExpressionRewriteRuleSet +{ + public InlineSqlFunctions(Metadata metadata, SqlParser sqlParser) + { + super(createRewrite(metadata, sqlParser)); + } + + private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + { + requireNonNull(metadata, "metadata is null"); + requireNonNull(sqlParser, "sqlParser is null"); + + return (expression, context) -> InlineSqlFunctionsRewriter.rewrite( + expression, + context.getSession(), + metadata, + getExpressionTypes( + context.getSession(), + metadata, + sqlParser, + context.getVariableAllocator().getTypes(), + expression, + emptyList(), + context.getWarningCollector())); + } + + @Override + public Set> rules() + { + // Aggregations are not rewritten because they cannot have SQL functions + return ImmutableSet.of( + projectExpressionRewrite(), + filterExpressionRewrite(), + joinExpressionRewrite(), + valuesExpressionRewrite()); + } + + public static class InlineSqlFunctionsRewriter + { + private InlineSqlFunctionsRewriter() {} + + public static Expression rewrite(Expression expression, Session session, Metadata metadata, Map, Type> expressionTypes) + { + if (isInlineSqlFunctions(session)) { + return ExpressionTreeRewriter.rewriteWith(new Visitor(session, metadata, expressionTypes), expression); + } + return expression; + } + + private static class Visitor + extends com.facebook.presto.sql.tree.ExpressionRewriter + { + private final Session session; + private final Metadata metadata; + private final Map, Type> expressionTypes; + + public Visitor(Session session, Metadata metadata, Map, Type> expressionTypes) + { + this.session = requireNonNull(session, "session is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.expressionTypes = expressionTypes; + } + + @Override + public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) + { + List argumentTypes = new ArrayList<>(); + List rewrittenArguments = new ArrayList<>(); + for (Expression argument : node.getArguments()) { + argumentTypes.add(expressionTypes.get(NodeRef.of(argument))); + rewrittenArguments.add(treeRewriter.rewrite(argument, context)); + } + + FunctionHandle functionHandle = metadata.getFunctionManager().resolveFunction( + session.getTransactionId(), + qualifyFunctionName(node.getName()), + fromTypes(argumentTypes)); + FunctionMetadata functionMetadata = metadata.getFunctionManager().getFunctionMetadata(functionHandle); + + if (functionMetadata.getImplementationType() != FunctionImplementationType.SQL) { + return new FunctionCall(node.getName(), rewrittenArguments); + } + return getSqlFunctionExpression( + functionMetadata, + (SqlInvokedScalarFunctionImplementation) metadata.getFunctionManager().getScalarFunctionImplementation(functionHandle), + metadata, + session.getSqlFunctionProperties(), + rewrittenArguments); + } + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java index 82d5523ddcae2..05967ad563a2f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java @@ -55,18 +55,18 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.util.Locale.ENGLISH; -import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; public final class SqlFunctionUtils { private SqlFunctionUtils() {} - public static Expression getSqlFunctionExpression(FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, SqlFunctionProperties sqlFunctionProperties, List arguments) + public static Expression getSqlFunctionExpression(FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, Metadata metadata, SqlFunctionProperties sqlFunctionProperties, List arguments) { checkArgument(functionMetadata.getImplementationType().equals(SQL), format("Expect SQL function, get %s", functionMetadata.getImplementationType())); checkArgument(functionMetadata.getArgumentNames().isPresent(), "ArgumentNames is missing"); Expression expression = normalizeParameters(functionMetadata.getArgumentNames().get(), parseSqlFunctionExpression(implementation, sqlFunctionProperties)); + expression = coerceIfNecessary(functionMetadata, expression, sqlFunctionProperties, metadata); return SqlFunctionArgumentBinder.bindFunctionArguments(expression, functionMetadata.getArgumentNames().get(), arguments); } @@ -226,7 +226,7 @@ public static Expression bindFunctionArguments(Expression function, List for (int i = 0; i < argumentNames.size(); i++) { argumentBindings.put(argumentNames.get(i), argumentValues.get(i)); } - return ExpressionTreeRewriter.rewriteWith(new ExpressionFunctionVisitor(argumentBindings.build()), function); + return ExpressionTreeRewriter.rewriteWith(new ExpressionFunctionVisitor(), function, argumentBindings.build()); } public static RowExpression bindFunctionArguments(RowExpression function, List> argumentNames, List argumentValues) @@ -262,20 +262,26 @@ public RowExpression rewriteVariableReference(VariableReferenceExpression variab } private static class ExpressionFunctionVisitor - extends ExpressionRewriter + extends ExpressionRewriter> { - private final Map argumentBindings; - - public ExpressionFunctionVisitor(Map argumentBindings) + @Override + public Expression rewriteLambdaExpression(LambdaExpression lambda, Map context, ExpressionTreeRewriter> treeRewriter) { - this.argumentBindings = requireNonNull(argumentBindings, "argumentBindings is null"); + ImmutableList lambdaStringArguments = lambda.getArguments().stream() + .map(x -> x.getName().getValue()) + .collect(toImmutableList()); + ImmutableMap lambdaContext = context.entrySet().stream() + .filter(entry -> !lambdaStringArguments.contains(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + Expression rewrittenBody = treeRewriter.rewrite(lambda.getBody(), lambdaContext); + return new LambdaExpression(lambda.getArguments(), rewrittenBody); } @Override - public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteIdentifier(Identifier node, Map context, ExpressionTreeRewriter> treeRewriter) { - if (argumentBindings.containsKey(node.getValue())) { - return argumentBindings.get(node.getValue()); + if (context.containsKey(node.getValue())) { + return context.get(node.getValue()); } return node; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 4f26114acb99b..ce0843ca8d597 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -143,7 +143,8 @@ public void testDefaults() .setOptimizeCommonSubExpressions(true) .setPreferDistributedUnion(true) .setOptimizeNullsInJoin(false) - .setWarnOnNoTableLayoutFilter("")); + .setWarnOnNoTableLayoutFilter("") + .setInlineSqlFunctions(true)); } @Test @@ -243,6 +244,7 @@ public void testExplicitPropertyMappings() .put("prefer-distributed-union", "false") .put("optimize-nulls-in-join", "true") .put("warn-on-no-table-layout-filter", "ry@nlikestheyankees,ds") + .put("inline-sql-functions", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -338,7 +340,8 @@ public void testExplicitPropertyMappings() .setOptimizeCommonSubExpressions(false) .setPreferDistributedUnion(false) .setOptimizeNullsInJoin(true) - .setWarnOnNoTableLayoutFilter("ry@nlikestheyankees,ds"); + .setWarnOnNoTableLayoutFilter("ry@nlikestheyankees,ds") + .setInlineSqlFunctions(false); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java new file mode 100644 index 0000000000000..33a146c1596d1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineSqlFunctions.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.function.QualifiedFunctionName; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; +import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.sql.ExpressionUtils; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.testing.TestingSession; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.common.type.StandardTypes.INTEGER; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.TypeProvider.viewOf; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static org.testng.Assert.assertEquals; + +public class TestInlineSqlFunctions +{ + private static final SqlInvokedFunction SQL_FUNCTION_SQUARE = new SqlInvokedFunction( + QualifiedFunctionName.of(new CatalogSchemaName("unittest", "memory"), "square"), + ImmutableList.of(new Parameter("x", parseTypeSignature(INTEGER))), + parseTypeSignature(INTEGER), + "square", + RoutineCharacteristics.builder() + .setDeterminism(DETERMINISTIC) + .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT) + .build(), + "RETURN x * x", + Optional.empty()); + + private static final SqlInvokedFunction THRIFT_FUNCTION_FOO = new SqlInvokedFunction( + QualifiedFunctionName.of(new CatalogSchemaName("unittest", "memory"), "foo"), + ImmutableList.of(new Parameter("x", parseTypeSignature(INTEGER))), + parseTypeSignature(INTEGER), + "thrift function foo", + RoutineCharacteristics.builder() + .setLanguage(new RoutineCharacteristics.Language("java")) + .setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT) + .build(), + "", + Optional.empty()); + + private static final SqlInvokedFunction SQL_FUNCTION_ADD_1_TO_INT_ARRAY = new SqlInvokedFunction( + QualifiedFunctionName.of(new CatalogSchemaName("unittest", "memory"), "add_1_int"), + ImmutableList.of(new Parameter("x", parseTypeSignature("array(int)"))), + parseTypeSignature("array(int)"), + "add 1 to all elements of array", + RoutineCharacteristics.builder() + .setDeterminism(DETERMINISTIC) + .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT) + .build(), + "RETURN transform(x, x -> x + 1)", + Optional.empty()); + + private static final SqlInvokedFunction SQL_FUNCTION_ADD_1_TO_BIGINT_ARRAY = new SqlInvokedFunction( + QualifiedFunctionName.of(new CatalogSchemaName("unittest", "memory"), "add_1_bigint"), + ImmutableList.of(new Parameter("x", parseTypeSignature("array(bigint)"))), + parseTypeSignature("array(bigint)"), + "add 1 to all elements of array", + RoutineCharacteristics.builder() + .setDeterminism(DETERMINISTIC) + .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT) + .build(), + "RETURN transform(x, x -> x + 1)", + Optional.empty()); + + private RuleTester tester; + + @BeforeTest + public void setup() + { + RuleTester tester = new RuleTester(); + FunctionManager functionManager = tester.getMetadata().getFunctionManager(); + InMemoryFunctionNamespaceManager namespaceManager = new InMemoryFunctionNamespaceManager( + "unittest", + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("{\"sql\": \"SQL\",\"java\": \"THRIFT\"}")); + functionManager.addFunctionNamespace("unittest", namespaceManager); + functionManager.createFunction(SQL_FUNCTION_SQUARE, true); + functionManager.createFunction(THRIFT_FUNCTION_FOO, true); + functionManager.createFunction(SQL_FUNCTION_ADD_1_TO_INT_ARRAY, true); + functionManager.createFunction(SQL_FUNCTION_ADD_1_TO_BIGINT_ARRAY, true); + this.tester = tester; + } + + @Test + public void testInlineFunction() + { + assertInlined(tester, "unittest.memory.square(x)", "x * x", ImmutableMap.of("x", IntegerType.INTEGER)); + } + + @Test + public void testInlineFunctionInsideFunction() + { + assertInlined(tester, "abs(unittest.memory.square(x))", "abs(x * x)", ImmutableMap.of("x", IntegerType.INTEGER)); + } + + @Test + public void testInlineFunctionContainingLambda() + { + assertInlined(tester, "unittest.memory.add_1_int(x)", "transform(x, x -> x + 1)", ImmutableMap.of("x", new ArrayType(IntegerType.INTEGER))); + } + + @Test + public void testInlineSqlFunctionCoercesConstantWithCast() + { + assertInlined(tester, + "unittest.memory.add_1_bigint(x)", + "transform(x, x -> x + CAST(1 AS bigint))", + ImmutableMap.of("x", new ArrayType(BigintType.BIGINT))); + } + + @Test + public void testInlineBuiltinSqlFunction() + { + assertInlined(tester, + "array_sum(x)", + "reduce(x, BIGINT '0', (s, x) -> (s + coalesce(x, BIGINT '0')), (s -> s))", + ImmutableMap.of("x", new ArrayType(IntegerType.INTEGER))); + } + + @Test + public void testNoInlineThriftFunction() + { + assertInlined(tester, "unittest.memory.foo(x)", "unittest.memory.foo(x)", ImmutableMap.of("x", IntegerType.INTEGER)); + } + + @Test + public void testInlineFunctionIntoPlan() + { + tester.assertThat(new InlineSqlFunctions(tester.getMetadata(), tester.getSqlParser()).projectExpressionRewrite()) + .on(p -> p.project( + assignment( + p.variable("squared"), + new FunctionCall(QualifiedName.of("unittest", "memory", "square"), ImmutableList.of(new SymbolReference("a")))), + p.values(p.variable("a", IntegerType.INTEGER)))) + .matches(project( + ImmutableMap.of("squared", expression("x * x")), + values(ImmutableMap.of("x", 0)))); + } + + @Test + public void testNoInlineIntoPlanWhenInlineIsDisabled() + { + tester.assertThat(new InlineSqlFunctions(tester.getMetadata(), tester.getSqlParser()).projectExpressionRewrite()) + .setSystemProperty("inline_sql_functions", "false") + .on(p -> p.project( + assignment( + p.variable("squared"), + new FunctionCall(QualifiedName.of("unittest", "memory", "square"), ImmutableList.of(new SymbolReference("a")))), + p.values(p.variable("a", IntegerType.INTEGER)))) + .doesNotFire(); + } + + private void assertInlined(RuleTester tester, String inputSql, String expected, Map variableTypes) + { + Session session = TestingSession.testSessionBuilder() + .setSystemProperty("inline_sql_functions", "true") + .build(); + Metadata metadata = tester.getMetadata(); + Expression inputSqlExpression = PlanBuilder.expression(inputSql); + Map, Type> expressionTypes = getExpressionTypes( + session, + metadata, + tester.getSqlParser(), + viewOf(variableTypes), + inputSqlExpression, + ImmutableList.of(), + WarningCollector.NOOP); + Expression inlinedExpression = InlineSqlFunctions.InlineSqlFunctionsRewriter.rewrite(inputSqlExpression, session, metadata, expressionTypes); + inlinedExpression = ExpressionUtils.rewriteIdentifiersToSymbolReferences(inlinedExpression); + Expression expectedExpression = PlanBuilder.expression(expected); + assertEquals(inlinedExpression, expectedExpression); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java index f5406c0a52391..f52fab50413be 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java @@ -162,6 +162,19 @@ public void testCreateFunctionWithCoercion() rows = computeActual("SELECT testing.test.return_int() + 3"); assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), 4); + + assertQuerySucceeds("CREATE FUNCTION testing.test.add_1_bigint(x array(bigint)) RETURNS array(bigint) RETURN transform(x, x -> x + 1)"); + String createFunctionAdd1BigintFormatted = "CREATE FUNCTION testing.test.add_1_bigint (\n" + + " x array(bigint)\n" + + ")\n" + + "RETURNS array(bigint)\n" + + "COMMENT ''\n" + + "LANGUAGE SQL\n" + + "NOT DETERMINISTIC\n" + + "CALLED ON NULL INPUT\n" + + "RETURN \"transform\"(x, (x) -> (x + CAST(1 AS bigint)))"; + rows = computeActual("SHOW CREATE FUNCTION testing.test.add_1_bigint(array(bigint))"); + assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionAdd1BigintFormatted, "array(bigint)")); } @Test