diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java index a8d0b901f..ba69b3ad7 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java @@ -28,6 +28,8 @@ import dev.openfeature.sdk.exceptions.TypeMismatchError; import dev.openfeature.sdk.internal.TriConsumer; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -38,11 +40,15 @@ */ @Slf4j public class InProcessResolver implements Resolver { + + static final String STATE_WATCHER_THREAD_NAME = "InProcessResolver.stateWatcher"; private final Storage flagStore; private final TriConsumer onConnectionEvent; private final Operator operator; private final String scope; private final QueueSource queueSource; + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final AtomicReference stateWatcher = new AtomicReference<>(); /** * Resolves flag values using @@ -67,52 +73,54 @@ public InProcessResolver( */ public void init() throws Exception { flagStore.init(); - final Thread stateWatcher = new Thread(() -> { - try { - while (true) { - final StorageStateChange storageStateChange = - flagStore.getStateQueue().take(); - switch (storageStateChange.getStorageState()) { - case OK: - log.debug("onConnectionEvent.accept ProviderEvent.PROVIDER_CONFIGURATION_CHANGED"); - - var eventDetails = ProviderEventDetails.builder() - .flagsChanged(storageStateChange.getChangedFlagsKeys()) - .message("configuration changed") - .build(); - - onConnectionEvent.accept( - ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, - eventDetails, - storageStateChange.getSyncMetadata()); - - log.debug("post onConnectionEvent.accept ProviderEvent.PROVIDER_CONFIGURATION_CHANGED"); - break; - case STALE: - onConnectionEvent.accept(ProviderEvent.PROVIDER_ERROR, null, null); - break; - case ERROR: - onConnectionEvent.accept( - ProviderEvent.PROVIDER_ERROR, - ProviderEventDetails.builder() - .errorCode(ErrorCode.PROVIDER_FATAL) - .build(), - null); - break; - default: - log.warn(String.format( - "Storage emitted unhandled status: %s", storageStateChange.getStorageState())); - } - } - } catch (InterruptedException e) { - log.warn("Storage state watcher interrupted", e); - Thread.currentThread().interrupt(); - } - }); + final Thread stateWatcher = new Thread(this::stateWatcher, STATE_WATCHER_THREAD_NAME); stateWatcher.setDaemon(true); + this.stateWatcher.set(stateWatcher); stateWatcher.start(); } + private void stateWatcher() { + try { + while (!shutdown.get()) { + final StorageStateChange storageStateChange = + flagStore.getStateQueue().take(); + switch (storageStateChange.getStorageState()) { + case OK: + log.debug("onConnectionEvent.accept ProviderEvent.PROVIDER_CONFIGURATION_CHANGED"); + + var eventDetails = ProviderEventDetails.builder() + .flagsChanged(storageStateChange.getChangedFlagsKeys()) + .message("configuration changed") + .build(); + + onConnectionEvent.accept( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + eventDetails, + storageStateChange.getSyncMetadata()); + + log.debug("post onConnectionEvent.accept ProviderEvent.PROVIDER_CONFIGURATION_CHANGED"); + break; + case STALE: + onConnectionEvent.accept(ProviderEvent.PROVIDER_ERROR, null, null); + break; + case ERROR: + onConnectionEvent.accept( + ProviderEvent.PROVIDER_ERROR, + ProviderEventDetails.builder() + .errorCode(ErrorCode.PROVIDER_FATAL) + .build(), + null); + break; + default: + log.warn(String.format( + "Storage emitted unhandled status: %s", storageStateChange.getStorageState())); + } + } + } catch (InterruptedException e) { + log.debug("Storage state watcher interrupted, most likely shutdown was invoked", e); + } + } + /** * Called when the provider enters error state after grace period. * Attempts to reinitialize the sync connector if enabled. @@ -132,7 +140,17 @@ public void onError() { * @throws InterruptedException if stream can't be closed within deadline. */ public void shutdown() throws InterruptedException { + if (!shutdown.compareAndSet(false, true)) { + log.debug("Shutdown already in progress or completed"); + return; + } flagStore.shutdown(); + stateWatcher.getAndUpdate(existing -> { + if (existing != null && existing.isAlive()) { + existing.interrupt(); + } + return null; + }); } /** diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java index 04670b397..4af959861 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolverTest.java @@ -52,6 +52,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import org.awaitility.Awaitility; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -543,6 +544,41 @@ void flagSetMetadataIsOverwrittenByFlagMetadataToEvaluation() throws Exception { assertThat(providerEvaluation.getFlagMetadata().getString("key")).isEqualTo("expected"); } + @Test + void testStateWatcherThreadIsCleanedUpDuringShutdown() throws Exception { + // given + final Map flagMap = new HashMap<>(); + flagMap.put("booleanFlag", BOOLEAN_FLAG); + + var initialThreadCount = currentDaemonThreadCount(); + + var queue = new LinkedBlockingQueue(); + InProcessResolver inProcessResolver = + getInProcessResolverWith(new MockStorage(flagMap, queue), (event, details, metadata) -> {}); + + // when + inProcessResolver.init(); + Thread stateWatcher = Thread.getAllStackTraces().keySet().stream() + .filter(thread -> InProcessResolver.STATE_WATCHER_THREAD_NAME.equals(thread.getName())) + .findFirst() + .orElseThrow(); + var threadCountAfterInit = currentDaemonThreadCount(); + var stateWatcherWasStarted = stateWatcher.isAlive(); + inProcessResolver.shutdown(); + + // then + assertThat(stateWatcherWasStarted).isTrue(); + assertThat(threadCountAfterInit).isGreaterThan(initialThreadCount); + Awaitility.await().until(() -> !stateWatcher.isAlive()); + assertThat(currentDaemonThreadCount()).isEqualTo(initialThreadCount); + } + + private long currentDaemonThreadCount() { + return Thread.getAllStackTraces().keySet().stream() + .filter(Thread::isDaemon) + .count(); + } + private InProcessResolver getInProcessResolverWith(final FlagdOptions options, final MockStorage storage) throws NoSuchFieldException, IllegalAccessException {