diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java index bb430b9c75874..624497b35d54f 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution; import com.facebook.presto.common.function.QualifiedFunctionName; +import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -25,8 +26,11 @@ import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CreateFunction; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Return; +import com.facebook.presto.sql.tree.RoutineBody; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; @@ -76,11 +80,11 @@ public ListenableFuture execute(CreateFunction statement, TransactionManager throw new PrestoException(NOT_SUPPORTED, "Invoking a dynamically registered function in SQL function body is not supported"); } - metadata.getFunctionManager().createFunction(createSqlInvokedFunction(statement), statement.isReplace()); + metadata.getFunctionManager().createFunction(createSqlInvokedFunction(statement, metadata, analysis), statement.isReplace()); return immediateFuture(null); } - private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement) + private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement, Metadata metadata, Analysis analysis) { QualifiedFunctionName functionName = qualifyFunctionName(statement.getFunctionName()); List parameters = statement.getParameters().stream() @@ -93,7 +97,17 @@ private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement) .setDeterminism(RoutineCharacteristics.Determinism.valueOf(statement.getCharacteristics().getDeterminism().name())) .setNullCallClause(RoutineCharacteristics.NullCallClause.valueOf(statement.getCharacteristics().getNullCallClause().name())) .build(); - String body = formatSql(statement.getBody(), Optional.empty()); + RoutineBody body = statement.getBody(); + + if (statement.getBody() instanceof Return) { + Expression bodyExpression = ((Return) statement.getBody()).getExpression(); + Type bodyType = analysis.getType(bodyExpression); + + if (!bodyType.equals(metadata.getType(returnType))) { + // Casting is safe-here, since we have verified that the actual type of the body is coercible to declared return type. + body = new Return(new Cast(bodyExpression, statement.getReturnType())); + } + } return new SqlInvokedFunction( functionName, @@ -101,7 +115,7 @@ private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement) returnType, description, routineCharacteristics, - body, + formatSql(body, Optional.empty()), Optional.empty()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 53e69b93901d6..ef7f8da6cf0b4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -592,7 +592,7 @@ protected Scope visitCreateFunction(CreateFunction node, Optional scope) if (node.getBody() instanceof Return) { Expression returnExpression = ((Return) node.getBody()).getExpression(); Type bodyType = analyzeExpression(returnExpression, functionScope).getExpressionTypes().get(NodeRef.of(returnExpression)); - if (!bodyType.equals(returnType)) { + if (!metadata.getTypeManager().canCoerce(bodyType, returnType)) { throw new SemanticException(TYPE_MISMATCH, node, "Function implementation type '%s' does not match declared return type '%s'", bodyType, returnType); } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java index d652e469ac035..f1e66c774259c 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java @@ -114,6 +114,56 @@ public void testCreateFunction() assertQueryFails("CREATE FUNCTION testing.test.foo(x varchar) RETURNS varchar LANGUAGE UNSUPPORTED EXTERNAL", "Catalog testing does not support functions implemented in language UNSUPPORTED"); } + @Test + public void testCreateFunctionWithCoercion() + { + assertQuerySucceeds("CREATE FUNCTION testing.test.return_double() RETURNS DOUBLE RETURN 1"); + String createFunctionReturnDoubleFormatted = "CREATE FUNCTION testing.test.return_double ()\n" + + "RETURNS DOUBLE\n" + + "COMMENT ''\n" + + "LANGUAGE SQL\n" + + "NOT DETERMINISTIC\n" + + "CALLED ON NULL INPUT\n" + + "RETURN CAST(1 AS double)"; + + MaterializedResult rows = computeActual("SHOW CREATE FUNCTION testing.test.return_double()"); + assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionReturnDoubleFormatted, "")); + + rows = computeActual("SELECT testing.test.return_double() + 1"); + assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), 2.0); + + assertQuerySucceeds("CREATE FUNCTION testing.test.return_varchar() RETURNS VARCHAR RETURN 'ABC'"); + String createFunctionReturnVarcharFormatted = "CREATE FUNCTION testing.test.return_varchar ()\n" + + "RETURNS varchar\n" + + "COMMENT ''\n" + + "LANGUAGE SQL\n" + + "NOT DETERMINISTIC\n" + + "CALLED ON NULL INPUT\n" + + "RETURN CAST('ABC' AS varchar)"; + + rows = computeActual("SHOW CREATE FUNCTION testing.test.return_varchar()"); + assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionReturnVarcharFormatted, "")); + + rows = computeActual("SELECT lower(testing.test.return_varchar())"); + assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), "abc"); + + // no explicit cast added + assertQuerySucceeds("CREATE FUNCTION testing.test.return_int() RETURNS INTEGER RETURN 1"); + String createFunctionReturnIntFormatted = "CREATE FUNCTION testing.test.return_int ()\n" + + "RETURNS INTEGER\n" + + "COMMENT ''\n" + + "LANGUAGE SQL\n" + + "NOT DETERMINISTIC\n" + + "CALLED ON NULL INPUT\n" + + "RETURN 1"; + + rows = computeActual("SHOW CREATE FUNCTION testing.test.return_int()"); + assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionReturnIntFormatted, "")); + + rows = computeActual("SELECT testing.test.return_int() + 3"); + assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), 4); + } + @Test public void testAlterFunctionInvalidFunctionName() {