diff --git a/runtime/service/src/main/java/org/apache/polaris/service/config/ServiceProducers.java b/runtime/service/src/main/java/org/apache/polaris/service/config/ServiceProducers.java index 080cbc5ba7..c9726326f8 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/config/ServiceProducers.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/config/ServiceProducers.java @@ -30,7 +30,6 @@ import jakarta.enterprise.inject.Instance; import jakarta.enterprise.inject.Produces; import jakarta.inject.Singleton; -import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.SecurityContext; import java.security.Principal; @@ -74,7 +73,6 @@ import org.apache.polaris.service.catalog.io.FileIOConfiguration; import org.apache.polaris.service.catalog.io.FileIOFactory; import org.apache.polaris.service.context.RealmContextConfiguration; -import org.apache.polaris.service.context.RealmContextFilter; import org.apache.polaris.service.context.RealmContextResolver; import org.apache.polaris.service.credentials.PolarisCredentialManagerConfiguration; import org.apache.polaris.service.events.PolarisEventListenerConfiguration; @@ -124,12 +122,6 @@ public PolarisDiagnostics polarisDiagnostics() { // Polaris core beans - request scope - @Produces - @RequestScoped - public RealmContext realmContext(@Context ContainerRequestContext request) { - return (RealmContext) request.getProperty(RealmContextFilter.REALM_CONTEXT_KEY); - } - @Produces @RequestScoped public CallContext polarisCallContext( @@ -417,6 +409,7 @@ public ManagedExecutor taskExecutor(TaskHandlerConfiguration config) { return SmallRyeManagedExecutor.builder() .injectionPointName("task-executor") .propagated(ThreadContext.ALL_REMAINING) + .cleared(ThreadContext.CDI) .maxAsync(config.maxConcurrentTasks()) .maxQueued(config.maxQueuedTasks()) .build(); diff --git a/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextFilter.java b/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextFilter.java index 78558bb5d5..9d2138f32f 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextFilter.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextFilter.java @@ -26,6 +26,7 @@ import jakarta.ws.rs.core.Response; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.polaris.service.config.FilterPriorities; +import org.apache.polaris.service.context.catalog.RealmContextHolder; import org.jboss.resteasy.reactive.server.ServerRequestFilter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,6 +38,7 @@ public class RealmContextFilter { private static final Logger LOGGER = LoggerFactory.getLogger(RealmContextFilter.class); @Inject RealmContextResolver realmContextResolver; + @Inject RealmContextHolder realmContextHolder; @ServerRequestFilter(preMatching = true, priority = FilterPriorities.REALM_CONTEXT_FILTER) public Uni resolveRealmContext(ContainerRequestContext rc) { @@ -49,6 +51,7 @@ public Uni resolveRealmContext(ContainerRequestContext rc) { rc.getUriInfo().getPath(), rc.getHeaders()::getFirst)) .onItem() + .invoke(realmContext -> realmContextHolder.set(realmContext)) .invoke(realmContext -> rc.setProperty(REALM_CONTEXT_KEY, realmContext)) // ContextLocals is used by RealmIdTagContributor to add the realm id to metrics .invoke(realmContext -> ContextLocals.put(REALM_CONTEXT_KEY, realmContext)) diff --git a/runtime/service/src/main/java/org/apache/polaris/service/context/catalog/RealmContextHolder.java b/runtime/service/src/main/java/org/apache/polaris/service/context/catalog/RealmContextHolder.java new file mode 100644 index 0000000000..6397119eaf --- /dev/null +++ b/runtime/service/src/main/java/org/apache/polaris/service/context/catalog/RealmContextHolder.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.polaris.service.context.catalog; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.inject.Produces; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.polaris.core.context.RealmContext; + +@RequestScoped +public class RealmContextHolder { + private final AtomicReference realmContext = new AtomicReference<>(); + + @Produces + @RequestScoped + public RealmContext get() { + // Note: if this producer is called before the context is set a CDI exception will occur. + return realmContext.get(); + } + + public void set(RealmContext rc) { + if (!realmContext.compareAndSet(null, rc)) { + throw new IllegalStateException("RealmContext already set"); + } + } +} diff --git a/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java b/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java index d6def1afc0..aeaa53260e 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java @@ -27,6 +27,8 @@ import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; import java.time.Clock; import java.util.List; @@ -37,12 +39,14 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.function.TriConsumer; import org.apache.polaris.core.context.CallContext; import org.apache.polaris.core.entity.PolarisBaseEntity; import org.apache.polaris.core.entity.PolarisEntityType; import org.apache.polaris.core.entity.TaskEntity; import org.apache.polaris.core.persistence.MetaStoreManagerFactory; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; +import org.apache.polaris.service.context.catalog.RealmContextHolder; import org.apache.polaris.service.events.AfterAttemptTaskEvent; import org.apache.polaris.service.events.BeforeAttemptTaskEvent; import org.apache.polaris.service.events.PolarisEventMetadata; @@ -65,22 +69,27 @@ public class TaskExecutorImpl implements TaskExecutor { private final Clock clock; private final MetaStoreManagerFactory metaStoreManagerFactory; private final TaskFileIOSupplier fileIOSupplier; + private final RealmContextHolder realmContextHolder; private final List taskHandlers = new CopyOnWriteArrayList<>(); + private final Optional> errorHandler; private final PolarisEventListener polarisEventListener; private final PolarisEventMetadataFactory eventMetadataFactory; @Nullable private final Tracer tracer; @SuppressWarnings("unused") // Required by CDI protected TaskExecutorImpl() { - this(null, null, null, null, null, null, null); + this(null, null, null, null, null, null, null, null, null); } @Inject public TaskExecutorImpl( @Identifier("task-executor") Executor executor, + @Identifier("task-error-handler") + Instance> errorHandler, Clock clock, MetaStoreManagerFactory metaStoreManagerFactory, TaskFileIOSupplier fileIOSupplier, + RealmContextHolder realmContextHolder, PolarisEventListener polarisEventListener, PolarisEventMetadataFactory eventMetadataFactory, @Nullable Tracer tracer) { @@ -88,9 +97,16 @@ public TaskExecutorImpl( this.clock = clock; this.metaStoreManagerFactory = metaStoreManagerFactory; this.fileIOSupplier = fileIOSupplier; + this.realmContextHolder = realmContextHolder; this.polarisEventListener = polarisEventListener; this.eventMetadataFactory = eventMetadataFactory; this.tracer = tracer; + + if (errorHandler != null && errorHandler.isResolvable()) { + this.errorHandler = Optional.of(errorHandler.get()); + } else { + this.errorHandler = Optional.empty(); + } } @Startup @@ -121,6 +137,7 @@ public void addTaskHandler(TaskHandler taskHandler) { @Override @SuppressWarnings("FutureReturnValueIgnored") // it _should_ be okay in this particular case public void addTaskHandlerContext(long taskEntityId, CallContext callContext) { + errorHandler.ifPresent(h -> h.accept(taskEntityId, true, null)); // Unfortunately CallContext is a request-scoped bean and must be cloned now, // because its usage inside the TaskExecutor thread pool will outlive its // lifespan, so the original CallContext will eventually be closed while @@ -142,12 +159,17 @@ public void addTaskHandlerContext(long taskEntityId, CallContext callContext) { if (attempt > 3) { return CompletableFuture.failedFuture(e); } + String realmId = callContext.getRealmContext().getRealmIdentifier(); return CompletableFuture.runAsync( - () -> handleTaskWithTracing(taskEntityId, callContext, eventMetadata, attempt), + () -> { + handleTaskWithTracing(realmId, taskEntityId, callContext, eventMetadata, attempt); + errorHandler.ifPresent(h -> h.accept(taskEntityId, false, null)); + }, executor) .exceptionallyComposeAsync( (t) -> { LOGGER.warn("Failed to handle task entity id {}", taskEntityId, t); + errorHandler.ifPresent(h -> h.accept(taskEntityId, false, e)); return tryHandleTask(taskEntityId, callContext, eventMetadata, t, attempt + 1); }, CompletableFuture.delayedExecutor( @@ -207,8 +229,16 @@ protected void handleTask( } } + @ActivateRequestContext protected void handleTaskWithTracing( - long taskEntityId, CallContext callContext, PolarisEventMetadata eventMetadata, int attempt) { + String realmId, + long taskEntityId, + CallContext callContext, + PolarisEventMetadata eventMetadata, + int attempt) { + // Note: each call to this method runs in a new CDI request context + realmContextHolder.set(() -> realmId); + if (tracer == null) { handleTask(taskEntityId, callContext, eventMetadata, attempt); } else { diff --git a/runtime/service/src/test/java/org/apache/polaris/service/admin/PolarisAuthzTestBase.java b/runtime/service/src/test/java/org/apache/polaris/service/admin/PolarisAuthzTestBase.java index c7b92bf80a..896b5fe600 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/admin/PolarisAuthzTestBase.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/admin/PolarisAuthzTestBase.java @@ -29,7 +29,6 @@ import jakarta.enterprise.context.RequestScoped; import jakarta.enterprise.inject.Alternative; import jakarta.inject.Inject; -import jakarta.ws.rs.container.ContainerRequestContext; import java.io.IOException; import java.util.Date; import java.util.List; @@ -89,6 +88,7 @@ import org.apache.polaris.service.catalog.policy.PolicyCatalog; import org.apache.polaris.service.config.ReservedProperties; import org.apache.polaris.service.context.catalog.PolarisCallContextCatalogFactory; +import org.apache.polaris.service.context.catalog.RealmContextHolder; import org.apache.polaris.service.events.PolarisEventMetadataFactory; import org.apache.polaris.service.events.listeners.PolarisEventListener; import org.apache.polaris.service.storage.PolarisStorageIntegrationProviderImpl; @@ -202,6 +202,7 @@ public Map getConfigOverrides() { @Inject protected UserSecretsManager userSecretsManager; @Inject protected CallContext callContext; @Inject protected RealmConfig realmConfig; + @Inject protected RealmContextHolder realmContextHolder; protected IcebergCatalog baseCatalog; protected PolarisGenericTableCatalog genericTableCatalog; @@ -231,13 +232,9 @@ public void before(TestInfo testInfo) { RealmContext realmContext = testInfo::getDisplayName; QuarkusMock.installMockForType(realmContext, RealmContext.class); + realmContextHolder.set(realmContext); polarisContext = callContext.getPolarisCallContext(); - ContainerRequestContext containerRequestContext = Mockito.mock(ContainerRequestContext.class); - Mockito.when(containerRequestContext.getProperty(Mockito.anyString())) - .thenReturn("request-id-1"); - QuarkusMock.installMockForType(containerRequestContext, ContainerRequestContext.class); - polarisAuthorizer = new PolarisAuthorizerImpl(realmConfig); PrincipalEntity rootPrincipal = diff --git a/runtime/service/src/test/java/org/apache/polaris/service/it/RestCatalogFileIntegrationTest.java b/runtime/service/src/test/java/org/apache/polaris/service/it/RestCatalogFileIntegrationTest.java index 5aaf858dec..c08d81e2f6 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/it/RestCatalogFileIntegrationTest.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/it/RestCatalogFileIntegrationTest.java @@ -21,8 +21,12 @@ import io.quarkus.test.junit.QuarkusTest; import io.quarkus.test.junit.QuarkusTestProfile; import io.quarkus.test.junit.TestProfile; +import io.smallrye.common.annotation.Identifier; +import jakarta.inject.Inject; import java.util.Map; import org.apache.polaris.service.it.test.PolarisRestCatalogFileIntegrationTest; +import org.apache.polaris.service.task.TaskErrorHandler; +import org.junit.jupiter.api.AfterEach; @QuarkusTest @TestProfile(RestCatalogFileIntegrationTest.Profile.class) @@ -47,4 +51,13 @@ public Map getConfigOverrides() { "true"); } } + + @Inject + @Identifier("task-error-handler") + TaskErrorHandler taskErrorHandler; + + @AfterEach + void checkTaskExceptions() { + taskErrorHandler.assertNoTaskExceptions(); + } } diff --git a/runtime/service/src/test/java/org/apache/polaris/service/task/TaskErrorHandler.java b/runtime/service/src/test/java/org/apache/polaris/service/task/TaskErrorHandler.java new file mode 100644 index 0000000000..88a8457ba1 --- /dev/null +++ b/runtime/service/src/test/java/org/apache/polaris/service/task/TaskErrorHandler.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.polaris.service.task; + +import io.smallrye.common.annotation.Identifier; +import jakarta.enterprise.context.ApplicationScoped; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.apache.commons.lang3.function.TriConsumer; +import org.assertj.core.api.Assertions; +import org.awaitility.Awaitility; + +@ApplicationScoped +@Identifier("task-error-handler") +public class TaskErrorHandler implements TriConsumer { + + private final ConcurrentMap tasks = new ConcurrentHashMap<>(); + + private record TaskStatus(boolean completed, Throwable error) {} + + @Override + public void accept(Long id, Boolean start, Throwable th) { + if (start) { + tasks.computeIfAbsent(id, x -> new TaskStatus(false, null)); + } else { + tasks.compute(id, (i, s) -> new TaskStatus(true, th != null || s == null ? th : s.error)); + } + } + + public void assertNoTaskExceptions() { + List ids = new ArrayList<>(tasks.keySet()); + TaskStatus incomplete = new TaskStatus(false, null); + Awaitility.await() + .atMost(Duration.ofSeconds(20)) + .until(() -> ids.stream().allMatch(id -> tasks.getOrDefault(id, incomplete).completed)); + + for (Long id : ids) { + TaskStatus s = tasks.remove(id); + Assertions.assertThatCode( + () -> { + if (s.error != null) { + throw s.error; + } + }) + .doesNotThrowAnyException(); + Assertions.assertThat(s.completed).isTrue(); + } + } +} diff --git a/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java b/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java index 15e3b4ed89..4151c9fce3 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java @@ -24,6 +24,7 @@ import org.apache.polaris.core.entity.TaskEntity; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; import org.apache.polaris.service.TestServices; +import org.apache.polaris.service.context.catalog.RealmContextHolder; import org.apache.polaris.service.events.AfterAttemptTaskEvent; import org.apache.polaris.service.events.BeforeAttemptTaskEvent; import org.apache.polaris.service.events.PolarisEventMetadata; @@ -62,10 +63,12 @@ void testEventsAreEmitted() { TaskExecutorImpl executor = new TaskExecutorImpl( Runnable::run, + null, testServices.clock(), testServices.metaStoreManagerFactory(), new TaskFileIOSupplier( testServices.fileIOFactory(), testServices.storageAccessConfigProvider()), + new RealmContextHolder(), testServices.polarisEventListener(), testServices.eventMetadataFactory(), null);