Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions TUnit.Assertions.Tests/WaitsForAssertionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,110 @@ public async Task WaitsFor_WithIsNotNull_ReturnsResolvedValue_Issue3623()
await Assert.That(entity.IsReady).IsEqualTo(true);
}

[Test]
public async Task WaitsFor_PropagatesExternalCancellation_Before_Internal_Timeout()
{
using var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromMilliseconds(100));

var stopwatch = Stopwatch.StartNew();

var act = async () => await Assert.That(() => false)
.WaitsFor(
assert => assert.IsTrue(),
timeout: TimeSpan.FromSeconds(30),
cancellationToken: cts.Token);

await Assert.That(act).Throws<OperationCanceledException>();

stopwatch.Stop();

await Assert.That(stopwatch.Elapsed).IsLessThan(TimeSpan.FromSeconds(1));
}

[Test]
public async Task WaitsFor_Throws_AssertionException_On_Internal_Timeout_When_Token_Not_Cancelled()
{
using var cts = new CancellationTokenSource();

var act = async () => await Assert.That(() => false)
.WaitsFor(
assert => assert.IsTrue(),
timeout: TimeSpan.FromMilliseconds(200),
cancellationToken: cts.Token);

await Assert.That(act).Throws<AssertionException>();
}

[Test]
public async Task WaitsFor_Honours_PreCancelled_Token_Before_First_Poll()
{
using var cts = new CancellationTokenSource();
cts.Cancel();

var stopwatch = Stopwatch.StartNew();

var act = async () => await Assert.That(() => false)
.WaitsFor(
assert => assert.IsTrue(),
timeout: TimeSpan.FromSeconds(5),
cancellationToken: cts.Token);

await Assert.That(act).Throws<OperationCanceledException>();

stopwatch.Stop();

await Assert.That(stopwatch.Elapsed).IsLessThan(TimeSpan.FromMilliseconds(500));
}

[Test]
public async Task Eventually_PropagatesExternalCancellation_Before_Internal_Timeout()
{
using var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromMilliseconds(100));

var stopwatch = Stopwatch.StartNew();

var act = async () => await Assert.That(() => false)
.Eventually(
assert => assert.IsTrue(),
timeout: TimeSpan.FromSeconds(30),
cancellationToken: cts.Token);

await Assert.That(act).Throws<OperationCanceledException>();

stopwatch.Stop();

await Assert.That(stopwatch.Elapsed).IsLessThan(TimeSpan.FromSeconds(1));
}

[Test]
public async Task WaitsFor_Propagates_OCE_When_Predicate_Observes_Supplied_Token()
{
using var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromMilliseconds(50));

Func<bool> tokenAwarePredicate = () =>
{
cts.Token.ThrowIfCancellationRequested();
return false;
};

var stopwatch = Stopwatch.StartNew();

var act = async () => await Assert.That(tokenAwarePredicate)
.WaitsFor(
assert => assert.IsTrue(),
timeout: TimeSpan.FromSeconds(30),
cancellationToken: cts.Token);

await Assert.That(act).Throws<OperationCanceledException>();

stopwatch.Stop();

await Assert.That(stopwatch.Elapsed).IsLessThan(TimeSpan.FromMilliseconds(500));
}

// Helper class for testing complex objects
private class TestEntity
{
Expand Down
23 changes: 20 additions & 3 deletions TUnit.Assertions/Conditions/WaitsForAssertion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@ public class WaitsForAssertion<TValue> : Assertion<TValue>
private readonly Func<IAssertionSource<TValue>, Assertion<TValue>> _assertionBuilder;
private readonly TimeSpan _timeout;
private readonly TimeSpan _pollingInterval;
private readonly CancellationToken _cancellationToken;

public WaitsForAssertion(
AssertionContext<TValue> context,
Func<IAssertionSource<TValue>, Assertion<TValue>> assertionBuilder,
TimeSpan timeout,
TimeSpan? pollingInterval = null)
TimeSpan? pollingInterval = null,
CancellationToken cancellationToken = default)
: base(context)
{
_assertionBuilder = assertionBuilder ?? throw new ArgumentNullException(nameof(assertionBuilder));
_timeout = timeout;
_pollingInterval = pollingInterval ?? TimeSpan.FromMilliseconds(10);
_cancellationToken = cancellationToken;

if (_timeout <= TimeSpan.Zero)
{
Expand All @@ -44,10 +47,20 @@ protected override async Task<AssertionResult> CheckAsync(EvaluationMetadata<TVa
Exception? lastException = null;
var attemptCount = 0;

using var cts = new CancellationTokenSource(_timeout);
// Link the supplied cancellation token with an internal timeout source so the polling
// loop honours both: external cancellation propagates as OperationCanceledException,
// and the internal timeout still produces the standard AssertionResult.Failed path.
// When the caller did not supply a cancellable token, the linking step is skipped to
// avoid the registration overhead.
using var linkedCts = _cancellationToken.CanBeCanceled
? CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken)
: new CancellationTokenSource();
linkedCts.CancelAfter(_timeout);

while (stopwatch.Elapsed < _timeout)
{
_cancellationToken.ThrowIfCancellationRequested();

attemptCount++;

try
Expand All @@ -71,7 +84,11 @@ protected override async Task<AssertionResult> CheckAsync(EvaluationMetadata<TVa

try
{
await Task.Delay(_pollingInterval, cts.Token);
await Task.Delay(_pollingInterval, linkedCts.Token);
}
catch (OperationCanceledException) when (_cancellationToken.IsCancellationRequested)
{
throw;
}
catch (OperationCanceledException)
{
Expand Down
8 changes: 6 additions & 2 deletions TUnit.Assertions/Extensions/AssertionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,7 @@ public static CompletesWithinAsyncAssertion CompletesWithin(
/// <param name="assertionBuilder">A function that builds the assertion to be evaluated on each poll</param>
/// <param name="timeout">The maximum time to wait for the assertion to pass</param>
/// <param name="pollingInterval">The interval between polling attempts (defaults to 10ms if not specified)</param>
/// <param name="cancellationToken">A token to cancel the polling loop. External cancellation throws <see cref="OperationCanceledException"/>; the internal timeout still produces the standard <see cref="Exceptions.AssertionException"/>.</param>
/// <param name="timeoutExpression">Captured expression for the timeout parameter</param>
/// <param name="pollingIntervalExpression">Captured expression for the polling interval parameter</param>
/// <returns>An assertion that can be awaited or chained with And/Or</returns>
Expand All @@ -1727,12 +1728,13 @@ public static WaitsForAssertion<TValue> WaitsFor<TValue>(
Func<IAssertionSource<TValue>, Assertion<TValue>> assertionBuilder,
TimeSpan timeout,
TimeSpan? pollingInterval = null,
CancellationToken cancellationToken = default,
[CallerArgumentExpression(nameof(timeout))] string? timeoutExpression = null,
[CallerArgumentExpression(nameof(pollingInterval))] string? pollingIntervalExpression = null)
{
var intervalExpr = pollingInterval.HasValue ? $", pollingInterval: {pollingIntervalExpression}" : "";
source.Context.ExpressionBuilder.Append($".WaitsFor(..., timeout: {timeoutExpression}{intervalExpr})");
return new WaitsForAssertion<TValue>(source.Context, assertionBuilder, timeout, pollingInterval);
return new WaitsForAssertion<TValue>(source.Context, assertionBuilder, timeout, pollingInterval, cancellationToken);
}

/// <summary>
Expand All @@ -1745,6 +1747,7 @@ public static WaitsForAssertion<TValue> WaitsFor<TValue>(
/// <param name="assertionBuilder">A function that builds the assertion to be evaluated on each poll</param>
/// <param name="timeout">The maximum time to wait for the assertion to pass</param>
/// <param name="pollingInterval">The interval between polling attempts (defaults to 10ms if not specified)</param>
/// <param name="cancellationToken">A token to cancel the polling loop. External cancellation throws <see cref="OperationCanceledException"/>; the internal timeout still produces the standard <see cref="Exceptions.AssertionException"/>.</param>
/// <param name="timeoutExpression">Captured expression for the timeout parameter</param>
/// <param name="pollingIntervalExpression">Captured expression for the polling interval parameter</param>
/// <returns>An assertion that can be awaited or chained with And/Or</returns>
Expand All @@ -1753,10 +1756,11 @@ public static WaitsForAssertion<TValue> Eventually<TValue>(
Func<IAssertionSource<TValue>, Assertion<TValue>> assertionBuilder,
TimeSpan timeout,
TimeSpan? pollingInterval = null,
CancellationToken cancellationToken = default,
[CallerArgumentExpression(nameof(timeout))] string? timeoutExpression = null,
[CallerArgumentExpression(nameof(pollingInterval))] string? pollingIntervalExpression = null)
{
return source.WaitsFor(assertionBuilder, timeout, pollingInterval, timeoutExpression, pollingIntervalExpression);
return source.WaitsFor(assertionBuilder, timeout, pollingInterval, cancellationToken, timeoutExpression, pollingIntervalExpression);
}

private static Action GetActionFromDelegate(DelegateAssertion source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2279,7 +2279,7 @@ namespace .Conditions
public static class ValueTaskAssertionExtensions { }
public class WaitsForAssertion<TValue> : .<TValue>
{
public WaitsForAssertion(.<TValue> context, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default) { }
public WaitsForAssertion(.<TValue> context, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default) { }
public override .<TValue?> AssertAsync() { }
protected override .<.> CheckAsync(.<TValue> metadata) { }
protected override string GetExpectation() { }
Expand Down Expand Up @@ -2674,7 +2674,7 @@ namespace .Extensions
where TCollection : .<.<TKey, TValue>>
where TKey : notnull { }
public static .<TValue> EqualTo<TValue>(this .<TValue> source, TValue? expected, [.("expected")] string? expression = null) { }
public static .<TValue> Eventually<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static .<TValue> Eventually<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
[("Use Length() instead, which provides all numeric assertion methods. Example: Asse" +
"(str).Length().IsGreaterThan(5)")]
public static ..LengthWrapper HasLength(this .<string> source) { }
Expand Down Expand Up @@ -2803,7 +2803,7 @@ namespace .Extensions
public static .<TException> ThrowsException<TException, TValue>(this .<TValue> source)
where TException : { }
public static .<TValue> ThrowsNothing<TValue>(this .<TValue> source) { }
public static .<TValue> WaitsFor<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static .<TValue> WaitsFor<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static ..WhenParsedIntoAssertion<T> WhenParsedInto<[.(..None | ..PublicMethods | ..Interfaces)] T>(this .<string> source) { }
public static .<TException, TInnerException> WithInnerException<TException, TInnerException>(this .<TException> source)
where TException :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2262,7 +2262,7 @@ namespace .Conditions
public static class ValueTaskAssertionExtensions { }
public class WaitsForAssertion<TValue> : .<TValue>
{
public WaitsForAssertion(.<TValue> context, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default) { }
public WaitsForAssertion(.<TValue> context, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default) { }
public override .<TValue?> AssertAsync() { }
protected override .<.> CheckAsync(.<TValue> metadata) { }
protected override string GetExpectation() { }
Expand Down Expand Up @@ -2653,7 +2653,7 @@ namespace .Extensions
where TCollection : .<.<TKey, TValue>>
where TKey : notnull { }
public static .<TValue> EqualTo<TValue>(this .<TValue> source, TValue? expected, [.("expected")] string? expression = null) { }
public static .<TValue> Eventually<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static .<TValue> Eventually<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
[("Use Length() instead, which provides all numeric assertion methods. Example: Asse" +
"(str).Length().IsGreaterThan(5)")]
public static ..LengthWrapper HasLength(this .<string> source) { }
Expand Down Expand Up @@ -2768,7 +2768,7 @@ namespace .Extensions
public static .<TException> ThrowsException<TException, TValue>(this .<TValue> source)
where TException : { }
public static .<TValue> ThrowsNothing<TValue>(this .<TValue> source) { }
public static .<TValue> WaitsFor<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static .<TValue> WaitsFor<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static ..WhenParsedIntoAssertion<T> WhenParsedInto<[.(..None | ..PublicMethods | ..Interfaces)] T>(this .<string> source) { }
public static .<TException, TInnerException> WithInnerException<TException, TInnerException>(this .<TException> source)
where TException :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2279,7 +2279,7 @@ namespace .Conditions
public static class ValueTaskAssertionExtensions { }
public class WaitsForAssertion<TValue> : .<TValue>
{
public WaitsForAssertion(.<TValue> context, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default) { }
public WaitsForAssertion(.<TValue> context, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default) { }
public override .<TValue?> AssertAsync() { }
protected override .<.> CheckAsync(.<TValue> metadata) { }
protected override string GetExpectation() { }
Expand Down Expand Up @@ -2674,7 +2674,7 @@ namespace .Extensions
where TCollection : .<.<TKey, TValue>>
where TKey : notnull { }
public static .<TValue> EqualTo<TValue>(this .<TValue> source, TValue? expected, [.("expected")] string? expression = null) { }
public static .<TValue> Eventually<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static .<TValue> Eventually<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
[("Use Length() instead, which provides all numeric assertion methods. Example: Asse" +
"(str).Length().IsGreaterThan(5)")]
public static ..LengthWrapper HasLength(this .<string> source) { }
Expand Down Expand Up @@ -2803,7 +2803,7 @@ namespace .Extensions
public static .<TException> ThrowsException<TException, TValue>(this .<TValue> source)
where TException : { }
public static .<TValue> ThrowsNothing<TValue>(this .<TValue> source) { }
public static .<TValue> WaitsFor<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static .<TValue> WaitsFor<TValue>(this .<TValue> source, <.<TValue>, .<TValue>> assertionBuilder, timeout, ? pollingInterval = default, .CancellationToken cancellationToken = default, [.("timeout")] string? timeoutExpression = null, [.("pollingInterval")] string? pollingIntervalExpression = null) { }
public static ..WhenParsedIntoAssertion<T> WhenParsedInto<[.(..None | ..PublicMethods | ..Interfaces)] T>(this .<string> source) { }
public static .<TException, TInnerException> WithInnerException<TException, TInnerException>(this .<TException> source)
where TException :
Expand Down
Loading
Loading