Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions src/coreclr/inc/corinfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1712,6 +1712,15 @@ enum CorInfoContinuationFlags
// OSR method saved in the beginning of 'Data', or -1 if the continuation
// belongs to a tier 0 method.
CORINFO_CONTINUATION_OSR_IL_OFFSET_IN_DATA = 4,
// If this bit is set the continuation should continue on the thread
// pool.
CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL = 8,
// If this bit is set the continuation has a SynchronizationContext
// that we should continue on.
CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT = 16,
// If this bit is set the continuation has a TaskScheduler
// that we should continue on.
CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_TASK_SCHEDULER = 32,
};

struct CORINFO_ASYNC_INFO
Expand All @@ -1737,6 +1746,7 @@ struct CORINFO_ASYNC_INFO
CORINFO_METHOD_HANDLE captureExecutionContextMethHnd;
// Method handle for AsyncHelpers.RestoreExecutionContext
CORINFO_METHOD_HANDLE restoreExecutionContextMethHnd;
CORINFO_METHOD_HANDLE captureContinuationContextMethHnd;
};

// Flags passed from JIT to runtime.
Expand Down
57 changes: 56 additions & 1 deletion src/coreclr/jit/async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,13 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock*
block->getTryIndex(), layout.ExceptionGCDataIndex);
}

if (call->GetAsyncInfo().ContinuationContextHandling == ContinuationContextHandling::ContinueOnCapturedContext)
{
layout.ContinuationContextGCDataIndex = layout.GCRefsCount++;
JITDUMP(" Continuation continues on captured context; context will be at GC@+%02u in GC data\n",
layout.ContinuationContextGCDataIndex);
}

if (call->GetAsyncInfo().ExecutionContextHandling == ExecutionContextHandling::AsyncSaveAndRestore)
{
layout.ExecContextGCDataIndex = layout.GCRefsCount++;
Expand Down Expand Up @@ -1200,13 +1207,16 @@ BasicBlock* AsyncTransformation::CreateSuspension(
LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, storeState));

// Fill in 'flags'
unsigned continuationFlags = 0;
const AsyncCallInfo& callInfo = call->GetAsyncInfo();
unsigned continuationFlags = 0;
if (layout.ReturnInGCData)
continuationFlags |= CORINFO_CONTINUATION_RESULT_IN_GCDATA;
if (block->hasTryIndex())
continuationFlags |= CORINFO_CONTINUATION_NEEDS_EXCEPTION;
if (m_comp->doesMethodHavePatchpoints() || m_comp->opts.IsOSR())
continuationFlags |= CORINFO_CONTINUATION_OSR_IL_OFFSET_IN_DATA;
if (callInfo.ContinuationContextHandling == ContinuationContextHandling::ContinueOnThreadPool)
continuationFlags |= CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL;

newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF);
unsigned flagsOffset = m_comp->info.compCompHnd->getFieldOffset(m_asyncInfo->continuationFlagsFldHnd);
Expand Down Expand Up @@ -1386,6 +1396,51 @@ void AsyncTransformation::FillInGCPointersOnSuspension(const ContinuationLayout&
}
}

if (layout.ContinuationContextGCDataIndex != UINT_MAX)
{
// Insert call AsyncHelpers.CaptureContinuationContext(ref
// newContinuation.GCData[ContinuationContextGCDataIndex], ref newContinuation.Flags).
GenTree* contextElementPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF);
GenTree* flagsPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF);
GenTreeCall* captureCall =
m_comp->gtNewCallNode(CT_USER_FUNC, m_asyncInfo->captureContinuationContextMethHnd, TYP_VOID);

captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(flagsPlaceholder));
captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(contextElementPlaceholder));

m_comp->compCurBB = suspendBB;
m_comp->fgMorphTree(captureCall);

LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, captureCall));

// Now replace contextElementPlaceholder with actual address of the context element
LIR::Use use;
bool gotUse = LIR::AsRange(suspendBB).TryGetUse(contextElementPlaceholder, &use);
assert(gotUse);

GenTree* objectArr = m_comp->gtNewLclvNode(objectArrLclNum, TYP_REF);
unsigned offset = OFFSETOF__CORINFO_Array__data + (layout.ContinuationContextGCDataIndex * TARGET_POINTER_SIZE);
GenTree* contextElementOffset =
m_comp->gtNewOperNode(GT_ADD, TYP_BYREF, objectArr, m_comp->gtNewIconNode((ssize_t)offset, TYP_I_IMPL));

LIR::AsRange(suspendBB).InsertBefore(contextElementPlaceholder, LIR::SeqTree(m_comp, contextElementOffset));
use.ReplaceWith(contextElementOffset);
LIR::AsRange(suspendBB).Remove(contextElementPlaceholder);

// And now replace flagsPlaceholder with actual address of the flags
gotUse = LIR::AsRange(suspendBB).TryGetUse(flagsPlaceholder, &use);
assert(gotUse);

newContinuation = m_comp->gtNewLclvNode(m_newContinuationVar, TYP_REF);
unsigned flagsOffset = m_comp->info.compCompHnd->getFieldOffset(m_asyncInfo->continuationFlagsFldHnd);
GenTree* flagsOffsetNode = m_comp->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation,
m_comp->gtNewIconNode((ssize_t)flagsOffset, TYP_I_IMPL));

LIR::AsRange(suspendBB).InsertBefore(flagsPlaceholder, LIR::SeqTree(m_comp, flagsOffsetNode));
use.ReplaceWith(flagsOffsetNode);
LIR::AsRange(suspendBB).Remove(flagsPlaceholder);
}

if (layout.ExecContextGCDataIndex != UINT_MAX)
{
GenTreeCall* captureExecContext =
Expand Down
17 changes: 9 additions & 8 deletions src/coreclr/jit/async.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ struct LiveLocalInfo

struct ContinuationLayout
{
unsigned DataSize = 0;
unsigned GCRefsCount = 0;
ClassLayout* ReturnStructLayout = nullptr;
unsigned ReturnSize = 0;
bool ReturnInGCData = false;
unsigned ReturnValDataOffset = UINT_MAX;
unsigned ExceptionGCDataIndex = UINT_MAX;
unsigned ExecContextGCDataIndex = UINT_MAX;
unsigned DataSize = 0;
unsigned GCRefsCount = 0;
ClassLayout* ReturnStructLayout = nullptr;
unsigned ReturnSize = 0;
bool ReturnInGCData = false;
unsigned ReturnValDataOffset = UINT_MAX;
unsigned ExceptionGCDataIndex = UINT_MAX;
unsigned ExecContextGCDataIndex = UINT_MAX;
unsigned ContinuationContextGCDataIndex = UINT_MAX;
const jitstd::vector<LiveLocalInfo>& Locals;

explicit ContinuationLayout(const jitstd::vector<LiveLocalInfo>& locals)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4430,6 +4430,7 @@ class Compiler
#endif
// This call is a task await
PREFIX_IS_TASK_AWAIT = 0x00000080,
PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT = 0x00000100,
};

static void impValidateMemoryAccessOpcode(const BYTE* codeAddr, const BYTE* codeEndp, bool volatilePrefix);
Expand Down
13 changes: 12 additions & 1 deletion src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -4316,10 +4316,21 @@ enum class ExecutionContextHandling
AsyncSaveAndRestore,
};

enum class ContinuationContextHandling
{
// No special handling of SynchronizationContext/TaskScheduler is required.
None,
// Continue on SynchronizationContext/TaskScheduler
ContinueOnCapturedContext,
// Continue on thread pool thread
ContinueOnThreadPool,
};

// Additional async call info.
struct AsyncCallInfo
{
ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None;
ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None;
ContinuationContextHandling ContinuationContextHandling = ContinuationContextHandling::None;
};

// Return type descriptor of a GT_CALL node.
Expand Down
11 changes: 6 additions & 5 deletions src/coreclr/jit/importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9131,17 +9131,18 @@ void Compiler::impImportBlockCode(BasicBlock* block)
// many other places. We unfortunately embed that knowledge here.
if (opcode != CEE_CALLI)
{
bool isAwait = false;
// TODO: The configVal should be wired to the actual implementation
// that control the flow of sync context.
// We do not have that yet.
int configVal = -1; // -1 not configured, 0/1 configured to false/true
bool isAwait = false;
int configVal = -1; // -1 not configured, 0/1 configured to false/true
if (compIsAsync() && JitConfig.JitOptimizeAwait())
{
if (impMatchTaskAwaitPattern(codeAddr, codeEndp, &configVal))
{
isAwait = true;
prefixFlags |= PREFIX_IS_TASK_AWAIT;
if (configVal != 0)
{
prefixFlags |= PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT;
}
}
}

Expand Down
25 changes: 21 additions & 4 deletions src/coreclr/jit/importercalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,17 +701,26 @@ var_types Compiler::impImportCall(OPCODE opcode,
{
AsyncCallInfo asyncInfo;

JITDUMP("Call is an async ");

if ((prefixFlags & PREFIX_IS_TASK_AWAIT) != 0)
{
JITDUMP("task await\n");
JITDUMP("Call is an async task await\n");

asyncInfo.ExecutionContextHandling = ExecutionContextHandling::SaveAndRestore;

if ((prefixFlags & PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT) != 0)
{
asyncInfo.ContinuationContextHandling = ContinuationContextHandling::ContinueOnCapturedContext;
JITDUMP(" Continuation continues on captured context\n");
}
else
{
asyncInfo.ContinuationContextHandling = ContinuationContextHandling::ContinueOnThreadPool;
JITDUMP(" Continuation continues on thread pool\n");
}
}
else
{
JITDUMP("non-task await\n");
JITDUMP("Call is an async non-task await\n");
// Only expected non-task await to see in IL is one of the AsyncHelpers.AwaitAwaiter variants.
// These are awaits of custom awaitables, and they come with the behavior that the execution context
// is captured and restored on suspension/resumption.
Expand Down Expand Up @@ -7884,6 +7893,14 @@ void Compiler::impMarkInlineCandidateHelper(GenTreeCall* call,
return;
}

if (call->IsAsync() && (call->GetAsyncInfo().ContinuationContextHandling != ContinuationContextHandling::None))
{
// Cannot currently handle moving to captured context/thread pool when logically returning from inlinee.
//
inlineResult->NoteFatal(InlineObservation::CALLSITE_CONTINUATION_HANDLING);
return;
}

// Ignore indirect calls, unless they are indirect virtual stub calls with profile info.
//
if (call->gtCallType == CT_INDIRECT)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/inline.def
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ INLINE_OBSERVATION(RETURN_TYPE_MISMATCH, bool, "return type mismatch",
INLINE_OBSERVATION(STFLD_NEEDS_HELPER, bool, "stfld needs helper", FATAL, CALLSITE)
INLINE_OBSERVATION(TOO_MANY_LOCALS, bool, "too many locals", FATAL, CALLSITE)
INLINE_OBSERVATION(PINVOKE_EH, bool, "PInvoke call site with EH", FATAL, CALLSITE)
INLINE_OBSERVATION(CONTINUATION_HANDLING, bool, "Callsite needs continuation handling", FATAL, CALLSITE)

// ------ Call Site Performance -------

Expand Down
1 change: 1 addition & 0 deletions src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,7 @@ DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_VALUETASK_RETURNING_THUNK_1, Finalize
DEFINE_METHOD(ASYNC_HELPERS, UNSAFE_AWAIT_AWAITER_1, UnsafeAwaitAwaiter, GM_T_RetVoid)
DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_EXECUTION_CONTEXT, CaptureExecutionContext, NoSig)
DEFINE_METHOD(ASYNC_HELPERS, RESTORE_EXECUTION_CONTEXT, RestoreExecutionContext, NoSig)
DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_CONTINUATION_CONTEXT, CaptureContinuationContext, NoSig)

DEFINE_CLASS(SPAN_HELPERS, System, SpanHelpers)
DEFINE_METHOD(SPAN_HELPERS, MEMSET, Fill, SM_RefByte_Byte_UIntPtr_RetVoid)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10256,6 +10256,7 @@ void CEEInfo::getAsyncInfo(CORINFO_ASYNC_INFO* pAsyncInfoOut)
pAsyncInfoOut->continuationsNeedMethodHandle = m_pMethodBeingCompiled->GetLoaderAllocator()->CanUnload();
pAsyncInfoOut->captureExecutionContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_EXECUTION_CONTEXT));
pAsyncInfoOut->restoreExecutionContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__RESTORE_EXECUTION_CONTEXT));
pAsyncInfoOut->captureContinuationContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_CONTINUATION_CONTEXT));

EE_TO_JIT_TRANSITION();
}
Expand Down
107 changes: 107 additions & 0 deletions src/tests/async/synchronization-context/synchronization-context.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

public class Async2SynchronizationContext
{
[Fact]
public static void TestSyncContexts()
{
SynchronizationContext prevContext = SynchronizationContext.Current;
try
{
SynchronizationContext.SetSynchronizationContext(new MySyncContext());
TestSyncContext().GetAwaiter().GetResult();
}
finally
{
SynchronizationContext.SetSynchronizationContext(prevContext);
}
}

private static async Task TestSyncContext()
{
MySyncContext context = (MySyncContext)SynchronizationContext.Current;
await WrappedYieldToThreadPool(suspend: false);
Assert.Same(context, SynchronizationContext.Current);

await WrappedYieldToThreadPool(suspend: true);
Assert.Same(context, SynchronizationContext.Current);

await WrappedYieldToThreadPool(suspend: true).ConfigureAwait(true);
Assert.Same(context, SynchronizationContext.Current);

await WrappedYieldToThreadPool(suspend: false).ConfigureAwait(false);
Assert.Same(context, SynchronizationContext.Current);

await WrappedYieldToThreadPool(suspend: true).ConfigureAwait(false);
Assert.Null(SynchronizationContext.Current);

await WrappedYieldToThreadWithCustomSyncContext();
Assert.Null(SynchronizationContext.Current);
}

private static async Task WrappedYieldToThreadPool(bool suspend)
{
if (suspend)
{
await Task.Yield();
}
}

private static async Task WrappedYieldToThreadWithCustomSyncContext()
{
Assert.Null(SynchronizationContext.Current);
await new YieldToThreadWithCustomSyncContext();
Assert.True(SynchronizationContext.Current is MySyncContext { });
}

private class MySyncContext : SynchronizationContext
{
public override void Post(SendOrPostCallback d, object state)
{
ThreadPool.UnsafeQueueUserWorkItem(_ =>
{
SynchronizationContext prevContext = Current;
try
{
SetSynchronizationContext(this);
d(state);
}
finally
{
SetSynchronizationContext(prevContext);
}
}, null);
}
}

private struct YieldToThreadWithCustomSyncContext : ICriticalNotifyCompletion
{
public YieldToThreadWithCustomSyncContext GetAwaiter() => this;

public void UnsafeOnCompleted(Action continuation)
{
new Thread(state =>
{
SynchronizationContext.SetSynchronizationContext(new MySyncContext());
continuation();
}).Start();
}

public void OnCompleted(Action continuation)
{
throw new NotImplementedException();
}

public bool IsCompleted => false;

public void GetResult() { }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk.IL">
<PropertyGroup>
<Optimize>True</Optimize>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(MSBuildProjectName).cs" />
</ItemGroup>
</Project>
Loading