Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Task.WhenEach to process tasks as they complete #100316

Merged
merged 6 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot()
}

[Fact]
public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
public async Task GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
{
// Arrange
var serviceCollection = new ServiceCollection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ private async Task ValidateOnStart_AddEagerValidation_DoesValidationWhenHostStar
}

[Fact]
private async void CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions()
private async Task CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions()
{
var hostBuilder = CreateHostBuilder(services =>
services.AddOptionsWithValidateOnStart<ComplexOptions, ComplexOptionsValidator>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private void AssertCounters(TestCountingLogger testLogger, int requestCount, boo
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public async void CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall)
public async Task CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall)
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddTransient(_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ public static IEnumerable<object[]> GetConverter_ByMultithread_ReturnsExpected_T

[Theory]
[MemberData(nameof(GetConverter_ByMultithread_ReturnsExpected_TestData))]
public async void GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType)
public async Task GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType)
{
TypeConverter[] actualConverters = await Task.WhenAll(
Enumerable.Range(0, 100).Select(_ =>
Expand All @@ -1415,7 +1415,7 @@ public static IEnumerable<object[]> GetConverterWithAddProvider_ByMultithread_Su

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsReflectionEmitSupported))] // Mock will try to JIT
[MemberData(nameof(GetConverterWithAddProvider_ByMultithread_Success_TestData))]
public async void GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType)
public async Task GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType)
{
TypeConverter[] actualConverters = await Task.WhenAll(
Enumerable.Range(0, 200).Select(_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public async Task SendAsync_SlowServerAndCancel_ThrowsTaskCanceledException()

[OuterLoop]
[Fact]
public async void SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException()
public async Task SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException()
{
var handler = new WinHttpHandler();
using (var client = new HttpClient(handler))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Threading.Tasks.Sources;

namespace System.Threading.Tasks
{
Expand Down Expand Up @@ -6659,6 +6660,191 @@ public static Task<Task<TResult>> WhenAny<TResult>(IEnumerable<Task<TResult>> ta
WhenAny<Task<TResult>>(tasks);
#endregion

#region WhenEach
/// <summary>Creates an <see cref="IAsyncEnumerable{T}"/> that will yield the supplied tasks as those tasks complete.</summary>
/// <param name="tasks">The task to iterate through when completed.</param>
/// <returns>An <see cref="IAsyncEnumerable{T}"/> for iterating through the supplied tasks.</returns>
/// <remarks>
/// The supplied tasks will become available to be output via the enumerable once they've completed. The exact order
/// in which the tasks will become available is not defined.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null.</exception>
public static IAsyncEnumerable<Task> WhenEach(params Task[] tasks)
{
ArgumentNullException.ThrowIfNull(tasks);
return WhenEach((ReadOnlySpan<Task>)tasks);
}

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task> WhenEach(ReadOnlySpan<Task> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task> WhenEach(IEnumerable<Task> tasks) =>
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(params Task<TResult>[] tasks)
{
ArgumentNullException.ThrowIfNull(tasks);
return WhenEach((ReadOnlySpan<Task<TResult>>)tasks);
}

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(ReadOnlySpan<Task<TResult>> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(ReadOnlySpan<Task>.CastUp(tasks)));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(IEnumerable<Task<TResult>> tasks) =>
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(tasks));

/// <summary>Object used by <see cref="Iterate"/> to store its state.</summary>
private sealed class WhenEachState : Queue<Task>, IValueTaskSource, ITaskCompletionAction
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>Implementation backing the ValueTask used to wait for the next task to be available.</summary>
/// <remarks>This is a mutable struct. Do not make it readonly.</remarks>
private ManualResetValueTaskSourceCore<bool> _waitForNextCompletedTask = new() { RunContinuationsAsynchronously = true }; // _waitForNextCompletedTask.Set is called while holding a lock
/// <summary>0 if this has never been used in an iteration; 1 if it has.</summary>
/// <remarks>This is used to ensure we only ever iterate through the tasks once.</remarks>
private int _enumerated;

/// <summary>Called at the beginning of the iterator to assume ownership of the state.</summary>
/// <returns>true if the caller owns the state; false if the caller should end immediately.</returns>
public bool TryStart() => Interlocked.Exchange(ref _enumerated, 1) == 0;

/// <summary>Gets or sets the number of tasks that haven't yet been yielded.</summary>
public int Remaining { get; set; }

void ITaskCompletionAction.Invoke(Task completingTask)
{
lock (this)
{
// Enqueue the task into the queue. If the Count is now 1, we transitioned from
// empty to non-empty, which means we need to signal the MRVTSC, as the consumer
// could be waiting on a ValueTask representing a completed task being available.
Enqueue(completingTask);
if (Count == 1)
{
Debug.Assert(_waitForNextCompletedTask.GetStatus(_waitForNextCompletedTask.Version) == ValueTaskSourceStatus.Pending);
_waitForNextCompletedTask.SetResult(default);
}
}
}
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => false;

// Delegate to _waitForNextCompletedTask for IValueTaskSource implementation.
void IValueTaskSource.GetResult(short token) => _waitForNextCompletedTask.GetResult(token);
ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitForNextCompletedTask.GetStatus(token);
void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) =>
_waitForNextCompletedTask.OnCompleted(continuation, state, token, flags);

/// <summary>Creates a <see cref="WhenEachState"/> from the specified tasks.</summary>
public static WhenEachState? Create(ReadOnlySpan<Task> tasks)
{
WhenEachState? waiter = null;

if (tasks.Length != 0)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
waiter = new();
foreach (Task task in tasks)
{
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}

waiter.Remaining++;
task.AddCompletionAction(waiter);
}
}

return waiter;
}

/// <inheritdoc cref="Create(ReadOnlySpan{Task})"/>
public static WhenEachState? Create(IEnumerable<Task> tasks)
{
ArgumentNullException.ThrowIfNull(tasks);

WhenEachState? waiter = null;

IEnumerator<Task> e = tasks.GetEnumerator();
if (e.MoveNext())
{
waiter = new();
do
{
Task task = e.Current;
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}

waiter.Remaining++;
task.AddCompletionAction(waiter);
}
while (e.MoveNext());
}

return waiter;
}

/// <summary>Iterates through the tasks represented by the provided waiter.</summary>
public static async IAsyncEnumerable<T> Iterate<T>(WhenEachState? waiter, [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : Task
{
// The enumerable could have GetAsyncEnumerator called on it multiple times. As we're dealing with Tasks that
// only ever transition from non-completed to completed, re-enumeration doesn't have much benefit, so we take
// advantage of the optimizations possible by not supporting that and simply have the semantics that, no matter
// how many times the enumerable is enumerated, every task is yielded only once. The original GetAsyncEnumerator
// call will give back all the tasks, and all subsequent iterations will be empty.
if (waiter?.TryStart() is not true)
{
yield break;
}

// Loop until we've yielded all tasks.
while (waiter.Remaining > 0)
{
// Either get the next completed task from the queue, or get a
// ValueTask with which to wait for the next task to complete.
Task? next;
ValueTask waitTask = default;
lock (waiter)
{
// Reset the MRVTSC if it was signaled, then try to dequeue a task and
// either return one we got or return a ValueTask that will be signaled
// when the next completed task is available.
waiter._waitForNextCompletedTask.Reset();
if (!waiter.TryDequeue(out next))
{
waitTask = new(waiter, waiter._waitForNextCompletedTask.Version);
}
}

// If we got a completed Task, yield it.
if (next is not null)
{
cancellationToken.ThrowIfCancellationRequested();
waiter.Remaining--;
yield return (T)next;
continue;
}

// If we have a cancellation token and the ValueTask isn't already completed,
// get a Task from the ValueTask so we can use WaitAsync to make the wait cancelable.
// Otherwise, just await the ValueTask directly. We don't need to be concerned
// about suppressing exceptions, as the ValueTask is only ever completed successfully.
if (cancellationToken.CanBeCanceled && !waitTask.IsCompleted)
{
waitTask = new ValueTask(waitTask.AsTask().WaitAsync(cancellationToken));
}
await waitTask.ConfigureAwait(false);
}
}
}
#endregion

internal static Task<TResult> CreateUnwrapPromise<TResult>(Task outerTask, bool lookForOce)
{
Debug.Assert(outerTask != null);
Expand Down
6 changes: 6 additions & 0 deletions src/libraries/System.Runtime/ref/System.Runtime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15335,6 +15335,12 @@ public static void WaitAll(System.Threading.Tasks.Task[] tasks, System.Threading
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Threading.Tasks.Task<TResult> task1, System.Threading.Tasks.Task<TResult> task2) { throw null; }
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(params System.Threading.Tasks.Task[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.ReadOnlySpan<System.Threading.Tasks.Task> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.ReadOnlySpan<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Runtime.CompilerServices.YieldAwaitable Yield() { throw null; }
}
public static partial class TaskAsyncEnumerableExtensions
Expand Down
Loading
Loading