Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.execution;

import com.facebook.presto.common.function.QualifiedFunctionName;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.security.AccessControl;
Expand All @@ -25,8 +26,11 @@
import com.facebook.presto.sql.analyzer.Analysis;
import com.facebook.presto.sql.analyzer.Analyzer;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CreateFunction;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.RoutineBody;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.util.concurrent.ListenableFuture;

Expand Down Expand Up @@ -76,11 +80,11 @@ public ListenableFuture<?> execute(CreateFunction statement, TransactionManager
throw new PrestoException(NOT_SUPPORTED, "Invoking a dynamically registered function in SQL function body is not supported");
}

metadata.getFunctionManager().createFunction(createSqlInvokedFunction(statement), statement.isReplace());
metadata.getFunctionManager().createFunction(createSqlInvokedFunction(statement, metadata, analysis), statement.isReplace());
return immediateFuture(null);
}

private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement)
private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement, Metadata metadata, Analysis analysis)
{
QualifiedFunctionName functionName = qualifyFunctionName(statement.getFunctionName());
List<Parameter> parameters = statement.getParameters().stream()
Expand All @@ -93,15 +97,25 @@ private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement)
.setDeterminism(RoutineCharacteristics.Determinism.valueOf(statement.getCharacteristics().getDeterminism().name()))
.setNullCallClause(RoutineCharacteristics.NullCallClause.valueOf(statement.getCharacteristics().getNullCallClause().name()))
.build();
String body = formatSql(statement.getBody(), Optional.empty());
RoutineBody body = statement.getBody();

if (statement.getBody() instanceof Return) {
Expression bodyExpression = ((Return) statement.getBody()).getExpression();
Type bodyType = analysis.getType(bodyExpression);

if (!bodyType.equals(metadata.getType(returnType))) {
// Casting is safe-here, since we have verified that the actual type of the body is coercible to declared return type.
body = new Return(new Cast(bodyExpression, statement.getReturnType()));
}
}

return new SqlInvokedFunction(
functionName,
parameters,
returnType,
description,
routineCharacteristics,
body,
formatSql(body, Optional.empty()),
Optional.empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ protected Scope visitCreateFunction(CreateFunction node, Optional<Scope> scope)
if (node.getBody() instanceof Return) {
Expression returnExpression = ((Return) node.getBody()).getExpression();
Type bodyType = analyzeExpression(returnExpression, functionScope).getExpressionTypes().get(NodeRef.of(returnExpression));
if (!bodyType.equals(returnType)) {
if (!metadata.getTypeManager().canCoerce(bodyType, returnType)) {
throw new SemanticException(TYPE_MISMATCH, node, "Function implementation type '%s' does not match declared return type '%s'", bodyType, returnType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,56 @@ public void testCreateFunction()
assertQueryFails("CREATE FUNCTION testing.test.foo(x varchar) RETURNS varchar LANGUAGE UNSUPPORTED EXTERNAL", "Catalog testing does not support functions implemented in language UNSUPPORTED");
}

@Test
public void testCreateFunctionWithCoercion()
{
assertQuerySucceeds("CREATE FUNCTION testing.test.return_double() RETURNS DOUBLE RETURN 1");
String createFunctionReturnDoubleFormatted = "CREATE FUNCTION testing.test.return_double ()\n" +
"RETURNS DOUBLE\n" +
"COMMENT ''\n" +
"LANGUAGE SQL\n" +
"NOT DETERMINISTIC\n" +
"CALLED ON NULL INPUT\n" +
"RETURN CAST(1 AS double)";

MaterializedResult rows = computeActual("SHOW CREATE FUNCTION testing.test.return_double()");
assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionReturnDoubleFormatted, ""));

rows = computeActual("SELECT testing.test.return_double() + 1");
assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), 2.0);

assertQuerySucceeds("CREATE FUNCTION testing.test.return_varchar() RETURNS VARCHAR RETURN 'ABC'");
String createFunctionReturnVarcharFormatted = "CREATE FUNCTION testing.test.return_varchar ()\n" +
"RETURNS varchar\n" +
"COMMENT ''\n" +
"LANGUAGE SQL\n" +
"NOT DETERMINISTIC\n" +
"CALLED ON NULL INPUT\n" +
"RETURN CAST('ABC' AS varchar)";

rows = computeActual("SHOW CREATE FUNCTION testing.test.return_varchar()");
assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionReturnVarcharFormatted, ""));

rows = computeActual("SELECT lower(testing.test.return_varchar())");
assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), "abc");

// no explicit cast added
assertQuerySucceeds("CREATE FUNCTION testing.test.return_int() RETURNS INTEGER RETURN 1");
String createFunctionReturnIntFormatted = "CREATE FUNCTION testing.test.return_int ()\n" +
"RETURNS INTEGER\n" +
"COMMENT ''\n" +
"LANGUAGE SQL\n" +
"NOT DETERMINISTIC\n" +
"CALLED ON NULL INPUT\n" +
"RETURN 1";

rows = computeActual("SHOW CREATE FUNCTION testing.test.return_int()");
assertEquals(rows.getMaterializedRows().get(0).getFields(), ImmutableList.of(createFunctionReturnIntFormatted, ""));

rows = computeActual("SELECT testing.test.return_int() + 3");
assertEquals(rows.getMaterializedRows().get(0).getFields().get(0), 4);
}

@Test
public void testAlterFunctionInvalidFunctionName()
{
Expand Down