Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
Expand All @@ -24,7 +23,6 @@
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.trino.client.ProtocolHeaders;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.metadata.SessionPropertyManager;
import io.trino.security.AccessControl;
import io.trino.security.SecurityContext;
Expand Down Expand Up @@ -58,7 +56,6 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA;
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
import static io.trino.sql.SqlPath.EMPTY_PATH;
import static io.trino.util.Failures.checkCondition;
Expand Down Expand Up @@ -316,6 +313,11 @@ public Optional<Slice> getExchangeEncryptionKey()
return exchangeEncryptionKey;
}

public SessionPropertyManager getSessionPropertyManager()
{
return sessionPropertyManager;
}

public Session beginTransactionId(TransactionId transactionId, TransactionManager transactionManager, AccessControl accessControl)
{
requireNonNull(transactionId, "transactionId is null");
Expand Down Expand Up @@ -584,16 +586,13 @@ private void validateSystemProperties(AccessControl accessControl, Map<String, S
}
}

public Session createViewSession(Optional<String> catalog, Optional<String> schema, Identity identity, List<CatalogSchemaName> path)
public Session createViewSession(Optional<String> catalog, Optional<String> schema, Identity identity, List<CatalogSchemaName> viewPath)
{
return createViewSession(catalog, schema, identity, path.forView(viewPath));
}

public Session createViewSession(Optional<String> catalog, Optional<String> schema, Identity identity, SqlPath sqlPath)
{
// For a view, we prepend the global function schema to the path, which should not be in the path
// We do not change the raw path, as that is use for the current_path function
SqlPath sqlPath = new SqlPath(
ImmutableList.<CatalogSchemaName>builder()
.add(new CatalogSchemaName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA))
.addAll(path)
.build(),
getPath().getRawPath());
return builder(sessionPropertyManager)
.setQueryId(getQueryId())
.setTransactionId(getTransactionId().orElse(null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.trino.memory.LocalMemoryManager;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.metadata.LanguageFunctionProvider;
import io.trino.operator.RetryPolicy;
import io.trino.operator.scalar.JoniRegexpFunctions;
import io.trino.operator.scalar.JoniRegexpReplaceLambdaFunction;
Expand Down Expand Up @@ -136,12 +137,14 @@ public class SqlTaskManager

private final CounterStat failedTasks = new CounterStat();
private final Optional<StuckSplitTasksInterrupter> stuckSplitTasksInterrupter;
private final LanguageFunctionProvider languageFunctionProvider;

@Inject
public SqlTaskManager(
VersionEmbedder versionEmbedder,
ConnectorServicesProvider connectorServicesProvider,
LocalExecutionPlanner planner,
LanguageFunctionProvider languageFunctionProvider,
LocationFactory locationFactory,
TaskExecutor taskExecutor,
SplitMonitor splitMonitor,
Expand All @@ -159,6 +162,7 @@ public SqlTaskManager(
this(versionEmbedder,
connectorServicesProvider,
planner,
languageFunctionProvider,
locationFactory,
taskExecutor,
splitMonitor,
Expand All @@ -180,6 +184,7 @@ public SqlTaskManager(
VersionEmbedder versionEmbedder,
ConnectorServicesProvider connectorServicesProvider,
LocalExecutionPlanner planner,
LanguageFunctionProvider languageFunctionProvider,
LocationFactory locationFactory,
TaskExecutor taskExecutor,
SplitMonitor splitMonitor,
Expand All @@ -196,6 +201,7 @@ public SqlTaskManager(
Predicate<List<StackTraceElement>> stuckSplitStackTracePredicate)
{
this.connectorServicesProvider = requireNonNull(connectorServicesProvider, "connectorServicesProvider is null");
this.languageFunctionProvider = languageFunctionProvider;

requireNonNull(nodeInfo, "nodeInfo is null");
infoCacheTime = config.getInfoMaxAge();
Expand Down Expand Up @@ -230,7 +236,10 @@ public SqlTaskManager(
tracer,
sqlTaskExecutionFactory,
taskNotificationExecutor,
sqlTask -> finishedTaskStats.merge(sqlTask.getIoStats()),
sqlTask -> {
languageFunctionProvider.unregisterTask(taskId);
finishedTaskStats.merge(sqlTask.getIoStats());
},
maxBufferSize,
maxBroadcastBufferSize,
requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"),
Expand Down Expand Up @@ -528,6 +537,9 @@ private TaskInfo doUpdateTask(
}
});

fragment.map(PlanFragment::getLanguageFunctions)
.ifPresent(languageFunctions -> languageFunctionProvider.registerTask(taskId, languageFunctions));

sqlTask.recordHeartbeat();
return sqlTask.updateTask(session, stageSpan, fragment, splitAssignments, outputBuffers, dynamicFilterDomains, speculative);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.spi.connector.CatalogSchemaTableName;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.security.GrantInfo;
import io.trino.spi.security.Identity;
import io.trino.spi.security.Privilege;
Expand Down Expand Up @@ -166,6 +167,12 @@ public void setViewOwner(Session session, CatalogSchemaTableName view, TrinoPrin
throw notSupportedException(view.getCatalogName());
}

@Override
public Optional<Identity> getFunctionRunAsIdentity(Session session, CatalogSchemaFunctionName functionName)
{
return Optional.empty();
}

@Override
public void schemaCreated(Session session, CatalogSchemaName schema) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ CatalogFunctionBinding bindFunction(List<TypeSignatureProvider> parameterTypes,

Optional<CatalogFunctionBinding> tryBindFunction(List<TypeSignatureProvider> parameterTypes, Collection<CatalogFunctionMetadata> candidates)
{
if (candidates.isEmpty()) {
return Optional.empty();
}

List<CatalogFunctionMetadata> exactCandidates = candidates.stream()
.filter(function -> function.functionMetadata().getSignature().getTypeVariableConstraints().isEmpty())
.collect(toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionProvider;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InvocationConvention;
Expand All @@ -44,16 +43,14 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.primitives.Primitives.wrap;
import static io.trino.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.client.NodeVersion.UNKNOWN;
import static io.trino.metadata.LanguageFunctionManager.isTrinoSqlLanguageFunction;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static java.lang.String.format;
Expand All @@ -63,14 +60,15 @@
public class FunctionManager
{
private final NonEvictableCache<FunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final NonEvictableCache<FunctionKey, AggregationImplementation> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, WindowFunctionSupplier> specializedWindowCache;
private final NonEvictableCache<ResolvedFunction, AggregationImplementation> specializedAggregationCache;
private final NonEvictableCache<ResolvedFunction, WindowFunctionSupplier> specializedWindowCache;

private final CatalogServiceProvider<FunctionProvider> functionProviders;
private final GlobalFunctionCatalog globalFunctionCatalog;
private final LanguageFunctionProvider languageFunctionProvider;

@Inject
public FunctionManager(CatalogServiceProvider<FunctionProvider> functionProviders, GlobalFunctionCatalog globalFunctionCatalog)
public FunctionManager(CatalogServiceProvider<FunctionProvider> functionProviders, GlobalFunctionCatalog globalFunctionCatalog, LanguageFunctionProvider languageFunctionProvider)
{
specializedScalarCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
Expand All @@ -86,6 +84,7 @@ public FunctionManager(CatalogServiceProvider<FunctionProvider> functionProvider

this.functionProviders = requireNonNull(functionProviders, "functionProviders is null");
this.globalFunctionCatalog = requireNonNull(globalFunctionCatalog, "globalFunctionCatalog is null");
this.languageFunctionProvider = requireNonNull(languageFunctionProvider, "functionProvider is null");
}

public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
Expand All @@ -102,19 +101,27 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunc
private ScalarFunctionImplementation getScalarFunctionImplementationInternal(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
{
FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction);
ScalarFunctionImplementation scalarFunctionImplementation = getFunctionProvider(resolvedFunction).getScalarFunctionImplementation(
resolvedFunction.getFunctionId(),
resolvedFunction.getSignature(),
functionDependencies,
invocationConvention);

ScalarFunctionImplementation scalarFunctionImplementation;
if (isTrinoSqlLanguageFunction(resolvedFunction.getFunctionId())) {
scalarFunctionImplementation = languageFunctionProvider.specialize(this, resolvedFunction, functionDependencies, invocationConvention);
}
else {
scalarFunctionImplementation = getFunctionProvider(resolvedFunction).getScalarFunctionImplementation(
resolvedFunction.getFunctionId(),
resolvedFunction.getSignature(),
functionDependencies,
invocationConvention);
}

verifyMethodHandleSignature(resolvedFunction.getSignature(), scalarFunctionImplementation, invocationConvention);
return scalarFunctionImplementation;
}

public AggregationImplementation getAggregationImplementation(ResolvedFunction resolvedFunction)
{
try {
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregationImplementationInternal(resolvedFunction));
return uncheckedCacheGet(specializedAggregationCache, resolvedFunction, () -> getAggregationImplementationInternal(resolvedFunction));
}
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
Expand All @@ -134,7 +141,7 @@ private AggregationImplementation getAggregationImplementationInternal(ResolvedF
public WindowFunctionSupplier getWindowFunctionSupplier(ResolvedFunction resolvedFunction)
{
try {
return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionSupplierInternal(resolvedFunction));
return uncheckedCacheGet(specializedWindowCache, resolvedFunction, () -> getWindowFunctionSupplierInternal(resolvedFunction));
}
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
Expand Down Expand Up @@ -303,58 +310,12 @@ private static void verifyFunctionSignature(boolean check, String message, Objec
}
}

private static class FunctionKey
private record FunctionKey(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
{
private final FunctionId functionId;
private final BoundSignature boundSignature;
private final Optional<InvocationConvention> invocationConvention;

public FunctionKey(ResolvedFunction resolvedFunction)
{
this(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), Optional.empty());
}

public FunctionKey(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
{
this(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), Optional.of(invocationConvention));
}

public FunctionKey(FunctionId functionId, BoundSignature boundSignature, Optional<InvocationConvention> invocationConvention)
{
this.functionId = requireNonNull(functionId, "functionId is null");
this.boundSignature = requireNonNull(boundSignature, "boundSignature is null");
this.invocationConvention = requireNonNull(invocationConvention, "invocationConvention is null");
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionKey that = (FunctionKey) o;
return functionId.equals(that.functionId) &&
boundSignature.equals(that.boundSignature) &&
invocationConvention.equals(that.invocationConvention);
}

@Override
public int hashCode()
{
return Objects.hash(functionId, boundSignature, invocationConvention);
}

@Override
public String toString()
private FunctionKey
{
return toStringHelper(this).omitNullValues()
.add("functionId", functionId)
.add("boundSignature", boundSignature)
.add("invocationConvention", invocationConvention.orElse(null))
.toString();
requireNonNull(resolvedFunction, "resolvedFunction is null");
requireNonNull(invocationConvention, "invocationConvention is null");
}
}

Expand All @@ -364,6 +325,6 @@ public static FunctionManager createTestingFunctionManager()
GlobalFunctionCatalog functionCatalog = new GlobalFunctionCatalog();
functionCatalog.addFunctions(SystemFunctionBundle.create(new FeaturesConfig(), typeOperators, new BlockTypeOperators(typeOperators), UNKNOWN));
functionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(new InternalBlockEncodingSerde(new BlockEncodingManager(), TESTING_TYPE_MANAGER))));
return new FunctionManager(CatalogServiceProvider.fail(), functionCatalog);
return new FunctionManager(CatalogServiceProvider.fail(), functionCatalog, LanguageFunctionProvider.DISABLED);
}
}
Loading