diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index b9b6dbb87947c0..9f379245a309b2 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -61,6 +61,15 @@ internal enum CorInfoContinuationFlags // OSR method saved in the beginning of 'Data', or -1 if the continuation // belongs to a tier 0 method. CORINFO_CONTINUATION_OSR_IL_OFFSET_IN_DATA = 4, + // If this bit is set the continuation should continue on the thread + // pool. + CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL = 8, + // If this bit is set the continuation has a SynchronizationContext + // that we should continue on. + CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT = 16, + // If this bit is set the continuation has a TaskScheduler + // that we should continue on. + CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_TASK_SCHEDULER = 32, } internal sealed unsafe class Continuation @@ -93,6 +102,29 @@ internal sealed unsafe class Continuation // public byte[]? Data; public object?[]? GCData; + + public object GetContinuationContext() + { + int index = 0; + if ((Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_RESULT_IN_GCDATA) != 0) + index++; + if ((Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_NEEDS_EXCEPTION) != 0) + index++; + Debug.Assert(GCData != null && GCData.Length > index); + object? continuationContext = GCData[index]; + Debug.Assert(continuationContext != null); + return continuationContext; + } + + public void SetException(Exception ex) + { + int index = 0; + if ((Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_RESULT_IN_GCDATA) != 0) + index++; + + Debug.Assert(GCData != null && GCData.Length > index); + GCData[index] = ex; + } } public static partial class AsyncHelpers @@ -171,122 +203,338 @@ private static unsafe object AllocContinuationResultBox(void* ptr) return RuntimeTypeHandle.InternalAllocNoChecks((MethodTable*)pMT); } - // wrapper to await a notifier - private struct AwaitableProxy : ICriticalNotifyCompletion + private interface IThunkTaskOps { - private readonly INotifyCompletion _notifier; + static abstract Action GetContinuationAction(T task); + static abstract Continuation GetContinuationState(T task); + static abstract void SetContinuationState(T task, Continuation value); + static abstract bool SetCompleted(T task, Continuation continuation); + static abstract void PostToSyncContext(T task, SynchronizationContext syncCtx); + } - public AwaitableProxy(INotifyCompletion notifier) + private sealed class ThunkTask : Task + { + public ThunkTask() { - _notifier = notifier; + // We use the base Task's state object field to store the Continuation while posting the task around. + // Ensure that state object isn't published out for others to see. + Debug.Assert((m_stateFlags & (int)InternalTaskOptions.PromiseTask) != 0, "Expected state flags to already be configured."); + Debug.Assert(m_stateObject is null, "Expected to be able to use the state object field for Continuation."); + m_action = MoveNext; + m_stateFlags |= (int)InternalTaskOptions.HiddenState; } - public bool IsCompleted => false; + internal override void ExecuteFromThreadPool(Thread threadPoolThread) + { + MoveNext(); + } - public void OnCompleted(Action action) + private void MoveNext() { - _notifier!.OnCompleted(action); + ThunkTaskCore.MoveNext, Ops>(this); } - public AwaitableProxy GetAwaiter() { return this; } + public void HandleSuspended() + { + ThunkTaskCore.HandleSuspended, Ops>(this); + } - public void UnsafeOnCompleted(Action action) + private static readonly SendOrPostCallback s_postCallback = static state => { - if (_notifier is ICriticalNotifyCompletion criticalNotification) + Debug.Assert(state is ThunkTask); + ((ThunkTask)state).MoveNext(); + }; + + private struct Ops : IThunkTaskOps> + { + public static Action GetContinuationAction(ThunkTask task) => (Action)task.m_action!; + public static void MoveNext(ThunkTask task) => task.MoveNext(); + public static Continuation GetContinuationState(ThunkTask task) => (Continuation)task.m_stateObject!; + public static void SetContinuationState(ThunkTask task, Continuation value) { - criticalNotification.UnsafeOnCompleted(action); + task.m_stateObject = value; } - else + + public static bool SetCompleted(ThunkTask task, Continuation continuation) { - _notifier!.OnCompleted(action); + T result; + if (RuntimeHelpers.IsReferenceOrContainsReferences()) + { + if (typeof(T).IsValueType) + { + result = Unsafe.As(ref continuation.GCData![0]!.GetRawData()); + } + else + { + result = Unsafe.As(ref continuation.GCData![0]!); + } + } + else + { + result = Unsafe.As(ref continuation.Data![0]); + } + + return task.TrySetResult(result); } - } - public void GetResult() { } + public static void PostToSyncContext(ThunkTask task, SynchronizationContext syncContext) + { + syncContext.Post(s_postCallback, task); + } + } } - private static Continuation UnlinkHeadContinuation(out AwaitableProxy awaitableProxy) + private sealed class ThunkTask : Task { - ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - awaitableProxy = new AwaitableProxy(state.Notifier!); - state.Notifier = null; - - Continuation sentinelContinuation = state.SentinelContinuation!; - Continuation head = sentinelContinuation.Next!; - sentinelContinuation.Next = null; - return head; - } + public ThunkTask() + { + // We use the base Task's state object field to store the Continuation while posting the task around. + // Ensure that state object isn't published out for others to see. + Debug.Assert((m_stateFlags & (int)InternalTaskOptions.PromiseTask) != 0, "Expected state flags to already be configured."); + Debug.Assert(m_stateObject is null, "Expected to be able to use the state object field for Continuation."); + m_action = MoveNext; + m_stateFlags |= (int)InternalTaskOptions.HiddenState; + } - // When a Task-returning thunk gets a continuation result - // it calls here to make a Task that awaits on the current async state. - // NOTE: This cannot be Runtime Async. Must use C# state machine or make one by hand. - private static async Task FinalizeTaskReturningThunk(Continuation continuation) - { - Continuation finalContinuation = new Continuation(); + internal override void ExecuteFromThreadPool(Thread threadPoolThread) + { + MoveNext(); + } - // Note that the exact location the return value is placed is tied - // into getAsyncResumptionStub in the VM, so do not change this - // without also changing that code (and the JIT). - if (RuntimeHelpers.IsReferenceOrContainsReferences()) + private void MoveNext() { - finalContinuation.Flags = CorInfoContinuationFlags.CORINFO_CONTINUATION_RESULT_IN_GCDATA | CorInfoContinuationFlags.CORINFO_CONTINUATION_NEEDS_EXCEPTION; - finalContinuation.GCData = new object[1]; + ThunkTaskCore.MoveNext(this); } - else + + public void HandleSuspended() { - finalContinuation.Flags = CorInfoContinuationFlags.CORINFO_CONTINUATION_NEEDS_EXCEPTION; - finalContinuation.Data = new byte[Unsafe.SizeOf()]; + ThunkTaskCore.HandleSuspended(this); } - continuation.Next = finalContinuation; + private static readonly SendOrPostCallback s_postCallback = static state => + { + Debug.Assert(state is ThunkTask); + ((ThunkTask)state).MoveNext(); + }; - while (true) + private struct Ops : IThunkTaskOps { - Continuation headContinuation = UnlinkHeadContinuation(out var awaitableProxy); - await awaitableProxy; - Continuation? finalResult = DispatchContinuations(headContinuation); - if (finalResult != null) + public static Action GetContinuationAction(ThunkTask task) => (Action)task.m_action!; + public static void MoveNext(ThunkTask task) => task.MoveNext(); + public static Continuation GetContinuationState(ThunkTask task) => (Continuation)task.m_stateObject!; + public static void SetContinuationState(ThunkTask task, Continuation value) { - Debug.Assert(finalResult == finalContinuation); - if (RuntimeHelpers.IsReferenceOrContainsReferences()) + task.m_stateObject = value; + } + + public static bool SetCompleted(ThunkTask task, Continuation continuation) + { + return task.TrySetResult(); + } + + public static void PostToSyncContext(ThunkTask task, SynchronizationContext syncContext) + { + syncContext.Post(s_postCallback, task); + } + } + } + + private static class ThunkTaskCore + { + public static unsafe void MoveNext(T task) where T : Task where TOps : IThunkTaskOps + { + ExecutionAndSyncBlockStore contexts = default; + contexts.Push(); + Continuation continuation = TOps.GetContinuationState(task); + + while (true) + { + try { - if (typeof(T).IsValueType) + Continuation? newContinuation = continuation.Resume(continuation); + + if (newContinuation != null) + { + newContinuation.Next = continuation.Next; + HandleSuspended(task); + contexts.Pop(); + return; + } + + Debug.Assert(continuation.Next != null); + continuation = continuation.Next; + } + catch (Exception ex) + { + Continuation nextContinuation = UnwindToPossibleHandler(continuation); + if (nextContinuation.Resume == null) + { + // Tail of AsyncTaskMethodBuilderT.SetException + bool successfullySet = ex is OperationCanceledException oce ? + task.TrySetCanceled(oce.CancellationToken, oce) : + task.TrySetException(ex); + + contexts.Pop(); + + if (!successfullySet) + { + ThrowHelper.ThrowInvalidOperationException(ExceptionResource.TaskT_TransitionToFinal_AlreadyCompleted); + } + + return; + } + + nextContinuation.SetException(ex); + + continuation = nextContinuation; + } + + if (continuation.Resume == null) + { + bool successfullySet = TOps.SetCompleted(task, continuation); + + contexts.Pop(); + + if (!successfullySet) { - return Unsafe.As(ref finalResult.GCData![0]!.GetRawData()); + ThrowHelper.ThrowInvalidOperationException(ExceptionResource.TaskT_TransitionToFinal_AlreadyCompleted); } - return Unsafe.As(ref finalResult.GCData![0]!); + return; + } + + if (QueueContinuationFollowUpActionIfNecessary(task, continuation)) + { + contexts.Pop(); + return; + } + } + } + + private static Continuation UnwindToPossibleHandler(Continuation continuation) + { + while (true) + { + Debug.Assert(continuation.Next != null); + continuation = continuation.Next; + if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_NEEDS_EXCEPTION) != 0) + return continuation; + } + } + + public static void HandleSuspended(T task) where T : Task where TOps : IThunkTaskOps + { + Continuation headContinuation = UnlinkHeadContinuation(out INotifyCompletion? notifier); + + // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. + // These never have special continuation handling. + const CorInfoContinuationFlags continueFlags = + CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT | + CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL | + CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_TASK_SCHEDULER; + Debug.Assert((headContinuation.Flags & continueFlags) == 0); + + TOps.SetContinuationState(task, headContinuation); + + try + { + if (notifier is ICriticalNotifyCompletion crit) + { + crit.UnsafeOnCompleted(TOps.GetContinuationAction(task)); } else { - return Unsafe.As(ref finalResult.Data![0]); + Debug.Assert(notifier != null); + notifier.OnCompleted(TOps.GetContinuationAction(task)); } } + catch (Exception ex) + { + Task.ThrowAsync(ex, targetContext: null); + } } - } - private static async Task FinalizeTaskReturningThunk(Continuation continuation) - { - Continuation finalContinuation = new Continuation + private static Continuation UnlinkHeadContinuation(out INotifyCompletion? notifier) { - Flags = CorInfoContinuationFlags.CORINFO_CONTINUATION_NEEDS_EXCEPTION, - }; - continuation.Next = finalContinuation; + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + notifier = state.Notifier; + state.Notifier = null; + + Continuation sentinelContinuation = state.SentinelContinuation!; + Continuation head = sentinelContinuation.Next!; + sentinelContinuation.Next = null; + return head; + } - while (true) + private static bool QueueContinuationFollowUpActionIfNecessary(T task, Continuation continuation) where T : Task where TOps : IThunkTaskOps { - Continuation headContinuation = UnlinkHeadContinuation(out var awaitableProxy); - await awaitableProxy; - Continuation? finalResult = DispatchContinuations(headContinuation); - if (finalResult != null) + if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL) != 0) { - Debug.Assert(finalResult == finalContinuation); - return; + SynchronizationContext? ctx = Thread.CurrentThreadAssumedInitialized._synchronizationContext; + if (ctx == null || ctx.GetType() == typeof(SynchronizationContext)) + { + TaskScheduler? sched = TaskScheduler.InternalCurrent; + if (sched == null || sched == TaskScheduler.Default) + { + // Can inline + return false; + } + } + + TOps.SetContinuationState(task, continuation); + ThreadPool.UnsafeQueueUserWorkItemInternal(task, preferLocal: true); + return true; + } + + if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT) != 0) + { + object continuationContext = continuation.GetContinuationContext(); + Debug.Assert(continuationContext is SynchronizationContext { }); + SynchronizationContext continuationSyncCtx = (SynchronizationContext)continuationContext; + + if (continuationSyncCtx == Thread.CurrentThreadAssumedInitialized._synchronizationContext) + { + // Inline + return false; + } + + TOps.SetContinuationState(task, continuation); + + try + { + TOps.PostToSyncContext(task, continuationSyncCtx); + } + catch (Exception ex) + { + Task.ThrowAsync(ex, targetContext: null); + } + + return true; + } + + if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_TASK_SCHEDULER) != 0) + { + object continuationContext = continuation.GetContinuationContext(); + Debug.Assert(continuationContext is TaskScheduler { }); + TaskScheduler sched = (TaskScheduler)continuationContext; + + TOps.SetContinuationState(task, continuation); + // TODO: We do not need TaskSchedulerAwaitTaskContinuation here, just need to refactor its Run method... + var taskSchedCont = new TaskSchedulerAwaitTaskContinuation(sched, TOps.GetContinuationAction(task), flowExecutionContext: false); + taskSchedCont.Run(Task.CompletedTask, canInlineContinuationTask: true); + + return true; } + + return false; } } - private static async ValueTask FinalizeValueTaskReturningThunk(Continuation continuation) + // Change return type to ThunkTask -- no benefit since this is used for Task returning thunks only +#pragma warning disable CA1859 + // When a Task-returning thunk gets a continuation result + // it calls here to make a Task that awaits on the current async state. + private static Task FinalizeTaskReturningThunk(Continuation continuation) { Continuation finalContinuation = new Continuation(); @@ -306,32 +554,12 @@ private static async Task FinalizeTaskReturningThunk(Continuation continuation) continuation.Next = finalContinuation; - while (true) - { - Continuation headContinuation = UnlinkHeadContinuation(out var awaitableProxy); - await awaitableProxy; - Continuation? finalResult = DispatchContinuations(headContinuation); - if (finalResult != null) - { - Debug.Assert(finalResult == finalContinuation); - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - if (typeof(T).IsValueType) - { - return Unsafe.As(ref finalResult.GCData![0]!.GetRawData()); - } - - return Unsafe.As(ref finalResult.GCData![0]!); - } - else - { - return Unsafe.As(ref finalResult.Data![0]); - } - } - } + ThunkTask result = new(); + result.HandleSuspended(); + return result; } - private static async ValueTask FinalizeValueTaskReturningThunk(Continuation continuation) + private static Task FinalizeTaskReturningThunk(Continuation continuation) { Continuation finalContinuation = new Continuation { @@ -339,72 +567,21 @@ private static async ValueTask FinalizeValueTaskReturningThunk(Continuation cont }; continuation.Next = finalContinuation; - while (true) - { - Continuation headContinuation = UnlinkHeadContinuation(out var awaitableProxy); - await awaitableProxy; - Continuation? finalResult = DispatchContinuations(headContinuation); - if (finalResult != null) - { - Debug.Assert(finalResult == finalContinuation); - return; - } - } + ThunkTask result = new(); + result.HandleSuspended(); + return result; } - // Return a continuation object if that is the one which has the final - // result of the Task, if the real output of the series of continuations was - // an exception, it is allowed to propagate out. - // OR - // return NULL to indicate that this isn't yet done. - private static unsafe Continuation? DispatchContinuations(Continuation? continuation) + private static ValueTask FinalizeValueTaskReturningThunk(Continuation continuation) { - Debug.Assert(continuation != null); - - while (true) - { - Continuation? newContinuation; - try - { - newContinuation = continuation.Resume(continuation); - } - catch (Exception ex) - { - continuation = UnwindToPossibleHandler(continuation); - if (continuation.Resume == null) - { - throw; - } - - continuation.GCData![(continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_RESULT_IN_GCDATA) != 0 ? 1 : 0] = ex; - continue; - } - - if (newContinuation != null) - { - newContinuation.Next = continuation.Next; - return null; - } - - continuation = continuation.Next; - Debug.Assert(continuation != null); - - if (continuation.Resume == null) - { - return continuation; // Return the result containing Continuation - } - } + // We only come to these methods in the expensive case (already + // suspended), so ValueTask optimization here is not relevant. + return new ValueTask(FinalizeTaskReturningThunk(continuation)); } - private static Continuation UnwindToPossibleHandler(Continuation continuation) + private static ValueTask FinalizeValueTaskReturningThunk(Continuation continuation) { - while (true) - { - Debug.Assert(continuation.Next != null); - continuation = continuation.Next; - if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_NEEDS_EXCEPTION) != 0) - return continuation; - } + return new ValueTask(FinalizeTaskReturningThunk(continuation)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -423,5 +600,26 @@ private static void RestoreExecutionContext(ExecutionContext? previousExecutionC ExecutionContext.RestoreChangedContextToThread(thread, previousExecutionCtx, currentExecutionCtx); } } + + private static void CaptureContinuationContext(ref object context, ref CorInfoContinuationFlags flags) + { + SynchronizationContext? syncCtx = Thread.CurrentThreadAssumedInitialized._synchronizationContext; + if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext)) + { + flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT; + context = syncCtx; + return; + } + + TaskScheduler? sched = TaskScheduler.InternalCurrent; + if (sched != null && sched != TaskScheduler.Default) + { + flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_TASK_SCHEDULER; + context = sched; + return; + } + + flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL; + } } } diff --git a/src/coreclr/inc/corinfo.h b/src/coreclr/inc/corinfo.h index 6182619dbbf77e..313d67a11ec6ee 100644 --- a/src/coreclr/inc/corinfo.h +++ b/src/coreclr/inc/corinfo.h @@ -1712,6 +1712,15 @@ enum CorInfoContinuationFlags // OSR method saved in the beginning of 'Data', or -1 if the continuation // belongs to a tier 0 method. CORINFO_CONTINUATION_OSR_IL_OFFSET_IN_DATA = 4, + // If this bit is set the continuation should continue on the thread + // pool. + CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL = 8, + // If this bit is set the continuation has a SynchronizationContext + // that we should continue on. + CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT = 16, + // If this bit is set the continuation has a TaskScheduler + // that we should continue on. + CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_TASK_SCHEDULER = 32, }; struct CORINFO_ASYNC_INFO @@ -1737,6 +1746,7 @@ struct CORINFO_ASYNC_INFO CORINFO_METHOD_HANDLE captureExecutionContextMethHnd; // Method handle for AsyncHelpers.RestoreExecutionContext CORINFO_METHOD_HANDLE restoreExecutionContextMethHnd; + CORINFO_METHOD_HANDLE captureContinuationContextMethHnd; }; // Flags passed from JIT to runtime. diff --git a/src/coreclr/jit/async.cpp b/src/coreclr/jit/async.cpp index e72f1fe7007afa..7a26ab4b004856 100644 --- a/src/coreclr/jit/async.cpp +++ b/src/coreclr/jit/async.cpp @@ -1040,6 +1040,13 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock* block->getTryIndex(), layout.ExceptionGCDataIndex); } + if (call->GetAsyncInfo().ContinuationContextHandling == ContinuationContextHandling::ContinueOnCapturedContext) + { + layout.ContinuationContextGCDataIndex = layout.GCRefsCount++; + JITDUMP(" Continuation continues on captured context; context will be at GC@+%02u in GC data\n", + layout.ContinuationContextGCDataIndex); + } + if (call->GetAsyncInfo().ExecutionContextHandling == ExecutionContextHandling::AsyncSaveAndRestore) { layout.ExecContextGCDataIndex = layout.GCRefsCount++; @@ -1200,13 +1207,16 @@ BasicBlock* AsyncTransformation::CreateSuspension( LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, storeState)); // Fill in 'flags' - unsigned continuationFlags = 0; + const AsyncCallInfo& callInfo = call->GetAsyncInfo(); + unsigned continuationFlags = 0; if (layout.ReturnInGCData) continuationFlags |= CORINFO_CONTINUATION_RESULT_IN_GCDATA; if (block->hasTryIndex()) continuationFlags |= CORINFO_CONTINUATION_NEEDS_EXCEPTION; if (m_comp->doesMethodHavePatchpoints() || m_comp->opts.IsOSR()) continuationFlags |= CORINFO_CONTINUATION_OSR_IL_OFFSET_IN_DATA; + if (callInfo.ContinuationContextHandling == ContinuationContextHandling::ContinueOnThreadPool) + continuationFlags |= CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL; newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF); unsigned flagsOffset = m_comp->info.compCompHnd->getFieldOffset(m_asyncInfo->continuationFlagsFldHnd); @@ -1386,6 +1396,51 @@ void AsyncTransformation::FillInGCPointersOnSuspension(const ContinuationLayout& } } + if (layout.ContinuationContextGCDataIndex != UINT_MAX) + { + // Insert call AsyncHelpers.CaptureContinuationContext(ref + // newContinuation.GCData[ContinuationContextGCDataIndex], ref newContinuation.Flags). + GenTree* contextElementPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); + GenTree* flagsPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); + GenTreeCall* captureCall = + m_comp->gtNewCallNode(CT_USER_FUNC, m_asyncInfo->captureContinuationContextMethHnd, TYP_VOID); + + captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(flagsPlaceholder)); + captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(contextElementPlaceholder)); + + m_comp->compCurBB = suspendBB; + m_comp->fgMorphTree(captureCall); + + LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, captureCall)); + + // Now replace contextElementPlaceholder with actual address of the context element + LIR::Use use; + bool gotUse = LIR::AsRange(suspendBB).TryGetUse(contextElementPlaceholder, &use); + assert(gotUse); + + GenTree* objectArr = m_comp->gtNewLclvNode(objectArrLclNum, TYP_REF); + unsigned offset = OFFSETOF__CORINFO_Array__data + (layout.ContinuationContextGCDataIndex * TARGET_POINTER_SIZE); + GenTree* contextElementOffset = + m_comp->gtNewOperNode(GT_ADD, TYP_BYREF, objectArr, m_comp->gtNewIconNode((ssize_t)offset, TYP_I_IMPL)); + + LIR::AsRange(suspendBB).InsertBefore(contextElementPlaceholder, LIR::SeqTree(m_comp, contextElementOffset)); + use.ReplaceWith(contextElementOffset); + LIR::AsRange(suspendBB).Remove(contextElementPlaceholder); + + // And now replace flagsPlaceholder with actual address of the flags + gotUse = LIR::AsRange(suspendBB).TryGetUse(flagsPlaceholder, &use); + assert(gotUse); + + newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF); + unsigned flagsOffset = m_comp->info.compCompHnd->getFieldOffset(m_asyncInfo->continuationFlagsFldHnd); + GenTree* flagsOffsetNode = m_comp->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, + m_comp->gtNewIconNode((ssize_t)flagsOffset, TYP_I_IMPL)); + + LIR::AsRange(suspendBB).InsertBefore(flagsPlaceholder, LIR::SeqTree(m_comp, flagsOffsetNode)); + use.ReplaceWith(flagsOffsetNode); + LIR::AsRange(suspendBB).Remove(flagsPlaceholder); + } + if (layout.ExecContextGCDataIndex != UINT_MAX) { GenTreeCall* captureExecContext = diff --git a/src/coreclr/jit/async.h b/src/coreclr/jit/async.h index 83732fd241187c..e75f2ac8d157f6 100644 --- a/src/coreclr/jit/async.h +++ b/src/coreclr/jit/async.h @@ -18,14 +18,15 @@ struct LiveLocalInfo struct ContinuationLayout { - unsigned DataSize = 0; - unsigned GCRefsCount = 0; - ClassLayout* ReturnStructLayout = nullptr; - unsigned ReturnSize = 0; - bool ReturnInGCData = false; - unsigned ReturnValDataOffset = UINT_MAX; - unsigned ExceptionGCDataIndex = UINT_MAX; - unsigned ExecContextGCDataIndex = UINT_MAX; + unsigned DataSize = 0; + unsigned GCRefsCount = 0; + ClassLayout* ReturnStructLayout = nullptr; + unsigned ReturnSize = 0; + bool ReturnInGCData = false; + unsigned ReturnValDataOffset = UINT_MAX; + unsigned ExceptionGCDataIndex = UINT_MAX; + unsigned ExecContextGCDataIndex = UINT_MAX; + unsigned ContinuationContextGCDataIndex = UINT_MAX; const jitstd::vector& Locals; explicit ContinuationLayout(const jitstd::vector& locals) diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index 36e35ba3ecc190..77e19733c10f15 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -4430,6 +4430,7 @@ class Compiler #endif // This call is a task await PREFIX_IS_TASK_AWAIT = 0x00000080, + PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT = 0x00000100, }; static void impValidateMemoryAccessOpcode(const BYTE* codeAddr, const BYTE* codeEndp, bool volatilePrefix); diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 921a34fa057fbb..10d0b149e92b75 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -4316,10 +4316,21 @@ enum class ExecutionContextHandling AsyncSaveAndRestore, }; +enum class ContinuationContextHandling +{ + // No special handling of SynchronizationContext/TaskScheduler is required. + None, + // Continue on SynchronizationContext/TaskScheduler + ContinueOnCapturedContext, + // Continue on thread pool thread + ContinueOnThreadPool, +}; + // Additional async call info. struct AsyncCallInfo { - ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None; + ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None; + ContinuationContextHandling ContinuationContextHandling = ContinuationContextHandling::None; }; // Return type descriptor of a GT_CALL node. diff --git a/src/coreclr/jit/importer.cpp b/src/coreclr/jit/importer.cpp index 21f12c21e0dc59..b618e04c7f13bf 100644 --- a/src/coreclr/jit/importer.cpp +++ b/src/coreclr/jit/importer.cpp @@ -9131,17 +9131,18 @@ void Compiler::impImportBlockCode(BasicBlock* block) // many other places. We unfortunately embed that knowledge here. if (opcode != CEE_CALLI) { - bool isAwait = false; - // TODO: The configVal should be wired to the actual implementation - // that control the flow of sync context. - // We do not have that yet. - int configVal = -1; // -1 not configured, 0/1 configured to false/true + bool isAwait = false; + int configVal = -1; // -1 not configured, 0/1 configured to false/true if (compIsAsync() && JitConfig.JitOptimizeAwait()) { if (impMatchTaskAwaitPattern(codeAddr, codeEndp, &configVal)) { isAwait = true; prefixFlags |= PREFIX_IS_TASK_AWAIT; + if (configVal != 0) + { + prefixFlags |= PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT; + } } } diff --git a/src/coreclr/jit/importercalls.cpp b/src/coreclr/jit/importercalls.cpp index bc6c2da6dae770..b74ed83765e2a8 100644 --- a/src/coreclr/jit/importercalls.cpp +++ b/src/coreclr/jit/importercalls.cpp @@ -701,17 +701,26 @@ var_types Compiler::impImportCall(OPCODE opcode, { AsyncCallInfo asyncInfo; - JITDUMP("Call is an async "); - if ((prefixFlags & PREFIX_IS_TASK_AWAIT) != 0) { - JITDUMP("task await\n"); + JITDUMP("Call is an async task await\n"); asyncInfo.ExecutionContextHandling = ExecutionContextHandling::SaveAndRestore; + + if ((prefixFlags & PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT) != 0) + { + asyncInfo.ContinuationContextHandling = ContinuationContextHandling::ContinueOnCapturedContext; + JITDUMP(" Continuation continues on captured context\n"); + } + else + { + asyncInfo.ContinuationContextHandling = ContinuationContextHandling::ContinueOnThreadPool; + JITDUMP(" Continuation continues on thread pool\n"); + } } else { - JITDUMP("non-task await\n"); + JITDUMP("Call is an async non-task await\n"); // Only expected non-task await to see in IL is one of the AsyncHelpers.AwaitAwaiter variants. // These are awaits of custom awaitables, and they come with the behavior that the execution context // is captured and restored on suspension/resumption. @@ -7884,6 +7893,14 @@ void Compiler::impMarkInlineCandidateHelper(GenTreeCall* call, return; } + if (call->IsAsync() && (call->GetAsyncInfo().ContinuationContextHandling != ContinuationContextHandling::None)) + { + // Cannot currently handle moving to captured context/thread pool when logically returning from inlinee. + // + inlineResult->NoteFatal(InlineObservation::CALLSITE_CONTINUATION_HANDLING); + return; + } + // Ignore indirect calls, unless they are indirect virtual stub calls with profile info. // if (call->gtCallType == CT_INDIRECT) diff --git a/src/coreclr/jit/inline.def b/src/coreclr/jit/inline.def index b8bc674bce09a3..47a226e07d4116 100644 --- a/src/coreclr/jit/inline.def +++ b/src/coreclr/jit/inline.def @@ -163,6 +163,7 @@ INLINE_OBSERVATION(RETURN_TYPE_MISMATCH, bool, "return type mismatch", INLINE_OBSERVATION(STFLD_NEEDS_HELPER, bool, "stfld needs helper", FATAL, CALLSITE) INLINE_OBSERVATION(TOO_MANY_LOCALS, bool, "too many locals", FATAL, CALLSITE) INLINE_OBSERVATION(PINVOKE_EH, bool, "PInvoke call site with EH", FATAL, CALLSITE) +INLINE_OBSERVATION(CONTINUATION_HANDLING, bool, "Callsite needs continuation handling", FATAL, CALLSITE) // ------ Call Site Performance ------- diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index 16cd41fd7e90f4..5b894a5312d97c 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -728,6 +728,7 @@ DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_VALUETASK_RETURNING_THUNK_1, Finalize DEFINE_METHOD(ASYNC_HELPERS, UNSAFE_AWAIT_AWAITER_1, UnsafeAwaitAwaiter, GM_T_RetVoid) DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_EXECUTION_CONTEXT, CaptureExecutionContext, NoSig) DEFINE_METHOD(ASYNC_HELPERS, RESTORE_EXECUTION_CONTEXT, RestoreExecutionContext, NoSig) +DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_CONTINUATION_CONTEXT, CaptureContinuationContext, NoSig) DEFINE_CLASS(SPAN_HELPERS, System, SpanHelpers) DEFINE_METHOD(SPAN_HELPERS, MEMSET, Fill, SM_RefByte_Byte_UIntPtr_RetVoid) diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index 7e818028899d9f..e52613a97f5ac6 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -10256,6 +10256,7 @@ void CEEInfo::getAsyncInfo(CORINFO_ASYNC_INFO* pAsyncInfoOut) pAsyncInfoOut->continuationsNeedMethodHandle = m_pMethodBeingCompiled->GetLoaderAllocator()->CanUnload(); pAsyncInfoOut->captureExecutionContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_EXECUTION_CONTEXT)); pAsyncInfoOut->restoreExecutionContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__RESTORE_EXECUTION_CONTEXT)); + pAsyncInfoOut->captureContinuationContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_CONTINUATION_CONTEXT)); EE_TO_JIT_TRANSITION(); } diff --git a/src/tests/async/synchronization-context/synchronization-context.cs b/src/tests/async/synchronization-context/synchronization-context.cs new file mode 100644 index 00000000000000..e028ee09c9b061 --- /dev/null +++ b/src/tests/async/synchronization-context/synchronization-context.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +public class Async2SynchronizationContext +{ + [Fact] + public static void TestSyncContexts() + { + SynchronizationContext prevContext = SynchronizationContext.Current; + try + { + SynchronizationContext.SetSynchronizationContext(new MySyncContext()); + TestSyncContext().GetAwaiter().GetResult(); + } + finally + { + SynchronizationContext.SetSynchronizationContext(prevContext); + } + } + + private static async Task TestSyncContext() + { + MySyncContext context = (MySyncContext)SynchronizationContext.Current; + await WrappedYieldToThreadPool(suspend: false); + Assert.Same(context, SynchronizationContext.Current); + + await WrappedYieldToThreadPool(suspend: true); + Assert.Same(context, SynchronizationContext.Current); + + await WrappedYieldToThreadPool(suspend: true).ConfigureAwait(true); + Assert.Same(context, SynchronizationContext.Current); + + await WrappedYieldToThreadPool(suspend: false).ConfigureAwait(false); + Assert.Same(context, SynchronizationContext.Current); + + await WrappedYieldToThreadPool(suspend: true).ConfigureAwait(false); + Assert.Null(SynchronizationContext.Current); + + await WrappedYieldToThreadWithCustomSyncContext(); + Assert.Null(SynchronizationContext.Current); + } + + private static async Task WrappedYieldToThreadPool(bool suspend) + { + if (suspend) + { + await Task.Yield(); + } + } + + private static async Task WrappedYieldToThreadWithCustomSyncContext() + { + Assert.Null(SynchronizationContext.Current); + await new YieldToThreadWithCustomSyncContext(); + Assert.True(SynchronizationContext.Current is MySyncContext { }); + } + + private class MySyncContext : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object state) + { + ThreadPool.UnsafeQueueUserWorkItem(_ => + { + SynchronizationContext prevContext = Current; + try + { + SetSynchronizationContext(this); + d(state); + } + finally + { + SetSynchronizationContext(prevContext); + } + }, null); + } + } + + private struct YieldToThreadWithCustomSyncContext : ICriticalNotifyCompletion + { + public YieldToThreadWithCustomSyncContext GetAwaiter() => this; + + public void UnsafeOnCompleted(Action continuation) + { + new Thread(state => + { + SynchronizationContext.SetSynchronizationContext(new MySyncContext()); + continuation(); + }).Start(); + } + + public void OnCompleted(Action continuation) + { + throw new NotImplementedException(); + } + + public bool IsCompleted => false; + + public void GetResult() { } + } +} diff --git a/src/tests/async/synchronization-context/synchronization-context.csproj b/src/tests/async/synchronization-context/synchronization-context.csproj new file mode 100644 index 00000000000000..1ae294349c376f --- /dev/null +++ b/src/tests/async/synchronization-context/synchronization-context.csproj @@ -0,0 +1,8 @@ + + + True + + + + +