diff --git a/src/Polly/Timeout/AsyncTimeoutPolicy.cs b/src/Polly/Timeout/AsyncTimeoutPolicy.cs index e8a29b12508..03036818aa5 100644 --- a/src/Polly/Timeout/AsyncTimeoutPolicy.cs +++ b/src/Polly/Timeout/AsyncTimeoutPolicy.cs @@ -3,7 +3,6 @@ /// /// A timeout policy which can be applied to async delegates. /// -#pragma warning disable CA1062 // Validate arguments of public methods public class AsyncTimeoutPolicy : AsyncPolicy, ITimeoutPolicy { private readonly Func _timeoutProvider; @@ -26,8 +25,14 @@ protected override Task ImplementationAsync( Func> action, Context context, CancellationToken cancellationToken, - bool continueOnCapturedContext) => - AsyncTimeoutEngine.ImplementationAsync( + bool continueOnCapturedContext) + { + if (action is null) + { + throw new ArgumentNullException(nameof(action)); + } + + return AsyncTimeoutEngine.ImplementationAsync( action, context, _timeoutProvider, @@ -35,6 +40,7 @@ protected override Task ImplementationAsync( _onTimeoutAsync, continueOnCapturedContext, cancellationToken); + } } /// @@ -63,8 +69,14 @@ protected override Task ImplementationAsync( Func> action, Context context, CancellationToken cancellationToken, - bool continueOnCapturedContext) => - AsyncTimeoutEngine.ImplementationAsync( + bool continueOnCapturedContext) + { + if (action is null) + { + throw new ArgumentNullException(nameof(action)); + } + + return AsyncTimeoutEngine.ImplementationAsync( action, context, _timeoutProvider, @@ -72,4 +84,5 @@ protected override Task ImplementationAsync( _onTimeoutAsync, continueOnCapturedContext, cancellationToken); + } } diff --git a/test/Polly.Specs/Timeout/TimeoutAsyncSpecs.cs b/test/Polly.Specs/Timeout/TimeoutAsyncSpecs.cs index 5851f421d3e..2bc95235431 100644 --- a/test/Polly.Specs/Timeout/TimeoutAsyncSpecs.cs +++ b/test/Polly.Specs/Timeout/TimeoutAsyncSpecs.cs @@ -7,6 +7,34 @@ public class TimeoutAsyncSpecs : TimeoutSpecsBase { #region Configuration + [Fact] + public void Should_throw_when_action_is_null() + { + var flags = BindingFlags.NonPublic | BindingFlags.Instance; + Func> action = null!; + Func timeoutProvider = (_) => TimeSpan.Zero; + TimeoutStrategy timeoutStrategy = TimeoutStrategy.Optimistic; + Func onTimeoutAsync = (_, _, _, _) => Task.CompletedTask; + + var instance = Activator.CreateInstance( + typeof(AsyncTimeoutPolicy), + flags, + null, + [timeoutProvider, timeoutStrategy, onTimeoutAsync], + null)!; + var instanceType = instance.GetType(); + var methods = instanceType.GetMethods(flags); + var methodInfo = methods.First(method => method is { Name: "ImplementationAsync", ReturnType.Name: "Task`1" }); + var generic = methodInfo.MakeGenericMethod(typeof(EmptyStruct)); + + var func = () => generic.Invoke(instance, [action, new Context(), CancellationToken.None, false]); + + var exceptionAssertions = func.Should().Throw(); + exceptionAssertions.And.Message.Should().Be("Exception has been thrown by the target of an invocation."); + exceptionAssertions.And.InnerException.Should().BeOfType() + .Which.ParamName.Should().Be("action"); + } + [Fact] public void Should_throw_when_timeout_is_zero_by_timespan() { diff --git a/test/Polly.Specs/Timeout/TimeoutTResultAsyncSpecs.cs b/test/Polly.Specs/Timeout/TimeoutTResultAsyncSpecs.cs index e3380f1059e..40ffb703dfe 100644 --- a/test/Polly.Specs/Timeout/TimeoutTResultAsyncSpecs.cs +++ b/test/Polly.Specs/Timeout/TimeoutTResultAsyncSpecs.cs @@ -5,6 +5,33 @@ public class TimeoutTResultAsyncSpecs : TimeoutSpecsBase { #region Configuration + [Fact] + public void Should_throw_when_action_is_null() + { + var flags = BindingFlags.NonPublic | BindingFlags.Instance; + Func> action = null!; + Func timeoutProvider = (_) => TimeSpan.Zero; + TimeoutStrategy timeoutStrategy = TimeoutStrategy.Optimistic; + Func onTimeoutAsync = (_, _, _, _) => Task.CompletedTask; + + var instance = Activator.CreateInstance( + typeof(AsyncTimeoutPolicy), + flags, + null, + [timeoutProvider, timeoutStrategy, onTimeoutAsync], + null)!; + var instanceType = instance.GetType(); + var methods = instanceType.GetMethods(flags); + var methodInfo = methods.First(method => method is { Name: "ImplementationAsync", ReturnType.Name: "Task`1" }); + + var func = () => methodInfo.Invoke(instance, [action, new Context(), CancellationToken.None, false]); + + var exceptionAssertions = func.Should().Throw(); + exceptionAssertions.And.Message.Should().Be("Exception has been thrown by the target of an invocation."); + exceptionAssertions.And.InnerException.Should().BeOfType() + .Which.ParamName.Should().Be("action"); + } + [Fact] public void Should_throw_when_timeout_is_zero_by_timespan() {