Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.LanguageFunctionManager;
import io.trino.metadata.Metadata;
import io.trino.metadata.QualifiedObjectName;
import io.trino.security.AccessControl;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.spi.function.LanguageFunction;
import io.trino.sql.SqlEnvironmentConfig;
import io.trino.sql.parser.ParsingException;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.tree.CreateFunction;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionSpecification;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.QualifiedName;

import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;

import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.SYNTAX_ERROR;
import static io.trino.sql.SqlFormatter.formatSql;
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
import static io.trino.sql.routine.SqlRoutineAnalyzer.isRunAsInvoker;
import static java.util.Objects.requireNonNull;

public class CreateFunctionTask
implements DataDefinitionTask<CreateFunction>
{
private final Optional<CatalogSchemaName> defaultFunctionSchema;
private final SqlParser sqlParser;
private final Metadata metadata;
private final FunctionManager functionManager;
private final AccessControl accessControl;
private final LanguageFunctionManager languageFunctionManager;

@Inject
public CreateFunctionTask(
SqlEnvironmentConfig sqlEnvironmentConfig,
SqlParser sqlParser,
Metadata metadata,
FunctionManager functionManager,
AccessControl accessControl,
LanguageFunctionManager languageFunctionManager)
{
this.defaultFunctionSchema = defaultFunctionSchema(sqlEnvironmentConfig);
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.functionManager = requireNonNull(functionManager, "functionManager is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.languageFunctionManager = requireNonNull(languageFunctionManager, "languageFunctionManager is null");
}

@Override
public String getName()
{
return "CREATE FUNCTION";
}

@Override
public ListenableFuture<Void> execute(CreateFunction statement, QueryStateMachine stateMachine, List<Expression> parameters, WarningCollector warningCollector)
{
Session session = stateMachine.getSession();

FunctionSpecification function = statement.getSpecification();
QualifiedObjectName name = qualifiedFunctionName(defaultFunctionSchema, statement, function.getName());

accessControl.checkCanCreateFunction(session.toSecurityContext(), name);

String formatted = formatSql(function);
verifyFormattedFunction(formatted, function);

languageFunctionManager.verifyForCreate(session, formatted, functionManager, accessControl);

String signatureToken = languageFunctionManager.getSignatureToken(function.getParameters());

// system path elements currently are not stored
List<CatalogSchemaName> path = session.getPath().getPath().stream()
.filter(element -> !element.getCatalogName().equals(GlobalSystemConnector.NAME))
.toList();

Optional<String> owner = isRunAsInvoker(function) ? Optional.empty() : Optional.of(session.getUser());

LanguageFunction languageFunction = new LanguageFunction(signatureToken, formatted, path, owner);

boolean replace = false;
if (metadata.languageFunctionExists(session, name, signatureToken)) {
Comment thread
electrum marked this conversation as resolved.
Outdated
if (!statement.isReplace()) {
throw semanticException(ALREADY_EXISTS, statement, "Function already exists");
}
accessControl.checkCanDropFunction(session.toSecurityContext(), name);
replace = true;
}

metadata.createLanguageFunction(session, name, languageFunction, replace);

return immediateVoidFuture();
}

private void verifyFormattedFunction(String sql, FunctionSpecification function)
{
try {
FunctionSpecification parsed = sqlParser.createFunctionSpecification(sql);
if (!function.equals(parsed)) {
throw formattingFailure(null, "Function does not round-trip", function, sql);
}
}
catch (ParsingException e) {
throw formattingFailure(e, "Formatted function does not parse", function, sql);
}
}

static Optional<CatalogSchemaName> defaultFunctionSchema(SqlEnvironmentConfig config)
{
return combine(config.getDefaultFunctionCatalog(), config.getDefaultFunctionSchema(), CatalogSchemaName::new);
}

static QualifiedObjectName qualifiedFunctionName(Optional<CatalogSchemaName> functionSchema, Node node, QualifiedName name)
{
List<String> parts = name.getParts();
return switch (parts.size()) {
case 1 -> {
CatalogSchemaName schema = functionSchema.orElseThrow(() ->
semanticException(NOT_SUPPORTED, node, "Catalog and schema must be specified when function schema is not configured"));
yield new QualifiedObjectName(schema.getCatalogName(), schema.getSchemaName(), parts.get(0));
}
case 2 -> throw semanticException(NOT_SUPPORTED, node, "Function name must be unqualified or fully qualified with catalog and schema");
case 3 -> new QualifiedObjectName(parts.get(0), parts.get(1), parts.get(2));
default -> throw semanticException(SYNTAX_ERROR, node, "Too many dots in function name: %s", name);
};
}

private static TrinoException formattingFailure(Throwable cause, String message, FunctionSpecification function, String sql)
{
TrinoException exception = new TrinoException(GENERIC_INTERNAL_ERROR, message, cause);
exception.addSuppressed(new RuntimeException("Function: " + function));
exception.addSuppressed(new RuntimeException("Formatted: [%s]".formatted(sql)));
return exception;
}

private static <T, U, R> Optional<R> combine(Optional<T> first, Optional<U> second, BiFunction<T, U, R> combiner)
{
return (first.isPresent() && second.isPresent())
? Optional.of(combiner.apply(first.get(), second.get()))
: Optional.empty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.LanguageFunctionManager;
import io.trino.metadata.Metadata;
import io.trino.metadata.QualifiedObjectName;
import io.trino.security.AccessControl;
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.sql.SqlEnvironmentConfig;
import io.trino.sql.tree.DropFunction;
import io.trino.sql.tree.Expression;

import java.util.List;
import java.util.Optional;

import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
import static io.trino.execution.CreateFunctionTask.defaultFunctionSchema;
import static io.trino.execution.CreateFunctionTask.qualifiedFunctionName;
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
import static java.util.Objects.requireNonNull;

public class DropFunctionTask
implements DataDefinitionTask<DropFunction>
{
private final Optional<CatalogSchemaName> functionSchema;
private final Metadata metadata;
private final AccessControl accessControl;
private final LanguageFunctionManager languageFunctionManager;

@Inject
public DropFunctionTask(
SqlEnvironmentConfig sqlEnvironmentConfig,
Metadata metadata,
AccessControl accessControl,
LanguageFunctionManager languageFunctionManager)
{
this.functionSchema = defaultFunctionSchema(sqlEnvironmentConfig);
this.metadata = requireNonNull(metadata, "metadata is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.languageFunctionManager = requireNonNull(languageFunctionManager, "languageFunctionManager is null");
}

@Override
public String getName()
{
return "DROP FUNCTION";
}

@Override
public ListenableFuture<Void> execute(DropFunction statement, QueryStateMachine stateMachine, List<Expression> parameters, WarningCollector warningCollector)
{
Session session = stateMachine.getSession();

QualifiedObjectName name = qualifiedFunctionName(functionSchema, statement, statement.getName());

accessControl.checkCanDropFunction(session.toSecurityContext(), name);

String signatureToken = languageFunctionManager.getSignatureToken(statement.getParameters());

if (!metadata.languageFunctionExists(session, name, signatureToken)) {
if (!statement.isExists()) {
Comment thread
electrum marked this conversation as resolved.
Outdated
throw semanticException(NOT_FOUND, statement, "Function not found");
}
return immediateVoidFuture();
}

metadata.dropLanguageFunction(session, name, signatureToken);

return immediateVoidFuture();
}
}
7 changes: 7 additions & 0 deletions core/trino-main/src/main/java/io/trino/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.LanguageFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.security.GrantInfo;
Expand Down Expand Up @@ -693,6 +694,12 @@ default ResolvedFunction getCoercion(Type fromType, Type toType)

FunctionDependencyDeclaration getFunctionDependencies(Session session, CatalogHandle catalogHandle, FunctionId functionId, BoundSignature boundSignature);

boolean languageFunctionExists(Session session, QualifiedObjectName name, String signatureToken);

void createLanguageFunction(Session session, QualifiedObjectName name, LanguageFunction function, boolean replace);

void dropLanguageFunction(Session session, QualifiedObjectName name, String signatureToken);

/**
* Creates the specified materialized view with the specified view definition.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.trino.FeaturesConfig;
import io.trino.Session;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.metadata.LanguageFunctionManager.LanguageFunctionLoader;
import io.trino.metadata.LanguageFunctionManager.RunAsIdentityLoader;
import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder;
import io.trino.spi.ErrorCode;
Expand Down Expand Up @@ -93,6 +92,7 @@
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.LanguageFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.SchemaFunctionName;
import io.trino.spi.function.Signature;
Expand Down Expand Up @@ -2314,6 +2314,7 @@ public Collection<FunctionMetadata> listFunctions(Session session, CatalogSchema
ConnectorSession connectorSession = session.toConnectorSession(catalogMetadata.getCatalogHandle());
ConnectorMetadata metadata = catalogMetadata.getMetadata(session);
functions.addAll(metadata.listFunctions(connectorSession, schema.getSchemaName()));
functions.addAll(languageFunctionManager.listFunctions(metadata.listLanguageFunctions(connectorSession, schema.getSchemaName())));
});
return functions.build();
}
Expand Down Expand Up @@ -2409,8 +2410,7 @@ private List<CatalogFunctionMetadata> getFunctions(Session session, ConnectorMet
.orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "No identity for SECURITY DEFINER function: " + functionName));
};

LanguageFunctionLoader emptyLoader = (ignoredSession, ignoredName) -> ImmutableList.of();
languageFunctionManager.getFunctions(session, catalogHandle, name, emptyLoader, identityLoader).stream()
languageFunctionManager.getFunctions(session, catalogHandle, name, metadata::getLanguageFunctions, identityLoader).stream()
.map(function -> new CatalogFunctionMetadata(catalogHandle, name.getSchemaName(), function))
.forEach(functions::add);

Expand Down Expand Up @@ -2448,6 +2448,38 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Session sessio
return builder.build();
}

@Override
public boolean languageFunctionExists(Session session, QualifiedObjectName name, String signatureToken)
{
return getOptionalCatalogMetadata(session, name.getCatalogName())
.map(catalogMetadata -> {
ConnectorMetadata metadata = catalogMetadata.getMetadata(session);
ConnectorSession connectorSession = session.toConnectorSession(catalogMetadata.getCatalogHandle());
return metadata.languageFunctionExists(connectorSession, name.asSchemaFunctionName(), signatureToken);
})
.orElse(false);
}

@Override
public void createLanguageFunction(Session session, QualifiedObjectName name, LanguageFunction function, boolean replace)
{
CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, name.getCatalogName());
CatalogHandle catalogHandle = catalogMetadata.getCatalogHandle();
ConnectorMetadata metadata = catalogMetadata.getMetadata(session);

metadata.createLanguageFunction(session.toConnectorSession(catalogHandle), name.asSchemaFunctionName(), function, replace);
}

@Override
public void dropLanguageFunction(Session session, QualifiedObjectName name, String signatureToken)
{
CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, name.getCatalogName());
CatalogHandle catalogHandle = catalogMetadata.getCatalogHandle();
ConnectorMetadata metadata = catalogMetadata.getMetadata(session);

metadata.dropLanguageFunction(session.toConnectorSession(catalogHandle), name.asSchemaFunctionName(), signatureToken);
}

@VisibleForTesting
public static FunctionBinding toFunctionBinding(FunctionId functionId, BoundSignature boundSignature, Signature functionSignature)
{
Expand Down
14 changes: 14 additions & 0 deletions core/trino-main/src/main/java/io/trino/security/AccessControl.java
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,20 @@ void checkCanRevokeRoles(SecurityContext context,
*/
Set<SchemaFunctionName> filterFunctions(SecurityContext context, String catalogName, Set<SchemaFunctionName> functionNames);

/**
* Check if identity is allowed to create the specified function.
*
* @throws AccessDeniedException if not allowed
*/
void checkCanCreateFunction(SecurityContext context, QualifiedObjectName functionName);

/**
* Check if identity is allowed to drop the specified function.
*
* @throws AccessDeniedException if not allowed
*/
void checkCanDropFunction(SecurityContext context, QualifiedObjectName functionName);

default List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName)
{
return ImmutableList.of();
Expand Down
Loading