diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index 0045c9e53dbf3d..b429a94e5160d3 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -267,17 +267,6 @@ private static void TransparentAwait(object o) AsyncSuspend(sentinelContinuation); } - private interface IRuntimeAsyncTaskOps - { - static abstract Action GetContinuationAction(T task); - static abstract Continuation MoveContinuationState(T task); - static abstract void SetContinuationState(T task, Continuation value); - static abstract bool SetCompleted(T task); - static abstract void PostToSyncContext(T task, SynchronizationContext syncCtx); - static abstract void ValueTaskSourceOnCompleted(T task, IValueTaskSourceNotifier vtsNotifier, ValueTaskSourceOnCompletedFlags configFlags); - static abstract ref byte GetResultStorage(T task); - } - // Represents execution of a chain of suspended and resuming runtime // async functions. private sealed class RuntimeAsyncTask : Task, ITaskCompletionAction @@ -288,195 +277,144 @@ public RuntimeAsyncTask() // Ensure that state object isn't published out for others to see. Debug.Assert((m_stateFlags & (int)InternalTaskOptions.PromiseTask) != 0, "Expected state flags to already be configured."); Debug.Assert(m_stateObject is null, "Expected to be able to use the state object field for Continuation."); - m_action = MoveNext; + m_action = DispatchContinuations; m_stateFlags |= (int)InternalTaskOptions.HiddenState; } internal override void ExecuteFromThreadPool(Thread threadPoolThread) { - MoveNext(); - } - - private void MoveNext() - { - RuntimeAsyncTaskCore.DispatchContinuations, Ops>(this); - } - - public void HandleSuspended() - { - RuntimeAsyncTaskCore.HandleSuspended, Ops>(this); + DispatchContinuations(); } void ITaskCompletionAction.Invoke(Task completingTask) { - MoveNext(); + DispatchContinuations(); } bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true; - private static readonly SendOrPostCallback s_postCallback = static state => - { - Debug.Assert(state is RuntimeAsyncTask); - ((RuntimeAsyncTask)state).MoveNext(); - }; + private Action GetContinuationAction() => (Action)m_action!; - public static readonly Action s_runContinuationAction = static state => + private Continuation MoveContinuationState() { - Debug.Assert(state is RuntimeAsyncTask); - ((RuntimeAsyncTask)state).MoveNext(); - }; - - private struct Ops : IRuntimeAsyncTaskOps> - { - public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!; - public static Continuation MoveContinuationState(RuntimeAsyncTask task) - { - Continuation continuation = (Continuation)task.m_stateObject!; - task.m_stateObject = null; - return continuation; - } - - public static void SetContinuationState(RuntimeAsyncTask task, Continuation value) - { - Debug.Assert(task.m_stateObject == null); - task.m_stateObject = value; - } - - public static bool SetCompleted(RuntimeAsyncTask task) - { - return task.TrySetResult(task.m_result); - } - - public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationContext syncContext) - { - syncContext.Post(s_postCallback, task); - } - - public static void ValueTaskSourceOnCompleted(RuntimeAsyncTask task, IValueTaskSourceNotifier vtsNotifier, ValueTaskSourceOnCompletedFlags configFlags) - { - vtsNotifier.OnCompleted(s_runContinuationAction, task, configFlags); - } - - public static ref byte GetResultStorage(RuntimeAsyncTask task) => ref Unsafe.As(ref task.m_result); + Continuation continuation = (Continuation)m_stateObject!; + m_stateObject = null; + return continuation; } - } - // Represents execution of a chain of suspended and resuming runtime - // async functions. - private sealed class RuntimeAsyncTask : Task, ITaskCompletionAction - { - public RuntimeAsyncTask() + private void SetContinuationState(Continuation value) { - // We use the base Task's state object field to store the Continuation while posting the task around. - // Ensure that state object isn't published out for others to see. - Debug.Assert((m_stateFlags & (int)InternalTaskOptions.PromiseTask) != 0, "Expected state flags to already be configured."); - Debug.Assert(m_stateObject is null, "Expected to be able to use the state object field for Continuation."); - m_action = MoveNext; - m_stateFlags |= (int)InternalTaskOptions.HiddenState; + Debug.Assert(m_stateObject == null); + m_stateObject = value; } - internal override void ExecuteFromThreadPool(Thread threadPoolThread) + internal void HandleSuspended() { - MoveNext(); - } + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - private void MoveNext() - { - RuntimeAsyncTaskCore.DispatchContinuations(this); - } + RestoreContextsOnSuspension(false, state.ExecutionContext, state.SynchronizationContext); - public void HandleSuspended() - { - RuntimeAsyncTaskCore.HandleSuspended(this); - } + ICriticalNotifyCompletion? critNotifier = state.CriticalNotifier; + INotifyCompletion? notifier = state.Notifier; + IValueTaskSourceNotifier? vtsNotifier = state.ValueTaskSourceNotifier; + Task? taskNotifier = state.TaskNotifier; - void ITaskCompletionAction.Invoke(Task completingTask) - { - MoveNext(); - } + state.CriticalNotifier = null; + state.Notifier = null; + state.ValueTaskSourceNotifier = null; + state.TaskNotifier = null; + state.ExecutionContext = null; + state.SynchronizationContext = null; - bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true; + Continuation sentinelContinuation = state.SentinelContinuation!; + Continuation headContinuation = sentinelContinuation.Next!; + sentinelContinuation.Next = null; - private static readonly SendOrPostCallback s_postCallback = static state => - { - Debug.Assert(state is RuntimeAsyncTask); - ((RuntimeAsyncTask)state).MoveNext(); - }; + // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. + // These never have special continuation context handling. + const ContinuationFlags continueFlags = + ContinuationFlags.ContinueOnCapturedSynchronizationContext | + ContinuationFlags.ContinueOnThreadPool | + ContinuationFlags.ContinueOnCapturedTaskScheduler; - public static readonly Action s_runContinuationAction = static state => - { - Debug.Assert(state is RuntimeAsyncTask); - ((RuntimeAsyncTask)state).MoveNext(); - }; + Debug.Assert((headContinuation.Flags & continueFlags) == 0); - private struct Ops : IRuntimeAsyncTaskOps - { - public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!; - public static Continuation MoveContinuationState(RuntimeAsyncTask task) - { - Continuation continuation = (Continuation)task.m_stateObject!; - task.m_stateObject = null; - return continuation; - } + SetContinuationState(headContinuation); - public static void SetContinuationState(RuntimeAsyncTask task, Continuation value) + try { - Debug.Assert(task.m_stateObject == null); - task.m_stateObject = value; - } + if (critNotifier != null) + { + critNotifier.UnsafeOnCompleted(GetContinuationAction()); + } + else if (taskNotifier != null) + { + // Runtime async callable wrapper for task returning + // method. This implements the context transparent + // forwarding and makes these wrappers minimal cost. + if (!taskNotifier.TryAddCompletionAction(this)) + { + ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true); + } + } + else if (vtsNotifier != null) + { + // The awaiter must inform the ValueTaskSource on whether the continuation + // wants to run on a context, although the source may decide to ignore the suggestion. + // Since the behavior of the source takes precedence, we clear the context flags of + // the awaiting continuation (so it will run transparently on what the source decides) + // and then tell the source if the awaiting frame prefers to continue on a context. + // The reason why we do it here and not when the notifier is created is because + // the continuation chain builds from the innermost frame out and at the time when the + // notifier is created we do not know yet if the caller wants to continue on a context. + ValueTaskSourceOnCompletedFlags configFlags = ValueTaskSourceOnCompletedFlags.None; - public static bool SetCompleted(RuntimeAsyncTask task) - { - return task.TrySetResult(); - } + // Skip to a nontransparent/user continuation. Such continuaton must exist. + // Since we see a VTS notifier, something was directly or indirectly + // awaiting an async thunk for a ValueTask-returning method. + // That can only happen in nontransparent/user code. + Continuation nextUserContinuation = headContinuation.Next!; + while ((nextUserContinuation.Flags & continueFlags) == 0 && nextUserContinuation.Next != null) + { + nextUserContinuation = nextUserContinuation.Next; + } - public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationContext syncContext) - { - syncContext.Post(s_postCallback, task); - } + ContinuationFlags continuationFlags = nextUserContinuation.Flags; + const ContinuationFlags continueOnContextFlags = + ContinuationFlags.ContinueOnCapturedSynchronizationContext | + ContinuationFlags.ContinueOnCapturedTaskScheduler; - public static void ValueTaskSourceOnCompleted(RuntimeAsyncTask task, IValueTaskSourceNotifier vtsNotifier, ValueTaskSourceOnCompletedFlags configFlags) + if ((continuationFlags & continueOnContextFlags) != 0) + { + // if await has captured some context, inform the source + configFlags |= ValueTaskSourceOnCompletedFlags.UseSchedulingContext; + } + + // Clear continuation flags, so that continuation runs transparently + nextUserContinuation.Flags &= ~continueFlags; + vtsNotifier.OnCompleted(s_runContinuationAction, this, configFlags); + } + else + { + Debug.Assert(notifier != null); + notifier.OnCompleted(GetContinuationAction()); + } + } + catch (Exception ex) { - vtsNotifier.OnCompleted(s_runContinuationAction, task, configFlags); + Task.ThrowAsync(ex, targetContext: null); } - - public static ref byte GetResultStorage(RuntimeAsyncTask task) => ref Unsafe.NullRef(); } - } - private static class RuntimeAsyncTaskCore - { - [StructLayout(LayoutKind.Explicit)] - private unsafe ref struct DispatcherInfo - { - // Dispatcher info for next dispatcher present on stack, or - // null if none. - [FieldOffset(0)] - public DispatcherInfo* Next; - - // Next continuation the dispatcher will process. -#if TARGET_64BIT - [FieldOffset(8)] -#else - [FieldOffset(4)] -#endif - public Continuation? NextContinuation; - } - - // Information about current task dispatching, to be used for async - // stackwalking. - [ThreadStatic] - private static unsafe DispatcherInfo* t_dispatcherInfo; - - public static unsafe void DispatchContinuations(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps + private unsafe void DispatchContinuations() { ExecutionAndSyncBlockStore contexts = default; contexts.Push(); - DispatcherInfo dispatcherInfo; - dispatcherInfo.Next = t_dispatcherInfo; - dispatcherInfo.NextContinuation = TOps.MoveContinuationState(task); - t_dispatcherInfo = &dispatcherInfo; + RuntimeAsyncTaskCore.DispatcherInfo dispatcherInfo; + dispatcherInfo.Next = RuntimeAsyncTaskCore.t_dispatcherInfo; + dispatcherInfo.NextContinuation = MoveContinuationState(); + RuntimeAsyncTaskCore.t_dispatcherInfo = &dispatcherInfo; while (true) { @@ -487,15 +425,15 @@ public static unsafe void DispatchContinuations(T task) where T : Task, Continuation? nextContinuation = curContinuation.Next; dispatcherInfo.NextContinuation = nextContinuation; - ref byte resultLoc = ref nextContinuation != null ? ref nextContinuation.GetResultStorageOrNull() : ref TOps.GetResultStorage(task); + ref byte resultLoc = ref nextContinuation != null ? ref nextContinuation.GetResultStorageOrNull() : ref GetResultStorage(); Continuation? newContinuation = curContinuation.ResumeInfo->Resume(curContinuation, ref resultLoc); if (newContinuation != null) { newContinuation.Next = nextContinuation; - HandleSuspended(task); + HandleSuspended(); contexts.Pop(); - t_dispatcherInfo = dispatcherInfo.Next; + RuntimeAsyncTaskCore.t_dispatcherInfo = dispatcherInfo.Next; return; } } @@ -506,12 +444,12 @@ public static unsafe void DispatchContinuations(T task) where T : Task, { // Tail of AsyncTaskMethodBuilderT.SetException bool successfullySet = ex is OperationCanceledException oce ? - task.TrySetCanceled(oce.CancellationToken, oce) : - task.TrySetException(ex); + TrySetCanceled(oce.CancellationToken, oce) : + TrySetException(ex); contexts.Pop(); - t_dispatcherInfo = dispatcherInfo.Next; + RuntimeAsyncTaskCore.t_dispatcherInfo = dispatcherInfo.Next; if (!successfullySet) { @@ -527,11 +465,11 @@ public static unsafe void DispatchContinuations(T task) where T : Task, if (dispatcherInfo.NextContinuation == null) { - bool successfullySet = TOps.SetCompleted(task); + bool successfullySet = TrySetResult(m_result); contexts.Pop(); - t_dispatcherInfo = dispatcherInfo.Next; + RuntimeAsyncTaskCore.t_dispatcherInfo = dispatcherInfo.Next; if (!successfullySet) { @@ -541,15 +479,17 @@ public static unsafe void DispatchContinuations(T task) where T : Task, return; } - if (QueueContinuationFollowUpActionIfNecessary(task, dispatcherInfo.NextContinuation)) + if (QueueContinuationFollowUpActionIfNecessary(dispatcherInfo.NextContinuation)) { contexts.Pop(); - t_dispatcherInfo = dispatcherInfo.Next; + RuntimeAsyncTaskCore.t_dispatcherInfo = dispatcherInfo.Next; return; } } } + private ref byte GetResultStorage() => ref Unsafe.As(ref m_result); + private static Continuation? UnwindToPossibleHandler(Continuation? continuation) { while (true) @@ -561,105 +501,7 @@ public static unsafe void DispatchContinuations(T task) where T : Task, } } - public static void HandleSuspended(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps - { - ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - - RestoreContextsOnSuspension(false, state.ExecutionContext, state.SynchronizationContext); - - ICriticalNotifyCompletion? critNotifier = state.CriticalNotifier; - INotifyCompletion? notifier = state.Notifier; - IValueTaskSourceNotifier? vtsNotifier = state.ValueTaskSourceNotifier; - Task? taskNotifier = state.TaskNotifier; - - state.CriticalNotifier = null; - state.Notifier = null; - state.ValueTaskSourceNotifier = null; - state.TaskNotifier = null; - state.ExecutionContext = null; - state.SynchronizationContext = null; - - Continuation sentinelContinuation = state.SentinelContinuation!; - Continuation headContinuation = sentinelContinuation.Next!; - sentinelContinuation.Next = null; - - // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. - // These never have special continuation context handling. - const ContinuationFlags continueFlags = - ContinuationFlags.ContinueOnCapturedSynchronizationContext | - ContinuationFlags.ContinueOnThreadPool | - ContinuationFlags.ContinueOnCapturedTaskScheduler; - - Debug.Assert((headContinuation.Flags & continueFlags) == 0); - - TOps.SetContinuationState(task, headContinuation); - - try - { - if (critNotifier != null) - { - critNotifier.UnsafeOnCompleted(TOps.GetContinuationAction(task)); - } - else if (taskNotifier != null) - { - // Runtime async callable wrapper for task returning - // method. This implements the context transparent - // forwarding and makes these wrappers minimal cost. - if (!taskNotifier.TryAddCompletionAction(task)) - { - ThreadPool.UnsafeQueueUserWorkItemInternal(task, preferLocal: true); - } - } - else if (vtsNotifier != null) - { - // The awaiter must inform the ValueTaskSource on whether the continuation - // wants to run on a context, although the source may decide to ignore the suggestion. - // Since the behavior of the source takes precedence, we clear the context flags of - // the awaiting continuation (so it will run transparently on what the source decides) - // and then tell the source if the awaiting frame prefers to continue on a context. - // The reason why we do it here and not when the notifier is created is because - // the continuation chain builds from the innermost frame out and at the time when the - // notifier is created we do not know yet if the caller wants to continue on a context. - ValueTaskSourceOnCompletedFlags configFlags = ValueTaskSourceOnCompletedFlags.None; - - // Skip to a nontransparent/user continuation. Such continuaton must exist. - // Since we see a VTS notifier, something was directly or indirectly - // awaiting an async thunk for a ValueTask-returning method. - // That can only happen in nontransparent/user code. - Continuation nextUserContinuation = headContinuation.Next!; - while ((nextUserContinuation.Flags & continueFlags) == 0 && nextUserContinuation.Next != null) - { - nextUserContinuation = nextUserContinuation.Next; - } - - ContinuationFlags continuationFlags = nextUserContinuation.Flags; - const ContinuationFlags continueOnContextFlags = - ContinuationFlags.ContinueOnCapturedSynchronizationContext | - ContinuationFlags.ContinueOnCapturedTaskScheduler; - - if ((continuationFlags & continueOnContextFlags) != 0) - { - // if await has captured some context, inform the source - configFlags |= ValueTaskSourceOnCompletedFlags.UseSchedulingContext; - } - - // Clear continuation flags, so that continuation runs transparently - nextUserContinuation.Flags &= ~continueFlags; - TOps.ValueTaskSourceOnCompleted(task, vtsNotifier, configFlags); - } - else - { - Debug.Assert(notifier != null); - notifier.OnCompleted(TOps.GetContinuationAction(task)); - } - } - catch (Exception ex) - { - Task.ThrowAsync(ex, targetContext: null); - } - } - - private static bool QueueContinuationFollowUpActionIfNecessary(T task, Continuation continuation) where T : Task where TOps : IRuntimeAsyncTaskOps + private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuation) { if ((continuation.Flags & ContinuationFlags.ContinueOnThreadPool) != 0) { @@ -674,8 +516,8 @@ private static bool QueueContinuationFollowUpActionIfNecessary(T task, } } - TOps.SetContinuationState(task, continuation); - ThreadPool.UnsafeQueueUserWorkItemInternal(task, preferLocal: true); + SetContinuationState(continuation); + ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true); return true; } @@ -691,11 +533,11 @@ private static bool QueueContinuationFollowUpActionIfNecessary(T task, return false; } - TOps.SetContinuationState(task, continuation); + SetContinuationState(continuation); try { - TOps.PostToSyncContext(task, continuationSyncCtx); + continuationSyncCtx.Post(s_postCallback, this); } catch (Exception ex) { @@ -711,9 +553,9 @@ private static bool QueueContinuationFollowUpActionIfNecessary(T task, Debug.Assert(continuationContext is TaskScheduler { }); TaskScheduler sched = (TaskScheduler)continuationContext; - TOps.SetContinuationState(task, continuation); + SetContinuationState(continuation); // TODO: We do not need TaskSchedulerAwaitTaskContinuation here, just need to refactor its Run method... - var taskSchedCont = new TaskSchedulerAwaitTaskContinuation(sched, TOps.GetContinuationAction(task), flowExecutionContext: false); + var taskSchedCont = new TaskSchedulerAwaitTaskContinuation(sched, GetContinuationAction(), flowExecutionContext: false); taskSchedCont.Run(Task.CompletedTask, canInlineContinuationTask: true); return true; @@ -721,6 +563,43 @@ private static bool QueueContinuationFollowUpActionIfNecessary(T task, return false; } + + private static readonly SendOrPostCallback s_postCallback = static state => + { + Debug.Assert(state is RuntimeAsyncTask); + ((RuntimeAsyncTask)state).DispatchContinuations(); + }; + + private static readonly Action s_runContinuationAction = static state => + { + Debug.Assert(state is RuntimeAsyncTask); + ((RuntimeAsyncTask)state).DispatchContinuations(); + }; + } + + internal static class RuntimeAsyncTaskCore + { + [StructLayout(LayoutKind.Explicit)] + internal unsafe ref struct DispatcherInfo + { + // Dispatcher info for next dispatcher present on stack, or + // null if none. + [FieldOffset(0)] + public DispatcherInfo* Next; + + // Next continuation the dispatcher will process. +#if TARGET_64BIT + [FieldOffset(8)] +#else + [FieldOffset(4)] +#endif + public Continuation? NextContinuation; + } + + // Information about current task dispatching, to be used for async + // stackwalking. + [ThreadStatic] + internal static unsafe DispatcherInfo* t_dispatcherInfo; } // Change return type to RuntimeAsyncTask -- no benefit since this is used for Task returning thunks only @@ -736,7 +615,7 @@ private static bool QueueContinuationFollowUpActionIfNecessary(T task, private static Task FinalizeTaskReturningThunk() { - RuntimeAsyncTask result = new(); + RuntimeAsyncTask result = new(); result.HandleSuspended(); return result; }