diff --git a/src/OpenFeature/Api.cs b/src/OpenFeature/Api.cs index e4a9826c5..33a7c79d9 100644 --- a/src/OpenFeature/Api.cs +++ b/src/OpenFeature/Api.cs @@ -43,7 +43,8 @@ internal Api() { } public async Task SetProviderAsync(FeatureProvider featureProvider) { this._eventExecutor.RegisterDefaultFeatureProvider(featureProvider); - await this._repository.SetProviderAsync(featureProvider, this.GetContext(), this.AfterInitialization, this.AfterError).ConfigureAwait(false); + await this._repository.SetProviderAsync(featureProvider, this.GetContext(), this.AfterInitializationAsync, this.AfterErrorAsync) + .ConfigureAwait(false); } @@ -62,7 +63,8 @@ public async Task SetProviderAsync(string domain, FeatureProvider featureProvide throw new ArgumentNullException(nameof(domain)); } this._eventExecutor.RegisterClientFeatureProvider(domain, featureProvider); - await this._repository.SetProviderAsync(domain, featureProvider, this.GetContext(), this.AfterInitialization, this.AfterError).ConfigureAwait(false); + await this._repository.SetProviderAsync(domain, featureProvider, this.GetContext(), this.AfterInitializationAsync, this.AfterErrorAsync) + .ConfigureAwait(false); } /// @@ -324,7 +326,7 @@ internal void RemoveClientHandler(string client, ProviderEventTypes eventType, E /// /// Update the provider state to READY and emit a READY event after successful init. /// - private async Task AfterInitialization(FeatureProvider provider) + private async Task AfterInitializationAsync(FeatureProvider provider, CancellationToken cancellationToken = default) { provider.Status = ProviderStatus.Ready; var eventPayload = new ProviderEventPayload @@ -334,13 +336,14 @@ private async Task AfterInitialization(FeatureProvider provider) ProviderName = provider.GetMetadata()?.Name, }; - await this._eventExecutor.EventChannel.Writer.WriteAsync(new Event { Provider = provider, EventPayload = eventPayload }).ConfigureAwait(false); + await this._eventExecutor.EventChannel.Writer.WriteAsync(new Event { Provider = provider, EventPayload = eventPayload }, cancellationToken) + .ConfigureAwait(false); } /// /// Update the provider state to ERROR and emit an ERROR after failed init. /// - private async Task AfterError(FeatureProvider provider, Exception? ex) + private async Task AfterErrorAsync(FeatureProvider provider, Exception? ex, CancellationToken cancellationToken = default) { provider.Status = typeof(ProviderFatalException) == ex?.GetType() ? ProviderStatus.Fatal : ProviderStatus.Error; var eventPayload = new ProviderEventPayload @@ -350,7 +353,8 @@ private async Task AfterError(FeatureProvider provider, Exception? ex) ProviderName = provider.GetMetadata()?.Name, }; - await this._eventExecutor.EventChannel.Writer.WriteAsync(new Event { Provider = provider, EventPayload = eventPayload }).ConfigureAwait(false); + await this._eventExecutor.EventChannel.Writer.WriteAsync(new Event { Provider = provider, EventPayload = eventPayload }, cancellationToken) + .ConfigureAwait(false); } /// diff --git a/src/OpenFeature/ProviderRepository.cs b/src/OpenFeature/ProviderRepository.cs index 4cea63b08..5539d0eab 100644 --- a/src/OpenFeature/ProviderRepository.cs +++ b/src/OpenFeature/ProviderRepository.cs @@ -55,8 +55,8 @@ public async ValueTask DisposeAsync() internal async Task SetProviderAsync( FeatureProvider? featureProvider, EvaluationContext context, - Func? afterInitSuccess = null, - Func? afterInitError = null, + Func? afterInitSuccess = null, + Func? afterInitError = null, CancellationToken cancellationToken = default) { // Cannot unset the feature provider. @@ -93,8 +93,8 @@ await InitProviderAsync(this._defaultProvider, context, afterInitSuccess, afterI private static async Task InitProviderAsync( FeatureProvider? newProvider, EvaluationContext context, - Func? afterInitialization, - Func? afterError, + Func? afterInitialization, + Func? afterError, CancellationToken cancellationToken = default) { if (newProvider == null) @@ -108,14 +108,14 @@ private static async Task InitProviderAsync( await newProvider.InitializeAsync(context, cancellationToken).ConfigureAwait(false); if (afterInitialization != null) { - await afterInitialization.Invoke(newProvider).ConfigureAwait(false); + await afterInitialization.Invoke(newProvider, cancellationToken).ConfigureAwait(false); } } catch (Exception ex) { if (afterError != null) { - await afterError.Invoke(newProvider, ex).ConfigureAwait(false); + await afterError.Invoke(newProvider, ex, cancellationToken).ConfigureAwait(false); } } } @@ -138,8 +138,8 @@ private static async Task InitProviderAsync( internal async Task SetProviderAsync(string domain, FeatureProvider? featureProvider, EvaluationContext context, - Func? afterInitSuccess = null, - Func? afterInitError = null, + Func? afterInitSuccess = null, + Func? afterInitError = null, CancellationToken cancellationToken = default) { // Cannot set a provider for a null domain. diff --git a/test/OpenFeature.Tests/ProviderRepositoryTests.cs b/test/OpenFeature.Tests/ProviderRepositoryTests.cs index 4284eaeeb..43fc71355 100644 --- a/test/OpenFeature.Tests/ProviderRepositoryTests.cs +++ b/test/OpenFeature.Tests/ProviderRepositoryTests.cs @@ -39,7 +39,7 @@ public async Task AfterInitialization_Is_Invoked_For_Setting_Default_Provider() providerMock.Status.Returns(ProviderStatus.NotReady); var context = new EvaluationContextBuilder().Build(); var callCount = 0; - await repository.SetProviderAsync(providerMock, context, afterInitSuccess: (theProvider) => + await repository.SetProviderAsync(providerMock, context, afterInitSuccess: (theProvider, ct) => { Assert.Equal(providerMock, theProvider); callCount++; @@ -48,6 +48,31 @@ await repository.SetProviderAsync(providerMock, context, afterInitSuccess: (theP Assert.Equal(1, callCount); } + [Fact] + public async Task AfterInitialization_Is_Invoked_With_CancellationToken() + { + var repository = new ProviderRepository(); + var providerMock = Substitute.For(); + providerMock.Status.Returns(ProviderStatus.NotReady); + + using var cancellationTokenSource = new CancellationTokenSource(); + var cancellationToken = cancellationTokenSource.Token; + + var context = new EvaluationContextBuilder().Build(); + + var initCancellationToken = CancellationToken.None; + await repository.SetProviderAsync(providerMock, context, afterInitSuccess: (theProvider, ct) => + { + Assert.Equal(providerMock, theProvider); + + initCancellationToken = ct; + + return Task.CompletedTask; + }, cancellationToken: cancellationToken); + + Assert.Equal(cancellationToken, initCancellationToken); + } + [Fact] public async Task AfterError_Is_Invoked_If_Initialization_Errors_Default_Provider() { @@ -55,10 +80,10 @@ public async Task AfterError_Is_Invoked_If_Initialization_Errors_Default_Provide var providerMock = Substitute.For(); providerMock.Status.Returns(ProviderStatus.NotReady); var context = new EvaluationContextBuilder().Build(); - providerMock.When(x => x.InitializeAsync(context)).Throw(new Exception("BAD THINGS")); + providerMock.When(x => x.InitializeAsync(context, Arg.Any())).Throw(new Exception("BAD THINGS")); var callCount = 0; Exception? receivedError = null; - await repository.SetProviderAsync(providerMock, context, afterInitError: (theProvider, error) => + await repository.SetProviderAsync(providerMock, context, afterInitError: (theProvider, error, ct) => { Assert.Equal(providerMock, theProvider); callCount++; @@ -69,6 +94,32 @@ await repository.SetProviderAsync(providerMock, context, afterInitError: (thePro Assert.Equal(1, callCount); } + [Fact] + public async Task AfterError_Is_Invoked_With_CancellationToken() + { + var repository = new ProviderRepository(); + var providerMock = Substitute.For(); + providerMock.Status.Returns(ProviderStatus.NotReady); + + using var cancellationTokenSource = new CancellationTokenSource(); + var cancellationToken = cancellationTokenSource.Token; + + var context = new EvaluationContextBuilder().Build(); + providerMock.When(x => x.InitializeAsync(context, cancellationToken)).Throw(new Exception("BAD THINGS")); + + var errorCancellationToken = CancellationToken.None; + await repository.SetProviderAsync(providerMock, context, afterInitError: (theProvider, error, ct) => + { + Assert.Equal(providerMock, theProvider); + + errorCancellationToken = ct; + + return Task.CompletedTask; + }, cancellationToken: cancellationToken); + + Assert.Equal(cancellationToken, errorCancellationToken); + } + [Theory] [InlineData(ProviderStatus.Ready)] [InlineData(ProviderStatus.Stale)] @@ -94,7 +145,7 @@ internal async Task AfterInitialize_Is_Not_Called_For_Ready_Provider(ProviderSta providerMock.Status.Returns(status); var context = new EvaluationContextBuilder().Build(); var callCount = 0; - await repository.SetProviderAsync(providerMock, context, afterInitSuccess: provider => + await repository.SetProviderAsync(providerMock, context, afterInitSuccess: (provider, ct) => { callCount++; return Task.CompletedTask; @@ -150,7 +201,7 @@ public async Task AfterInitialization_Is_Invoked_For_Setting_Named_Provider() providerMock.Status.Returns(ProviderStatus.NotReady); var context = new EvaluationContextBuilder().Build(); var callCount = 0; - await repository.SetProviderAsync("the-name", providerMock, context, afterInitSuccess: (theProvider) => + await repository.SetProviderAsync("the-name", providerMock, context, afterInitSuccess: (theProvider, ct) => { Assert.Equal(providerMock, theProvider); callCount++; @@ -159,6 +210,30 @@ await repository.SetProviderAsync("the-name", providerMock, context, afterInitSu Assert.Equal(1, callCount); } + [Fact] + public async Task AfterInitialization_WithNamedProvider_Is_Invoked_With_CancellationToken() + { + var repository = new ProviderRepository(); + var providerMock = Substitute.For(); + providerMock.Status.Returns(ProviderStatus.NotReady); + + var context = new EvaluationContextBuilder().Build(); + using var cancellationTokenSource = new CancellationTokenSource(); + var cancellationToken = cancellationTokenSource.Token; + + var initCancellationToken = CancellationToken.None; + await repository.SetProviderAsync("the-name", providerMock, context, afterInitSuccess: (theProvider, ct) => + { + Assert.Equal(providerMock, theProvider); + + initCancellationToken = ct; + + return Task.CompletedTask; + }, cancellationToken: cancellationToken); + + Assert.Equal(cancellationToken, initCancellationToken); + } + [Fact] public async Task AfterError_Is_Invoked_If_Initialization_Errors_Named_Provider() { @@ -166,10 +241,10 @@ public async Task AfterError_Is_Invoked_If_Initialization_Errors_Named_Provider( var providerMock = Substitute.For(); providerMock.Status.Returns(ProviderStatus.NotReady); var context = new EvaluationContextBuilder().Build(); - providerMock.When(x => x.InitializeAsync(context)).Throw(new Exception("BAD THINGS")); + providerMock.When(x => x.InitializeAsync(context, Arg.Any())).Throw(new Exception("BAD THINGS")); var callCount = 0; Exception? receivedError = null; - await repository.SetProviderAsync("the-provider", providerMock, context, afterInitError: (theProvider, error) => + await repository.SetProviderAsync("the-provider", providerMock, context, afterInitError: (theProvider, error, ct) => { Assert.Equal(providerMock, theProvider); callCount++; @@ -180,6 +255,32 @@ await repository.SetProviderAsync("the-provider", providerMock, context, afterIn Assert.Equal(1, callCount); } + [Fact] + public async Task AfterError_WithNamedProvider_Is_Invoked_With_CancellationToken() + { + var repository = new ProviderRepository(); + var providerMock = Substitute.For(); + providerMock.Status.Returns(ProviderStatus.NotReady); + + using var cancellationTokenSource = new CancellationTokenSource(); + var cancellationToken = cancellationTokenSource.Token; + + var context = new EvaluationContextBuilder().Build(); + providerMock.When(x => x.InitializeAsync(context, cancellationToken)).Throw(new Exception("BAD THINGS")); + + var errorCancellationToken = CancellationToken.None; + await repository.SetProviderAsync("the-provider", providerMock, context, afterInitError: (theProvider, error, ct) => + { + Assert.Equal(providerMock, theProvider); + + errorCancellationToken = ct; + + return Task.CompletedTask; + }, cancellationToken: cancellationToken); + + Assert.Equal(cancellationToken, errorCancellationToken); + } + [Theory] [InlineData(ProviderStatus.Ready)] [InlineData(ProviderStatus.Stale)] @@ -206,7 +307,7 @@ internal async Task AfterInitialize_Is_Not_Called_For_Ready_Named_Provider(Provi var context = new EvaluationContextBuilder().Build(); var callCount = 0; await repository.SetProviderAsync("the-name", providerMock, context, - afterInitSuccess: provider => + afterInitSuccess: (provider, ct) => { callCount++; return Task.CompletedTask;