Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,23 @@ private static Task ForEachAsync<TSource>(IEnumerable<TSource> source, int dop,
// Continue to loop while there are more elements to be processed.
while (!state.Cancellation.IsCancellationRequested)
{
// Get the next element from the enumerator. This requires asynchronously locking around MoveNextAsync/Current.
// Get the next element from the enumerator. This requires asynchronously locking around MoveNext/Current.
TSource element;
lock (state)
await state.AcquireLock();
try
{
if (!state.Enumerator.MoveNext())
if (state.Cancellation.IsCancellationRequested || // check now that the lock has been acquired
!state.Enumerator.MoveNext())
{
break;
}

element = state.Enumerator.Current;
}
finally
{
state.ReleaseLock();
}

// If the remaining dop allows it and we've not yet queued the next worker, do so now. We wait
// until after we've grabbed an item from the enumerator to a) avoid unnecessary contention on the
Expand Down Expand Up @@ -249,20 +255,11 @@ private static Task ForEachAsync<TSource>(IAsyncEnumerable<TSource> source, int
{
// Get the next element from the enumerator. This requires asynchronously locking around MoveNextAsync/Current.
TSource element;
await state.AcquireLock();
try
{
// TODO https://github.com/dotnet/runtime/issues/22144:
// Use a no-throwing await if/when one is available built-in.
await state.Lock.WaitAsync(state.Cancellation.Token);
}
catch (OperationCanceledException)
{
break;
}

try
{
if (!await state.Enumerator.MoveNextAsync())
if (state.Cancellation.IsCancellationRequested || // check now that the lock has been acquired
!await state.Enumerator.MoveNextAsync())
{
break;
}
Expand All @@ -271,7 +268,7 @@ private static Task ForEachAsync<TSource>(IAsyncEnumerable<TSource> source, int
}
finally
{
state.Lock.Release();
state.ReleaseLock();
}

// If the remaining dop allows it and we've not yet queued the next worker, do so now. We wait
Expand Down Expand Up @@ -354,6 +351,8 @@ private abstract class ForEachAsyncState<TSource> : TaskCompletionSource, IThrea
private readonly TaskScheduler _scheduler;
/// <summary>The <see cref="ExecutionContext"/> present at the time of the ForEachAsync invocation. This is only used if on the default scheduler.</summary>
private readonly ExecutionContext? _executionContext;
/// <summary>Semaphore used to provide exclusive access to the enumerator.</summary>
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);

/// <summary>The number of outstanding workers. When this hits 0, the operation has completed.</summary>
private int _completionRefCount;
Expand Down Expand Up @@ -417,6 +416,21 @@ public void QueueWorkerIfDopAvailable()
/// <returns>true if this is the last worker to complete iterating; otherwise, false.</returns>
public bool SignalWorkerCompletedIterating() => Interlocked.Decrement(ref _completionRefCount) == 0;

/// <summary>Asynchronously acquires exclusive access to the enumerator.</summary>
public Task AcquireLock() =>
// We explicitly don't pass this.Cancellation to WaitAsync. Doing so adds overhead, and it isn't actually
// necessary. All of the operations that monitor the lock are part of the same ForEachAsync operation, and the Task
// returned from ForEachAsync can't complete until all of the constituent operations have completed, including whoever
// holds the lock while this worker is waiting on the lock. Thus, the lock will need to be released for the overall
// operation to complete. Passing the token would allow the overall operation to potentially complete a bit faster in
// the face of cancellation, in exchange for making it a bit slower / more overhead in the common case of cancellation
// not being requested. We want to optimize for the latter. This also then avoids an exception throw / catch when
// cancellation is requested.
_lock.WaitAsync(CancellationToken.None);

/// <summary>Relinquishes exclusive access to the enumerator.</summary>
public void ReleaseLock() => _lock.Release();

/// <summary>Stores an exception and triggers cancellation in order to alert all workers to stop as soon as possible.</summary>
/// <param name="e">The exception.</param>
public void RecordException(Exception e)
Expand Down Expand Up @@ -444,6 +458,7 @@ public void Complete()
else if (_exceptions is null)
{
// Everything completed successfully.
Debug.Assert(!Cancellation.IsCancellationRequested);
taskSet = TrySetResult();
}
else
Expand Down Expand Up @@ -500,7 +515,6 @@ public void Dispose()
/// <typeparam name="TSource">Specifies the type of data being enumerated.</typeparam>
private sealed class AsyncForEachAsyncState<TSource> : ForEachAsyncState<TSource>, IAsyncDisposable
{
public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public readonly IAsyncEnumerator<TSource> Enumerator;

public AsyncForEachAsyncState(
Expand Down