Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ public MetadataType GetContinuationType(GCPointerMap pointerMap)
private sealed class ContinuationTypeHashtable : LockFreeReaderHashtable<GCPointerMap, AsyncContinuationType>
{
private readonly CompilerTypeSystemContext _parent;
private MetadataType _continuationType;

public ContinuationTypeHashtable(CompilerTypeSystemContext parent)
=> _parent = parent;
Expand All @@ -214,10 +213,22 @@ protected override bool CompareValueToValue(AsyncContinuationType value1, AsyncC
=> value1.PointerMap.Equals(value2.PointerMap);
protected override AsyncContinuationType CreateValueFromKey(GCPointerMap key)
{
_continuationType ??= _parent.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "Continuation"u8);
return new AsyncContinuationType(_continuationType, key);
return new AsyncContinuationType(_parent.ContinuationType, key);
}
}
private ContinuationTypeHashtable _continuationTypeHashtable;

private MetadataType _continuationType;

/// <summary>
/// Gets the base type for async continuations.
/// </summary>
public MetadataType ContinuationType
{
get
{
return _continuationType ??= SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "Continuation"u8);
}
}
}
}
4 changes: 2 additions & 2 deletions src/coreclr/tools/Common/JitInterface/CorInfoImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1857,7 +1857,7 @@ private void resolveToken(ref CORINFO_RESOLVED_TOKEN pResolvedToken)
// in rare cases a method that returns Task is not actually TaskReturning (i.e. returns T).
// we cannot resolve to an Async variant in such case.
// return NULL, so that caller would re-resolve as a regular method call
method = method.IsAsync && method.GetMethodDefinition().Signature.ReturnsTaskOrValueTask()
method = method.GetTypicalMethodDefinition().Signature.ReturnsTaskOrValueTask()
? _compilation.TypeSystemContext.GetAsyncVariantMethod(method)
: null;
}
Expand Down Expand Up @@ -3391,7 +3391,7 @@ private void getEEInfo(ref CORINFO_EE_INFO pEEInfoOut)

private void getAsyncInfo(ref CORINFO_ASYNC_INFO pAsyncInfoOut)
{
DefType continuation = _compilation.TypeSystemContext.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "Continuation"u8);
DefType continuation = _compilation.TypeSystemContext.ContinuationType;
pAsyncInfoOut.continuationClsHnd = ObjectToHandle(continuation);
pAsyncInfoOut.continuationNextFldHnd = ObjectToHandle(continuation.GetKnownField("Next"u8));
pAsyncInfoOut.continuationResumeInfoFldHnd = ObjectToHandle(continuation.GetKnownField("ResumeInfo"u8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,14 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
{
_context = factory.TypeSystemContext;

// Do not try to optimize around continuation types, we don't keep good track of them.
// Allow CoreLib not to have this type.
if (_context.SystemModule.GetType("System.Runtime.CompilerServices"u8, "Continuation"u8, throwIfNotFound: false) is MetadataType continuationType)
{
_unsealedTypes.Add(continuationType);
_disqualifiedTypes.Add(continuationType);
}

var vtables = new Dictionary<TypeDesc, List<MethodDesc>>();
var dynamicInterfaceCastableImplementationTargets = new HashSet<TypeDesc>();

Expand Down
173 changes: 173 additions & 0 deletions src/coreclr/tools/aot/ILCompiler.Compiler/IL/ILImporter.Scanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ public enum ImportState : byte
private DependencyList _dependencies;
private BasicBlock _lateBasicBlocks;

private bool _asyncDependenciesReported;

private sealed class ExceptionRegion
{
public ILExceptionRegion ILRegion;
Expand Down Expand Up @@ -177,6 +179,14 @@ public ILImporter(ILScanner compilation, MethodDesc method, MethodIL methodIL =
}
}

if (_canonMethod.IsAsyncCall())
{
const string reason = "Async state machine";
DefType asyncHelpers = _compilation.TypeSystemContext.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8);
_dependencies.Add(_factory.MethodEntrypoint(asyncHelpers.GetKnownMethod("CaptureContexts"u8, null)), reason);
_dependencies.Add(_factory.MethodEntrypoint(asyncHelpers.GetKnownMethod("RestoreContexts"u8, null)), reason);
}

FindBasicBlocks();
ImportBasicBlocks();

Expand Down Expand Up @@ -310,6 +320,99 @@ private IMethodNode GetMethodEntrypoint(MethodDesc method)
return _factory.MethodEntrypointOrTentativeMethod(method);
}

// Check if a method call starts a task await pattern that can be
// optimized for runtime async.
// Roughly corresponds to impMatchTaskAwaitPattern in RyuJIT codebase
private bool MatchTaskAwaitPattern()
{
// We look for the following code patterns in runtime async methods:
//
// call[virt] <Method>
// [ OPTIONAL ]
// {
// [ OPTIONAL ]
// {
// stloc X;
// ldloca X
// }
// ldc.i4.0 / ldc.i4.1
// call[virt] <ConfigureAwait>
// }
// call <Await>

// Find where this basic block ends
int nextBBOffset = _currentOffset;
while (nextBBOffset < _basicBlocks.Length && _basicBlocks[nextBBOffset] == null)
nextBBOffset++;

// Create ILReader for what's left in the basic block
var reader = new ILReader(new ReadOnlySpan<byte>(_ilBytes, _currentOffset, nextBBOffset - _currentOffset));

if (!reader.HasNext)
return false;

ILOpcode opcode;

// If we can read at least two call tokens + an ldc, this could be ConfigureAwait
// so check for that.
if (reader.Size > 2 * (1 + sizeof(int)))
{
opcode = reader.ReadILOpcode();

// ConfigureAwait on a ValueTask will start with stloc/ldloca.
int stlocNum = opcode switch
{
>= ILOpcode.stloc_0 and <= ILOpcode.stloc_3 => opcode - ILOpcode.stloc_0,
ILOpcode.stloc => reader.ReadILUInt16(),
ILOpcode.stloc_s => reader.ReadILByte(),
_ => -1,
};

// if it was a stloc, check for matching ldloca
if (stlocNum != -1)
{
opcode = reader.ReadILOpcode();
int ldlocaNum = opcode switch
{
ILOpcode.ldloca_s => reader.ReadILByte(),
ILOpcode.ldloca => reader.ReadILUInt16(),
_ => -1,
};

if (stlocNum != ldlocaNum)
return false;

opcode = reader.ReadILOpcode();
}

if (opcode is (not ILOpcode.ldc_i4_0) and (not ILOpcode.ldc_i4_1))
{
if (stlocNum != -1)
{
// we had stloc/ldloca, we must see ConfigAwait
return false;
}

goto checkForAwait;
}

opcode = reader.ReadILOpcode();
if (opcode is (not ILOpcode.call) and (not ILOpcode.callvirt)
|| !IsTaskConfigureAwait((MethodDesc)_methodIL.GetObject(reader.ReadILToken()))
|| !reader.HasNext)
{
return false;
}
}

opcode = reader.ReadILOpcode();

checkForAwait:

return opcode == ILOpcode.call
&& IsAsyncHelpersAwait((MethodDesc)_methodIL.GetObject(reader.ReadILToken()));
}

private void ImportCall(ILOpcode opcode, int token)
{
// We get both the canonical and runtime determined form - JitInterface mostly operates
Expand Down Expand Up @@ -346,6 +449,40 @@ private void ImportCall(ILOpcode opcode, int token)
Debug.Assert(false); break;
}

// Are we scanning a call within a state machine?
if (opcode is ILOpcode.call or ILOpcode.callvirt
&& _canonMethod.IsAsyncCall())
{
// Add dependencies on infra to do suspend/resume. We only need to do this once per method scanned.
if (!_asyncDependenciesReported && method.IsAsync)
{
_asyncDependenciesReported = true;

const string asyncReason = "Async state machine";

var resumptionStub = new AsyncResumptionStub(_canonMethod, _compilation.TypeSystemContext.GeneratedAssembly.GetGlobalModuleType());
_dependencies.Add(_compilation.NodeFactory.MethodEntrypoint(resumptionStub), asyncReason);

_dependencies.Add(_factory.ConstructedTypeSymbol(_compilation.TypeSystemContext.ContinuationType), asyncReason);

DefType asyncHelpers = _compilation.TypeSystemContext.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8);

_dependencies.Add(_factory.MethodEntrypoint(asyncHelpers.GetKnownMethod("AllocContinuation"u8, null)), asyncReason);
_dependencies.Add(_factory.MethodEntrypoint(asyncHelpers.GetKnownMethod("CaptureExecutionContext"u8, null)), asyncReason);
_dependencies.Add(_factory.MethodEntrypoint(asyncHelpers.GetKnownMethod("RestoreExecutionContext"u8, null)), asyncReason);
_dependencies.Add(_factory.MethodEntrypoint(asyncHelpers.GetKnownMethod("CaptureContinuationContext"u8, null)), asyncReason);
}

// If this is the task await pattern, we're actually going to call the variant
// so switch our focus to the variant.
if (method.GetTypicalMethodDefinition().Signature.ReturnsTaskOrValueTask()
&& MatchTaskAwaitPattern())
{
runtimeDeterminedMethod = _factory.TypeSystemContext.GetAsyncVariantMethod(runtimeDeterminedMethod);
method = _factory.TypeSystemContext.GetAsyncVariantMethod(method);
}
}

if (opcode == ILOpcode.newobj)
{
TypeDesc owningType = runtimeDeterminedMethod.OwningType;
Expand Down Expand Up @@ -1550,6 +1687,42 @@ private static bool IsMemoryMarshalGetArrayDataReference(MethodDesc method)
return false;
}

private static bool IsAsyncHelpersAwait(MethodDesc method)
{
if (method.IsIntrinsic && method.Name.SequenceEqual("Await"u8))
{
MetadataType owningType = method.OwningType as MetadataType;
if (owningType != null)
{
return owningType.Module == method.Context.SystemModule
&& owningType.Name.SequenceEqual("AsyncHelpers"u8)
&& owningType.Namespace.SequenceEqual("System.Runtime.CompilerServices"u8);
}
}

return false;
}

private static bool IsTaskConfigureAwait(MethodDesc method)
{
if (method.IsIntrinsic && method.Name.SequenceEqual("ConfigureAwait"u8))
{
MetadataType owningType = method.OwningType as MetadataType;
if (owningType != null)
{
ReadOnlySpan<byte> typeName = owningType.Name;
return owningType.Module == method.Context.SystemModule
&& owningType.Namespace.SequenceEqual("System.Threading.Tasks"u8)
&& (typeName.SequenceEqual("Task"u8)
|| typeName.SequenceEqual("Task`1"u8)
|| typeName.SequenceEqual("ValueTask"u8)
|| typeName.SequenceEqual("ValueTask`1"u8));
}
}

return false;
}

private DefType GetWellKnownType(WellKnownType wellKnownType)
{
return _compilation.TypeSystemContext.GetWellKnownType(wellKnownType);
Expand Down
Loading