Skip to content

Commit adb1091

Browse files
authored
Make runtime async callable thunks transparent (#120386)
Runtime async callable thunks were using `TaskAwaiter` directly, but that has the normal await semantics which will either post continuations to the captured synchronization context or to the thread pool. This introduces an observable behavior change with async1 where sometimes even configured awaits will end up posting back to a captured synchronization context. For example, consider an example like: ```csharp private static async Task Foo() { SynchronizationContext.SetSynchronizationContext(new TrackingSynchronizationContext()); await Task.Delay(1000).ConfigureAwait(false); } ```csharp Before this change the runtime async call to `Task.Delay` creates a runtime async callback thunk that roughly looks like ```csharp Task DelayThunk(int time) { TaskAwaiter awaiter = Task.Delay(time).GetAwaiter(); if (!await.IsCompleted) AsyncHelpers.UnsafeAwaiterAwaiter(awaiter); awaiter.GetResult(); } ``` however, when this thunk is called we end up posting back to `TrackingSynchronizationContext`, before the continuation for `Foo` then must move its continuation back to the thread pool. This PR fixes this and makes the thunks transparent in context behavior. At the same time it also optimizes the runtime async -> async1 path to be more efficient in the suspension case: the async1 task now directly invokes the runtime async infrastructure as its continuion, instead of going through multiple layers of indirection. I also fixed a bug where we could end up invoking the wrong variant of OnCompleted for non-unsafe notifiers if they implemented both INotifyCompletion and ICriticalNotifyCompletion. I have also renamed ThunkTask -> RuntimeAsyncTask.
1 parent 8d3c401 commit adb1091

File tree

9 files changed

+294
-198
lines changed

9 files changed

+294
-198
lines changed

src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs

Lines changed: 104 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ public static partial class AsyncHelpers
140140
private struct RuntimeAsyncAwaitState
141141
{
142142
public Continuation? SentinelContinuation;
143+
public ICriticalNotifyCompletion? CriticalNotifier;
143144
public INotifyCompletion? Notifier;
145+
public Task? CalledTask;
144146
}
145147

146148
[ThreadStatic]
@@ -203,7 +205,21 @@ private static unsafe object AllocContinuationResultBox(void* ptr)
203205
return RuntimeTypeHandle.InternalAllocNoChecks((MethodTable*)pMT);
204206
}
205207

206-
private interface IThunkTaskOps<T>
208+
[BypassReadyToRun]
209+
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
210+
[RequiresPreviewFeatures]
211+
private static void TransparentAwaitTask(Task t)
212+
{
213+
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
214+
Continuation? sentinelContinuation = state.SentinelContinuation;
215+
if (sentinelContinuation == null)
216+
state.SentinelContinuation = sentinelContinuation = new Continuation();
217+
218+
state.CalledTask = t;
219+
AsyncSuspend(sentinelContinuation);
220+
}
221+
222+
private interface IRuntimeAsyncTaskOps<T>
207223
{
208224
static abstract Action GetContinuationAction(T task);
209225
static abstract Continuation GetContinuationState(T task);
@@ -212,9 +228,12 @@ private interface IThunkTaskOps<T>
212228
static abstract void PostToSyncContext(T task, SynchronizationContext syncCtx);
213229
}
214230

215-
private sealed class ThunkTask<T> : Task<T>
231+
/// <summary>
232+
/// Represents a wrapped runtime async operation.
233+
/// </summary>
234+
private sealed class RuntimeAsyncTask<T> : Task<T>, ITaskCompletionAction
216235
{
217-
public ThunkTask()
236+
public RuntimeAsyncTask()
218237
{
219238
// We use the base Task's state object field to store the Continuation while posting the task around.
220239
// Ensure that state object isn't published out for others to see.
@@ -231,31 +250,38 @@ internal override void ExecuteFromThreadPool(Thread threadPoolThread)
231250

232251
private void MoveNext()
233252
{
234-
ThunkTaskCore.MoveNext<ThunkTask<T>, Ops>(this);
253+
RuntimeAsyncTaskCore.DispatchContinuations<RuntimeAsyncTask<T>, Ops>(this);
235254
}
236255

237256
public void HandleSuspended()
238257
{
239-
ThunkTaskCore.HandleSuspended<ThunkTask<T>, Ops>(this);
258+
RuntimeAsyncTaskCore.HandleSuspended<RuntimeAsyncTask<T>, Ops>(this);
259+
}
260+
261+
void ITaskCompletionAction.Invoke(Task completingTask)
262+
{
263+
MoveNext();
240264
}
241265

266+
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true;
267+
242268
private static readonly SendOrPostCallback s_postCallback = static state =>
243269
{
244-
Debug.Assert(state is ThunkTask<T>);
245-
((ThunkTask<T>)state).MoveNext();
270+
Debug.Assert(state is RuntimeAsyncTask<T>);
271+
((RuntimeAsyncTask<T>)state).MoveNext();
246272
};
247273

248-
private struct Ops : IThunkTaskOps<ThunkTask<T>>
274+
private struct Ops : IRuntimeAsyncTaskOps<RuntimeAsyncTask<T>>
249275
{
250-
public static Action GetContinuationAction(ThunkTask<T> task) => (Action)task.m_action!;
251-
public static void MoveNext(ThunkTask<T> task) => task.MoveNext();
252-
public static Continuation GetContinuationState(ThunkTask<T> task) => (Continuation)task.m_stateObject!;
253-
public static void SetContinuationState(ThunkTask<T> task, Continuation value)
276+
public static Action GetContinuationAction(RuntimeAsyncTask<T> task) => (Action)task.m_action!;
277+
public static void MoveNext(RuntimeAsyncTask<T> task) => task.MoveNext();
278+
public static Continuation GetContinuationState(RuntimeAsyncTask<T> task) => (Continuation)task.m_stateObject!;
279+
public static void SetContinuationState(RuntimeAsyncTask<T> task, Continuation value)
254280
{
255281
task.m_stateObject = value;
256282
}
257283

258-
public static bool SetCompleted(ThunkTask<T> task, Continuation continuation)
284+
public static bool SetCompleted(RuntimeAsyncTask<T> task, Continuation continuation)
259285
{
260286
T result;
261287
if (RuntimeHelpers.IsReferenceOrContainsReferences<T>())
@@ -277,16 +303,19 @@ public static bool SetCompleted(ThunkTask<T> task, Continuation continuation)
277303
return task.TrySetResult(result);
278304
}
279305

280-
public static void PostToSyncContext(ThunkTask<T> task, SynchronizationContext syncContext)
306+
public static void PostToSyncContext(RuntimeAsyncTask<T> task, SynchronizationContext syncContext)
281307
{
282308
syncContext.Post(s_postCallback, task);
283309
}
284310
}
285311
}
286312

287-
private sealed class ThunkTask : Task
313+
/// <summary>
314+
/// Represents a wrapped runtime async operation.
315+
/// </summary>
316+
private sealed class RuntimeAsyncTask : Task, ITaskCompletionAction
288317
{
289-
public ThunkTask()
318+
public RuntimeAsyncTask()
290319
{
291320
// We use the base Task's state object field to store the Continuation while posting the task around.
292321
// Ensure that state object isn't published out for others to see.
@@ -303,45 +332,52 @@ internal override void ExecuteFromThreadPool(Thread threadPoolThread)
303332

304333
private void MoveNext()
305334
{
306-
ThunkTaskCore.MoveNext<ThunkTask, Ops>(this);
335+
RuntimeAsyncTaskCore.DispatchContinuations<RuntimeAsyncTask, Ops>(this);
307336
}
308337

309338
public void HandleSuspended()
310339
{
311-
ThunkTaskCore.HandleSuspended<ThunkTask, Ops>(this);
340+
RuntimeAsyncTaskCore.HandleSuspended<RuntimeAsyncTask, Ops>(this);
312341
}
313342

343+
void ITaskCompletionAction.Invoke(Task completingTask)
344+
{
345+
MoveNext();
346+
}
347+
348+
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true;
349+
314350
private static readonly SendOrPostCallback s_postCallback = static state =>
315351
{
316-
Debug.Assert(state is ThunkTask);
317-
((ThunkTask)state).MoveNext();
352+
Debug.Assert(state is RuntimeAsyncTask);
353+
((RuntimeAsyncTask)state).MoveNext();
318354
};
319355

320-
private struct Ops : IThunkTaskOps<ThunkTask>
356+
private struct Ops : IRuntimeAsyncTaskOps<RuntimeAsyncTask>
321357
{
322-
public static Action GetContinuationAction(ThunkTask task) => (Action)task.m_action!;
323-
public static void MoveNext(ThunkTask task) => task.MoveNext();
324-
public static Continuation GetContinuationState(ThunkTask task) => (Continuation)task.m_stateObject!;
325-
public static void SetContinuationState(ThunkTask task, Continuation value)
358+
public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!;
359+
public static void MoveNext(RuntimeAsyncTask task) => task.MoveNext();
360+
public static Continuation GetContinuationState(RuntimeAsyncTask task) => (Continuation)task.m_stateObject!;
361+
public static void SetContinuationState(RuntimeAsyncTask task, Continuation value)
326362
{
327363
task.m_stateObject = value;
328364
}
329365

330-
public static bool SetCompleted(ThunkTask task, Continuation continuation)
366+
public static bool SetCompleted(RuntimeAsyncTask task, Continuation continuation)
331367
{
332368
return task.TrySetResult();
333369
}
334370

335-
public static void PostToSyncContext(ThunkTask task, SynchronizationContext syncContext)
371+
public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationContext syncContext)
336372
{
337373
syncContext.Post(s_postCallback, task);
338374
}
339375
}
340376
}
341377

342-
private static class ThunkTaskCore
378+
private static class RuntimeAsyncTaskCore
343379
{
344-
public static unsafe void MoveNext<T, TOps>(T task) where T : Task where TOps : IThunkTaskOps<T>
380+
public static unsafe void DispatchContinuations<T, TOps>(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps<T>
345381
{
346382
ExecutionAndSyncBlockStore contexts = default;
347383
contexts.Push();
@@ -422,9 +458,20 @@ private static Continuation UnwindToPossibleHandler(Continuation continuation)
422458
}
423459
}
424460

425-
public static void HandleSuspended<T, TOps>(T task) where T : Task where TOps : IThunkTaskOps<T>
461+
public static void HandleSuspended<T, TOps>(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps<T>
426462
{
427-
Continuation headContinuation = UnlinkHeadContinuation(out INotifyCompletion? notifier);
463+
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
464+
ICriticalNotifyCompletion? critNotifier = state.CriticalNotifier;
465+
INotifyCompletion? notifier = state.Notifier;
466+
Task? calledTask = state.CalledTask;
467+
468+
state.CriticalNotifier = null;
469+
state.Notifier = null;
470+
state.CalledTask = null;
471+
472+
Continuation sentinelContinuation = state.SentinelContinuation!;
473+
Continuation headContinuation = sentinelContinuation.Next!;
474+
sentinelContinuation.Next = null;
428475

429476
// Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
430477
// These never have special continuation handling.
@@ -438,9 +485,19 @@ public static void HandleSuspended<T, TOps>(T task) where T : Task where TOps :
438485

439486
try
440487
{
441-
if (notifier is ICriticalNotifyCompletion crit)
488+
if (critNotifier != null)
489+
{
490+
critNotifier.UnsafeOnCompleted(TOps.GetContinuationAction(task));
491+
}
492+
else if (calledTask != null)
442493
{
443-
crit.UnsafeOnCompleted(TOps.GetContinuationAction(task));
494+
// Runtime async callable wrapper for task returning
495+
// method. This implements the context transparent
496+
// forwarding and makes these wrappers minimal cost.
497+
if (!calledTask.TryAddCompletionAction(task))
498+
{
499+
ThreadPool.UnsafeQueueUserWorkItemInternal(task, preferLocal: true);
500+
}
444501
}
445502
else
446503
{
@@ -454,19 +511,7 @@ public static void HandleSuspended<T, TOps>(T task) where T : Task where TOps :
454511
}
455512
}
456513

457-
private static Continuation UnlinkHeadContinuation(out INotifyCompletion? notifier)
458-
{
459-
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
460-
notifier = state.Notifier;
461-
state.Notifier = null;
462-
463-
Continuation sentinelContinuation = state.SentinelContinuation!;
464-
Continuation head = sentinelContinuation.Next!;
465-
sentinelContinuation.Next = null;
466-
return head;
467-
}
468-
469-
private static bool QueueContinuationFollowUpActionIfNecessary<T, TOps>(T task, Continuation continuation) where T : Task where TOps : IThunkTaskOps<T>
514+
private static bool QueueContinuationFollowUpActionIfNecessary<T, TOps>(T task, Continuation continuation) where T : Task where TOps : IRuntimeAsyncTaskOps<T>
470515
{
471516
if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL) != 0)
472517
{
@@ -554,7 +599,7 @@ private static bool QueueContinuationFollowUpActionIfNecessary<T, TOps>(T task,
554599

555600
continuation.Next = finalContinuation;
556601

557-
ThunkTask<T?> result = new();
602+
RuntimeAsyncTask<T?> result = new();
558603
result.HandleSuspended();
559604
return result;
560605
}
@@ -567,7 +612,7 @@ private static Task FinalizeTaskReturningThunk(Continuation continuation)
567612
};
568613
continuation.Next = finalContinuation;
569614

570-
ThunkTask result = new();
615+
RuntimeAsyncTask result = new();
571616
result.HandleSuspended();
572617
return result;
573618
}
@@ -679,5 +724,16 @@ private static void CaptureContinuationContext(SynchronizationContext syncCtx, r
679724

680725
flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL;
681726
}
727+
728+
internal static T CompletedTaskResult<T>(Task<T> task)
729+
{
730+
TaskAwaiter.ValidateEnd(task);
731+
return task.ResultOnSuccess;
732+
}
733+
734+
internal static void CompletedTask(Task task)
735+
{
736+
TaskAwaiter.ValidateEnd(task);
737+
}
682738
}
683739
}

0 commit comments

Comments
 (0)