diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index 8f5c29cf2857d6..e4d4d92e2b1e45 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -556,7 +556,7 @@ public static void HandleSuspended(T task) where T : Task, ITaskComplet sentinelContinuation.Next = null; // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. - // These never have special continuation handling. + // These never have special continuation context handling. const ContinuationFlags continueFlags = ContinuationFlags.ContinueOnCapturedSynchronizationContext | ContinuationFlags.ContinueOnThreadPool | @@ -774,28 +774,33 @@ private static void CaptureContexts(out ExecutionContext? execCtx, out Synchroni syncCtx = thread._synchronizationContext; } + // Restore contexts onto the current Thread. If "resumed" is true, this is not the initial call of the async method. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void RestoreContexts(bool suspended, ExecutionContext? previousExecCtx, SynchronizationContext? previousSyncCtx) + private static void RestoreContexts(bool resumed, ExecutionContext? previousExecCtx, SynchronizationContext? previousSyncCtx) { - Thread thread = Thread.CurrentThreadAssumedInitialized; - if (!suspended && previousSyncCtx != thread._synchronizationContext) + if (!resumed) { - thread._synchronizationContext = previousSyncCtx; - } + Thread thread = Thread.CurrentThreadAssumedInitialized; + if (previousSyncCtx != thread._synchronizationContext) + { + thread._synchronizationContext = previousSyncCtx; + } - ExecutionContext? currentExecCtx = thread._executionContext; - if (previousExecCtx != currentExecCtx) - { - ExecutionContext.RestoreChangedContextToThread(thread, previousExecCtx, currentExecCtx); + ExecutionContext? currentExecCtx = thread._executionContext; + if (previousExecCtx != currentExecCtx) + { + ExecutionContext.RestoreChangedContextToThread(thread, previousExecCtx, currentExecCtx); + } } } - private static void CaptureContinuationContext(SynchronizationContext syncCtx, ref object context, ref ContinuationFlags flags) + private static void CaptureContinuationContext(ref object continuationContext, ref ContinuationFlags flags) { + SynchronizationContext?
syncCtx = Thread.CurrentThreadAssumedInitialized._synchronizationContext; if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext)) { flags |= ContinuationFlags.ContinueOnCapturedSynchronizationContext; - context = syncCtx; + continuationContext = syncCtx; return; } @@ -803,7 +808,7 @@ private static void CaptureContinuationContext(SynchronizationContext syncCtx, r if (sched != null && sched != TaskScheduler.Default) { flags |= ContinuationFlags.ContinueOnCapturedTaskScheduler; - context = sched; + continuationContext = sched; return; } diff --git a/src/coreclr/inc/corinfo.h b/src/coreclr/inc/corinfo.h index d2b3422ab35d7a..010a817cbe7201 100644 --- a/src/coreclr/inc/corinfo.h +++ b/src/coreclr/inc/corinfo.h @@ -722,6 +722,7 @@ enum CorInfoOptions CORINFO_GENERICS_CTXT_FROM_METHODDESC | CORINFO_GENERICS_CTXT_FROM_METHODTABLE), CORINFO_GENERICS_CTXT_KEEP_ALIVE = 0x00000100, // Keep the generics context alive throughout the method even if there is no explicit use, and report its location to the CLR + CORINFO_ASYNC_SAVE_CONTEXTS = 0x00000200, // Runtime async method must save and restore contexts }; // diff --git a/src/coreclr/inc/jiteeversionguid.h b/src/coreclr/inc/jiteeversionguid.h index ab8004ce707608..dd019944d810eb 100644 --- a/src/coreclr/inc/jiteeversionguid.h +++ b/src/coreclr/inc/jiteeversionguid.h @@ -37,11 +37,11 @@ #include -constexpr GUID JITEEVersionIdentifier = { /* a802fbbf-3e14-4b34-a348-5fba9fd756d4 */ - 0xa802fbbf, - 0x3e14, - 0x4b34, - {0xa3, 0x48, 0x5f, 0xba, 0x9f, 0xd7, 0x56, 0xd4} +constexpr GUID JITEEVersionIdentifier = { /* f0752445-e116-444d-98ea-aaa7cbc30baa */ + 0xf0752445, + 0xe116, + 0x444d, + {0x98, 0xea, 0xaa, 0xa7, 0xcb, 0xc3, 0x0b, 0xaa} }; #endif // JIT_EE_VERSIONING_GUID_H diff --git a/src/coreclr/jit/async.cpp b/src/coreclr/jit/async.cpp index 851bca6e80d436..f208bd4a50003a 100644 --- a/src/coreclr/jit/async.cpp +++ b/src/coreclr/jit/async.cpp @@ -47,194 +47,215 @@ //------------------------------------------------------------------------ // Compiler::SaveAsyncContexts: -// Insert code to save and restore contexts around async call sites. +// Insert code in async methods that saves and restores contexts. // // Returns: // Suitable phase status. // // Remarks: -// Runs early, after import but before inlining. Thus RET_EXPRs may be -// present, and async calls may later be inlined. +// This inserts code to save the current ExecutionContext and +// SynchronizationContext at the beginning of async functions, and code that +// restores these contexts at the end. Additionally inserts uses of each of +// these contexts at async calls to model the fact that on suspension, these +// locals will be used there.
// PhaseStatus Compiler::SaveAsyncContexts() { - if (!compMustSaveAsyncContexts) + if ((info.compMethodInfo->options & CORINFO_ASYNC_SAVE_CONTEXTS) == 0) { - JITDUMP("No async calls where execution context capture/restore is necessary\n"); - ValidateNoAsyncSavesNecessary(); return PhaseStatus::MODIFIED_NOTHING; } - PhaseStatus result = PhaseStatus::MODIFIED_NOTHING; + // Create locals for ExecutionContext and SynchronizationContext + lvaAsyncExecutionContextVar = lvaGrabTemp(false DEBUGARG("Async ExecutionContext")); + lvaGetDesc(lvaAsyncExecutionContextVar)->lvType = TYP_REF; - BasicBlock* curBB = fgFirstBB; - while (curBB != nullptr) - { - BasicBlock* nextBB = curBB->Next(); + lvaAsyncSynchronizationContextVar = lvaGrabTemp(false DEBUGARG("Async SynchronizationContext")); + lvaGetDesc(lvaAsyncSynchronizationContextVar)->lvType = TYP_REF; - for (Statement* stmt : curBB->Statements()) - { - GenTree* tree = stmt->GetRootNode(); - if (tree->OperIs(GT_STORE_LCL_VAR)) - { - tree = tree->AsLclVarCommon()->Data(); - } + // Create try-fault structure. This is actually a try-finally, but we + // manually insert the restore code in a (merged) return block, so EH wise + // we only need to restore on fault. + BasicBlock* const tryBegBB = fgSplitBlockAtBeginning(fgFirstBB); + BasicBlock* const tryLastBB = fgLastBB; - if (!tree->IsCall()) - { - ValidateNoAsyncSavesNecessaryInStatement(stmt); - continue; - } + // Create fault handler block + BasicBlock* faultBB = fgNewBBafter(BBJ_EHFAULTRET, tryLastBB, false); + faultBB->bbRefs = 1; // Artificial ref count + faultBB->inheritWeightPercentage(tryBegBB, 0); - GenTreeCall* call = tree->AsCall(); - if (!call->IsAsync()) - { - ValidateNoAsyncSavesNecessaryInStatement(stmt); - continue; - } + // Add a new EH table entry. It encloses all others, so placing it at the + // end is the right thing to do. + unsigned XTnew = compHndBBtabCount; + EHblkDsc* newEntry = fgTryAddEHTableEntries(XTnew); - const AsyncCallInfo& asyncCallInfo = call->GetAsyncInfo(); + if (newEntry == nullptr) + { + IMPL_LIMITATION("too many exception clauses"); + } - // Currently we always expect that ExecutionContext and - // SynchronizationContext correlate about their save/restore - // behavior. 
- assert((asyncCallInfo.ExecutionContextHandling == ExecutionContextHandling::SaveAndRestore) == - asyncCallInfo.SaveAndRestoreSynchronizationContextField); + // Initialize the new entry + asyncContextRestoreEHID = impInlineRoot()->compEHID++; + newEntry->ebdID = asyncContextRestoreEHID; + newEntry->ebdHandlerType = EH_HANDLER_FAULT; - if (asyncCallInfo.ExecutionContextHandling != ExecutionContextHandling::SaveAndRestore) - { - continue; - } + newEntry->ebdTryBeg = tryBegBB; + newEntry->ebdTryLast = tryLastBB; - unsigned suspendedLclNum = - lvaGrabTemp(false DEBUGARG(printfAlloc("Suspended indicator for [%06u]", dspTreeID(call)))); - unsigned execCtxLclNum = - lvaGrabTemp(false DEBUGARG(printfAlloc("ExecutionContext for [%06u]", dspTreeID(call)))); - unsigned syncCtxLclNum = - lvaGrabTemp(false DEBUGARG(printfAlloc("SynchronizationContext for [%06u]", dspTreeID(call)))); + newEntry->ebdHndBeg = faultBB; + newEntry->ebdHndLast = faultBB; - LclVarDsc* suspendedLclDsc = lvaGetDesc(suspendedLclNum); - suspendedLclDsc->lvType = TYP_UBYTE; - suspendedLclDsc->lvHasLdAddrOp = true; + newEntry->ebdTyp = 0; // unused for fault - LclVarDsc* execCtxLclDsc = lvaGetDesc(execCtxLclNum); - execCtxLclDsc->lvType = TYP_REF; - execCtxLclDsc->lvHasLdAddrOp = true; + newEntry->ebdEnclosingTryIndex = EHblkDsc::NO_ENCLOSING_INDEX; + newEntry->ebdEnclosingHndIndex = EHblkDsc::NO_ENCLOSING_INDEX; - LclVarDsc* syncCtxLclDsc = lvaGetDesc(syncCtxLclNum); - syncCtxLclDsc->lvType = TYP_REF; - syncCtxLclDsc->lvHasLdAddrOp = true; + newEntry->ebdTryBegOffset = tryBegBB->bbCodeOffs; + newEntry->ebdTryEndOffset = tryLastBB->bbCodeOffsEnd; + newEntry->ebdFilterBegOffset = 0; + newEntry->ebdHndBegOffset = 0; + newEntry->ebdHndEndOffset = 0; - call->asyncInfo->SynchronizationContextLclNum = syncCtxLclNum; + // Set flags on new region + tryBegBB->SetFlags(BBF_DONT_REMOVE | BBF_IMPORTED); + faultBB->SetFlags(BBF_DONT_REMOVE | BBF_IMPORTED); + faultBB->bbCatchTyp = BBCT_FAULT; - call->gtArgs.PushBack(this, NewCallArg::Primitive(gtNewLclAddrNode(suspendedLclNum, 0)) - .WellKnown(WellKnownArg::AsyncSuspendedIndicator)); + tryBegBB->setTryIndex(XTnew); + tryBegBB->clearHndIndex(); - JITDUMP("Saving contexts around [%06u], ExecutionContext = V%02u, SynchronizationContext = V%02u\n", - call->gtTreeID, execCtxLclNum, syncCtxLclNum); + faultBB->clearTryIndex(); + faultBB->setHndIndex(XTnew); - CORINFO_ASYNC_INFO* asyncInfo = eeGetAsyncInfo(); + // Walk user code blocks and set try index + for (BasicBlock* tmpBB = tryBegBB->Next(); tmpBB != faultBB; tmpBB = tmpBB->Next()) + { + if (!tmpBB->hasTryIndex()) + { + tmpBB->setTryIndex(XTnew); + } + } - GenTreeCall* capture = gtNewCallNode(CT_USER_FUNC, asyncInfo->captureContextsMethHnd, TYP_VOID); - capture->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclAddrNode(syncCtxLclNum, 0))); - capture->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclAddrNode(execCtxLclNum, 0))); + // Walk EH table and update enclosing try indices + for (unsigned XTnum = 0; XTnum < XTnew; XTnum++) + { + EHblkDsc* HBtab = &compHndBBtab[XTnum]; + if (HBtab->ebdEnclosingTryIndex == EHblkDsc::NO_ENCLOSING_INDEX) + { + HBtab->ebdEnclosingTryIndex = (unsigned short)XTnew; + } + } - CORINFO_CALL_INFO callInfo = {}; - callInfo.hMethod = capture->gtCallMethHnd; - callInfo.methodFlags = info.compCompHnd->getMethodAttribs(callInfo.hMethod); - impMarkInlineCandidate(capture, MAKE_METHODCONTEXT(callInfo.hMethod), false, &callInfo, compInlineContext); + JITDUMP("Created EH descriptor EH#%u for try/fault wrapping body to 
save/restore async contexts\n", XTnew); + INDEBUG(fgVerifyHandlerTab()); - Statement* captureStmt = fgNewStmtFromTree(capture); - fgInsertStmtBefore(curBB, stmt, captureStmt); + // Get async helper methods + CORINFO_ASYNC_INFO* asyncInfo = eeGetAsyncInfo(); - JITDUMP("Inserted capture:\n"); - DISPSTMT(captureStmt); + // Insert CaptureContexts call before the try (keep it before so the + // try/finally can be removed if there are no exception side effects) + GenTreeCall* captureCall = gtNewCallNode(CT_USER_FUNC, asyncInfo->captureContextsMethHnd, TYP_VOID); + captureCall->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclAddrNode(lvaAsyncSynchronizationContextVar, 0))); + captureCall->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclAddrNode(lvaAsyncExecutionContextVar, 0))); + lvaGetDesc(lvaAsyncSynchronizationContextVar)->lvHasLdAddrOp = true; + lvaGetDesc(lvaAsyncExecutionContextVar)->lvHasLdAddrOp = true; - BasicBlock* restoreBB = curBB; - Statement* restoreAfterStmt = stmt; + CORINFO_CALL_INFO callInfo = {}; + callInfo.hMethod = captureCall->gtCallMethHnd; + callInfo.methodFlags = info.compCompHnd->getMethodAttribs(callInfo.hMethod); + impMarkInlineCandidate(captureCall, MAKE_METHODCONTEXT(callInfo.hMethod), false, &callInfo, compInlineContext); - if (call->IsInlineCandidate() && (call->gtReturnType != TYP_VOID)) - { - restoreAfterStmt = stmt->GetNextStmt(); - assert(restoreAfterStmt->GetRootNode()->OperIs(GT_RET_EXPR) || - (restoreAfterStmt->GetRootNode()->OperIs(GT_STORE_LCL_VAR) && - restoreAfterStmt->GetRootNode()->AsLclVarCommon()->Data()->OperIs(GT_RET_EXPR))); - } + Statement* captureStmt = fgNewStmtFromTree(captureCall); + fgInsertStmtAtBeg(fgFirstBB, captureStmt); - if (curBB->hasTryIndex()) - { -#ifdef FEATURE_EH_WINDOWS_X86 - IMPL_LIMITATION("Cannot handle insertion of try-finally without funclets"); -#else - // Await is inside a try, need to insert try-finally around it. - restoreBB = InsertTryFinallyForContextRestore(curBB, stmt, restoreAfterStmt); - restoreAfterStmt = nullptr; - // we have split the block that could have another await.
- nextBB = restoreBB->Next(); -#endif - } + JITDUMP("Inserted capture\n"); + DISPSTMT(captureStmt); - GenTreeCall* restore = gtNewCallNode(CT_USER_FUNC, asyncInfo->restoreContextsMethHnd, TYP_VOID); - restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(syncCtxLclNum))); - restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(execCtxLclNum))); - restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(suspendedLclNum))); + // Insert RestoreContexts call in fault (exceptional case) + // First argument: resumed = (continuation != null) + GenTree* continuation = gtNewLclvNode(lvaAsyncContinuationArg, TYP_REF); + GenTree* null = gtNewNull(); + GenTree* resumed = gtNewOperNode(GT_NE, TYP_INT, continuation, null); - callInfo = {}; - callInfo.hMethod = restore->gtCallMethHnd; - callInfo.methodFlags = info.compCompHnd->getMethodAttribs(callInfo.hMethod); - impMarkInlineCandidate(restore, MAKE_METHODCONTEXT(callInfo.hMethod), false, &callInfo, compInlineContext); + GenTreeCall* restoreCall = gtNewCallNode(CT_USER_FUNC, asyncInfo->restoreContextsMethHnd, TYP_VOID); + restoreCall->gtArgs.PushFront(this, + NewCallArg::Primitive(gtNewLclVarNode(lvaAsyncSynchronizationContextVar, TYP_REF))); + restoreCall->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(lvaAsyncExecutionContextVar, TYP_REF))); + restoreCall->gtArgs.PushFront(this, NewCallArg::Primitive(resumed)); - Statement* restoreStmt = fgNewStmtFromTree(restore); - if (restoreAfterStmt == nullptr) - { - fgInsertStmtNearEnd(restoreBB, restoreStmt); - } - else - { - fgInsertStmtAfter(restoreBB, restoreAfterStmt, restoreStmt); - } + Statement* restoreStmt = fgNewStmtFromTree(restoreCall); + fgInsertStmtAtEnd(faultBB, restoreStmt); - JITDUMP("Inserted restore:\n"); - DISPSTMT(restoreStmt); + // Now add uses of the new contexts to all async calls (modelling the + // fact that on suspension, we restore the context from those values). Also + // convert BBJ_RETURNs into an exit to a block outside the region. + BasicBlock* newReturnBB = nullptr; + unsigned mergedReturnLcl = BAD_VAR_NUM; + + for (BasicBlock* block : Blocks()) + { + AddContextArgsToAsyncCalls(block); - result = PhaseStatus::MODIFIED_EVERYTHING; + if (!block->KindIs(BBJ_RETURN) || (block == newReturnBB)) + { + continue; } - curBB = nextBB; - } + JITDUMP("Merging BBJ_RETURN block " FMT_BB "\n", block->bbNum); - return result; -} + if (newReturnBB == nullptr) + { + newReturnBB = CreateReturnBB(&mergedReturnLcl); + newReturnBB->inheritWeightPercentage(block, 0); + } -//------------------------------------------------------------------------ -// Compiler::ValidateNoAsyncSavesNecessary: -// Check that there are no async calls requiring saving of ExecutionContext -// in the method.
-// -void Compiler::ValidateNoAsyncSavesNecessary() -{ -#ifdef DEBUG - for (BasicBlock* block : Blocks()) - { - for (Statement* stmt : block->Statements()) + // Store return value to common local + Statement* retStmt = block->lastStmt(); + assert((retStmt != nullptr) && retStmt->GetRootNode()->OperIs(GT_RETURN)); + + if (mergedReturnLcl != BAD_VAR_NUM) { - ValidateNoAsyncSavesNecessaryInStatement(stmt); + GenTree* retVal = retStmt->GetRootNode()->AsOp()->GetReturnValue(); + Statement* insertAfter = retStmt; + GenTree* storeRetVal = + gtNewTempStore(mergedReturnLcl, retVal, CHECK_SPILL_NONE, &insertAfter, retStmt->GetDebugInfo(), block); + Statement* storeStmt = fgNewStmtFromTree(storeRetVal); + fgInsertStmtAtEnd(block, storeStmt); + JITDUMP("Inserted store to common return local\n"); + DISPSTMT(storeStmt); } + + retStmt->GetRootNode()->gtBashToNOP(); + + // Jump to new shared restore + return block + block->SetKindAndTargetEdge(BBJ_ALWAYS, fgAddRefPred(newReturnBB, block)); + fgReturnCount--; } -#endif + + if (newReturnBB != nullptr) + { + newReturnBB->bbWeight = newReturnBB->computeIncomingWeight(); + } + + // After merging of returns we have at most 1 return (and we may have 0, if + // there were no returns before due to infinite loops or exceptions). + assert(fgReturnCount <= 1); + + return PhaseStatus::MODIFIED_EVERYTHING; } //------------------------------------------------------------------------ -// Compiler::ValidateNoAsyncSavesNecessaryInStatement: -// Check that there are no async calls requiring saving of ExecutionContext -// in the statement. +// Compiler::AddContextArgsToAsyncCalls: +// Add uses of the saved ExecutionContext and SynchronizationContext to all +// async calls. // -// Parameters: -// stmt - The statement +// Remarks: +// This models the fact that calls have uses of the saved contexts on +// suspension. The async transformation will later move the uses into the +// suspension code path. 
// -void Compiler::ValidateNoAsyncSavesNecessaryInStatement(Statement* stmt) +void Compiler::AddContextArgsToAsyncCalls(BasicBlock* block) { -#ifdef DEBUG struct Visitor : GenTreeVisitor { enum @@ -249,123 +270,117 @@ void Compiler::ValidateNoAsyncSavesNecessaryInStatement(Statement* stmt) fgWalkResult PreOrderVisit(GenTree** use, GenTree* user) { - if (((*use)->gtFlags & GTF_CALL) == 0) + GenTree* tree = *use; + if ((tree->gtFlags & GTF_CALL) == 0) { return WALK_SKIP_SUBTREES; } - if ((*use)->IsCall()) + if (!tree->IsCall() || !tree->AsCall()->IsAsync()) { - assert(!(*use)->AsCall()->IsAsyncAndAlwaysSavesAndRestoresExecutionContext()); + return WALK_CONTINUE; } + GenTreeCall* call = tree->AsCall(); + GenTree* execCtx = m_compiler->gtNewLclVarNode(m_compiler->lvaAsyncExecutionContextVar, TYP_REF); + GenTree* syncCtx = m_compiler->gtNewLclVarNode(m_compiler->lvaAsyncSynchronizationContextVar, TYP_REF); + JITDUMP("Adding exec context [%06u], sync context [%06u] to async call [%06u]\n", dspTreeID(execCtx), + dspTreeID(syncCtx), dspTreeID(call)); + call->gtArgs.PushFront(m_compiler, + NewCallArg::Primitive(syncCtx).WellKnown(WellKnownArg::AsyncSynchronizationContext)); + call->gtArgs.PushFront(m_compiler, + NewCallArg::Primitive(execCtx).WellKnown(WellKnownArg::AsyncExecutionContext)); return WALK_CONTINUE; } }; Visitor visitor(this); - visitor.WalkTree(stmt->GetRootNodePointer(), nullptr); -#endif + for (Statement* stmt : block->Statements()) + { + visitor.WalkTree(stmt->GetRootNodePointer(), nullptr); + } } //------------------------------------------------------------------------ -// Compiler::InsertTryFinallyForContextRestore: -// Insert a try-finally around the specified statements in the specified -// block. +// Compiler::CreateReturnBB: +// Create a new return block to exit the async method. +// +// Parameters: +// mergedReturnLcl - [out] The local created to hold the merged return value. +// BAD_VAR_NUM if the async method does not return a result. // // Returns: -// Finally block of inserted try-finally. +// A new basic block that restores contexts and returns a merged result. 
// -BasicBlock* Compiler::InsertTryFinallyForContextRestore(BasicBlock* block, Statement* firstStmt, Statement* lastStmt) +BasicBlock* Compiler::CreateReturnBB(unsigned* mergedReturnLcl) { - assert(!block->hasHndIndex()); - EHblkDsc* ebd = fgTryAddEHTableEntries(block->bbTryIndex - 1, 1); - if (ebd == nullptr) - { - IMPL_LIMITATION("Awaits require insertion of too many EH clauses"); - } + BasicBlock* newReturnBB = fgNewBBafter(BBJ_RETURN, fgLastBB, /* extendRegion */ false); + newReturnBB->bbTryIndex = 0; // EH region + newReturnBB->bbHndIndex = 0; + fgReturnCount++; + JITDUMP("Created new BBJ_RETURN block " FMT_BB "\n", newReturnBB->bbNum); - if (firstStmt == block->firstStmt()) - { - block = fgSplitBlockAtBeginning(block); - } - else - { - block = fgSplitBlockAfterStatement(block, firstStmt->GetPrevStmt()); - } - - BasicBlock* tailBB = fgSplitBlockAfterStatement(block, lastStmt); + // Insert "restore" call + CORINFO_ASYNC_INFO* asyncInfo = eeGetAsyncInfo(); - BasicBlock* callFinally = fgNewBBafter(BBJ_CALLFINALLY, block, false); - BasicBlock* callFinallyRet = fgNewBBafter(BBJ_CALLFINALLYRET, callFinally, false); - BasicBlock* finallyRet = fgNewBBafter(BBJ_EHFINALLYRET, callFinallyRet, false); - BasicBlock* goToTailBlock = fgNewBBafter(BBJ_ALWAYS, finallyRet, false); - - callFinally->inheritWeight(block); - callFinallyRet->inheritWeight(block); - finallyRet->inheritWeight(block); - goToTailBlock->inheritWeight(block); - - // Set some info the starting blocks like fgFindBasicBlocks does - block->SetFlags(BBF_DONT_REMOVE); - finallyRet->SetFlags(BBF_DONT_REMOVE); - finallyRet->bbRefs++; // Artificial ref count on handler begins - - fgRemoveRefPred(block->GetTargetEdge()); - // Wire up the control flow for the new blocks - block->SetTargetEdge(fgAddRefPred(callFinally, block)); - callFinally->SetTargetEdge(fgAddRefPred(finallyRet, callFinally)); - - FlowEdge** succs = new (this, CMK_BasicBlock) FlowEdge* [1] { - fgAddRefPred(callFinallyRet, finallyRet) - }; - succs[0]->setLikelihood(1.0); - BBJumpTable* ehfDesc = new (this, CMK_BasicBlock) BBJumpTable(succs, 1); - finallyRet->SetEhfTargets(ehfDesc); + GenTree* continuation = gtNewLclvNode(lvaAsyncContinuationArg, TYP_REF); + GenTree* null = gtNewNull(); + GenTree* resumed = gtNewOperNode(GT_NE, TYP_INT, continuation, null); - callFinallyRet->SetTargetEdge(fgAddRefPred(goToTailBlock, callFinallyRet)); - goToTailBlock->SetTargetEdge(fgAddRefPred(tailBB, goToTailBlock)); + GenTreeCall* restoreCall = gtNewCallNode(CT_USER_FUNC, asyncInfo->restoreContextsMethHnd, TYP_VOID); + restoreCall->gtArgs.PushFront(this, + NewCallArg::Primitive(gtNewLclVarNode(lvaAsyncSynchronizationContextVar, TYP_REF))); + restoreCall->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(lvaAsyncExecutionContextVar, TYP_REF))); + restoreCall->gtArgs.PushFront(this, NewCallArg::Primitive(resumed)); - // Most of these blocks go in the old EH region - callFinally->bbTryIndex = block->bbTryIndex; - callFinallyRet->bbTryIndex = block->bbTryIndex; - finallyRet->bbTryIndex = block->bbTryIndex; - goToTailBlock->bbTryIndex = block->bbTryIndex; + // This restore is an inline candidate (unlike the fault one) + CORINFO_CALL_INFO callInfo = {}; + callInfo.hMethod = restoreCall->gtCallMethHnd; + callInfo.methodFlags = info.compCompHnd->getMethodAttribs(callInfo.hMethod); + impMarkInlineCandidate(restoreCall, MAKE_METHODCONTEXT(callInfo.hMethod), false, &callInfo, compInlineContext); - callFinally->bbHndIndex = block->bbHndIndex; - callFinallyRet->bbHndIndex = 
block->bbHndIndex; - finallyRet->bbHndIndex = block->bbHndIndex; - goToTailBlock->bbHndIndex = block->bbHndIndex; + Statement* restoreStmt = fgNewStmtFromTree(restoreCall); + fgInsertStmtAtEnd(newReturnBB, restoreStmt); + JITDUMP("Inserted restore statement in return block\n"); + DISPSTMT(restoreStmt); - // block goes into the inserted EH clause and the finally becomes the handler - block->bbTryIndex--; - finallyRet->bbHndIndex = block->bbTryIndex; + *mergedReturnLcl = BAD_VAR_NUM; - ebd->ebdID = impInlineRoot()->compEHID++; - ebd->ebdHandlerType = EH_HANDLER_FINALLY; + GenTree* ret; + if (compMethodHasRetVal()) + { + *mergedReturnLcl = lvaGrabTemp(false DEBUGARG("Async merged return local")); - ebd->ebdTryBeg = block; - ebd->ebdTryLast = block; + var_types retLclType = compMethodReturnsRetBufAddr() ? TYP_BYREF : genActualType(info.compRetType); - ebd->ebdHndBeg = finallyRet; - ebd->ebdHndLast = finallyRet; + if (varTypeIsStruct(retLclType)) + { + lvaSetStruct(*mergedReturnLcl, info.compMethodInfo->args.retTypeClass, false); - ebd->ebdTyp = 0; - ebd->ebdEnclosingTryIndex = (unsigned short)goToTailBlock->getTryIndex(); - ebd->ebdEnclosingHndIndex = EHblkDsc::NO_ENCLOSING_INDEX; + if (compMethodReturnsMultiRegRetType()) + { + lvaGetDesc(*mergedReturnLcl)->lvIsMultiRegRet = true; + } + } + else + { + lvaGetDesc(*mergedReturnLcl)->lvType = retLclType; + } - ebd->ebdTryBegOffset = block->bbCodeOffs; - ebd->ebdTryEndOffset = block->bbCodeOffsEnd; - ebd->ebdFilterBegOffset = 0; - ebd->ebdHndBegOffset = 0; - ebd->ebdHndEndOffset = 0; + GenTree* retTemp = gtNewLclVarNode(*mergedReturnLcl); + ret = gtNewOperNode(GT_RETURN, retTemp->TypeGet(), retTemp); + } + else + { + ret = new (this, GT_RETURN) GenTreeOp(GT_RETURN, TYP_VOID); + } - finallyRet->bbCatchTyp = BBCT_FINALLY; - GenTree* retFilt = gtNewOperNode(GT_RETFILT, TYP_VOID, nullptr); - Statement* retFiltStmt = fgNewStmtFromTree(retFilt); - fgInsertStmtAtEnd(finallyRet, retFiltStmt); + Statement* retStmt = fgNewStmtFromTree(ret); - return finallyRet; + fgInsertStmtAtEnd(newReturnBB, retStmt); + JITDUMP("Inserted return statement in return block\n"); + DISPSTMT(retStmt); + return newReturnBB; } class AsyncLiveness @@ -808,8 +823,6 @@ void AsyncTransformation::Transform( ContinuationLayout layout = LayOutContinuation(block, call, ContinuationNeedsKeepAlive(life), liveLocals); - ClearSuspendedIndicator(block, call); - CallDefinitionInfo callDefInfo = CanonicalizeCallDefinition(block, call, life); unsigned stateNum = (unsigned)m_resumptionBBs.size(); @@ -857,12 +870,14 @@ void AsyncTransformation::CreateLiveSetForSuspension(BasicBlock* call->VisitLocalDefs(m_comp, visitDef); - const AsyncCallInfo& asyncInfo = call->GetAsyncInfo(); - - if (asyncInfo.SynchronizationContextLclNum != BAD_VAR_NUM) + // Exclude method-level context locals (only live on synchronous path) + if (m_comp->lvaAsyncSynchronizationContextVar != BAD_VAR_NUM) { - // This one is only live on the synchronous path, which liveness cannot prove - excludedLocals.AddOrUpdate(asyncInfo.SynchronizationContextLclNum, true); + excludedLocals.AddOrUpdate(m_comp->lvaAsyncSynchronizationContextVar, true); + } + if (m_comp->lvaAsyncExecutionContextVar != BAD_VAR_NUM) + { + excludedLocals.AddOrUpdate(m_comp->lvaAsyncExecutionContextVar, true); } life.GetLiveLocals(liveLocals, [&](unsigned lclNum) { @@ -1129,9 +1144,18 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock* if (block->hasTryIndex()) { - layout.ExceptionOffset = allocLayout(TARGET_POINTER_SIZE, TARGET_POINTER_SIZE); 
- JITDUMP(" " FMT_BB " is in try region %u; exception will be at offset %u\n", block->bbNum, - block->getTryIndex(), layout.ExceptionOffset); + // If we are enclosed in any try region that isn't our special "context + // restore" try region then we need to rethrow an exception. For our + // special "context restore" try region we know that it is a no-op on + // the resumption path. + EHblkDsc* ehDsc = m_comp->ehGetDsc(block->getTryIndex()); + if ((ehDsc->ebdID != m_comp->asyncContextRestoreEHID) || + (ehDsc->ebdEnclosingTryIndex != EHblkDsc::NO_ENCLOSING_INDEX)) + { + layout.ExceptionOffset = allocLayout(TARGET_POINTER_SIZE, TARGET_POINTER_SIZE); + JITDUMP(" " FMT_BB " is in try region %u; exception will be at offset %u\n", block->bbNum, + block->getTryIndex(), layout.ExceptionOffset); + } } if (call->GetAsyncInfo().ContinuationContextHandling == ContinuationContextHandling::ContinueOnCapturedContext) @@ -1157,12 +1181,9 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock* JITDUMP(" Continuation needs keep alive object; will be at offset %u\n", layout.KeepAliveOffset); } - if (call->GetAsyncInfo().ExecutionContextHandling == ExecutionContextHandling::AsyncSaveAndRestore) - { - layout.ExecContextOffset = allocLayout(TARGET_POINTER_SIZE, TARGET_POINTER_SIZE); - JITDUMP(" Call has async-only save and restore of ExecutionContext; ExecutionContext will be at offset %u\n", - layout.ExecContextOffset); - } + layout.ExecutionContextOffset = allocLayout(TARGET_POINTER_SIZE, TARGET_POINTER_SIZE); + JITDUMP(" Call has async-only save and restore of ExecutionContext; ExecutionContext will be at offset %u\n", + layout.ExecutionContextOffset); for (LiveLocalInfo& inf : liveLocals) { @@ -1190,7 +1211,7 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock* bitmapBuilder.SetIfNotMax(layout.ExceptionOffset); bitmapBuilder.SetIfNotMax(layout.ContinuationContextOffset); bitmapBuilder.SetIfNotMax(layout.KeepAliveOffset); - bitmapBuilder.SetIfNotMax(layout.ExecContextOffset); + bitmapBuilder.SetIfNotMax(layout.ExecutionContextOffset); if (layout.ReturnSize > 0) { @@ -1251,80 +1272,6 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock* return layout; } -//------------------------------------------------------------------------ -// AsyncTransformation::ClearSuspendedIndicator: -// Generate IR to clear the value of the suspended indicator local. 
-// -// Parameters: -// block - Block to generate IR into -// call - The async call (not contained in "block") -// -void AsyncTransformation::ClearSuspendedIndicator(BasicBlock* block, GenTreeCall* call) -{ - CallArg* suspendedArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); - if (suspendedArg == nullptr) - { - return; - } - - GenTree* suspended = suspendedArg->GetNode(); - if (!suspended->IsLclVarAddr() && - (!suspended->OperIs(GT_LCL_VAR) || m_comp->lvaVarAddrExposed(suspended->AsLclVarCommon()->GetLclNum()))) - { - // We will need a second use of this, so spill to a local - LIR::Use use(LIR::AsRange(block), &suspendedArg->NodeRef(), call); - use.ReplaceWithLclVar(m_comp); - suspended = use.Def(); - } - - GenTree* value = m_comp->gtNewIconNode(0); - GenTree* storeSuspended = - m_comp->gtNewStoreValueNode(TYP_UBYTE, m_comp->gtCloneExpr(suspended), value, GTF_IND_NONFAULTING); - - LIR::AsRange(block).InsertBefore(call, LIR::SeqTree(m_comp, storeSuspended)); -} - -//------------------------------------------------------------------------ -// AsyncTransformation::SetSuspendedIndicator: -// Generate IR to set the value of the suspended indicator local, and remove -// the argument from the call. -// -// Parameters: -// block - Block to generate IR into -// callBlock - Block containing the call -// call - The async call -// -void AsyncTransformation::SetSuspendedIndicator(BasicBlock* block, BasicBlock* callBlock, GenTreeCall* call) -{ - CallArg* suspendedArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); - if (suspendedArg == nullptr) - { - return; - } - - GenTree* suspended = suspendedArg->GetNode(); - assert(suspended->IsLclVarAddr() || suspended->OperIs(GT_LCL_VAR)); // Ensured by ClearSuspendedIndicator - - GenTree* value = m_comp->gtNewIconNode(1); - GenTree* storeSuspended = - m_comp->gtNewStoreValueNode(TYP_UBYTE, m_comp->gtCloneExpr(suspended), value, GTF_IND_NONFAULTING); - - LIR::AsRange(block).InsertAtEnd(LIR::SeqTree(m_comp, storeSuspended)); - - call->gtArgs.RemoveUnsafe(suspendedArg); - call->asyncInfo->HasSuspensionIndicatorDef = false; - - // Avoid leaving LCL_ADDR around which will DNER the local. - if (suspended->IsLclVarAddr()) - { - LIR::AsRange(callBlock).Remove(suspended); - } - else - { - suspended->SetUnusedValue(); - } -} - //------------------------------------------------------------------------ // AsyncTransformation::CanonicalizeCallDefinition: // Put the call definition in a canonical form. This ensures that either the @@ -1483,6 +1430,8 @@ BasicBlock* AsyncTransformation::CreateSuspension( FillInDataOnSuspension(call, layout, suspendBB); } + RestoreContexts(block, call, suspendBB); + if (suspendBB->KindIs(BBJ_RETURN)) { newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF); @@ -1600,52 +1549,38 @@ void AsyncTransformation::FillInDataOnSuspension(GenTreeCall* call, if (layout.ContinuationContextOffset != UINT_MAX) { - const AsyncCallInfo& callInfo = call->GetAsyncInfo(); - assert(callInfo.SaveAndRestoreSynchronizationContextField); - assert(callInfo.ExecutionContextHandling == ExecutionContextHandling::SaveAndRestore); - assert(callInfo.SynchronizationContextLclNum != BAD_VAR_NUM); - // Insert call // AsyncHelpers.CaptureContinuationContext( - // syncContextFromBeforeCall, // ref newContinuation.ContinuationContext, // ref newContinuation.Flags). 
- GenTree* syncContextPlaceholder = m_comp->gtNewNull(); - GenTree* contextElementPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); - GenTree* flagsPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); + GenTree* contContextElementPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); + GenTree* flagsPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); GenTreeCall* captureCall = m_comp->gtNewCallNode(CT_USER_FUNC, m_asyncInfo->captureContinuationContextMethHnd, TYP_VOID); captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(flagsPlaceholder)); - captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(contextElementPlaceholder)); - captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(syncContextPlaceholder)); + captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(contContextElementPlaceholder)); m_comp->compCurBB = suspendBB; m_comp->fgMorphTree(captureCall); LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, captureCall)); - // Replace sync context placeholder with actual sync context from before call + // Replace contContextElementPlaceholder with actual address of the continuation context element LIR::Use use; - bool gotUse = LIR::AsRange(suspendBB).TryGetUse(syncContextPlaceholder, &use); - assert(gotUse); - GenTree* syncContextLcl = m_comp->gtNewLclvNode(callInfo.SynchronizationContextLclNum, TYP_REF); - LIR::AsRange(suspendBB).InsertBefore(syncContextPlaceholder, syncContextLcl); - use.ReplaceWith(syncContextLcl); - LIR::AsRange(suspendBB).Remove(syncContextPlaceholder); - - // Replace contextElementPlaceholder with actual address of the context element - gotUse = LIR::AsRange(suspendBB).TryGetUse(contextElementPlaceholder, &use); + bool gotUse = LIR::AsRange(suspendBB).TryGetUse(contContextElementPlaceholder, &use); assert(gotUse); - GenTree* newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF); - unsigned offset = OFFSETOF__CORINFO_Continuation__data + layout.ContinuationContextOffset; - GenTree* contextElementOffset = m_comp->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, - m_comp->gtNewIconNode((ssize_t)offset, TYP_I_IMPL)); + GenTree* newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF); + unsigned contContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ContinuationContextOffset; + GenTree* contContextElementOffset = + m_comp->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, + m_comp->gtNewIconNode((ssize_t)contContextOffset, TYP_I_IMPL)); - LIR::AsRange(suspendBB).InsertBefore(contextElementPlaceholder, LIR::SeqTree(m_comp, contextElementOffset)); - use.ReplaceWith(contextElementOffset); - LIR::AsRange(suspendBB).Remove(contextElementPlaceholder); + LIR::AsRange(suspendBB).InsertBefore(contContextElementPlaceholder, + LIR::SeqTree(m_comp, contContextElementOffset)); + use.ReplaceWith(contContextElementOffset); + LIR::AsRange(suspendBB).Remove(contContextElementPlaceholder); // Replace flagsPlaceholder with actual address of the flags gotUse = LIR::AsRange(suspendBB).TryGetUse(flagsPlaceholder, &use); @@ -1661,7 +1596,7 @@ void AsyncTransformation::FillInDataOnSuspension(GenTreeCall* call, LIR::AsRange(suspendBB).Remove(flagsPlaceholder); } - if (layout.ExecContextOffset != UINT_MAX) + if (layout.ExecutionContextOffset != UINT_MAX) { GenTreeCall* captureExecContext = m_comp->gtNewCallNode(CT_USER_FUNC, m_asyncInfo->captureExecutionContextMethHnd, TYP_REF); @@ -1670,12 +1605,109 @@ void AsyncTransformation::FillInDataOnSuspension(GenTreeCall* call, m_comp->fgMorphTree(captureExecContext); GenTree* 
newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF); - unsigned offset = OFFSETOF__CORINFO_Continuation__data + layout.ExecContextOffset; + unsigned offset = OFFSETOF__CORINFO_Continuation__data + layout.ExecutionContextOffset; GenTree* store = StoreAtOffset(newContinuation, offset, captureExecContext, TYP_REF); LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, store)); } } +//------------------------------------------------------------------------ +// AsyncTransformation::RestoreContexts: +// Create IR to restore contexts on suspension. +// +// Parameters: +// block - Block that contains the async call +// call - The async call +// suspendBB - The basic block to add IR to. +// +void AsyncTransformation::RestoreContexts(BasicBlock* block, GenTreeCall* call, BasicBlock* suspendBB) +{ + CallArg* execContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext); + CallArg* syncContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext); + assert((execContextArg != nullptr) == (syncContextArg != nullptr)); + if (execContextArg == nullptr) + { + JITDUMP(" Call [%06u] does not have async contexts; skipping restore on suspension\n", + Compiler::dspTreeID(call)); + return; + } + + JITDUMP(" Call [%06u] has async contexts; will restore on suspension\n", Compiler::dspTreeID(call)); + + // Insert call + // AsyncHelpers.RestoreContexts(resumed, execContext, syncContext); + + GenTree* resumedPlaceholder = m_comp->gtNewIconNode(0); + GenTree* execContextPlaceholder = m_comp->gtNewNull(); + GenTree* syncContextPlaceholder = m_comp->gtNewNull(); + GenTreeCall* restoreCall = m_comp->gtNewCallNode(CT_USER_FUNC, m_asyncInfo->restoreContextsMethHnd, TYP_VOID); + + restoreCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(syncContextPlaceholder)); + restoreCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(execContextPlaceholder)); + restoreCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(resumedPlaceholder)); + + m_comp->compCurBB = suspendBB; + m_comp->fgMorphTree(restoreCall); + + LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, restoreCall)); + + // Replace resumedPlaceholder with actual "continuationParameter != null" arg + LIR::Use use; + bool gotUse = LIR::AsRange(suspendBB).TryGetUse(resumedPlaceholder, &use); + assert(gotUse); + + GenTree* continuation = m_comp->gtNewLclvNode(m_comp->lvaAsyncContinuationArg, TYP_REF); + GenTree* null = m_comp->gtNewNull(); + GenTree* started = m_comp->gtNewOperNode(GT_NE, TYP_INT, continuation, null); + + LIR::AsRange(suspendBB).InsertBefore(resumedPlaceholder, LIR::SeqTree(m_comp, started)); + use.ReplaceWith(started); + LIR::AsRange(suspendBB).Remove(resumedPlaceholder); + + // Replace execContextPlaceholder with actual value + GenTree* execContext = execContextArg->GetNode(); + if (!execContext->OperIs(GT_LCL_VAR)) + { + // We are moving execContext into a different BB so create a temp for it. 
+ LIR::Use use(LIR::AsRange(block), &execContextArg->NodeRef(), call); + use.ReplaceWithLclVar(m_comp); + execContext = use.Def(); + } + + gotUse = LIR::AsRange(suspendBB).TryGetUse(execContextPlaceholder, &use); + assert(gotUse); + + LIR::AsRange(block).Remove(execContext); + LIR::AsRange(suspendBB).InsertBefore(execContextPlaceholder, execContext); + use.ReplaceWith(execContext); + LIR::AsRange(suspendBB).Remove(execContextPlaceholder); + + call->gtArgs.RemoveUnsafe(execContextArg); + + // Replace syncContextPlaceholder with actual value + GenTree* syncContext = syncContextArg->GetNode(); + if (!syncContext->OperIs(GT_LCL_VAR)) + { + // We are moving syncContext into a different BB so create a temp for it. + LIR::Use use(LIR::AsRange(block), &syncContextArg->NodeRef(), call); + use.ReplaceWithLclVar(m_comp); + syncContext = use.Def(); + } + + gotUse = LIR::AsRange(suspendBB).TryGetUse(syncContextPlaceholder, &use); + assert(gotUse); + + LIR::AsRange(block).Remove(syncContext); + LIR::AsRange(suspendBB).InsertBefore(syncContextPlaceholder, syncContext); + use.ReplaceWith(syncContext); + LIR::AsRange(suspendBB).Remove(syncContextPlaceholder); + + call->gtArgs.RemoveUnsafe(syncContextArg); + + JITDUMP(" Created RestoreContexts call on suspension:\n"); + DISPTREERANGE(LIR::AsRange(suspendBB), restoreCall); +} + //------------------------------------------------------------------------ // AsyncTransformation::CreateCheckAndSuspendAfterCall: // Split the block containing the specified async call, and create the IR @@ -1778,8 +1810,6 @@ BasicBlock* AsyncTransformation::CreateResumption(BasicBlock* bloc LIR::AsRange(resumeBB).InsertAtEnd(LIR::SeqTree(m_comp, ilOffsetNode)); - SetSuspendedIndicator(resumeBB, block, call); - if (layout.Size > 0) { RestoreFromDataOnResumption(layout, resumeBB); @@ -1811,7 +1841,7 @@ BasicBlock* AsyncTransformation::CreateResumption(BasicBlock* bloc // void AsyncTransformation::RestoreFromDataOnResumption(const ContinuationLayout& layout, BasicBlock* resumeBB) { - if (layout.ExecContextOffset != BAD_VAR_NUM) + if (layout.ExecutionContextOffset != BAD_VAR_NUM) { GenTree* valuePlaceholder = m_comp->gtNewZeroConNode(TYP_REF); GenTreeCall* restoreCall = @@ -1828,7 +1858,7 @@ void AsyncTransformation::RestoreFromDataOnResumption(const ContinuationLayout& assert(gotUse); GenTree* continuation = m_comp->gtNewLclvNode(m_comp->lvaAsyncContinuationArg, TYP_REF); - unsigned execContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ExecContextOffset; + unsigned execContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ExecutionContextOffset; GenTree* execContextValue = LoadFromOffset(continuation, execContextOffset, TYP_REF); LIR::AsRange(resumeBB).InsertBefore(valuePlaceholder, LIR::SeqTree(m_comp, execContextValue)); diff --git a/src/coreclr/jit/async.h b/src/coreclr/jit/async.h index 13621666a0163b..96afec2773e198 100644 --- a/src/coreclr/jit/async.h +++ b/src/coreclr/jit/async.h @@ -30,7 +30,7 @@ struct ContinuationLayout unsigned ReturnAlignment = 0; unsigned ReturnSize = 0; unsigned ReturnValOffset = UINT_MAX; - unsigned ExecContextOffset = UINT_MAX; + unsigned ExecutionContextOffset = UINT_MAX; const jitstd::vector& Locals; CORINFO_CLASS_HANDLE ClassHnd = NO_CLASS_HANDLE; @@ -106,6 +106,7 @@ class AsyncTransformation GenTree* prevContinuation, const ContinuationLayout& layout); void FillInDataOnSuspension(GenTreeCall* call, const ContinuationLayout& layout, BasicBlock* suspendBB); + void RestoreContexts(BasicBlock* block, GenTreeCall* call, BasicBlock* 
suspendBB); void CreateCheckAndSuspendAfterCall(BasicBlock* block, GenTreeCall* call, const CallDefinitionInfo& callDefInfo, diff --git a/src/coreclr/jit/compiler.cpp b/src/coreclr/jit/compiler.cpp index f5c618fccea289..fb56de4755510a 100644 --- a/src/coreclr/jit/compiler.cpp +++ b/src/coreclr/jit/compiler.cpp @@ -4368,7 +4368,7 @@ void Compiler::compCompile(void** methodCodePtr, uint32_t* methodCodeSize, JitFl // DoPhase(this, PHASE_POST_IMPORT, &Compiler::fgPostImportationCleanup); - // Capture and restore contexts around awaited calls, if needed. + // Capture and restore contexts around the body, if needed. // DoPhase(this, PHASE_ASYNC_SAVE_CONTEXTS, &Compiler::SaveAsyncContexts); diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index 5f51bd68e138db..e8c4cc28b30db2 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -3744,7 +3744,6 @@ class Compiler bool gtIsTypeof(GenTree* tree, CORINFO_CLASS_HANDLE* handle = nullptr); GenTreeLclVarCommon* gtCallGetDefinedRetBufLclAddr(GenTreeCall* call); - GenTreeLclVarCommon* gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(GenTreeCall* call); //------------------------------------------------------------------------- // Functions to display the trees @@ -3966,6 +3965,11 @@ class Compiler unsigned lvaMonAcquired = BAD_VAR_NUM; // boolean variable introduced into in synchronized methods // that tracks whether the lock has been taken + unsigned lvaAsyncExecutionContextVar = BAD_VAR_NUM; // ExecutionContext local for async methods + unsigned lvaAsyncSynchronizationContextVar = BAD_VAR_NUM; // SynchronizationContext local for async methods + + unsigned short asyncContextRestoreEHID = USHRT_MAX; + unsigned lvaArg0Var = BAD_VAR_NUM; // The lclNum of arg0. Normally this will be info.compThisArg. // However, if there is a "ldarga 0" or "starg 0" in the IL, // we will redirect all "ldarg(a) 0" and "starg 0" to this temp. @@ -5553,9 +5557,8 @@ class Compiler #endif PhaseStatus SaveAsyncContexts(); - BasicBlock* InsertTryFinallyForContextRestore(BasicBlock* block, Statement* firstStmt, Statement* lastStmt); - void ValidateNoAsyncSavesNecessary(); - void ValidateNoAsyncSavesNecessaryInStatement(Statement* stmt); + void AddContextArgsToAsyncCalls(BasicBlock* block); + BasicBlock* CreateReturnBB(unsigned* mergedReturnLcl); PhaseStatus TransformAsync(); // This field keep the R2R helper call that would be inserted to trigger the constructor @@ -9837,7 +9840,6 @@ class Compiler bool compSuppressedZeroInit = false; // There are vars with lvSuppressedZeroInit set bool compMaskConvertUsed = false; // Does the method have Convert Mask To Vector nodes. bool compUsesThrowHelper = false; // There is a call to a THROW_HELPER for the compiled method. - bool compMustSaveAsyncContexts = false; // There is an async call that needs capture/restore of async contexts. // NOTE: These values are only reliable after // the importing is completely finished. 
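The compiler.h additions above are easiest to read together with the new SaveAsyncContexts shape: the two new locals hold the contexts captured on entry, the try-fault restores them on exceptional exit, and the merged return block restores them on the normal path. Below is a rough C# analogue of the code the phase now produces; it is a sketch, not emitted code. AsyncHelpers.CaptureContexts/RestoreContexts are the CoreLib helpers from this diff, while the Continuation type, the method body, and the use of catch/rethrow in place of the JIT's try-fault are illustrative stand-ins.

    // Rough C# analogue of a runtime-async method after SaveAsyncContexts.
    // "continuation" stands for the implicit continuation argument; it is
    // non-null only when the method is being resumed.
    static int ExampleAsyncMethod(Continuation? continuation)
    {
        AsyncHelpers.CaptureContexts(out ExecutionContext? execCtx,
                                     out SynchronizationContext? syncCtx);
        int result;
        try
        {
            // ... original body; every BBJ_RETURN now stores its value to a
            // merged return local and jumps to the shared epilogue below ...
            result = 42;
        }
        catch // stands in for the fault handler: restore, then rethrow
        {
            // resumed == (continuation != null); RestoreContexts is a no-op
            // on the resumption path.
            AsyncHelpers.RestoreContexts(continuation != null, execCtx, syncCtx);
            throw;
        }
        AsyncHelpers.RestoreContexts(continuation != null, execCtx, syncCtx);
        return result;
    }

Since the restore in the merged return block must run after the callee returns, tail calls out of async methods are no longer possible, which is why impImportCall below now disables them for async callers.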
diff --git a/src/coreclr/jit/compiler.hpp b/src/coreclr/jit/compiler.hpp index 36406e9fa8ef15..1ae7cab070e018 100644 --- a/src/coreclr/jit/compiler.hpp +++ b/src/coreclr/jit/compiler.hpp @@ -4705,20 +4705,7 @@ GenTree::VisitResult GenTree::VisitLocalDefs(Compiler* comp, TVisitor visitor) } if (OperIs(GT_CALL)) { - GenTreeCall* call = AsCall(); - if (call->IsAsync()) - { - GenTreeLclVarCommon* suspendedArg = comp->gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(call); - if (suspendedArg != nullptr) - { - bool isEntire = comp->lvaLclExactSize(suspendedArg->GetLclNum()) == 1; - if (visitor(LocalDef(suspendedArg, isEntire, suspendedArg->GetLclOffs(), 1)) == VisitResult::Abort) - { - return VisitResult::Abort; - } - } - } - + GenTreeCall* call = AsCall(); GenTreeLclVarCommon* lclAddr = comp->gtCallGetDefinedRetBufLclAddr(call); if (lclAddr != nullptr) { @@ -4761,16 +4748,7 @@ GenTree::VisitResult GenTree::VisitLocalDefNodes(Compiler* comp, TVisitor visito } if (OperIs(GT_CALL)) { - GenTreeCall* call = AsCall(); - if (call->IsAsync()) - { - GenTreeLclVarCommon* suspendedArg = comp->gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(call); - if ((suspendedArg != nullptr) && (visitor(suspendedArg) == VisitResult::Abort)) - { - return VisitResult::Abort; - } - } - + GenTreeCall* call = AsCall(); GenTreeLclVarCommon* lclAddr = comp->gtCallGetDefinedRetBufLclAddr(call); if (lclAddr != nullptr) { diff --git a/src/coreclr/jit/fginline.cpp b/src/coreclr/jit/fginline.cpp index 08f2f7402dd41e..ec41df22af9671 100644 --- a/src/coreclr/jit/fginline.cpp +++ b/src/coreclr/jit/fginline.cpp @@ -1361,7 +1361,7 @@ void Compiler::fgInvokeInlineeCompiler(GenTreeCall* call, InlineResult* inlineRe struct Param { Compiler* pThis; - GenTree* call; + GenTreeCall* call; CORINFO_METHOD_HANDLE fncHandle; InlineCandidateInfo* inlineCandidateInfo; InlineInfo* inlineInfo; @@ -1420,6 +1420,10 @@ void Compiler::fgInvokeInlineeCompiler(GenTreeCall* call, InlineResult* inlineRe compileFlagsForInlinee.Clear(JitFlags::JIT_FLAG_DEBUG_EnC); compileFlagsForInlinee.Clear(JitFlags::JIT_FLAG_REVERSE_PINVOKE); compileFlagsForInlinee.Clear(JitFlags::JIT_FLAG_TRACK_TRANSITIONS); + if (!pParam->call->IsAsync()) + { + compileFlagsForInlinee.Clear(JitFlags::JIT_FLAG_ASYNC); + } #ifdef DEBUG if (pParam->pThis->verbose) @@ -2316,6 +2320,8 @@ Statement* Compiler::fgInlinePrependStatements(InlineInfo* inlineInfo) { case WellKnownArg::RetBuffer: case WellKnownArg::AsyncContinuation: + case WellKnownArg::AsyncExecutionContext: + case WellKnownArg::AsyncSynchronizationContext: continue; case WellKnownArg::InstParam: argInfo = inlineInfo->inlInstParamArgInfo; diff --git a/src/coreclr/jit/fgopt.cpp b/src/coreclr/jit/fgopt.cpp index 9120da5926735e..a574d1633dd390 100644 --- a/src/coreclr/jit/fgopt.cpp +++ b/src/coreclr/jit/fgopt.cpp @@ -361,6 +361,11 @@ PhaseStatus Compiler::fgPostImportationCleanup() cur->SetFlags(BBF_REMOVED); removedBlks++; + if (cur->KindIs(BBJ_RETURN)) + { + fgReturnCount--; + } + // Drop the block from the list. // // We rely on the fact that this does not clear out diff --git a/src/coreclr/jit/flowgraph.cpp b/src/coreclr/jit/flowgraph.cpp index 262277983af084..32eabc31cb4da3 100644 --- a/src/coreclr/jit/flowgraph.cpp +++ b/src/coreclr/jit/flowgraph.cpp @@ -2410,7 +2410,7 @@ PhaseStatus Compiler::fgAddInternal() // when we are asked to generate enter/leave callbacks // or for methods with PInvoke // or for methods calling into unmanaged code - // or for synchronized methods. 
+ // or for synchronized methods // BasicBlock* lastBlockBeforeGenReturns = fgLastBB; if (compIsProfilerHookNeeded() || compMethodRequiresPInvokeFrame() || opts.IsReversePInvoke() || diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index b9d49d5655f500..790a6869084258 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -1513,10 +1513,13 @@ void CallArgs::RemovedWellKnownArg(WellKnownArg arg) // comp - The compiler. // cc - The calling convention. // arg - The kind of argument. +// reg - [out] The custom register assigned to the argument. Can be REG_NA for a non-ABI argument. // // Returns: -// The custom register assignment, or REG_NA if this is a normally treated -// argument. +// True if this is a specially passed argument. If so "reg" is set to the +// register to use for passing the argument, or REG_NA if the argument is not +// actually passed (used to represent arbitrary uses that are later expanded +// out for some cases). // // Remarks: // Many JIT helpers have custom calling conventions in order to improve @@ -1525,14 +1528,15 @@ void CallArgs::RemovedWellKnownArg(WellKnownArg arg) // them. Note that we only support passing such arguments in custom registers // and generally never on stack. // -regNumber CallArgs::GetCustomRegister(Compiler* comp, CorInfoCallConvExtension cc, WellKnownArg arg) +bool CallArgs::GetCustomRegister(Compiler* comp, CorInfoCallConvExtension cc, WellKnownArg arg, regNumber* reg) { switch (arg) { #if defined(TARGET_X86) || defined(TARGET_ARM) // The x86 and arm32 CORINFO_HELP_INIT_PINVOKE_FRAME helpers have a custom calling convention. case WellKnownArg::PInvokeFrame: - return REG_PINVOKE_FRAME; + *reg = REG_PINVOKE_FRAME; + return true; #endif #if defined(TARGET_ARM) // A non-standard calling convention using wrapper delegate invoke is used @@ -1545,64 +1549,85 @@ regNumber CallArgs::GetCustomRegister(Compiler* comp, CorInfoCallConvExtension c // delegate IL stub) to achieve its goal for delegate VSD call. See // COMDelegate::NeedsWrapperDelegate() in the VM for details. case WellKnownArg::WrapperDelegateCell: - return comp->virtualStubParamInfo->GetReg(); + *reg = comp->virtualStubParamInfo->GetReg(); + return true; #endif #if defined(TARGET_X86) // The x86 shift helpers have custom calling conventions and expect the lo // part of the long to be in EAX and the hi part to be in EDX. 
case WellKnownArg::ShiftLow: - return REG_LNGARG_LO; + *reg = REG_LNGARG_LO; + return true; case WellKnownArg::ShiftHigh: - return REG_LNGARG_HI; + *reg = REG_LNGARG_HI; + return true; #endif case WellKnownArg::RetBuffer: if (hasFixedRetBuffReg(cc)) { - return theFixedRetBuffReg(cc); + *reg = theFixedRetBuffReg(cc); + return true; } break; case WellKnownArg::VirtualStubCell: - return comp->virtualStubParamInfo->GetReg(); + *reg = comp->virtualStubParamInfo->GetReg(); + return true; case WellKnownArg::PInvokeCookie: - return REG_PINVOKE_COOKIE_PARAM; + *reg = REG_PINVOKE_COOKIE_PARAM; + return true; case WellKnownArg::PInvokeTarget: - return REG_PINVOKE_TARGET_PARAM; + *reg = REG_PINVOKE_TARGET_PARAM; + return true; case WellKnownArg::R2RIndirectionCell: - return REG_R2R_INDIRECT_PARAM; + *reg = REG_R2R_INDIRECT_PARAM; + return true; case WellKnownArg::ValidateIndirectCallTarget: if (REG_VALIDATE_INDIRECT_CALL_ADDR != REG_ARG_0) { - return REG_VALIDATE_INDIRECT_CALL_ADDR; + *reg = REG_VALIDATE_INDIRECT_CALL_ADDR; + return true; } break; #ifdef REG_DISPATCH_INDIRECT_CELL_ADDR case WellKnownArg::DispatchIndirectCallTarget: - return REG_DISPATCH_INDIRECT_CALL_ADDR; + *reg = REG_DISPATCH_INDIRECT_CALL_ADDR; + return true; #endif #ifdef SWIFT_SUPPORT case WellKnownArg::SwiftError: assert(cc == CorInfoCallConvExtension::Swift); - return REG_SWIFT_ERROR; + *reg = REG_SWIFT_ERROR; + return true; case WellKnownArg::SwiftSelf: assert(cc == CorInfoCallConvExtension::Swift); - return REG_SWIFT_SELF; + *reg = REG_SWIFT_SELF; + return true; #endif // SWIFT_SUPPORT + case WellKnownArg::StackArrayLocal: + case WellKnownArg::AsyncExecutionContext: + case WellKnownArg::AsyncSynchronizationContext: + // These are pseudo-args; they are not actual arguments, but we + // reuse the argument mechanism to represent them as arbitrary uses + // that are later expanded out. + *reg = REG_NA; + return true; + default: break; } - return REG_NA; + return false; } //--------------------------------------------------------------- @@ -1619,7 +1644,8 @@ regNumber CallArgs::GetCustomRegister(Compiler* comp, CorInfoCallConvExtension c // bool CallArgs::IsNonStandard(Compiler* comp, GenTreeCall* call, CallArg* arg) { - return GetCustomRegister(comp, call->GetUnmanagedCallConv(), arg->GetWellKnownArg()) != REG_NA; + regNumber reg; + return GetCustomRegister(comp, call->GetUnmanagedCallConv(), arg->GetWellKnownArg(), ®) && (reg != REG_NA); } //--------------------------------------------------------------- @@ -2276,24 +2302,6 @@ bool GenTreeCall::IsAsync() const return (gtCallMoreFlags & GTF_CALL_M_ASYNC) != 0; } -//------------------------------------------------------------------------- -// IsAsyncAndAlwaysSavesAndRestoresExecutionContext: -// Check if this is an async call that always saves and restores the -// ExecutionContext around it. -// -// Return Value: -// True if so. -// -// Remarks: -// Normal user await calls have this behavior, while custom awaiters (via -// AsyncHelpers.AwaitAwaiter) only saves and restores the ExecutionContext if -// actual suspension happens. -// -bool GenTreeCall::IsAsyncAndAlwaysSavesAndRestoresExecutionContext() const -{ - return IsAsync() && (GetAsyncInfo().ExecutionContextHandling == ExecutionContextHandling::SaveAndRestore); -} - //------------------------------------------------------------------------- // HasNonStandardAddedArgs: Return true if the method has non-standard args added to the call // argument list during argument morphing (fgMorphArgs), e.g., passed in R10 or R11 on AMD64. 
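The GetCustomRegister rework above turns the old REG_NA sentinel into an explicit tri-state: an ordinary ABI-classified argument (returns false), a special argument passed in a fixed register, or a special pseudo-arg that is never materialized as a real argument (the new StackArrayLocal/AsyncExecutionContext/AsyncSynchronizationContext cases). A minimal C# sketch of the same contract, with invented enum members standing in for the JIT's regNumber and WellKnownArg tables:

    // Sketch only; names and register choices are illustrative.
    enum RegNumber { RegNA, RegFixed }
    enum WellKnownArg { Ordinary, R2RIndirectionCell, AsyncExecutionContext }

    static bool TryGetCustomRegister(WellKnownArg arg, out RegNumber reg)
    {
        switch (arg)
        {
            case WellKnownArg.R2RIndirectionCell:
                reg = RegNumber.RegFixed; // special: passed in a fixed register
                return true;
            case WellKnownArg.AsyncExecutionContext:
                reg = RegNumber.RegNA;    // special: pseudo-arg, an arbitrary
                return true;              // use that is expanded out later
            default:
                reg = RegNumber.RegNA;    // ordinary argument
                return false;
        }
    }

IsNonStandard accordingly becomes "returned true and reg != REG_NA", matching the updated C++ above.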
@@ -13236,8 +13244,10 @@ const char* Compiler::gtGetWellKnownArgNameForArgMsg(WellKnownArg arg) return "&lcl arr"; case WellKnownArg::RuntimeMethodHandle: return "meth hnd"; - case WellKnownArg::AsyncSuspendedIndicator: - return "async susp"; + case WellKnownArg::AsyncExecutionContext: + return "exec ctx"; + case WellKnownArg::AsyncSynchronizationContext: + return "sync ctx"; default: return nullptr; } @@ -19558,32 +19568,6 @@ GenTreeLclVarCommon* Compiler::gtCallGetDefinedRetBufLclAddr(GenTreeCall* call) return node->AsLclVarCommon(); } -//------------------------------------------------------------------------ -// gtCallGetDefinedAsyncSuspendedIndicatorLclAddr: -// Get the tree corresponding to the address of the indicator local that this call defines. -// -// Parameters: -// call - the Call node -// -// Returns: -// A tree representing the address of a local. -// -GenTreeLclVarCommon* Compiler::gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(GenTreeCall* call) -{ - if (!call->IsAsync() || !call->GetAsyncInfo().HasSuspensionIndicatorDef) - { - return nullptr; - } - - CallArg* asyncSuspensionIndicatorArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); - assert(asyncSuspensionIndicatorArg != nullptr); - GenTree* node = asyncSuspensionIndicatorArg->GetNode(); - - assert(node->OperIs(GT_LCL_ADDR) && lvaGetDesc(node->AsLclVarCommon())->IsDefinedViaAddress()); - - return node->AsLclVarCommon(); -} - //------------------------------------------------------------------------ // ParseArrayAddress: Rehydrate the array and index expression from ARR_ADDR. // diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 794b6f66b09203..b76b6f1bc5049e 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -4341,18 +4341,6 @@ inline GenTreeCallDebugFlags& operator &=(GenTreeCallDebugFlags& a, GenTreeCallD // clang-format on -enum class ExecutionContextHandling -{ - // No special handling of execution context is required. - None, - // Always save and restore ExecutionContext around this await. - // Used for task awaits. - SaveAndRestore, - // Save and restore execution context on suspension/resumption only. - // Used for custom awaitables. - AsyncSaveAndRestore, -}; - enum class ContinuationContextHandling { // No special handling of SynchronizationContext/TaskScheduler is required. @@ -4369,42 +4357,10 @@ struct AsyncCallInfo // DebugInfo with SOURCE_TYPE_ASYNC pointing at the await call IL instruction DebugInfo CallAsyncDebugInfo; - // The following information is used to implement the proper observable handling of `ExecutionContext`, - // `SynchronizationContext` and `TaskScheduler` in async methods. - // - // The breakdown of the handling is as follows: - // - // - For custom awaitables there is no special handling of `SynchronizationContext` or `TaskScheduler`. All the - // handling that exists is custom implemented by the user. In this case "ContinuationContextHandling == None" and - // "SaveAndRestoreSynchronizationContextField == false". - // - // - For custom awaitables there _is_ special handling of `ExecutionContext`: when the custom awaitable suspends, - // the JIT ensures that the `ExecutionContext` will be captured on suspension and restored when the continuation is - // running. This is represented by "ExecutionContextHandling == AsyncSaveAndRestore". 
-    //
-    // - For task awaits there is special handling of `SynchronizationContext` and `TaskScheduler` in multiple ways:
-    //
-    //   * The JIT ensures that `Thread.CurrentThread._synchronizationContext` is saved and restored around
-    //     synchronously finishing calls. This is represented by "SaveAndRestoreSynchronizationContextField == true".
-    //
-    //   * The JIT/runtime/BCL ensure that when the callee suspends, the caller will eventually be resumed on the
-    //     `SynchronizationContext`/`TaskScheduler` present before the call started, depending on the configuration of the
-    //     task await by the user. This resumption can be inlined if the `SynchronizationContext` is current when the
-    //     continuation is about to run, and otherwise will be posted to it. This is represented by
-    //     "ContinuationContextHandling == ContinueOnCapturedContext/ContinueOnThreadPool".
-    //
-    //   * When the callee suspends restoration of `Thread.CurrentThread._synchronizationContext` is left up to the
-    //     custom implementation of the `SynchronizationContext`, it must not be done by the JIT.
-    //
-    // - For task awaits the runtime/BCL ensure that `Thread.CurrentThread._executionContext` is captured before the
-    //   call and restored after it. This happens consistently regardless of whether the callee finishes synchronously or
-    //   not. This is represented by "ExecutionContextHandling == SaveAndRestore".
-    //
-    ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None;
-    ContinuationContextHandling ContinuationContextHandling = ContinuationContextHandling::None;
-    bool SaveAndRestoreSynchronizationContextField = false;
-    bool HasSuspensionIndicatorDef = false;
-    unsigned SynchronizationContextLclNum = BAD_VAR_NUM;
+    // Where execution continues after each call depends on how the await was
+    // configured and on whether it is a task await or a custom await. This
+    // field records that behavior.
+    ContinuationContextHandling ContinuationContextHandling = ContinuationContextHandling::None;
 };

 // Return type descriptor of a GT_CALL node.
@@ -4677,7 +4633,8 @@ enum class WellKnownArg : unsigned
     X86TailCallSpecialArg,
     StackArrayLocal,
     RuntimeMethodHandle,
-    AsyncSuspendedIndicator,
+    AsyncExecutionContext,
+    AsyncSynchronizationContext,
 };

 #ifdef DEBUG
@@ -4857,10 +4814,10 @@ class CallArgs
     bool m_alignmentDone : 1;
 #endif

-    void      AddedWellKnownArg(WellKnownArg arg);
-    void      RemovedWellKnownArg(WellKnownArg arg);
-    regNumber GetCustomRegister(Compiler* comp, CorInfoCallConvExtension cc, WellKnownArg arg);
-    void      SortArgs(Compiler* comp, GenTreeCall* call, CallArg** sortedArgs);
+    void AddedWellKnownArg(WellKnownArg arg);
+    void RemovedWellKnownArg(WellKnownArg arg);
+    bool GetCustomRegister(Compiler* comp, CorInfoCallConvExtension cc, WellKnownArg arg, regNumber* reg);
+    void SortArgs(Compiler* comp, GenTreeCall* call, CallArg** sortedArgs);

 public:
     CallArgs();
@@ -5139,8 +5096,6 @@ struct GenTreeCall final : public GenTree
         return *asyncInfo;
     }

-    bool IsAsyncAndAlwaysSavesAndRestoresExecutionContext() const;
-
     //---------------------------------------------------------------------------
     // GetRegNumByIdx: get i'th return register allocated to this call node.
     //
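For reference, the two surviving ContinuationContextHandling modes correspond directly to how a task await is written in C#. A minimal illustrative sketch (standard semantics, not code from this change):

    using System.Threading.Tasks;

    class ContinuationContextDemo
    {
        static async Task AwaitBothWaysAsync()
        {
            // ContinueOnCapturedContext: resume on the SynchronizationContext
            // or TaskScheduler captured before the await.
            await Task.Delay(10);

            // ContinueOnThreadPool: ignore the captured context and resume on
            // a thread pool thread.
            await Task.Delay(10).ConfigureAwait(false);
        }
    }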
diff --git a/src/coreclr/jit/importer.cpp b/src/coreclr/jit/importer.cpp
index 5d67351639e840..9bc7aa13eb59d3 100644
--- a/src/coreclr/jit/importer.cpp
+++ b/src/coreclr/jit/importer.cpp
@@ -4441,15 +4441,6 @@ bool Compiler::impIsImplicitTailCallCandidate(
         return false;
     }

-    // We cannot tailcall ValueTask returning methods as we need to preserve
-    // the Continuation instance for ValueTaskSource handling (the BCL needs
-    // to look at continuation.Next). We cannot easily differentiate between
-    // ValueTask and Task here, so we just disable it more generally.
-    if ((prefixFlags & PREFIX_IS_TASK_AWAIT) != 0)
-    {
-        return false;
-    }
-
 #if !FEATURE_TAILCALL_OPT_SHARED_RETURN
     // the block containing call is marked as BBJ_RETURN
     // We allow shared ret tail call optimization on recursive calls even under
@@ -13527,6 +13518,8 @@ void Compiler::impInlineInitVars(InlineInfo* pInlineInfo)
             {
                 case WellKnownArg::RetBuffer:
                 case WellKnownArg::AsyncContinuation:
+                case WellKnownArg::AsyncExecutionContext:
+                case WellKnownArg::AsyncSynchronizationContext:
                     // These do not appear in the table of inline arg info; do not include them
                     continue;
                 case WellKnownArg::InstParam:
diff --git a/src/coreclr/jit/importercalls.cpp b/src/coreclr/jit/importercalls.cpp
index f7915a7b98ef61..51e0424fe58d8d 100644
--- a/src/coreclr/jit/importercalls.cpp
+++ b/src/coreclr/jit/importercalls.cpp
@@ -66,6 +66,7 @@ var_types Compiler::impImportCall(OPCODE opcode,
     // to see any imperative security.
     // Reverse P/Invokes need a call to CORINFO_HELP_JIT_REVERSE_PINVOKE_EXIT
     // at the end, so tailcalls should be disabled.
+    // Async methods need to restore contexts, so tailcalls should be disabled.
     if (info.compFlags & CORINFO_FLG_SYNCH)
     {
         canTailCall = false;
@@ -76,6 +77,11 @@ var_types Compiler::impImportCall(OPCODE opcode,
         canTailCall             = false;
         szCanTailCallFailReason = "Caller is Reverse P/Invoke";
     }
+    else if (compIsAsync())
+    {
+        canTailCall             = false;
+        szCanTailCallFailReason = "Caller is async method";
+    }
 #if !FEATURE_FIXED_OUT_ARGS
     else if (info.compIsVarArgs)
     {
@@ -1433,35 +1439,10 @@ var_types Compiler::impImportCall(OPCODE opcode,
                     // Propagate retExpr as the placeholder for the call.
                     call = retExpr;
-
-                    if (origCall->IsAsyncAndAlwaysSavesAndRestoresExecutionContext())
-                    {
-                        // Async calls that require save/restore of
-                        // ExecutionContext need to be top most so that we can
-                        // insert try-finally around them. We can inline these, so
-                        // we need to ensure that the RET_EXPR is findable when we
-                        // later expand this.
-
-                        unsigned   resultLcl = lvaGrabTemp(true DEBUGARG("async"));
-                        LclVarDsc* varDsc    = lvaGetDesc(resultLcl);
-                        // Keep the information about small typedness to avoid
-                        // inserting unnecessary casts around normalization.
-                        if (varTypeIsSmall(origCall->gtReturnType))
-                        {
-                            assert(origCall->NormalizesSmallTypesOnReturn());
-                            varDsc->lvType = origCall->gtReturnType;
-                        }
-
-                        impStoreToTemp(resultLcl, call, CHECK_SPILL_ALL);
-                        // impStoreToTemp can change src arg list and return type for call that returns struct.
-                        var_types type = genActualType(lvaGetDesc(resultLcl)->TypeGet());
-                        call           = gtNewLclvNode(resultLcl, type);
-                    }
                 }
                 else
                 {
-                    if (call->IsCall() &&
-                        (isFatPointerCandidate || call->AsCall()->IsAsyncAndAlwaysSavesAndRestoresExecutionContext()))
+                    if (call->IsCall() && isFatPointerCandidate)
                     {
                         // these calls should be in statements of the form call() or var = call().
                         // Such form allows to find statements with fat calls without walking through whole trees
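Disabling tailcalls for async callers follows from the epilog work the method now owes: the restore of the caller's contexts must run after any callee returns, so the caller's frame cannot be discarded before the call. A rough source-level intuition (hypothetical sketch, not the generated code):

    using System.Threading.Tasks;

    class NoTailCallDemo
    {
        static async Task<int> CallerAsync()
        {
            // Even in tail position this cannot become a tailcall: the
            // context-restore code inserted by SaveAsyncContexts still has
            // to run in this frame after the callee returns.
            return await CalleeAsync();
        }

        static async Task<int> CalleeAsync()
        {
            await Task.Yield();
            return 42;
        }
    }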
@@ -6839,9 +6820,6 @@ void Compiler::impSetupAndSpillForAsyncCall(GenTreeCall* call,
     {
         JITDUMP("Call is an async task await\n");

-        asyncInfo.ExecutionContextHandling                  = ExecutionContextHandling::SaveAndRestore;
-        asyncInfo.SaveAndRestoreSynchronizationContextField = true;
-
         if ((prefixFlags & PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT) != 0)
         {
             asyncInfo.ContinuationContextHandling = ContinuationContextHandling::ContinueOnCapturedContext;
@@ -6861,40 +6839,9 @@
     else
     {
         JITDUMP("Call is an async non-task await\n");
-        // Only expected non-task await to see in IL is one of the AsyncHelpers.AwaitAwaiter variants.
-        // These are awaits of custom awaitables, and they come with the behavior that the execution context
-        // is captured and restored on suspension/resumption.
-        // We could perhaps skip this for AwaitAwaiter (but not for UnsafeAwaitAwaiter) since it is expected
-        // that the safe INotifyCompletion will take care of flowing ExecutionContext.
-        asyncInfo.ExecutionContextHandling = ExecutionContextHandling::AsyncSaveAndRestore;
-    }
-
-    // For tailcalls the contexts does not need saving/restoring: they will be
-    // overwritten by the caller anyway.
-    //
-    // More specifically, if we can show that
-    // Thread.CurrentThread._executionContext is not accessed between the
-    // call and returning then we can omit save/restore of the execution
-    // context. We do not do that optimization yet.
-    if ((prefixFlags & PREFIX_TAILCALL) != 0)
-    {
-        asyncInfo.ExecutionContextHandling                  = ExecutionContextHandling::None;
-        asyncInfo.ContinuationContextHandling               = ContinuationContextHandling::None;
-        asyncInfo.SaveAndRestoreSynchronizationContextField = false;
     }

     call->AsCall()->SetIsAsync(new (this, CMK_Async) AsyncCallInfo(asyncInfo));
-
-    if (asyncInfo.ExecutionContextHandling == ExecutionContextHandling::SaveAndRestore)
-    {
-        compMustSaveAsyncContexts = true;
-
-        // In this case we will need to save the context after the arguments are evaluated.
-        // Spill the arguments to accomplish that.
-        // (We could do this via splitting in SaveAsyncContexts, but since we need to
-        // handle inline candidates we won't gain much.)
-        impSpillSideEffects(true, CHECK_SPILL_ALL DEBUGARG("Async await with execution context save and restore"));
-    }
 }

 //------------------------------------------------------------------------
diff --git a/src/coreclr/jit/lclmorph.cpp b/src/coreclr/jit/lclmorph.cpp
index 4bda0cfcb4c827..d071f4f45212b2 100644
--- a/src/coreclr/jit/lclmorph.cpp
+++ b/src/coreclr/jit/lclmorph.cpp
@@ -1512,24 +1512,6 @@ class LocalAddressVisitor final : public GenTreeVisitor
             }
         }

-        if ((callUser != nullptr) && callUser->IsAsync() && m_compiler->IsValidLclAddr(lclNum, val.Offset()))
-        {
-            CallArg* suspendedArg = callUser->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator);
-            if ((suspendedArg != nullptr) && (val.Node() == suspendedArg->GetNode()))
-            {
-                INDEBUG(varDsc->SetDefinedViaAddress(true));
-                escapeAddr = false;
-                defFlag    = GTF_VAR_DEF;
-
-                if ((val.Offset() != 0) || (varDsc->lvExactSize() != 1))
-                {
-                    defFlag |= GTF_VAR_USEASG;
-                }
-
-                callUser->asyncInfo->HasSuspensionIndicatorDef = true;
-            }
-        }
-
         if (escapeAddr)
         {
             unsigned exposedLclNum = varDsc->lvIsStructField ? varDsc->lvParentLcl : lclNum;
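The per-call ExecutionContextHandling modes removed above are subsumed by the method-level save/restore; the observable contract is unchanged and is easiest to state with an AsyncLocal. A minimal sketch of that contract (standard semantics, not code from this change):

    using System.Diagnostics;
    using System.Threading;
    using System.Threading.Tasks;

    class AsyncLocalFlowDemo
    {
        static readonly AsyncLocal<string?> s_tag = new AsyncLocal<string?>();

        static async Task CallerAsync()
        {
            s_tag.Value = "caller";
            await CalleeAsync();
            // The callee's write did not flow back: the caller's
            // ExecutionContext is restored across the await.
            Debug.Assert(s_tag.Value == "caller");
        }

        static async Task CalleeAsync()
        {
            s_tag.Value = "callee"; // visible only downstream of this method
            await Task.Yield();
        }
    }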
diff --git a/src/coreclr/jit/morph.cpp b/src/coreclr/jit/morph.cpp
index 48493499ff22b3..ee05a6b47aa72e 100644
--- a/src/coreclr/jit/morph.cpp
+++ b/src/coreclr/jit/morph.cpp
@@ -678,8 +678,10 @@ const char* getWellKnownArgName(WellKnownArg arg)
             return "StackArrayLocal";
         case WellKnownArg::RuntimeMethodHandle:
             return "RuntimeMethodHandle";
-        case WellKnownArg::AsyncSuspendedIndicator:
-            return "AsyncSuspendedIndicator";
+        case WellKnownArg::AsyncExecutionContext:
+            return "AsyncExecutionContext";
+        case WellKnownArg::AsyncSynchronizationContext:
+            return "AsyncSynchronizationContext";
     }

     return "N/A";
@@ -1853,24 +1855,22 @@ void CallArgs::AddFinalArgsAndDetermineABIInfo(Compiler* comp, GenTreeCall* call
         // These should not affect the placement of any other args or stack space required.
         // Example: on AMD64 R10 and R11 are used for indirect VSD (generic interface) and cookie calls.
         // TODO-Cleanup: Integrate this into the new style ABI classifiers.
-        regNumber nonStdRegNum = GetCustomRegister(comp, call->GetUnmanagedCallConv(), arg.GetWellKnownArg());
-
-        if (nonStdRegNum == REG_NA)
+        regNumber nonStdRegNum;
+        if (GetCustomRegister(comp, call->GetUnmanagedCallConv(), arg.GetWellKnownArg(), &nonStdRegNum))
         {
-            if (arg.GetWellKnownArg() == WellKnownArg::AsyncSuspendedIndicator)
+            if (nonStdRegNum != REG_NA)
             {
-                // Represents definition of a local. Expanded out by async transformation.
-                arg.AbiInfo = ABIPassingInformation(comp, 0);
+                ABIPassingSegment segment = ABIPassingSegment::InRegister(nonStdRegNum, 0, TARGET_POINTER_SIZE);
+                arg.AbiInfo               = ABIPassingInformation::FromSegmentByValue(comp, segment);
             }
             else
             {
-                arg.AbiInfo = classifier.Classify(comp, argSigType, argLayout, arg.GetWellKnownArg());
+                arg.AbiInfo = ABIPassingInformation(comp, 0);
             }
         }
         else
         {
-            ABIPassingSegment segment = ABIPassingSegment::InRegister(nonStdRegNum, 0, TARGET_POINTER_SIZE);
-            arg.AbiInfo               = ABIPassingInformation::FromSegmentByValue(comp, segment);
+            arg.AbiInfo = classifier.Classify(comp, argSigType, argLayout, arg.GetWellKnownArg());
         }

         JITDUMP("Argument %u ABI info: ", GetIndex(&arg));
@@ -1930,13 +1930,6 @@ void CallArgs::DetermineABIInfo(Compiler* comp, GenTreeCall* call)
     for (CallArg& arg : Args())
     {
-        if (arg.GetWellKnownArg() == WellKnownArg::AsyncSuspendedIndicator)
-        {
-            // Represents definition of a local. Expanded out by async transformation.
-            arg.AbiInfo = ABIPassingInformation(comp, 0);
-            continue;
-        }
-
         const var_types            argSigType  = arg.GetSignatureType();
         const CORINFO_CLASS_HANDLE argSigClass = arg.GetSignatureClassHandle();
         ClassLayout*               argLayout   = argSigClass == NO_CLASS_HANDLE ? nullptr : comp->typGetObjLayout(argSigClass);
@@ -1945,16 +1938,22 @@
         // These should not affect the placement of any other args or stack space required.
         // Example: on AMD64 R10 and R11 are used for indirect VSD (generic interface) and cookie calls.
         // TODO-Cleanup: Integrate this into the new style ABI classifiers.
-        regNumber nonStdRegNum = GetCustomRegister(comp, call->GetUnmanagedCallConv(), arg.GetWellKnownArg());
-
-        if (nonStdRegNum == REG_NA)
+        regNumber nonStdRegNum;
+        if (GetCustomRegister(comp, call->GetUnmanagedCallConv(), arg.GetWellKnownArg(), &nonStdRegNum))
         {
-            arg.AbiInfo = classifier.Classify(comp, argSigType, argLayout, arg.GetWellKnownArg());
+            if (nonStdRegNum != REG_NA)
+            {
+                ABIPassingSegment segment = ABIPassingSegment::InRegister(nonStdRegNum, 0, TARGET_POINTER_SIZE);
+                arg.AbiInfo               = ABIPassingInformation::FromSegmentByValue(comp, segment);
+            }
+            else
+            {
+                arg.AbiInfo = ABIPassingInformation(comp, 0);
+            }
         }
         else
         {
-            ABIPassingSegment segment = ABIPassingSegment::InRegister(nonStdRegNum, 0, TARGET_POINTER_SIZE);
-            arg.AbiInfo               = ABIPassingInformation::FromSegmentByValue(comp, segment);
+            arg.AbiInfo = classifier.Classify(comp, argSigType, argLayout, arg.GetWellKnownArg());
         }
     }
 }
diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp
index f6902fbc24f91e..321bb257561fc5 100644
--- a/src/coreclr/vm/jitinterface.cpp
+++ b/src/coreclr/vm/jitinterface.cpp
@@ -7577,9 +7577,10 @@ COR_ILMETHOD_DECODER* CEEInfo::getMethodInfoWorker(
     _ASSERTE(scopeHnd != NULL);
     methInfo->scope = scopeHnd;
     methInfo->options = (CorInfoOptions)(((UINT32)methInfo->options) |
-        ((ftn->AcquiresInstMethodTableFromThis() ? CORINFO_GENERICS_CTXT_FROM_THIS : 0) |
-         (ftn->RequiresInstMethodTableArg() ? CORINFO_GENERICS_CTXT_FROM_METHODTABLE : 0) |
-         (ftn->RequiresInstMethodDescArg() ? CORINFO_GENERICS_CTXT_FROM_METHODDESC : 0)));
+        ((ftn->AcquiresInstMethodTableFromThis() ? CORINFO_GENERICS_CTXT_FROM_THIS : 0) |
+         (ftn->RequiresInstMethodTableArg() ? CORINFO_GENERICS_CTXT_FROM_METHODTABLE : 0) |
+         (ftn->RequiresInstMethodDescArg() ? CORINFO_GENERICS_CTXT_FROM_METHODDESC : 0) |
+         (ftn->RequiresAsyncContextSaveAndRestore() ? CORINFO_ASYNC_SAVE_CONTEXTS : 0)));

     if (methInfo->options & CORINFO_GENERICS_CTXT_MASK)
     {
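What CORINFO_ASYNC_SAVE_CONTEXTS buys, stated observably: a SynchronizationContext switch inside an async method does not leak to its caller when the method completes synchronously. A hedged sketch of that guarantee, assuming runtime async mirrors the classic method builder's restore-on-synchronous-exit behavior (invented method names):

    using System.Diagnostics;
    using System.Threading;
    using System.Threading.Tasks;

    class ContextRestoreDemo
    {
        static async Task CallerAsync()
        {
            SynchronizationContext? original = SynchronizationContext.Current;
            await SwitchesContextAsync();
            // The async callee completed synchronously, so its context
            // switch was undone before control returned here.
            Debug.Assert(ReferenceEquals(original, SynchronizationContext.Current));
        }

        static async Task SwitchesContextAsync()
        {
            SynchronizationContext.SetSynchronizationContext(new SynchronizationContext());
            await Task.CompletedTask; // completes synchronously
        }
    }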
diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp
index 2807fb4beaf7e4..c0fd9e972a172a 100644
--- a/src/coreclr/vm/method.hpp
+++ b/src/coreclr/vm/method.hpp
@@ -69,7 +69,7 @@ enum class AsyncMethodKind

     // Regular methods that return Task/ValueTask
     // Such method has its actual IL body and there also a synthetic variant that is an
-    // Async-callable think. (AsyncVariantThunk)
+    // Async-callable thunk. (AsyncVariantThunk)
     TaskReturning,

     // Task-returning methods marked as MethodImpl::Async in metadata.
@@ -1979,6 +1979,17 @@ class MethodDesc
         m_wFlags |= mdfHasAsyncMethodData;
     }

+    // Returns true if this is an async method that requires save and restore
+    // of async contexts. Regular user-implemented runtime async methods
+    // require this behavior, but thunks should be transparent and must not
+    // exhibit it.
+    inline bool RequiresAsyncContextSaveAndRestore() const
+    {
+        if (!HasAsyncMethodData())
+            return false;
+        return GetAddrOfAsyncMethodData()->kind == AsyncMethodKind::AsyncVariantImpl;
+    }
+
 #ifdef FEATURE_METADATA_UPDATER
     inline BOOL IsEnCAddedMethod()
     {
diff --git a/src/tests/async/synchronization-context/synchronization-context.cs b/src/tests/async/synchronization-context/synchronization-context.cs
index 33280981cebee7..9ddda0d0cf9fb4 100644
--- a/src/tests/async/synchronization-context/synchronization-context.cs
+++ b/src/tests/async/synchronization-context/synchronization-context.cs
@@ -241,4 +241,25 @@ private static async Task TestNoSyncContextInRuntimeCallableThunkAsync()
         await Task.Delay(100).ConfigureAwait(false);
         Assert.Equal(0, syncCtx.NumPosts);
     }
+
+    [Fact]
+    public static void TestAsyncToSyncContextSwitch()
+    {
+        TestAsyncToSyncContextSwitchAsync().GetAwaiter().GetResult();
+    }
+
+    private static async Task TestAsyncToSyncContextSwitchAsync()
+    {
+        MySyncContext context1 = new MySyncContext();
+        MySyncContext context2 = new MySyncContext();
+        SynchronizationContext.SetSynchronizationContext(context1);
+        await SyncMethodThatSwitchesContext(context2);
+        Assert.Same(context2, SynchronizationContext.Current);
+    }
+
+    private static Task SyncMethodThatSwitchesContext(MySyncContext context)
+    {
+        SynchronizationContext.SetSynchronizationContext(context);
+        return Task.CompletedTask;
+    }
 }
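The new test relies on the MySyncContext helper defined earlier in this test file and not shown in this diff. For readers without the file at hand, a hypothetical stand-in that matches how the tests use it (counting posts so assertions can detect marshaling through the context) could look like:

    using System.Threading;

    class MySyncContext : SynchronizationContext
    {
        public int NumPosts;

        public override void Post(SendOrPostCallback d, object? state)
        {
            // Count the post, then run the continuation on the thread pool.
            Interlocked.Increment(ref NumPosts);
            ThreadPool.QueueUserWorkItem(_ => d(state));
        }
    }

Note the contrast with the earlier assertions: the context switch performed by the non-async SyncMethodThatSwitchesContext deliberately leaks to its caller, because save/restore is a property of async methods, not of individual awaits.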