diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 830d50808a14..bacd9bf60bf1 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -14,7 +14,6 @@ package io.trino; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; @@ -24,7 +23,6 @@ import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; import io.trino.client.ProtocolHeaders; -import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.SessionPropertyManager; import io.trino.security.AccessControl; import io.trino.security.SecurityContext; @@ -58,7 +56,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; -import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA; import static io.trino.spi.StandardErrorCode.NOT_FOUND; import static io.trino.sql.SqlPath.EMPTY_PATH; import static io.trino.util.Failures.checkCondition; @@ -316,6 +313,11 @@ public Optional getExchangeEncryptionKey() return exchangeEncryptionKey; } + public SessionPropertyManager getSessionPropertyManager() + { + return sessionPropertyManager; + } + public Session beginTransactionId(TransactionId transactionId, TransactionManager transactionManager, AccessControl accessControl) { requireNonNull(transactionId, "transactionId is null"); @@ -584,16 +586,13 @@ private void validateSystemProperties(AccessControl accessControl, Map catalog, Optional schema, Identity identity, List path) + public Session createViewSession(Optional catalog, Optional schema, Identity identity, List viewPath) + { + return createViewSession(catalog, schema, identity, path.forView(viewPath)); + } + + public Session createViewSession(Optional catalog, Optional schema, Identity identity, SqlPath sqlPath) { - // For a view, we prepend the global function schema to the path, which should not be in the path - // We do not change the raw path, as that is use for the current_path function - SqlPath sqlPath = new SqlPath( - ImmutableList.builder() - .add(new CatalogSchemaName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA)) - .addAll(path) - .build(), - getPath().getRawPath()); return builder(sessionPropertyManager) .setQueryId(getQueryId()) .setTransactionId(getTransactionId().orElse(null)) diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index ec6c2b041105..685d8c2718b1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -46,6 +46,7 @@ import io.trino.memory.LocalMemoryManager; import io.trino.memory.NodeMemoryConfig; import io.trino.memory.QueryContext; +import io.trino.metadata.LanguageFunctionProvider; import io.trino.operator.RetryPolicy; import io.trino.operator.scalar.JoniRegexpFunctions; import io.trino.operator.scalar.JoniRegexpReplaceLambdaFunction; @@ -136,12 +137,14 @@ public class SqlTaskManager private final CounterStat failedTasks = new CounterStat(); private final Optional stuckSplitTasksInterrupter; + private final LanguageFunctionProvider languageFunctionProvider; @Inject public SqlTaskManager( VersionEmbedder versionEmbedder, ConnectorServicesProvider connectorServicesProvider, LocalExecutionPlanner planner, + LanguageFunctionProvider languageFunctionProvider, LocationFactory locationFactory, TaskExecutor taskExecutor, SplitMonitor splitMonitor, @@ -159,6 +162,7 @@ public SqlTaskManager( this(versionEmbedder, connectorServicesProvider, planner, + languageFunctionProvider, locationFactory, taskExecutor, splitMonitor, @@ -180,6 +184,7 @@ public SqlTaskManager( VersionEmbedder versionEmbedder, ConnectorServicesProvider connectorServicesProvider, LocalExecutionPlanner planner, + LanguageFunctionProvider languageFunctionProvider, LocationFactory locationFactory, TaskExecutor taskExecutor, SplitMonitor splitMonitor, @@ -196,6 +201,7 @@ public SqlTaskManager( Predicate> stuckSplitStackTracePredicate) { this.connectorServicesProvider = requireNonNull(connectorServicesProvider, "connectorServicesProvider is null"); + this.languageFunctionProvider = languageFunctionProvider; requireNonNull(nodeInfo, "nodeInfo is null"); infoCacheTime = config.getInfoMaxAge(); @@ -230,7 +236,10 @@ public SqlTaskManager( tracer, sqlTaskExecutionFactory, taskNotificationExecutor, - sqlTask -> finishedTaskStats.merge(sqlTask.getIoStats()), + sqlTask -> { + languageFunctionProvider.unregisterTask(taskId); + finishedTaskStats.merge(sqlTask.getIoStats()); + }, maxBufferSize, maxBroadcastBufferSize, requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"), @@ -528,6 +537,9 @@ private TaskInfo doUpdateTask( } }); + fragment.map(PlanFragment::getLanguageFunctions) + .ifPresent(languageFunctions -> languageFunctionProvider.registerTask(taskId, languageFunctions)); + sqlTask.recordHeartbeat(); return sqlTask.updateTask(session, stageSpan, fragment, splitAssignments, outputBuffers, dynamicFilterDomains, speculative); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java b/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java index a41ebc776e38..3b86d92e84d0 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java @@ -18,6 +18,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.security.GrantInfo; import io.trino.spi.security.Identity; import io.trino.spi.security.Privilege; @@ -166,6 +167,12 @@ public void setViewOwner(Session session, CatalogSchemaTableName view, TrinoPrin throw notSupportedException(view.getCatalogName()); } + @Override + public Optional getFunctionRunAsIdentity(Session session, CatalogSchemaFunctionName functionName) + { + return Optional.empty(); + } + @Override public void schemaCreated(Session session, CatalogSchemaName schema) {} diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionBinder.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionBinder.java index 1fef13d9b4e7..4017b8f22fa8 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionBinder.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionBinder.java @@ -71,6 +71,10 @@ CatalogFunctionBinding bindFunction(List parameterTypes, Optional tryBindFunction(List parameterTypes, Collection candidates) { + if (candidates.isEmpty()) { + return Optional.empty(); + } + List exactCandidates = candidates.stream() .filter(function -> function.functionMetadata().getSignature().getTypeVariableConstraints().isEmpty()) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java index 92ca087ab96d..0c309d73e5f4 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java @@ -29,7 +29,6 @@ import io.trino.spi.function.AggregationImplementation; import io.trino.spi.function.BoundSignature; import io.trino.spi.function.FunctionDependencies; -import io.trino.spi.function.FunctionId; import io.trino.spi.function.FunctionProvider; import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention; @@ -44,16 +43,14 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; import java.util.List; -import java.util.Objects; -import java.util.Optional; -import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.primitives.Primitives.wrap; import static io.trino.cache.CacheUtils.uncheckedCacheGet; import static io.trino.cache.SafeCaches.buildNonEvictableCache; import static io.trino.client.NodeVersion.UNKNOWN; +import static io.trino.metadata.LanguageFunctionManager.isTrinoSqlLanguageFunction; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static java.lang.String.format; @@ -63,14 +60,15 @@ public class FunctionManager { private final NonEvictableCache specializedScalarCache; - private final NonEvictableCache specializedAggregationCache; - private final NonEvictableCache specializedWindowCache; + private final NonEvictableCache specializedAggregationCache; + private final NonEvictableCache specializedWindowCache; private final CatalogServiceProvider functionProviders; private final GlobalFunctionCatalog globalFunctionCatalog; + private final LanguageFunctionProvider languageFunctionProvider; @Inject - public FunctionManager(CatalogServiceProvider functionProviders, GlobalFunctionCatalog globalFunctionCatalog) + public FunctionManager(CatalogServiceProvider functionProviders, GlobalFunctionCatalog globalFunctionCatalog, LanguageFunctionProvider languageFunctionProvider) { specializedScalarCache = buildNonEvictableCache(CacheBuilder.newBuilder() .maximumSize(1000) @@ -86,6 +84,7 @@ public FunctionManager(CatalogServiceProvider functionProvider this.functionProviders = requireNonNull(functionProviders, "functionProviders is null"); this.globalFunctionCatalog = requireNonNull(globalFunctionCatalog, "globalFunctionCatalog is null"); + this.languageFunctionProvider = requireNonNull(languageFunctionProvider, "functionProvider is null"); } public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) @@ -102,11 +101,19 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunc private ScalarFunctionImplementation getScalarFunctionImplementationInternal(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) { FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction); - ScalarFunctionImplementation scalarFunctionImplementation = getFunctionProvider(resolvedFunction).getScalarFunctionImplementation( - resolvedFunction.getFunctionId(), - resolvedFunction.getSignature(), - functionDependencies, - invocationConvention); + + ScalarFunctionImplementation scalarFunctionImplementation; + if (isTrinoSqlLanguageFunction(resolvedFunction.getFunctionId())) { + scalarFunctionImplementation = languageFunctionProvider.specialize(this, resolvedFunction, functionDependencies, invocationConvention); + } + else { + scalarFunctionImplementation = getFunctionProvider(resolvedFunction).getScalarFunctionImplementation( + resolvedFunction.getFunctionId(), + resolvedFunction.getSignature(), + functionDependencies, + invocationConvention); + } + verifyMethodHandleSignature(resolvedFunction.getSignature(), scalarFunctionImplementation, invocationConvention); return scalarFunctionImplementation; } @@ -114,7 +121,7 @@ private ScalarFunctionImplementation getScalarFunctionImplementationInternal(Res public AggregationImplementation getAggregationImplementation(ResolvedFunction resolvedFunction) { try { - return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregationImplementationInternal(resolvedFunction)); + return uncheckedCacheGet(specializedAggregationCache, resolvedFunction, () -> getAggregationImplementationInternal(resolvedFunction)); } catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); @@ -134,7 +141,7 @@ private AggregationImplementation getAggregationImplementationInternal(ResolvedF public WindowFunctionSupplier getWindowFunctionSupplier(ResolvedFunction resolvedFunction) { try { - return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionSupplierInternal(resolvedFunction)); + return uncheckedCacheGet(specializedWindowCache, resolvedFunction, () -> getWindowFunctionSupplierInternal(resolvedFunction)); } catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); @@ -303,58 +310,12 @@ private static void verifyFunctionSignature(boolean check, String message, Objec } } - private static class FunctionKey + private record FunctionKey(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) { - private final FunctionId functionId; - private final BoundSignature boundSignature; - private final Optional invocationConvention; - - public FunctionKey(ResolvedFunction resolvedFunction) - { - this(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), Optional.empty()); - } - - public FunctionKey(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) - { - this(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), Optional.of(invocationConvention)); - } - - public FunctionKey(FunctionId functionId, BoundSignature boundSignature, Optional invocationConvention) - { - this.functionId = requireNonNull(functionId, "functionId is null"); - this.boundSignature = requireNonNull(boundSignature, "boundSignature is null"); - this.invocationConvention = requireNonNull(invocationConvention, "invocationConvention is null"); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - FunctionKey that = (FunctionKey) o; - return functionId.equals(that.functionId) && - boundSignature.equals(that.boundSignature) && - invocationConvention.equals(that.invocationConvention); - } - - @Override - public int hashCode() - { - return Objects.hash(functionId, boundSignature, invocationConvention); - } - - @Override - public String toString() + private FunctionKey { - return toStringHelper(this).omitNullValues() - .add("functionId", functionId) - .add("boundSignature", boundSignature) - .add("invocationConvention", invocationConvention.orElse(null)) - .toString(); + requireNonNull(resolvedFunction, "resolvedFunction is null"); + requireNonNull(invocationConvention, "invocationConvention is null"); } } @@ -364,6 +325,6 @@ public static FunctionManager createTestingFunctionManager() GlobalFunctionCatalog functionCatalog = new GlobalFunctionCatalog(); functionCatalog.addFunctions(SystemFunctionBundle.create(new FeaturesConfig(), typeOperators, new BlockTypeOperators(typeOperators), UNKNOWN)); functionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(new InternalBlockEncodingSerde(new BlockEncodingManager(), TESTING_TYPE_MANAGER)))); - return new FunctionManager(CatalogServiceProvider.fail(), functionCatalog); + return new FunctionManager(CatalogServiceProvider.fail(), functionCatalog, LanguageFunctionProvider.DISABLED); } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java index 489e60bf24d9..8ba0e9929cf2 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java @@ -50,6 +50,8 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.metadata.FunctionBinder.functionNotFound; import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; +import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; +import static io.trino.metadata.LanguageFunctionManager.isTrinoSqlLanguageFunction; import static io.trino.metadata.SignatureBinder.applyBoundVariables; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.MISSING_CATALOG_NAME; @@ -64,14 +66,21 @@ public class FunctionResolver { private final Metadata metadata; private final TypeManager typeManager; + private final LanguageFunctionManager languageFunctionManager; private final WarningCollector warningCollector; private final ResolvedFunctionDecoder functionDecoder; private final FunctionBinder functionBinder; - public FunctionResolver(Metadata metadata, TypeManager typeManager, ResolvedFunctionDecoder functionDecoder, WarningCollector warningCollector) + public FunctionResolver( + Metadata metadata, + TypeManager typeManager, + LanguageFunctionManager languageFunctionManager, + ResolvedFunctionDecoder functionDecoder, + WarningCollector warningCollector) { this.metadata = requireNonNull(metadata, "metadata is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.languageFunctionManager = requireNonNull(languageFunctionManager, "languageFunctionManager is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.functionDecoder = requireNonNull(functionDecoder, "functionDecoder is null"); this.functionBinder = new FunctionBinder(metadata, typeManager); @@ -129,18 +138,24 @@ public ResolvedFunction resolveFunction(Session session, QualifiedName name, Lis warningCollector.add(new TrinoWarning(DEPRECATED_FUNCTION, "Use of deprecated function: %s: %s".formatted(name, functionMetadata.getDescription()))); } - return resolve(session, catalogFunctionBinding); + return resolve(session, catalogFunctionBinding, accessControl); } - private ResolvedFunction resolve(Session session, CatalogFunctionBinding functionBinding) + private ResolvedFunction resolve(Session session, CatalogFunctionBinding functionBinding, AccessControl accessControl) { - FunctionDependencyDeclaration dependencies = metadata.getFunctionDependencies( - session, - functionBinding.catalogHandle(), - functionBinding.functionBinding().getFunctionId(), - functionBinding.functionBinding().getBoundSignature()); + FunctionDependencyDeclaration dependencies; + if (isTrinoSqlLanguageFunction(functionBinding.functionBinding().getFunctionId())) { + dependencies = languageFunctionManager.getDependencies(session, functionBinding.functionBinding().getFunctionId(), accessControl); + } + else { + dependencies = metadata.getFunctionDependencies( + session, + functionBinding.catalogHandle(), + functionBinding.functionBinding().getFunctionId(), + functionBinding.functionBinding().getBoundSignature()); + } - return resolveFunctionBinding( + ResolvedFunction resolvedFunction = resolveFunctionBinding( metadata, typeManager, functionBinder, @@ -150,7 +165,15 @@ private ResolvedFunction resolve(Session session, CatalogFunctionBinding functio functionBinding.functionMetadata(), dependencies, catalogSchemaFunctionName -> metadata.getFunctions(session, catalogSchemaFunctionName), - catalogFunctionBinding -> resolve(session, catalogFunctionBinding)); + catalogFunctionBinding -> resolve(session, catalogFunctionBinding, accessControl)); + + // For SQL language functions, register the resolved function with the function manager, + // allowing the resolved function to be used later to retrieve the implementation. + if (isTrinoSqlLanguageFunction(resolvedFunction.getFunctionId())) { + languageFunctionManager.registerResolvedFunction(session, resolvedFunction); + } + + return resolvedFunction; } private CatalogFunctionBinding bindFunction( @@ -286,7 +309,7 @@ public static List toPath(Session session, QualifiedN private static boolean canExecuteFunction(Session session, AccessControl accessControl, CatalogSchemaFunctionName functionName) { - if (isBuiltinFunctionName(functionName)) { + if (isInlineFunction(functionName) || isBuiltinFunctionName(functionName)) { return true; } return accessControl.canExecuteFunction( diff --git a/core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionManager.java new file mode 100644 index 000000000000..e4ae4ecfd77d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionManager.java @@ -0,0 +1,495 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.google.common.collect.ImmutableList; +import com.google.common.hash.Hashing; +import com.google.inject.Inject; +import io.trino.Session; +import io.trino.connector.system.GlobalSystemConnector; +import io.trino.execution.TaskId; +import io.trino.execution.warnings.WarningCollector; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.security.AccessControl; +import io.trino.security.ViewAccessControl; +import io.trino.spi.QueryId; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.CatalogSchemaName; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.CatalogSchemaFunctionName; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.LanguageFunction; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.security.GroupProvider; +import io.trino.spi.security.Identity; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeId; +import io.trino.spi.type.TypeManager; +import io.trino.sql.PlannerContext; +import io.trino.sql.SqlPath; +import io.trino.sql.analyzer.TypeSignatureTranslator; +import io.trino.sql.parser.SqlParser; +import io.trino.sql.routine.SqlRoutineAnalysis; +import io.trino.sql.routine.SqlRoutineAnalyzer; +import io.trino.sql.routine.SqlRoutineCompiler; +import io.trino.sql.routine.SqlRoutinePlanner; +import io.trino.sql.routine.ir.IrRoutine; +import io.trino.sql.tree.FunctionSpecification; +import io.trino.sql.tree.ParameterDeclaration; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.sql.routine.SqlRoutineAnalyzer.extractFunctionMetadata; +import static io.trino.sql.routine.SqlRoutineAnalyzer.isRunAsInvoker; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class LanguageFunctionManager + implements LanguageFunctionProvider +{ + public static final String QUERY_LOCAL_SCHEMA = "$query"; + private static final String SQL_FUNCTION_PREFIX = "$trino_sql_"; + private final SqlParser parser; + private final TypeManager typeManager; + private final GroupProvider groupProvider; + private SqlRoutineAnalyzer analyzer; + private SqlRoutinePlanner planner; + private final Map queryFunctions = new ConcurrentHashMap<>(); + + @Inject + public LanguageFunctionManager(SqlParser parser, TypeManager typeManager, GroupProvider groupProvider) + { + this.parser = requireNonNull(parser, "parser is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); + } + + // There is a circular dependency between LanguageFunctionManager and MetadataManager. + // To determine the dependencies of a language function, it must be analyzed, and that + // requires the metadata manager to resolve functions. The metadata manager needs the + // language function manager to resolve language functions. + public synchronized void setPlannerContext(PlannerContext plannerContext) + { + checkState(analyzer == null, "plannerContext already set"); + analyzer = new SqlRoutineAnalyzer(plannerContext, WarningCollector.NOOP); + planner = new SqlRoutinePlanner(plannerContext, WarningCollector.NOOP); + } + + public void tryRegisterQuery(Session session) + { + queryFunctions.putIfAbsent(session.getQueryId(), new QueryFunctions(session)); + } + + public void registerQuery(Session session) + { + boolean alreadyRegistered = queryFunctions.putIfAbsent(session.getQueryId(), new QueryFunctions(session)) != null; + if (alreadyRegistered) { + throw new IllegalStateException("Query already registered: " + session.getQueryId()); + } + } + + public void unregisterQuery(Session session) + { + queryFunctions.remove(session.getQueryId()); + } + + @Override + public void registerTask(TaskId taskId, List languageFunctions) + { + // the functions are already registered in the query, so we don't need to do anything here + } + + @Override + public void unregisterTask(TaskId taskId) {} + + private QueryFunctions getQueryFunctions(Session session) + { + QueryFunctions queryFunctions = this.queryFunctions.get(session.getQueryId()); + if (queryFunctions == null) { + throw new IllegalStateException("Query not registered: " + session.getQueryId()); + } + return queryFunctions; + } + + public List listFunctions(Collection languageFunctions) + { + return languageFunctions.stream() + .map(LanguageFunction::sql) + .map(sql -> extractFunctionMetadata(createSqlLanguageFunctionId(sql), parser.createFunctionSpecification(sql))) + .collect(toImmutableList()); + } + + public List getFunctions(Session session, CatalogHandle catalogHandle, SchemaFunctionName name, LanguageFunctionLoader languageFunctionLoader, RunAsIdentityLoader identityLoader) + { + return getQueryFunctions(session).getFunctions(catalogHandle, name, languageFunctionLoader, identityLoader); + } + + public FunctionMetadata getFunctionMetadata(Session session, FunctionId functionId) + { + return getQueryFunctions(session).getFunctionMetadata(functionId); + } + + public FunctionDependencyDeclaration getDependencies(Session session, FunctionId functionId, AccessControl accessControl) + { + return getQueryFunctions(session).getDependencies(functionId, accessControl); + } + + @Override + public ScalarFunctionImplementation specialize(FunctionManager functionManager, ResolvedFunction resolvedFunction, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) + { + // any resolved function in any query is guaranteed to have the same behavior, so we can use any query to get the implementation + return queryFunctions.values().stream() + .map(queryFunctions -> queryFunctions.specialize(resolvedFunction, functionManager, invocationConvention)) + .filter(Optional::isPresent) + .map(Optional::get) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Unknown function implementation: " + resolvedFunction.getFunctionId())); + } + + public void registerResolvedFunction(Session session, ResolvedFunction resolvedFunction) + { + getQueryFunctions(session).registerResolvedFunction(resolvedFunction); + } + + public List serializeFunctionsForWorkers(Session session) + { + return getQueryFunctions(session).serializeFunctionsForWorkers(); + } + + public void verifyForCreate(Session session, String sql, FunctionManager functionManager, AccessControl accessControl) + { + getQueryFunctions(session).verifyForCreate(sql, functionManager, accessControl); + } + + public void addInlineFunction(Session session, String sql, AccessControl accessControl) + { + getQueryFunctions(session).addInlineFunction(sql, accessControl); + } + + public interface LanguageFunctionLoader + { + Collection getLanguageFunction(ConnectorSession session, SchemaFunctionName name); + } + + public interface RunAsIdentityLoader + { + Identity getFunctionRunAsIdentity(Optional owner); + } + + public static boolean isInlineFunction(CatalogSchemaFunctionName functionName) + { + return functionName.getCatalogName().equals(GlobalSystemConnector.NAME) && functionName.getSchemaName().equals(QUERY_LOCAL_SCHEMA); + } + + public static boolean isTrinoSqlLanguageFunction(FunctionId functionId) + { + return functionId.toString().startsWith(SQL_FUNCTION_PREFIX); + } + + private static FunctionId createSqlLanguageFunctionId(String sql) + { + String hash = Hashing.sha256().hashUnencodedChars(sql).toString(); + return new FunctionId(SQL_FUNCTION_PREFIX + hash); + } + + public String getSignatureToken(List parameters) + { + return parameters.stream() + .map(ParameterDeclaration::getType) + .map(TypeSignatureTranslator::toTypeSignature) + .map(typeManager::getType) + .map(Type::getTypeId) + .map(TypeId::getId) + .collect(joining(",", "(", ")")); + } + + private class QueryFunctions + { + private final Session session; + private final Map functionListing = new ConcurrentHashMap<>(); + private final Map implementationsById = new ConcurrentHashMap<>(); + private final Map implementationsByResolvedFunction = new ConcurrentHashMap<>(); + + public QueryFunctions(Session session) + { + this.session = session; + } + + public void verifyForCreate(String sql, FunctionManager functionManager, AccessControl accessControl) + { + implementationWithoutSecurity(sql).verifyForCreate(functionManager, accessControl); + } + + public void addInlineFunction(String sql, AccessControl accessControl) + { + LanguageFunctionImplementation implementation = implementationWithoutSecurity(sql); + FunctionMetadata metadata = implementation.getFunctionMetadata(); + implementationsById.put(metadata.getFunctionId(), implementation); + SchemaFunctionName name = new SchemaFunctionName(QUERY_LOCAL_SCHEMA, metadata.getCanonicalName()); + getFunctionListing(GlobalSystemConnector.CATALOG_HANDLE, name).addFunction(metadata); + + // enforce that functions may only call already registered functions and prevent recursive calls + implementation.analyzeAndPlan(accessControl); + } + + public synchronized List getFunctions(CatalogHandle catalogHandle, SchemaFunctionName name, LanguageFunctionLoader languageFunctionLoader, RunAsIdentityLoader identityLoader) + { + return getFunctionListing(catalogHandle, name).getFunctions(languageFunctionLoader, identityLoader); + } + + public FunctionDependencyDeclaration getDependencies(FunctionId functionId, AccessControl accessControl) + { + LanguageFunctionImplementation function = implementationsById.get(functionId); + checkArgument(function != null, "Unknown function implementation: " + functionId); + return function.getFunctionDependencies(accessControl); + } + + public Optional specialize(ResolvedFunction resolvedFunction, FunctionManager functionManager, InvocationConvention invocationConvention) + { + LanguageFunctionImplementation function = implementationsByResolvedFunction.get(resolvedFunction); + if (function == null) { + return Optional.empty(); + } + return Optional.of(function.specialize(functionManager, invocationConvention)); + } + + public FunctionMetadata getFunctionMetadata(FunctionId functionId) + { + LanguageFunctionImplementation function = implementationsById.get(functionId); + checkArgument(function != null, "Unknown function implementation: " + functionId); + return function.getFunctionMetadata(); + } + + public void registerResolvedFunction(ResolvedFunction resolvedFunction) + { + FunctionId functionId = resolvedFunction.getFunctionId(); + LanguageFunctionImplementation function = implementationsById.get(functionId); + checkArgument(function != null, "Unknown function implementation: " + functionId); + implementationsByResolvedFunction.put(resolvedFunction, function); + } + + public List serializeFunctionsForWorkers() + { + return implementationsByResolvedFunction.entrySet().stream() + .map(entry -> new LanguageScalarFunctionData( + entry.getKey(), + entry.getValue().getFunctionDependencies(), + entry.getValue().getRoutine())) + .collect(toImmutableList()); + } + + private FunctionListing getFunctionListing(CatalogHandle catalogHandle, SchemaFunctionName name) + { + return functionListing.computeIfAbsent(new FunctionKey(catalogHandle, name), FunctionListing::new); + } + + private record FunctionKey(CatalogHandle catalogHandle, SchemaFunctionName name) {} + + private class FunctionListing + { + private final CatalogHandle catalogHandle; + private final SchemaFunctionName name; + private final List functions = new ArrayList<>(); + private boolean loaded; + + public FunctionListing(FunctionKey key) + { + catalogHandle = key.catalogHandle(); + name = key.name(); + } + + public synchronized void addFunction(FunctionMetadata function) + { + functions.add(function); + loaded = true; + } + + public synchronized List getFunctions(LanguageFunctionLoader languageFunctionLoader, RunAsIdentityLoader identityLoader) + { + if (loaded) { + return ImmutableList.copyOf(functions); + } + loaded = true; + + List implementations = languageFunctionLoader.getLanguageFunction(session.toConnectorSession(), name).stream() + .map(function -> implementationWithSecurity(function.sql(), function.path(), function.owner(), identityLoader)) + .collect(toImmutableList()); + + // verify all names are correct + // Note: language functions don't have aliases + Set names = implementations.stream() + .map(function -> function.getFunctionMetadata().getCanonicalName()) + .collect(toImmutableSet()); + if (!names.isEmpty() && !names.equals(Set.of(name.getFunctionName()))) { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, "Catalog %s returned functions named %s when listing functions named %s".formatted(catalogHandle.getCatalogName(), names, name)); + } + + // add the functions to this listing + implementations.forEach(implementation -> functions.add(implementation.getFunctionMetadata())); + + // add the functions to the catalog index + implementations.forEach(processedFunction -> implementationsById.put(processedFunction.getFunctionMetadata().getFunctionId(), processedFunction)); + + return ImmutableList.copyOf(functions); + } + } + + private LanguageFunctionImplementation implementationWithoutSecurity(String sql) + { + // use the original path during function creation and for inline functions + return new LanguageFunctionImplementation(sql, session.getPath(), Optional.empty(), Optional.empty()); + } + + private LanguageFunctionImplementation implementationWithSecurity(String sql, List path, Optional owner, RunAsIdentityLoader identityLoader) + { + // stored functions cannot see inline functions, so we need to rebuild the path + return new LanguageFunctionImplementation(sql, session.getPath().forView(path), owner, Optional.of(identityLoader)); + } + + private class LanguageFunctionImplementation + { + private final FunctionMetadata functionMetadata; + private final FunctionSpecification functionSpecification; + private final SqlPath path; + private final Optional owner; + private final Optional identityLoader; + private SqlRoutineAnalysis analysis; + private FunctionDependencyDeclaration dependencies; + private IrRoutine routine; + private boolean analyzing; + + private LanguageFunctionImplementation(String sql, SqlPath path, Optional owner, Optional identityLoader) + { + this.functionSpecification = parser.createFunctionSpecification(sql); + this.functionMetadata = extractFunctionMetadata(createSqlLanguageFunctionId(sql), functionSpecification); + this.path = requireNonNull(path, "path is null"); + this.owner = requireNonNull(owner, "owner is null"); + this.identityLoader = requireNonNull(identityLoader, "identityLoader is null"); + } + + public FunctionMetadata getFunctionMetadata() + { + return functionMetadata; + } + + public void verifyForCreate(FunctionManager functionManager, AccessControl accessControl) + { + checkState(identityLoader.isEmpty(), "create should not enforce security"); + analyzeAndPlan(accessControl); + new SqlRoutineCompiler(functionManager).compile(getRoutine()); + } + + private synchronized void analyzeAndPlan(AccessControl accessControl) + { + if (analysis != null) { + return; + } + if (analyzing) { + throw new TrinoException(NOT_SUPPORTED, "Recursive language functions are not supported: %s%s".formatted(functionMetadata.getCanonicalName(), functionMetadata.getSignature())); + } + + analyzing = true; + FunctionContext context = functionContext(accessControl); + analysis = analyzer.analyze(context.session(), context.accessControl(), functionSpecification); + + FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder dependencies = FunctionDependencyDeclaration.builder(); + for (ResolvedFunction resolvedFunction : analysis.analysis().getResolvedFunctions()) { + dependencies.addFunction(resolvedFunction.toCatalogSchemaFunctionName(), resolvedFunction.getSignature().getArgumentTypes()); + } + this.dependencies = dependencies.build(); + + routine = planner.planSqlFunction(session, functionSpecification, analysis); + analyzing = false; + } + + public synchronized FunctionDependencyDeclaration getFunctionDependencies(AccessControl accessControl) + { + analyzeAndPlan(accessControl); + return dependencies; + } + + public synchronized FunctionDependencyDeclaration getFunctionDependencies() + { + if (dependencies == null) { + throw new IllegalStateException("Function not analyzed: " + functionMetadata.getSignature()); + } + return dependencies; + } + + public synchronized IrRoutine getRoutine() + { + if (routine == null) { + throw new IllegalStateException("Function not analyzed: " + functionMetadata.getSignature()); + } + return routine; + } + + public ScalarFunctionImplementation specialize(FunctionManager functionManager, InvocationConvention invocationConvention) + { + // Recompile everytime this function is called as the function dependencies may have changed. + // The caller caches, so this should not be a problem. + // TODO: compiler should use function dependencies instead of function manager + SpecializedSqlScalarFunction function = new SqlRoutineCompiler(functionManager).compile(getRoutine()); + return function.getScalarFunctionImplementation(invocationConvention); + } + + private FunctionContext functionContext(AccessControl accessControl) + { + if (identityLoader.isEmpty() || isRunAsInvoker(functionSpecification)) { + Session functionSession = createFunctionSession(session.getIdentity()); + return new FunctionContext(functionSession, accessControl); + } + + Identity identity = identityLoader.get().getFunctionRunAsIdentity(owner); + + Identity newIdentity = Identity.from(identity) + .withGroups(groupProvider.getGroups(identity.getUser())) + .build(); + + Session functionSession = createFunctionSession(newIdentity); + + if (!identity.getUser().equals(session.getUser())) { + accessControl = new ViewAccessControl(accessControl); + } + + return new FunctionContext(functionSession, accessControl); + } + + private Session createFunctionSession(Identity identity) + { + return session.createViewSession(Optional.empty(), Optional.empty(), identity, path); + } + + private record FunctionContext(Session session, AccessControl accessControl) {} + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionProvider.java b/core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionProvider.java new file mode 100644 index 000000000000..126f454036b0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import io.trino.execution.TaskId; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; + +import java.util.List; + +public interface LanguageFunctionProvider +{ + LanguageFunctionProvider DISABLED = new LanguageFunctionProvider() + { + @Override + public ScalarFunctionImplementation specialize(FunctionManager functionManager, ResolvedFunction resolvedFunction, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) + { + throw new UnsupportedOperationException("SQL language functions are disabled"); + } + + @Override + public void registerTask(TaskId taskId, List languageFunctions) + { + if (!languageFunctions.isEmpty()) { + throw new UnsupportedOperationException("SQL language functions are disabled"); + } + } + + @Override + public void unregisterTask(TaskId taskId) {} + }; + + ScalarFunctionImplementation specialize( + FunctionManager functionManager, + ResolvedFunction resolvedFunction, + FunctionDependencies functionDependencies, + InvocationConvention invocationConvention); + + void registerTask(TaskId taskId, List languageFunctions); + + void unregisterTask(TaskId taskId); +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/LanguageScalarFunctionData.java b/core/trino-main/src/main/java/io/trino/metadata/LanguageScalarFunctionData.java new file mode 100644 index 000000000000..648d4e025e38 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/LanguageScalarFunctionData.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.sql.routine.ir.IrRoutine; + +import static java.util.Objects.requireNonNull; + +public record LanguageScalarFunctionData( + ResolvedFunction resolvedFunction, + FunctionDependencyDeclaration functionDependencies, + IrRoutine routine) +{ + public LanguageScalarFunctionData + { + requireNonNull(resolvedFunction, "resolvedFunction is null"); + requireNonNull(functionDependencies, "functionDependencies is null"); + requireNonNull(routine, "routine is null"); + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 0906f31f70f9..8e3e0213a546 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -27,6 +27,8 @@ import io.trino.FeaturesConfig; import io.trino.Session; import io.trino.connector.system.GlobalSystemConnector; +import io.trino.metadata.LanguageFunctionManager.LanguageFunctionLoader; +import io.trino.metadata.LanguageFunctionManager.RunAsIdentityLoader; import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder; import io.trino.spi.ErrorCode; import io.trino.spi.QueryId; @@ -107,6 +109,7 @@ import io.trino.spi.type.TypeNotFoundException; import io.trino.spi.type.TypeOperators; import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.parser.SqlParser; import io.trino.sql.planner.ConnectorExpressions; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.tree.QualifiedName; @@ -148,6 +151,7 @@ import static io.trino.metadata.CatalogMetadata.SecurityManagement.SYSTEM; import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA; import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; +import static io.trino.metadata.LanguageFunctionManager.isTrinoSqlLanguageFunction; import static io.trino.metadata.QualifiedObjectName.convertFromSchemaTableName; import static io.trino.metadata.RedirectionAwareTableHandle.noRedirection; import static io.trino.metadata.RedirectionAwareTableHandle.withRedirectionTo; @@ -182,6 +186,7 @@ public final class MetadataManager private final BuiltinFunctionResolver functionResolver; private final SystemSecurityMetadata systemSecurityMetadata; private final TransactionManager transactionManager; + private final LanguageFunctionManager languageFunctionManager; private final TypeManager typeManager; private final TypeCoercion typeCoercion; @@ -194,6 +199,7 @@ public MetadataManager( SystemSecurityMetadata systemSecurityMetadata, TransactionManager transactionManager, GlobalFunctionCatalog globalFunctionCatalog, + LanguageFunctionManager languageFunctionManager, TypeManager typeManager) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); @@ -204,6 +210,7 @@ public MetadataManager( this.systemSecurityMetadata = requireNonNull(systemSecurityMetadata, "systemSecurityMetadata is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + this.languageFunctionManager = requireNonNull(languageFunctionManager, "languageFunctionManager is null"); } @Override @@ -1022,6 +1029,7 @@ public Optional getSupportedType(Session session, CatalogHandle catalogHan @Override public void beginQuery(Session session) { + languageFunctionManager.registerQuery(session); } @Override @@ -1031,6 +1039,7 @@ public void cleanupQuery(Session session) if (queryCatalogs != null) { queryCatalogs.finish(); } + languageFunctionManager.unregisterQuery(session); } @Override @@ -2348,6 +2357,9 @@ public ResolvedFunction getCoercion(CatalogSchemaFunctionName name, Type fromTyp @Override public FunctionDependencyDeclaration getFunctionDependencies(Session session, CatalogHandle catalogHandle, FunctionId functionId, BoundSignature boundSignature) { + if (isTrinoSqlLanguageFunction(functionId)) { + throw new IllegalArgumentException("Function dependencies for SQL functions must be fetched directly from the language manager"); + } if (catalogHandle.equals(GlobalSystemConnector.CATALOG_HANDLE)) { return functions.getFunctionDependencies(functionId, boundSignature); } @@ -2375,11 +2387,33 @@ private Collection getBuiltinFunctions(String functionN .collect(toImmutableList()); } - private static List getFunctions(Session session, ConnectorMetadata metadata, CatalogHandle catalogHandle, SchemaFunctionName name) + private List getFunctions(Session session, ConnectorMetadata metadata, CatalogHandle catalogHandle, SchemaFunctionName name) { - return metadata.getFunctions(session.toConnectorSession(catalogHandle), name).stream() + ConnectorSession connectorSession = session.toConnectorSession(catalogHandle); + ImmutableList.Builder functions = ImmutableList.builder(); + + metadata.getFunctions(connectorSession, name).stream() .map(function -> new CatalogFunctionMetadata(catalogHandle, name.getSchemaName(), function)) - .collect(toImmutableList()); + .forEach(functions::add); + + RunAsIdentityLoader identityLoader = owner -> { + CatalogSchemaFunctionName functionName = new CatalogSchemaFunctionName(catalogHandle.getCatalogName(), name); + + Optional systemIdentity = Optional.empty(); + if (getCatalogMetadata(session, catalogHandle).getSecurityManagement() == SYSTEM) { + systemIdentity = systemSecurityMetadata.getFunctionRunAsIdentity(session, functionName); + } + + return systemIdentity.or(() -> owner.map(Identity::ofUser)) + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "No identity for SECURITY DEFINER function: " + functionName)); + }; + + LanguageFunctionLoader emptyLoader = (ignoredSession, ignoredName) -> ImmutableList.of(); + languageFunctionManager.getFunctions(session, catalogHandle, name, emptyLoader, identityLoader).stream() + .map(function -> new CatalogFunctionMetadata(catalogHandle, name.getSchemaName(), function)) + .forEach(functions::add); + + return functions.build(); } @Override @@ -2574,6 +2608,7 @@ public static class TestMetadataManagerBuilder private TransactionManager transactionManager; private TypeManager typeManager = TESTING_TYPE_MANAGER; private GlobalFunctionCatalog globalFunctionCatalog; + private LanguageFunctionManager languageFunctionManager; private TestMetadataManagerBuilder() {} @@ -2595,6 +2630,12 @@ public TestMetadataManagerBuilder withGlobalFunctionCatalog(GlobalFunctionCatalo return this; } + public TestMetadataManagerBuilder withLanguageFunctionManager(LanguageFunctionManager languageFunctionManager) + { + this.languageFunctionManager = languageFunctionManager; + return this; + } + public MetadataManager build() { TransactionManager transactionManager = this.transactionManager; @@ -2610,10 +2651,15 @@ public MetadataManager build() globalFunctionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(new InternalBlockEncodingSerde(new BlockEncodingManager(), typeManager)))); } + if (languageFunctionManager == null) { + languageFunctionManager = new LanguageFunctionManager(new SqlParser(), typeManager, user -> ImmutableSet.of()); + } + return new MetadataManager( new DisabledSystemSecurityMetadata(), transactionManager, globalFunctionCatalog, + languageFunctionManager, typeManager); } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/QualifiedObjectName.java b/core/trino-main/src/main/java/io/trino/metadata/QualifiedObjectName.java index 13651107dd69..e052373185c0 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/QualifiedObjectName.java +++ b/core/trino-main/src/main/java/io/trino/metadata/QualifiedObjectName.java @@ -22,6 +22,7 @@ import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.SchemaRoutineName; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.SchemaFunctionName; import java.util.Objects; import java.util.function.Function; @@ -96,6 +97,11 @@ public QualifiedTablePrefix asQualifiedTablePrefix() return new QualifiedTablePrefix(catalogName, schemaName, objectName); } + public SchemaFunctionName asSchemaFunctionName() + { + return new SchemaFunctionName(schemaName, objectName); + } + @Override public boolean equals(Object obj) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java b/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java index 7296a59e46e8..8318cd17b12f 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java @@ -16,6 +16,7 @@ import io.trino.Session; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.security.GrantInfo; import io.trino.spi.security.Identity; import io.trino.spi.security.Privilege; @@ -138,6 +139,11 @@ public interface SystemSecurityMetadata */ void setViewOwner(Session session, CatalogSchemaTableName view, TrinoPrincipal principal); + /** + * Get the identity to run the function as + */ + Optional getFunctionRunAsIdentity(Session session, CatalogSchemaFunctionName functionName); + /** * A schema was created */ diff --git a/core/trino-main/src/main/java/io/trino/metadata/WorkerLanguageFunctionProvider.java b/core/trino-main/src/main/java/io/trino/metadata/WorkerLanguageFunctionProvider.java new file mode 100644 index 000000000000..bd46c6f79d23 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/WorkerLanguageFunctionProvider.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import io.trino.execution.TaskId; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.sql.routine.SqlRoutineCompiler; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; + +public class WorkerLanguageFunctionProvider + implements LanguageFunctionProvider +{ + private final Map> queryFunctions = new ConcurrentHashMap<>(); + + @Override + public void registerTask(TaskId taskId, List functions) + { + queryFunctions.computeIfAbsent(taskId, ignored -> functions.stream().collect(toImmutableMap(LanguageScalarFunctionData::resolvedFunction, Function.identity()))); + } + + @Override + public void unregisterTask(TaskId taskId) + { + queryFunctions.remove(taskId); + } + + @Override + public ScalarFunctionImplementation specialize(FunctionManager functionManager, ResolvedFunction resolvedFunction, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) + { + LanguageScalarFunctionData functionData = queryFunctions.values().stream() + .map(queryFunctions -> queryFunctions.get(resolvedFunction)) + .filter(Objects::nonNull) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Unknown function implementation: " + resolvedFunction.getFunctionId())); + + // Recompile every time this function is called as the function dependencies may have changed. + // The caller caches, so this should not be a problem. + // TODO: compiler should use function dependencies instead of function manager + SpecializedSqlScalarFunction function = new SqlRoutineCompiler(functionManager).compile(functionData.routine()); + return function.getScalarFunctionImplementation(invocationConvention); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 94b506e3ec46..613edeb3335e 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -95,6 +95,8 @@ import io.trino.memory.TotalReservationLowMemoryKiller; import io.trino.memory.TotalReservationOnBlockedNodesQueryLowMemoryKiller; import io.trino.memory.TotalReservationOnBlockedNodesTaskLowMemoryKiller; +import io.trino.metadata.LanguageFunctionManager; +import io.trino.metadata.LanguageFunctionProvider; import io.trino.metadata.Split; import io.trino.operator.ForScheduler; import io.trino.operator.OperatorStats; @@ -104,6 +106,7 @@ import io.trino.server.ui.WebUiModule; import io.trino.server.ui.WorkerResource; import io.trino.spi.memory.ClusterMemoryPoolManager; +import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.AnalyzerFactory; import io.trino.sql.analyzer.QueryExplainerFactory; import io.trino.sql.planner.OptimizerStatsMBeanExporter; @@ -274,6 +277,11 @@ List getCompositeOutputDataSizeEstimatorDelegate // dynamic filtering service binder.bind(DynamicFilterService.class).in(Scopes.SINGLETON); + // language functions + binder.bind(LanguageFunctionManager.class).in(Scopes.SINGLETON); + binder.bind(InitializeLanguageFunctionManager.class).asEagerSingleton(); + binder.bind(LanguageFunctionProvider.class).to(LanguageFunctionManager.class).in(Scopes.SINGLETON); + // analyzer binder.bind(AnalyzerFactory.class).in(Scopes.SINGLETON); @@ -360,6 +368,16 @@ List getCompositeOutputDataSizeEstimatorDelegate binder.bind(ExecutorCleanup.class).asEagerSingleton(); } + // working around circular dependency Metadata <-> PlannerContext + private static class InitializeLanguageFunctionManager + { + @Inject + public InitializeLanguageFunctionManager(LanguageFunctionManager languageFunctionManager, PlannerContext plannerContext) + { + languageFunctionManager.setPlannerContext(plannerContext); + } + } + @Provides @Singleton public static ResourceGroupManager getResourceGroupManager(@SuppressWarnings("rawtypes") ResourceGroupManager manager) diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index 5ce7a985d0ac..ad76199fe4b1 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -75,6 +75,7 @@ import io.trino.metadata.InternalBlockEncodingSerde; import io.trino.metadata.InternalFunctionBundle; import io.trino.metadata.InternalNodeManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.LiteralFunction; import io.trino.metadata.Metadata; import io.trino.metadata.MetadataManager; @@ -399,6 +400,7 @@ protected void setup(Binder binder) binder.bind(TableProceduresRegistry.class).in(Scopes.SINGLETON); binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(PlannerContext.class).in(Scopes.SINGLETON); + binder.bind(LanguageFunctionManager.class).in(Scopes.SINGLETON); // function binder.bind(FunctionManager.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/main/java/io/trino/server/WorkerModule.java b/core/trino-main/src/main/java/io/trino/server/WorkerModule.java index 62f5ccfbc051..fbfb98632297 100644 --- a/core/trino-main/src/main/java/io/trino/server/WorkerModule.java +++ b/core/trino-main/src/main/java/io/trino/server/WorkerModule.java @@ -23,6 +23,8 @@ import io.trino.execution.resourcegroups.ResourceGroupManager; import io.trino.failuredetector.FailureDetector; import io.trino.failuredetector.NoOpFailureDetector; +import io.trino.metadata.LanguageFunctionProvider; +import io.trino.metadata.WorkerLanguageFunctionProvider; import io.trino.server.ui.NoWebUiAuthenticationFilter; import io.trino.server.ui.WebUiAuthenticationFilter; @@ -48,6 +50,10 @@ public void configure(Binder binder) throw new UnsupportedOperationException(); })); + // language functions + binder.bind(WorkerLanguageFunctionProvider.class).in(Scopes.SINGLETON); + binder.bind(LanguageFunctionProvider.class).to(WorkerLanguageFunctionProvider.class).in(Scopes.SINGLETON); + binder.bind(WebUiAuthenticationFilter.class).to(NoWebUiAuthenticationFilter.class).in(Scopes.SINGLETON); } diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index 96f17c473b6a..9486c0181b97 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -62,6 +62,7 @@ import io.trino.metadata.FunctionManager; import io.trino.metadata.GlobalFunctionCatalog; import io.trino.metadata.InternalNodeManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.ProcedureRegistry; import io.trino.metadata.SessionPropertyManager; @@ -164,6 +165,7 @@ public static Builder builder() private final QueryExplainer queryExplainer; private final SessionPropertyManager sessionPropertyManager; private final FunctionManager functionManager; + private final LanguageFunctionManager languageFunctionManager; private final GlobalFunctionCatalog globalFunctionCatalog; private final StatsCalculator statsCalculator; private final ProcedureRegistry procedureRegistry; @@ -357,6 +359,7 @@ private TestingTrinoServer( sessionPropertyDefaults = injector.getInstance(SessionPropertyDefaults.class); nodePartitioningManager = injector.getInstance(NodePartitioningManager.class); clusterMemoryManager = injector.getInstance(ClusterMemoryManager.class); + languageFunctionManager = injector.getInstance(LanguageFunctionManager.class); statsCalculator = injector.getInstance(StatsCalculator.class); procedureRegistry = injector.getInstance(ProcedureRegistry.class); injector.getInstance(CertificateAuthenticatorManager.class).useDefaultAuthenticator(); @@ -369,6 +372,7 @@ private TestingTrinoServer( sessionPropertyDefaults = null; nodePartitioningManager = null; clusterMemoryManager = null; + languageFunctionManager = null; statsCalculator = null; procedureRegistry = null; } @@ -545,6 +549,12 @@ public FunctionManager getFunctionManager() return functionManager; } + public LanguageFunctionManager getLanguageFunctionManager() + { + checkState(coordinator, "not a coordinator"); + return languageFunctionManager; + } + public void addFunctions(FunctionBundle functionBundle) { globalFunctionCatalog.addFunctions(functionBundle); diff --git a/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java b/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java index 9ec3c4dc8734..afb23f58186b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java +++ b/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java @@ -18,6 +18,7 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionManager; import io.trino.metadata.FunctionResolver; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder; import io.trino.spi.block.BlockEncodingSerde; @@ -42,6 +43,7 @@ public class PlannerContext private final BlockEncodingSerde blockEncodingSerde; private final TypeManager typeManager; private final FunctionManager functionManager; + private final LanguageFunctionManager languageFunctionManager; private final Tracer tracer; private final ResolvedFunctionDecoder functionDecoder; @@ -51,6 +53,7 @@ public PlannerContext(Metadata metadata, BlockEncodingSerde blockEncodingSerde, TypeManager typeManager, FunctionManager functionManager, + LanguageFunctionManager languageFunctionManager, Tracer tracer) { this.metadata = requireNonNull(metadata, "metadata is null"); @@ -58,6 +61,7 @@ public PlannerContext(Metadata metadata, this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.functionManager = requireNonNull(functionManager, "functionManager is null"); + this.languageFunctionManager = requireNonNull(languageFunctionManager, "languageFunctionManager is null"); // the function decoder contains caches that are critical for planner performance so this must be shared this.functionDecoder = new ResolvedFunctionDecoder(typeManager::getType); this.tracer = requireNonNull(tracer, "tracer is null"); @@ -100,7 +104,12 @@ public FunctionResolver getFunctionResolver() public FunctionResolver getFunctionResolver(WarningCollector warningCollector) { - return new FunctionResolver(metadata, typeManager, functionDecoder, warningCollector); + return new FunctionResolver(metadata, typeManager, languageFunctionManager, functionDecoder, warningCollector); + } + + public LanguageFunctionManager getLanguageFunctionManager() + { + return languageFunctionManager; } public Tracer getTracer() diff --git a/core/trino-main/src/main/java/io/trino/sql/SqlPath.java b/core/trino-main/src/main/java/io/trino/sql/SqlPath.java index ad7cd25eb83b..711a3e3f3ff2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/SqlPath.java +++ b/core/trino-main/src/main/java/io/trino/sql/SqlPath.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableList; import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.GlobalFunctionCatalog; +import io.trino.metadata.LanguageFunctionManager; import io.trino.spi.connector.CatalogSchemaName; import io.trino.sql.parser.SqlParser; import io.trino.sql.tree.Identifier; @@ -28,6 +29,7 @@ import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA; import static java.util.Objects.requireNonNull; public final class SqlPath @@ -40,6 +42,7 @@ public final class SqlPath public static SqlPath buildPath(String rawPath, Optional defaultCatalog) { ImmutableList.Builder path = ImmutableList.builder(); + path.add(new CatalogSchemaName(GlobalSystemConnector.NAME, LanguageFunctionManager.QUERY_LOCAL_SCHEMA)); path.add(new CatalogSchemaName(GlobalSystemConnector.NAME, GlobalFunctionCatalog.BUILTIN_SCHEMA)); for (SqlPathElement pathElement : parsePath(rawPath)) { pathElement.getCatalog() @@ -107,4 +110,16 @@ public String toString() { return rawPath; } + + public SqlPath forView(List storedPath) + { + // For a view, we prepend the global function schema to the path, as the + // global function schema should not be in the path that is stored for the view. + // We do not change the raw path, as that is used for the current_path function. + List viewPath = ImmutableList.builder() + .add(new CatalogSchemaName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA)) + .addAll(storedPath) + .build(); + return new SqlPath(viewPath, rawPath); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index e0753191590a..8ffd81278701 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -160,6 +160,7 @@ import io.trino.sql.tree.FieldReference; import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionCall; +import io.trino.sql.tree.FunctionSpecification; import io.trino.sql.tree.Grant; import io.trino.sql.tree.GroupBy; import io.trino.sql.tree.GroupingElement; @@ -209,6 +210,7 @@ import io.trino.sql.tree.Row; import io.trino.sql.tree.RowPattern; import io.trino.sql.tree.SampledRelation; +import io.trino.sql.tree.SecurityCharacteristic; import io.trino.sql.tree.Select; import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.SetColumnType; @@ -330,6 +332,7 @@ import static io.trino.spi.StandardErrorCode.NULL_TREATMENT_NOT_ALLOWED; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static io.trino.spi.StandardErrorCode.SCHEMA_NOT_FOUND; +import static io.trino.spi.StandardErrorCode.SYNTAX_ERROR; import static io.trino.spi.StandardErrorCode.TABLE_ALREADY_EXISTS; import static io.trino.spi.StandardErrorCode.TABLE_HAS_NO_COLUMNS; import static io.trino.spi.StandardErrorCode.TABLE_NOT_FOUND; @@ -355,6 +358,7 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy; +import static io.trino.sql.SqlFormatter.formatSql; import static io.trino.sql.analyzer.AggregationAnalyzer.verifyOrderByAggregations; import static io.trino.sql.analyzer.AggregationAnalyzer.verifySourceAggregations; import static io.trino.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions; @@ -1002,6 +1006,10 @@ protected Scope visitCreateView(CreateView node, Optional scope) { QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName()); + node.getQuery().getFunctions().stream().findFirst().ifPresent(function -> { + throw semanticException(NOT_SUPPORTED, function, "Views cannot contain inline functions"); + }); + // analyze the query that creates the view StatementAnalyzer analyzer = statementAnalyzerFactory.createStatementAnalyzer(analysis, session, warningCollector, CorrelationSupport.ALLOWED); @@ -1517,7 +1525,19 @@ protected Scope visitExplainAnalyze(ExplainAnalyze node, Optional scope) @Override protected Scope visitQuery(Query node, Optional scope) { - verify(node.getFunctions().isEmpty(), "Inline functions not yet supported"); + verify(isTopLevel || node.getFunctions().isEmpty(), "Inline functions must be at the top level"); + for (FunctionSpecification function : node.getFunctions()) { + if (function.getName().getPrefix().isPresent()) { + throw semanticException(SYNTAX_ERROR, function, "Inline function names cannot be qualified: " + function.getName()); + } + function.getRoutineCharacteristics().stream() + .filter(SecurityCharacteristic.class::isInstance) + .findFirst() + .ifPresent(security -> { + throw semanticException(NOT_SUPPORTED, security, "Security mode not supported for inline functions"); + }); + plannerContext.getLanguageFunctionManager().addInlineFunction(session, formatSql(function), accessControl); + } Scope withScope = analyzeWith(node, scope); Scope queryBodyScope = process(node.getQueryBody(), withScope); @@ -2487,6 +2507,11 @@ private Scope createScopeForView( } Query query = parseView(originalSql, name, table); + + if (!query.getFunctions().isEmpty()) { + throw semanticException(NOT_SUPPORTED, table, "View contains inline function: " + name); + } + analysis.registerTableForView(table); RelationType descriptor = analyzeView(query, name, catalog, schema, owner, path, table); analysis.unregisterTableForView(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index a68c2d4c201f..3e2bb5093b91 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -79,6 +79,7 @@ import static io.trino.SystemSessionProperties.isComplexExpressionPushdown; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; +import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; import static io.trino.metadata.LiteralFunction.LITERAL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; @@ -699,7 +700,10 @@ protected Optional visitFunctionCall(FunctionCall node, Voi } FunctionName name; - if (isBuiltinFunctionName(functionName)) { + if (isInlineFunction(functionName)) { + throw new IllegalArgumentException("Connector expressions cannot reference inline functions: " + functionName); + } + else if (isBuiltinFunctionName(functionName)) { name = new FunctionName(functionName.getFunctionName()); } else { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java index 2ee37ba886c9..4b51a8edff02 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java @@ -20,6 +20,7 @@ import com.google.errorprone.annotations.Immutable; import io.trino.connector.CatalogProperties; import io.trino.cost.StatsAndCosts; +import io.trino.metadata.LanguageScalarFunctionData; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNode; @@ -52,6 +53,7 @@ public class PlanFragment private final PartitioningScheme outputPartitioningScheme; private final StatsAndCosts statsAndCosts; private final List activeCatalogs; + private final List languageFunctions; private final Optional jsonRepresentation; // Only for creating instances without the JSON representation embedded @@ -68,7 +70,8 @@ private PlanFragment( List remoteSourceNodes, PartitioningScheme outputPartitioningScheme, StatsAndCosts statsAndCosts, - List activeCatalogs) + List activeCatalogs, + List languageFunctions) { this.id = requireNonNull(id, "id is null"); this.root = requireNonNull(root, "root is null"); @@ -83,6 +86,7 @@ private PlanFragment( this.outputPartitioningScheme = requireNonNull(outputPartitioningScheme, "outputPartitioningScheme is null"); this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null"); + this.languageFunctions = requireNonNull(languageFunctions, "languageFunctions is null"); this.jsonRepresentation = Optional.empty(); } @@ -97,6 +101,7 @@ public PlanFragment( @JsonProperty("outputPartitioningScheme") PartitioningScheme outputPartitioningScheme, @JsonProperty("statsAndCosts") StatsAndCosts statsAndCosts, @JsonProperty("activeCatalogs") List activeCatalogs, + @JsonProperty("languageFunctions") List languageFunctions, @JsonProperty("jsonRepresentation") Optional jsonRepresentation) { this.id = requireNonNull(id, "id is null"); @@ -108,6 +113,7 @@ public PlanFragment( this.partitionedSourcesSet = ImmutableSet.copyOf(partitionedSources); this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null"); + this.languageFunctions = requireNonNull(languageFunctions, "languageFunctions is null"); this.jsonRepresentation = requireNonNull(jsonRepresentation, "jsonRepresentation is null"); checkArgument( @@ -190,6 +196,12 @@ public List getActiveCatalogs() return activeCatalogs; } + @JsonProperty + public List getLanguageFunctions() + { + return languageFunctions; + } + @JsonProperty public Optional getJsonRepresentation() { @@ -216,7 +228,8 @@ public PlanFragment withoutEmbeddedJsonRepresentation() this.remoteSourceNodes, this.outputPartitioningScheme, this.statsAndCosts, - this.activeCatalogs); + this.activeCatalogs, + this.languageFunctions); } public List getTypes() @@ -270,7 +283,18 @@ private static void findRemoteSourceNodes(PlanNode node, ImmutableList.Builder bucketToPartition) { - return new PlanFragment(id, root, symbols, partitioning, partitionCount, partitionedSources, outputPartitioningScheme.withBucketToPartition(bucketToPartition), statsAndCosts, activeCatalogs, jsonRepresentation); + return new PlanFragment( + id, + root, + symbols, + partitioning, + partitionCount, + partitionedSources, + outputPartitioningScheme.withBucketToPartition(bucketToPartition), + statsAndCosts, + activeCatalogs, + languageFunctions, + jsonRepresentation); } @Override @@ -297,6 +321,7 @@ public PlanFragment withPartitionCount(Optional partitionCount) this.outputPartitioningScheme, this.statsAndCosts, this.activeCatalogs, + this.languageFunctions, this.jsonRepresentation); } @@ -312,6 +337,7 @@ public PlanFragment withOutputPartitioningScheme(PartitioningScheme outputPartit outputPartitioningScheme, this.statsAndCosts, this.activeCatalogs, + this.languageFunctions, this.jsonRepresentation); } @@ -327,6 +353,7 @@ public PlanFragment withRoot(PlanNode root) this.outputPartitioningScheme, this.statsAndCosts, this.activeCatalogs, + this.languageFunctions, this.jsonRepresentation); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index b64a65339f23..48ccb73292b3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -25,6 +25,8 @@ import io.trino.metadata.CatalogInfo; import io.trino.metadata.CatalogManager; import io.trino.metadata.FunctionManager; +import io.trino.metadata.LanguageFunctionManager; +import io.trino.metadata.LanguageScalarFunctionData; import io.trino.metadata.Metadata; import io.trino.metadata.TableHandle; import io.trino.metadata.TableProperties.TablePartitioning; @@ -96,6 +98,7 @@ public class PlanFragmenter private final FunctionManager functionManager; private final TransactionManager transactionManager; private final CatalogManager catalogManager; + private final LanguageFunctionManager languageFunctionManager; private final int stageCountWarningThreshold; @Inject @@ -104,6 +107,7 @@ public PlanFragmenter( FunctionManager functionManager, TransactionManager transactionManager, CatalogManager catalogManager, + LanguageFunctionManager languageFunctionManager, QueryManagerConfig queryManagerConfig) { this.metadata = requireNonNull(metadata, "metadata is null"); @@ -111,6 +115,7 @@ public PlanFragmenter( this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.catalogManager = requireNonNull(catalogManager, "catalogManager is null"); this.stageCountWarningThreshold = requireNonNull(queryManagerConfig, "queryManagerConfig is null").getStageCountWarningThreshold(); + this.languageFunctionManager = requireNonNull(languageFunctionManager, "languageFunctionManager is null"); } public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNode, WarningCollector warningCollector) @@ -119,7 +124,8 @@ public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNod .map(CatalogInfo::getCatalogHandle) .flatMap(catalogHandle -> catalogManager.getCatalogProperties(catalogHandle).stream()) .collect(toImmutableList()); - Fragmenter fragmenter = new Fragmenter(session, metadata, functionManager, plan.getTypes(), plan.getStatsAndCosts(), activeCatalogs); + List languageScalarFunctions = languageFunctionManager.serializeFunctionsForWorkers(session); + Fragmenter fragmenter = new Fragmenter(session, metadata, functionManager, plan.getTypes(), plan.getStatsAndCosts(), activeCatalogs, languageScalarFunctions); FragmentProperties properties = new FragmentProperties(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getRoot().getOutputSymbols())); if (forceSingleNode || isForceSingleNodeOutput(session)) { @@ -195,6 +201,7 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub outputPartitioningScheme.getPartitionCount()), fragment.getStatsAndCosts(), fragment.getActiveCatalogs(), + fragment.getLanguageFunctions(), fragment.getJsonRepresentation()); ImmutableList.Builder childrenBuilder = ImmutableList.builder(); @@ -215,9 +222,17 @@ private static class Fragmenter private final TypeProvider types; private final StatsAndCosts statsAndCosts; private final List activeCatalogs; + private final List languageFunctions; private final PlanFragmentIdAllocator idAllocator = new PlanFragmentIdAllocator(ROOT_FRAGMENT_ID + 1); - public Fragmenter(Session session, Metadata metadata, FunctionManager functionManager, TypeProvider types, StatsAndCosts statsAndCosts, List activeCatalogs) + public Fragmenter( + Session session, + Metadata metadata, + FunctionManager functionManager, + TypeProvider types, + StatsAndCosts statsAndCosts, + List activeCatalogs, + List languageFunctions) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); @@ -225,6 +240,7 @@ public Fragmenter(Session session, Metadata metadata, FunctionManager functionMa this.types = requireNonNull(types, "types is null"); this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null"); + this.languageFunctions = requireNonNull(languageFunctions, "languageFunctions is null"); } public SubPlan buildRootFragment(PlanNode root, FragmentProperties properties) @@ -252,6 +268,7 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan properties.getPartitioningScheme(), statsAndCosts.getForSubplan(root), activeCatalogs, + languageFunctions, Optional.of(jsonFragmentPlan(root, symbols, metadata, functionManager, session))); return new SubPlan(fragment, properties.getChildren()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java index 099c825fb9a3..7a88e1b76e48 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java @@ -127,6 +127,7 @@ public static SubPlan overridePartitionCountRecursively( sourceFragment.getOutputPartitioningScheme().withPartitionCount(Optional.of(newPartitionCount)), sourceFragment.getStatsAndCosts(), sourceFragment.getActiveCatalogs(), + sourceFragment.getLanguageFunctions(), sourceFragment.getJsonRepresentation()); SubPlan newSource = new SubPlan( runtimeAdaptivePlanFragment, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index e619a284f892..f8a070b67c93 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -162,6 +162,7 @@ import static io.trino.execution.StageInfo.getAllStages; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; +import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import static io.trino.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; @@ -650,6 +651,7 @@ public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types) new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputSymbols()), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); return GraphvizPrinter.printLogical(ImmutableList.of(fragment)); } @@ -2215,7 +2217,7 @@ public static String formatAggregation(Anonymizer anonymizer, Aggregation aggreg private static String formatFunctionName(ResolvedFunction function) { CatalogSchemaFunctionName name = function.getSignature().getName(); - if (isBuiltinFunctionName(name)) { + if (isInlineFunction(name) || isBuiltinFunctionName(name)) { return name.getFunctionName(); } return name.toString(); @@ -2231,7 +2233,7 @@ public Expression rewriteFunctionCall(FunctionCall node, Void context, Expressio FunctionCall rewritten = treeRewriter.defaultRewrite(node, context); CatalogSchemaFunctionName name = extractFunctionName(node.getName()); QualifiedName qualifiedName; - if (isBuiltinFunctionName(name)) { + if (isInlineFunction(name) || isBuiltinFunctionName(name)) { qualifiedName = QualifiedName.of(name.getFunctionName()); } else { diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index c3947b7a9d6d..b7cdd40c0839 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -93,6 +93,7 @@ import io.trino.metadata.InternalBlockEncodingSerde; import io.trino.metadata.InternalFunctionBundle; import io.trino.metadata.InternalNodeManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.LiteralFunction; import io.trino.metadata.MaterializedViewPropertyManager; import io.trino.metadata.Metadata; @@ -276,6 +277,7 @@ public class LocalQueryRunner private final TypeRegistry typeRegistry; private final GlobalFunctionCatalog globalFunctionCatalog; private final FunctionManager functionManager; + private final LanguageFunctionManager languageFunctionManager; private final StatsCalculator statsCalculator; private final ScalarStatsCalculator scalarStatsCalculator; private final CostCalculator costCalculator; @@ -378,15 +380,17 @@ private LocalQueryRunner( this.globalFunctionCatalog = new GlobalFunctionCatalog(); globalFunctionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(blockEncodingSerde))); globalFunctionCatalog.addFunctions(SystemFunctionBundle.create(featuresConfig, typeOperators, blockTypeOperators, nodeManager.getCurrentNode().getNodeVersion())); + this.groupProvider = new TestingGroupProviderManager(); + this.languageFunctionManager = new LanguageFunctionManager(sqlParser, typeManager, groupProvider); Metadata metadata = metadataDecorator.apply(new MetadataManager( new DisabledSystemSecurityMetadata(), transactionManager, globalFunctionCatalog, + languageFunctionManager, typeManager)); typeRegistry.addType(new JsonPath2016Type(new TypeDeserializer(typeManager), blockEncodingSerde)); this.joinCompiler = new JoinCompiler(typeOperators); PageIndexerFactory pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); - this.groupProvider = new TestingGroupProviderManager(); this.accessControl = new TestingAccessControlManager(transactionManager, eventListenerManager); accessControl.loadSystemAccessControl(AllowAllSystemAccessControl.NAME, ImmutableMap.of()); @@ -415,7 +419,7 @@ private LocalQueryRunner( this.sessionPropertyManager = createSessionPropertyManager(catalogManager, extraSessionProperties, taskManagerConfig, featuresConfig, optimizerConfig); this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler, typeOperators, createNodePartitioningProvider(catalogManager)); TableProceduresRegistry tableProceduresRegistry = new TableProceduresRegistry(createTableProceduresProvider(catalogManager)); - this.functionManager = new FunctionManager(createFunctionProvider(catalogManager), globalFunctionCatalog); + this.functionManager = new FunctionManager(createFunctionProvider(catalogManager), globalFunctionCatalog, languageFunctionManager); TableFunctionRegistry tableFunctionRegistry = new TableFunctionRegistry(createTableFunctionProvider(catalogManager)); this.schemaPropertyManager = createSchemaPropertyManager(catalogManager); this.columnPropertyManager = createColumnPropertyManager(catalogManager); @@ -431,7 +435,7 @@ private LocalQueryRunner( new JsonValueFunction(functionManager, metadata, typeManager), new JsonQueryFunction(functionManager, metadata, typeManager))); - this.plannerContext = new PlannerContext(metadata, typeOperators, blockEncodingSerde, typeManager, functionManager, tracer); + this.plannerContext = new PlannerContext(metadata, typeOperators, blockEncodingSerde, typeManager, functionManager, languageFunctionManager, tracer); this.pageFunctionCompiler = new PageFunctionCompiler(functionManager, 0); this.expressionCompiler = new ExpressionCompiler(functionManager, pageFunctionCompiler); this.joinFilterFunctionCompiler = new JoinFilterFunctionCompiler(functionManager); @@ -455,7 +459,7 @@ private LocalQueryRunner( this.costCalculator = new CostCalculatorUsingExchanges(taskCountEstimator); this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, taskCountEstimator); - this.planFragmenter = new PlanFragmenter(metadata, functionManager, transactionManager, catalogManager, new QueryManagerConfig()); + this.planFragmenter = new PlanFragmenter(metadata, functionManager, transactionManager, catalogManager, languageFunctionManager, new QueryManagerConfig()); GlobalSystemConnector globalSystemConnector = new GlobalSystemConnector(ImmutableSet.of( new NodeSystemTable(nodeManager), @@ -490,6 +494,7 @@ private LocalQueryRunner( exchangeManagerRegistry); catalogManager.registerGlobalSystemConnector(globalSystemConnector); + languageFunctionManager.setPlannerContext(plannerContext); // rewrite session to use managed SessionPropertyMetadata this.defaultSession = new Session( @@ -656,6 +661,12 @@ public FunctionManager getFunctionManager() return functionManager; } + @Override + public LanguageFunctionManager getLanguageFunctionManager() + { + return languageFunctionManager; + } + public TypeOperators getTypeOperators() { return plannerContext.getTypeOperators(); @@ -865,6 +876,7 @@ private MaterializedResultWithPlan executeInternal(Session session, @Language("S lock.readLock().lock(); try (Closer closer = Closer.create()) { accessControl.checkCanExecuteQuery(session.getIdentity()); + AtomicReference builder = new AtomicReference<>(); PageConsumerOutputFactory outputFactory = new PageConsumerOutputFactory(types -> { builder.compareAndSet(null, MaterializedResult.resultBuilder(session, types)); @@ -951,6 +963,7 @@ public List createDrivers(Session session, @Language("SQL") String sql, public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNode) { + languageFunctionManager.tryRegisterQuery(session); return planFragmenter.createSubPlans(session, plan, forceSingleNode, NOOP); } diff --git a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java index fd399c087804..4d5b6edf9a52 100644 --- a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java @@ -19,6 +19,7 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionBundle; import io.trino.metadata.FunctionManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SessionPropertyManager; @@ -62,6 +63,8 @@ public interface QueryRunner FunctionManager getFunctionManager(); + LanguageFunctionManager getLanguageFunctionManager(); + SplitManager getSplitManager(); ExchangeManager getExchangeManager(); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java index 24d2bbf0507f..bdaa49f19b60 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java @@ -105,6 +105,7 @@ public void setUp() session = testSessionBuilder().setCatalog(TEST_CATALOG_NAME).build(); localQueryRunner = LocalQueryRunner.create(session); + localQueryRunner.getLanguageFunctionManager().registerQuery(session); localQueryRunner.createCatalog(TEST_CATALOG_NAME, new TpchConnectorFactory(), ImmutableMap.of()); planFragmenter = new PlanFragmenter( @@ -112,6 +113,7 @@ public void setUp() localQueryRunner.getFunctionManager(), localQueryRunner.getTransactionManager(), localQueryRunner.getCatalogManager(), + localQueryRunner.getLanguageFunctionManager(), new QueryManagerConfig()); } diff --git a/core/trino-main/src/test/java/io/trino/dispatcher/TestLocalDispatchQuery.java b/core/trino-main/src/test/java/io/trino/dispatcher/TestLocalDispatchQuery.java index 1a290477fa4e..77de01cfb286 100644 --- a/core/trino-main/src/test/java/io/trino/dispatcher/TestLocalDispatchQuery.java +++ b/core/trino-main/src/test/java/io/trino/dispatcher/TestLocalDispatchQuery.java @@ -47,6 +47,7 @@ import io.trino.metadata.GlobalFunctionCatalog; import io.trino.metadata.InMemoryNodeManager; import io.trino.metadata.InternalNodeManager; +import io.trino.metadata.LanguageFunctionProvider; import io.trino.metadata.Metadata; import io.trino.metadata.SessionPropertyManager; import io.trino.operator.OperatorStats; @@ -131,7 +132,8 @@ public void testSubmittedForDispatchedQuery() metadata, new FunctionManager( new ConnectorCatalogServiceProvider<>("function provider", new NoConnectorServicesProvider(), ConnectorServices::getFunctionProvider), - new GlobalFunctionCatalog()), + new GlobalFunctionCatalog(), + LanguageFunctionProvider.DISABLED), new QueryMonitorConfig()); CreateTable createTable = new CreateTable(QualifiedName.of("table"), ImmutableList.of(), FAIL, ImmutableList.of(), Optional.empty()); QueryPreparer.PreparedQuery preparedQuery = new QueryPreparer.PreparedQuery(createTable, ImmutableList.of(), Optional.empty()); diff --git a/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java index 58e468ea2a1f..95e7d392277d 100644 --- a/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java @@ -38,6 +38,7 @@ import io.trino.memory.QueryContext; import io.trino.memory.context.LocalMemoryContext; import io.trino.metadata.InternalNode; +import io.trino.metadata.WorkerLanguageFunctionProvider; import io.trino.operator.DirectExchangeClient; import io.trino.operator.DirectExchangeClientSupplier; import io.trino.operator.RetryPolicy; @@ -322,6 +323,7 @@ private SqlTaskManager createSqlTaskManager(TaskManagerConfig taskManagerConfig, new EmbedVersion("testversion"), new NoConnectorServicesProvider(), createTestingPlanner(), + new WorkerLanguageFunctionProvider(), new MockLocationFactory(), taskExecutor, createTestSplitMonitor(), diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 147da449ae25..681afda39efe 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -121,6 +121,7 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); ImmutableMultimap.Builder initialSplits = ImmutableMultimap.builder(); diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index 6702a7e5bd2b..321a32999821 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -103,6 +103,7 @@ private TaskTestUtils() {} .withBucketToPartition(Optional.of(new int[1])), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); public static final DynamicFilterId DYNAMIC_FILTER_SOURCE_ID = new DynamicFilterId("filter"); @@ -127,6 +128,7 @@ private TaskTestUtils() {} .withBucketToPartition(Optional.of(new int[1])), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); public static LocalExecutionPlanner createTestingPlanner() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java index 5431adcff169..44e4dab8b735 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java @@ -245,6 +245,7 @@ private static PlanFragment createExchangePlanFragment() new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java index 8ddf8c6d6158..7f06895cb334 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java @@ -267,6 +267,7 @@ private static PlanFragment createValuesPlan() new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); return planFragment; diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java index baa4380f6440..7a36521c9303 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java @@ -32,6 +32,7 @@ import io.trino.execution.executor.timesharing.TimeSharingTaskExecutor; import io.trino.memory.LocalMemoryManager; import io.trino.memory.NodeMemoryConfig; +import io.trino.metadata.WorkerLanguageFunctionProvider; import io.trino.spi.connector.CatalogHandle; import io.trino.spiller.LocalSpillManager; import io.trino.spiller.NodeSpillConfig; @@ -122,6 +123,7 @@ private SqlTaskManager createSqlTaskManager( new EmbedVersion("testversion"), new NoConnectorServicesProvider(), createTestingPlanner(), + new WorkerLanguageFunctionProvider(), new BaseTestSqlTaskManager.MockLocationFactory(), taskExecutor, createTestSplitMonitor(), diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java index a2d61d13a107..8c3a2bc28541 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java @@ -583,6 +583,7 @@ private PlanFragment createFragment(TableHandle firstTableHandle, TableHandle se new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java index 17626637086c..441f3c2c4ebf 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java @@ -395,6 +395,7 @@ private static PlanFragment createFragment() new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), ImmutableList.of(), + ImmutableList.of(), Optional.empty()); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java index 75a183b50189..528e8b15d051 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java @@ -366,6 +366,7 @@ private static SubPlan createSubPlan(String fragmentId, PlanNode plan, Listbuilder() + .add(new CatalogSchemaName(GlobalSystemConnector.NAME, QUERY_LOCAL_SCHEMA)) .add(new CatalogSchemaName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA)) .add(new CatalogSchemaName("normal", "schema")) .add(new CatalogSchemaName("who.uses.periods", "in.schema.names")) diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index 1a3df6949e83..c0a8fc0f1180 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -73,6 +73,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeOperators; +import io.trino.spi.type.TypeSignature; import io.trino.sql.DynamicFilters; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; @@ -81,6 +82,8 @@ import io.trino.sql.tree.SymbolReference; import io.trino.testing.TestingSplit; import io.trino.type.TypeDeserializer; +import io.trino.type.TypeSignatureDeserializer; +import io.trino.type.TypeSignatureKeyDeserializer; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.DELETE; import jakarta.ws.rs.DefaultValue; @@ -558,6 +561,8 @@ public void configure(Binder binder) binder.bind(JsonMapper.class).in(SINGLETON); binder.bind(Metadata.class).toInstance(createTestMetadataManager()); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + jsonBinder(binder).addDeserializerBinding(TypeSignature.class).to(TypeSignatureDeserializer.class); + jsonBinder(binder).addKeyDeserializerBinding(TypeSignature.class).to(TypeSignatureKeyDeserializer.class); jsonCodecBinder(binder).bindJsonCodec(TaskStatus.class); jsonCodecBinder(binder).bindJsonCodec(VersionedDynamicFilterDomains.class); jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlPath.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlPath.java index 4a591b3ca234..b6e0ea11cf5b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlPath.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlPath.java @@ -13,39 +13,45 @@ */ package io.trino.sql; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.spi.connector.CatalogSchemaName; import io.trino.sql.parser.ParsingException; import org.junit.jupiter.api.Test; import java.util.Optional; +import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA; +import static io.trino.metadata.LanguageFunctionManager.QUERY_LOCAL_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; class TestSqlPath { + private static final CatalogSchemaName INLINE_SCHEMA_NAME = new CatalogSchemaName(GlobalSystemConnector.NAME, QUERY_LOCAL_SCHEMA); + private static final CatalogSchemaName BUILTIN_SCHEMA_NAME = new CatalogSchemaName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA); + @Test void empty() { assertThat(SqlPath.EMPTY_PATH.getRawPath()).isEmpty(); - assertThat(SqlPath.EMPTY_PATH.getPath()).containsExactly(new CatalogSchemaName("system", "builtin")); + assertThat(SqlPath.EMPTY_PATH.getPath()).containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME); } @Test void parsing() { assertThat(SqlPath.buildPath("a.b", Optional.empty()).getPath()) - .containsExactly(new CatalogSchemaName("system", "builtin"), new CatalogSchemaName("a", "b")); + .containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME, new CatalogSchemaName("a", "b")); assertThat(SqlPath.buildPath("a.b, c.d", Optional.empty()).getPath()) - .containsExactly(new CatalogSchemaName("system", "builtin"), new CatalogSchemaName("a", "b"), new CatalogSchemaName("c", "d")); + .containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME, new CatalogSchemaName("a", "b"), new CatalogSchemaName("c", "d")); assertThat(SqlPath.buildPath("y", Optional.of("x")).getPath()) - .containsExactly(new CatalogSchemaName("system", "builtin"), new CatalogSchemaName("x", "y")); + .containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME, new CatalogSchemaName("x", "y")); assertThat(SqlPath.buildPath("y, z", Optional.of("x")).getPath()) - .containsExactly(new CatalogSchemaName("system", "builtin"), new CatalogSchemaName("x", "y"), new CatalogSchemaName("x", "z")); + .containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME, new CatalogSchemaName("x", "y"), new CatalogSchemaName("x", "z")); assertThat(SqlPath.buildPath("a.b, c.d", Optional.of("x")).getPath()) - .containsExactly(new CatalogSchemaName("system", "builtin"), new CatalogSchemaName("a", "b"), new CatalogSchemaName("c", "d")); + .containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME, new CatalogSchemaName("a", "b"), new CatalogSchemaName("c", "d")); assertThat(SqlPath.buildPath("a.b, y", Optional.of("x")).getPath()) - .containsExactly(new CatalogSchemaName("system", "builtin"), new CatalogSchemaName("a", "b"), new CatalogSchemaName("x", "y")); + .containsExactly(INLINE_SCHEMA_NAME, BUILTIN_SCHEMA_NAME, new CatalogSchemaName("a", "b"), new CatalogSchemaName("x", "y")); assertThat(SqlPath.buildPath("a.b, c.d", Optional.empty()).getRawPath()).isEqualTo("a.b, c.d"); } diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 2d4329ffcca3..e3376c25803d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -3918,6 +3918,26 @@ public void testLambdaWithInvalidParameterCount() .hasMessageMatching("line 1:39: Expected a lambda that takes 2 argument\\(s\\) but got 3"); } + @Test + public void testInvalidInlineFunction() + { + assertFails("WITH FUNCTION test.abc() RETURNS int RETURN 42 SELECT 123") + .hasErrorCode(SYNTAX_ERROR) + .hasMessage("line 1:6: Inline function names cannot be qualified: test.abc"); + + assertFails("WITH function abc() RETURNS int SECURITY DEFINER RETURN 42 SELECT 123") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:33: Security mode not supported for inline functions"); + + assertFails(""" + CREATE VIEW test AS + WITH FUNCTION abc() RETURNS int RETURN 42 + SELECT 123 x + """) + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 2:6: Views cannot contain inline functions"); + } + @Test public void testInvalidDelete() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java index 636b1c532413..761c7bd41b64 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java @@ -64,6 +64,7 @@ public void setUp() localQueryRunner.getFunctionManager(), localQueryRunner.getTransactionManager(), localQueryRunner.getCatalogManager(), + localQueryRunner.getLanguageFunctionManager(), new QueryManagerConfig()); } @@ -143,6 +144,7 @@ public void testPartitionCountInPlanFragment() private SubPlan fragment(Plan plan) { + localQueryRunner.getLanguageFunctionManager().registerQuery(session); return inTransaction(session -> planFragmenter.createSubPlans(session, plan, false, WarningCollector.NOOP)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java index 489a5b02316d..8cbc29e37399 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java @@ -197,6 +197,7 @@ private static SubPlan createSubPlan(String fragmentId, PlanNode plan, List ImmutableSet.of()); + Metadata metadata = this.metadata; if (metadata == null) { TestMetadataManagerBuilder builder = MetadataManager.testMetadataManagerBuilder() .withTypeManager(typeManager) + .withLanguageFunctionManager(languageFunctionManager) .withGlobalFunctionCatalog(globalFunctionCatalog); if (transactionManager != null) { builder.withTransactionManager(transactionManager); @@ -133,7 +140,7 @@ public PlannerContext build() metadata = builder.build(); } - FunctionManager functionManager = new FunctionManager(CatalogServiceProvider.fail(), globalFunctionCatalog); + FunctionManager functionManager = new FunctionManager(CatalogServiceProvider.fail(), globalFunctionCatalog, LanguageFunctionProvider.DISABLED); globalFunctionCatalog.addFunctions(new InternalFunctionBundle( new JsonExistsFunction(functionManager, metadata, typeManager), new JsonValueFunction(functionManager, metadata, typeManager), @@ -146,6 +153,7 @@ public PlannerContext build() blockEncodingSerde, typeManager, functionManager, + languageFunctionManager, noopTracer()); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/LanguageFunction.java b/core/trino-spi/src/main/java/io/trino/spi/function/LanguageFunction.java new file mode 100644 index 000000000000..710a585ce947 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/LanguageFunction.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import io.trino.spi.connector.CatalogSchemaName; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record LanguageFunction( + String signatureToken, + String sql, + List path, + Optional owner) +{ + public LanguageFunction + { + requireNonNull(signatureToken, "signatureToken is null"); + requireNonNull(sql, "sql is null"); + path = List.copyOf(requireNonNull(path, "path is null")); + requireNonNull(owner, "owner is null"); + } +} diff --git a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java index 3fa5c76f7a89..c1890f3cfb78 100644 --- a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java +++ b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java @@ -31,6 +31,7 @@ import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.metadata.FunctionBundle; import io.trino.metadata.FunctionManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SessionPropertyManager; @@ -246,6 +247,12 @@ public FunctionManager getFunctionManager() return source.getFunctionManager(); } + @Override + public LanguageFunctionManager getLanguageFunctionManager() + { + return source.getLanguageFunctionManager(); + } + @Override public SplitManager getSplitManager() { diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java index 982b9b6a7169..d5b1bb8a502e 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java @@ -286,6 +286,7 @@ private SubPlan getSubPlan(Session session, @Language("SQL") String sql) queryRunner.getFunctionManager(), queryRunner.getTransactionManager(), new CoordinatorDynamicCatalogManager(new InMemoryCatalogStore(), new LazyCatalogFactory(), directExecutor()), + queryRunner.getLanguageFunctionManager(), new QueryManagerConfig()).createSubPlans(transactionSession, plan, false, WarningCollector.NOOP); }); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 5cb4c66e35d3..ce8e53716e6c 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -77,6 +77,7 @@ import static java.util.stream.Collectors.toList; import static java.util.stream.IntStream.range; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; @@ -6612,6 +6613,99 @@ public void testColumnNames() assertEquals(showCreateTableResult.getColumnNames(), ImmutableList.of("Create Table")); } + @Test + public void testInlineSqlFunctions() + { + assertThat(query(""" + WITH FUNCTION abc(x integer) RETURNS integer RETURN x * 2 + SELECT abc(21) + """)) + .matches("VALUES 42"); + assertThat(query(""" + WITH FUNCTION abc(x integer) RETURNS integer RETURN abs(x) + SELECT abc(-21) + """)) + .matches("VALUES 21"); + + assertThat(query(""" + WITH + FUNCTION abc(x integer) RETURNS integer RETURN x * 2, + FUNCTION xyz(x integer) RETURNS integer RETURN abc(x) + 1 + SELECT xyz(21) + """)) + .matches("VALUES 43"); + + assertThat(query(""" + WITH + FUNCTION my_pow(n int, p int) + RETURNS int + BEGIN + DECLARE r int DEFAULT n; + top: LOOP + IF p <= 1 THEN + LEAVE top; + END IF; + SET r = r * n; + SET p = p - 1; + END LOOP; + RETURN r; + END + SELECT my_pow(2, 8) + """)) + .matches("VALUES 256"); + + // validations for inline functions + assertQueryFails("WITH FUNCTION a.b() RETURNS int RETURN 42 SELECT a.b()", + "line 1:6: Inline function names cannot be qualified: a.b"); + + assertQueryFails("WITH FUNCTION x() RETURNS int SECURITY INVOKER RETURN 42 SELECT x()", + "line 1:31: Security mode not supported for inline functions"); + + assertQueryFails("WITH FUNCTION x() RETURNS bigint SECURITY DEFINER RETURN 42 SELECT x()", + "line 1:34: Security mode not supported for inline functions"); + + // Verify the current restrictions on inline functions are enforced + + // inline function can mask a global function + assertThat(query(""" + WITH FUNCTION abs(x integer) RETURNS integer RETURN x * 2 + SELECT abs(-10) + """)) + .matches("VALUES -20"); + assertThat(query(""" + WITH + FUNCTION abs(x integer) RETURNS integer RETURN x * 2, + FUNCTION wrap_abs(x integer) RETURNS integer RETURN abs(x) + SELECT wrap_abs(-10) + """)) + .matches("VALUES -20"); + + // inline function can have the same name as a global function with a different signature + assertThat(query(""" + WITH FUNCTION abs(x varchar) RETURNS varchar RETURN reverse(x) + SELECT abs('abc') + """)) + .skippingTypesCheck() + .matches("VALUES 'cba'"); + + // inline functions must be declared before they are used + assertThatThrownBy(() -> query(""" + WITH + FUNCTION a(x integer) RETURNS integer RETURN b(x), + FUNCTION b(x integer) RETURNS integer RETURN x * 2 + SELECT a(10) + """)) + .hasMessage("line 3:8: Function 'b' not registered"); + + // inline function cannot be recursive + // note: mutual recursion is not supported either, but it is not tested due to the forward declaration limitation above + assertThatThrownBy(() -> query(""" + WITH FUNCTION a(x integer) RETURNS integer RETURN a(x) + SELECT a(10) + """)) + .hasMessage("line 3:8: Recursive language functions are not supported: a(integer):integer"); + } + private static ZonedDateTime zonedDateTime(String value) { return ZONED_DATE_TIME_FORMAT.parse(value, ZonedDateTime::from); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index ae8fc0fa254e..acd427cda249 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -33,6 +33,7 @@ import io.trino.metadata.AllNodes; import io.trino.metadata.FunctionBundle; import io.trino.metadata.FunctionManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SessionPropertyManager; @@ -370,6 +371,12 @@ public FunctionManager getFunctionManager() return coordinator.getFunctionManager(); } + @Override + public LanguageFunctionManager getLanguageFunctionManager() + { + return coordinator.getLanguageFunctionManager(); + } + @Override public SplitManager getSplitManager() { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java index ae5e2534ec60..5a9824e46624 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java @@ -19,6 +19,7 @@ import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.metadata.FunctionBundle; import io.trino.metadata.FunctionManager; +import io.trino.metadata.LanguageFunctionManager; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SessionPropertyManager; @@ -147,6 +148,12 @@ public FunctionManager getFunctionManager() return server.getFunctionManager(); } + @Override + public LanguageFunctionManager getLanguageFunctionManager() + { + return server.getLanguageFunctionManager(); + } + @Override public SplitManager getSplitManager() { diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java index 9cec74e401ba..195d43bae951 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java @@ -348,6 +348,8 @@ public void testAccessControl() assertAccessAllowed("SHOW CREATE TABLE lineitem", privilege("orders", SHOW_CREATE_TABLE)); assertAccessDenied("SELECT my_function(1)", "Cannot execute function my_function", privilege("mock.function.my_function", EXECUTE_FUNCTION)); assertAccessAllowed("SELECT my_function(1)", privilege("max", EXECUTE_FUNCTION)); + assertAccessAllowed("SELECT abs(-10)", privilege("abs", EXECUTE_FUNCTION)); + assertAccessAllowed("SELECT abs(-10)", privilege("system.builtin.abs", EXECUTE_FUNCTION)); assertAccessAllowed("SHOW STATS FOR lineitem"); assertAccessAllowed("SHOW STATS FOR lineitem", privilege("orders", SELECT_COLUMN)); assertAccessAllowed("SHOW STATS FOR (SELECT * FROM lineitem)"); @@ -608,9 +610,10 @@ public void testFunctionAccessControl() "Cannot execute function my_function", new TestingPrivilege(Optional.empty(), "mock.function.my_function", EXECUTE_FUNCTION)); - // builtin functions are always allowed, and there are no security checks + // inline and builtin functions are always allowed, and there are no security checks TestingPrivilege denyAllFunctionCalls = new TestingPrivilege(Optional.empty(), name -> true, EXECUTE_FUNCTION); assertAccessAllowed("SELECT abs(42)", denyAllFunctionCalls); + assertAccessAllowed("WITH FUNCTION foo() RETURNS int RETURN 42 SELECT foo()", denyAllFunctionCalls); assertAccessDenied("SELECT my_function(42)", "Cannot execute function my_function", denyAllFunctionCalls); TestingPrivilege denyNonMyFunctionCalls = new TestingPrivilege(Optional.empty(), name -> !name.equals("mock.function.my_function"), EXECUTE_FUNCTION); diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestingSystemSecurityMetadata.java b/testing/trino-tests/src/test/java/io/trino/security/TestingSystemSecurityMetadata.java index c554c12e52a4..0870af6ef057 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestingSystemSecurityMetadata.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestingSystemSecurityMetadata.java @@ -20,6 +20,7 @@ import io.trino.metadata.SystemSecurityMetadata; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.security.GrantInfo; import io.trino.spi.security.Identity; import io.trino.spi.security.Privilege; @@ -242,6 +243,12 @@ public void setViewOwner(Session session, CatalogSchemaTableName view, TrinoPrin viewOwners.put(view, Identity.ofUser(principal.getName())); } + @Override + public Optional getFunctionRunAsIdentity(Session session, CatalogSchemaFunctionName functionName) + { + return Optional.empty(); + } + @Override public void schemaCreated(Session session, CatalogSchemaName schema) {}