Skip to content

Commit

Permalink
Add Task.WhenEach to process tasks as they complete (dotnet#100316)
Browse files Browse the repository at this point in the history
* Add Task.WhenEach to process tasks as they complete

* Address PR feedback

* Fix some async void tests to be async Task across libs

* Remove extra awaiter field from state machine

Also clean up an extra level of indentation.
  • Loading branch information
stephentoub authored and matouskozak committed Apr 30, 2024
1 parent 2d425a9 commit 96cc4d3
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot()
}

[Fact]
public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
public async Task GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
{
// Arrange
var serviceCollection = new ServiceCollection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ private async Task ValidateOnStart_AddEagerValidation_DoesValidationWhenHostStar
}

[Fact]
private async void CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions()
private async Task CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions()
{
var hostBuilder = CreateHostBuilder(services =>
services.AddOptionsWithValidateOnStart<ComplexOptions, ComplexOptionsValidator>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private void AssertCounters(TestCountingLogger testLogger, int requestCount, boo
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public async void CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall)
public async Task CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall)
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddTransient(_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ public static IEnumerable<object[]> GetConverter_ByMultithread_ReturnsExpected_T

[Theory]
[MemberData(nameof(GetConverter_ByMultithread_ReturnsExpected_TestData))]
public async void GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType)
public async Task GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType)
{
TypeConverter[] actualConverters = await Task.WhenAll(
Enumerable.Range(0, 100).Select(_ =>
Expand All @@ -1415,7 +1415,7 @@ public static IEnumerable<object[]> GetConverterWithAddProvider_ByMultithread_Su

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsReflectionEmitSupported))] // Mock will try to JIT
[MemberData(nameof(GetConverterWithAddProvider_ByMultithread_Success_TestData))]
public async void GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType)
public async Task GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType)
{
TypeConverter[] actualConverters = await Task.WhenAll(
Enumerable.Range(0, 200).Select(_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public async Task SendAsync_SlowServerAndCancel_ThrowsTaskCanceledException()

[OuterLoop]
[Fact]
public async void SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException()
public async Task SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException()
{
var handler = new WinHttpHandler();
using (var client = new HttpClient(handler))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Threading.Tasks.Sources;

namespace System.Threading.Tasks
{
Expand Down Expand Up @@ -6659,6 +6660,191 @@ public static Task<Task<TResult>> WhenAny<TResult>(IEnumerable<Task<TResult>> ta
WhenAny<Task<TResult>>(tasks);
#endregion

#region WhenEach
/// <summary>Creates an <see cref="IAsyncEnumerable{T}"/> that will yield the supplied tasks as those tasks complete.</summary>
/// <param name="tasks">The task to iterate through when completed.</param>
/// <returns>An <see cref="IAsyncEnumerable{T}"/> for iterating through the supplied tasks.</returns>
/// <remarks>
/// The supplied tasks will become available to be output via the enumerable once they've completed. The exact order
/// in which the tasks will become available is not defined.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null.</exception>
public static IAsyncEnumerable<Task> WhenEach(params Task[] tasks)
{
ArgumentNullException.ThrowIfNull(tasks);
return WhenEach((ReadOnlySpan<Task>)tasks);
}

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task> WhenEach(ReadOnlySpan<Task> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task> WhenEach(IEnumerable<Task> tasks) =>
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(params Task<TResult>[] tasks)
{
ArgumentNullException.ThrowIfNull(tasks);
return WhenEach((ReadOnlySpan<Task<TResult>>)tasks);
}

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(ReadOnlySpan<Task<TResult>> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(ReadOnlySpan<Task>.CastUp(tasks)));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(IEnumerable<Task<TResult>> tasks) =>
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(tasks));

/// <summary>Object used by <see cref="Iterate"/> to store its state.</summary>
private sealed class WhenEachState : Queue<Task>, IValueTaskSource, ITaskCompletionAction
{
/// <summary>Implementation backing the ValueTask used to wait for the next task to be available.</summary>
/// <remarks>This is a mutable struct. Do not make it readonly.</remarks>
private ManualResetValueTaskSourceCore<bool> _waitForNextCompletedTask = new() { RunContinuationsAsynchronously = true }; // _waitForNextCompletedTask.Set is called while holding a lock
/// <summary>0 if this has never been used in an iteration; 1 if it has.</summary>
/// <remarks>This is used to ensure we only ever iterate through the tasks once.</remarks>
private int _enumerated;

/// <summary>Called at the beginning of the iterator to assume ownership of the state.</summary>
/// <returns>true if the caller owns the state; false if the caller should end immediately.</returns>
public bool TryStart() => Interlocked.Exchange(ref _enumerated, 1) == 0;

/// <summary>Gets or sets the number of tasks that haven't yet been yielded.</summary>
public int Remaining { get; set; }

void ITaskCompletionAction.Invoke(Task completingTask)
{
lock (this)
{
// Enqueue the task into the queue. If the Count is now 1, we transitioned from
// empty to non-empty, which means we need to signal the MRVTSC, as the consumer
// could be waiting on a ValueTask representing a completed task being available.
Enqueue(completingTask);
if (Count == 1)
{
Debug.Assert(_waitForNextCompletedTask.GetStatus(_waitForNextCompletedTask.Version) == ValueTaskSourceStatus.Pending);
_waitForNextCompletedTask.SetResult(default);
}
}
}
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => false;

// Delegate to _waitForNextCompletedTask for IValueTaskSource implementation.
void IValueTaskSource.GetResult(short token) => _waitForNextCompletedTask.GetResult(token);
ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitForNextCompletedTask.GetStatus(token);
void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) =>
_waitForNextCompletedTask.OnCompleted(continuation, state, token, flags);

/// <summary>Creates a <see cref="WhenEachState"/> from the specified tasks.</summary>
public static WhenEachState? Create(ReadOnlySpan<Task> tasks)
{
WhenEachState? waiter = null;

if (tasks.Length != 0)
{
waiter = new();
foreach (Task task in tasks)
{
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}

waiter.Remaining++;
task.AddCompletionAction(waiter);
}
}

return waiter;
}

/// <inheritdoc cref="Create(ReadOnlySpan{Task})"/>
public static WhenEachState? Create(IEnumerable<Task> tasks)
{
ArgumentNullException.ThrowIfNull(tasks);

WhenEachState? waiter = null;

IEnumerator<Task> e = tasks.GetEnumerator();
if (e.MoveNext())
{
waiter = new();
do
{
Task task = e.Current;
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}

waiter.Remaining++;
task.AddCompletionAction(waiter);
}
while (e.MoveNext());
}

return waiter;
}

/// <summary>Iterates through the tasks represented by the provided waiter.</summary>
public static async IAsyncEnumerable<T> Iterate<T>(WhenEachState? waiter, [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : Task
{
// The enumerable could have GetAsyncEnumerator called on it multiple times. As we're dealing with Tasks that
// only ever transition from non-completed to completed, re-enumeration doesn't have much benefit, so we take
// advantage of the optimizations possible by not supporting that and simply have the semantics that, no matter
// how many times the enumerable is enumerated, every task is yielded only once. The original GetAsyncEnumerator
// call will give back all the tasks, and all subsequent iterations will be empty.
if (waiter?.TryStart() is not true)
{
yield break;
}

// Loop until we've yielded all tasks.
while (waiter.Remaining > 0)
{
// Either get the next completed task from the queue, or get a
// ValueTask with which to wait for the next task to complete.
Task? next;
ValueTask waitTask = default;
lock (waiter)
{
// Reset the MRVTSC if it was signaled, then try to dequeue a task and
// either return one we got or return a ValueTask that will be signaled
// when the next completed task is available.
waiter._waitForNextCompletedTask.Reset();
if (!waiter.TryDequeue(out next))
{
waitTask = new(waiter, waiter._waitForNextCompletedTask.Version);
}
}

// If we got a completed Task, yield it.
if (next is not null)
{
cancellationToken.ThrowIfCancellationRequested();
waiter.Remaining--;
yield return (T)next;
continue;
}

// If we have a cancellation token and the ValueTask isn't already completed,
// get a Task from the ValueTask so we can use WaitAsync to make the wait cancelable.
// Otherwise, just await the ValueTask directly. We don't need to be concerned
// about suppressing exceptions, as the ValueTask is only ever completed successfully.
if (cancellationToken.CanBeCanceled && !waitTask.IsCompleted)
{
waitTask = new ValueTask(waitTask.AsTask().WaitAsync(cancellationToken));
}
await waitTask.ConfigureAwait(false);
}
}
}
#endregion

internal static Task<TResult> CreateUnwrapPromise<TResult>(Task outerTask, bool lookForOce)
{
Debug.Assert(outerTask != null);
Expand Down
6 changes: 6 additions & 0 deletions src/libraries/System.Runtime/ref/System.Runtime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15335,6 +15335,12 @@ public static void WaitAll(System.Threading.Tasks.Task[] tasks, System.Threading
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Threading.Tasks.Task<TResult> task1, System.Threading.Tasks.Task<TResult> task2) { throw null; }
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(params System.Threading.Tasks.Task[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.ReadOnlySpan<System.Threading.Tasks.Task> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.ReadOnlySpan<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Runtime.CompilerServices.YieldAwaitable Yield() { throw null; }
}
public static partial class TaskAsyncEnumerableExtensions
Expand Down
Loading

0 comments on commit 96cc4d3

Please sign in to comment.