Skip to content

Commit 2acef7e

Browse files
authored
Implement custom awaitable support (#78071)
This adds support for awaiting task-like types that are not natively supported by runtime async. Closes #77897.
1 parent 23aa61f commit 2acef7e

File tree

6 files changed

+627
-84
lines changed

6 files changed

+627
-84
lines changed

src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/RuntimeAsyncRewriter.cs

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
56
using System.Diagnostics;
7+
using System.Diagnostics.CodeAnalysis;
8+
using System.Runtime.InteropServices;
69
using Microsoft.CodeAnalysis.CSharp.Symbols;
710

811
namespace Microsoft.CodeAnalysis.CSharp;
@@ -21,16 +24,19 @@ public static BoundStatement Rewrite(
2124
}
2225

2326
var rewriter = new RuntimeAsyncRewriter(compilationState.Compilation, new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics));
24-
return (BoundStatement)rewriter.Visit(node);
27+
var result = (BoundStatement)rewriter.Visit(node);
28+
return SpillSequenceSpiller.Rewrite(result, method, compilationState, diagnostics);
2529
}
2630

2731
private readonly CSharpCompilation _compilation;
2832
private readonly SyntheticBoundNodeFactory _factory;
33+
private readonly Dictionary<BoundAwaitableValuePlaceholder, BoundExpression> _placeholderMap;
2934

3035
private RuntimeAsyncRewriter(CSharpCompilation compilation, SyntheticBoundNodeFactory factory)
3136
{
3237
_compilation = compilation;
3338
_factory = factory;
39+
_placeholderMap = [];
3440
}
3541

3642
private NamedTypeSymbol Task
@@ -53,10 +59,11 @@ private NamedTypeSymbol ValueTaskT
5359
get => field ??= _compilation.GetWellKnownType(WellKnownType.System_Threading_Tasks_ValueTask_T);
5460
} = null!;
5561

56-
public BoundExpression VisitExpression(BoundExpression node)
62+
[return: NotNullIfNotNull(nameof(node))]
63+
public BoundExpression? VisitExpression(BoundExpression? node)
5764
{
5865
var result = Visit(node);
59-
return (BoundExpression)result;
66+
return (BoundExpression?)result;
6067
}
6168

6269
public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node)
@@ -88,8 +95,7 @@ public BoundExpression VisitExpression(BoundExpression node)
8895
}
8996
else
9097
{
91-
// PROTOTYPE: when it's not a method with Task/TaskT/ValueTask/ValueTaskT returns, use the helpers
92-
return base.VisitAwaitExpression(node);
98+
return RewriteCustomAwaiterAwait(node);
9399
}
94100

95101
// PROTOTYPE: Make sure that we report an error in initial binding if these are missing
@@ -112,4 +118,81 @@ public BoundExpression VisitExpression(BoundExpression node)
112118
// System.Runtime.CompilerServices.RuntimeHelpers.Await(awaitedExpression)
113119
return _factory.Call(receiver: null, awaitMethod, VisitExpression(node.Expression));
114120
}
121+
122+
private BoundExpression RewriteCustomAwaiterAwait(BoundAwaitExpression node)
123+
{
124+
// await expr
125+
// becomes
126+
// var _tmp = expr.GetAwaiter();
127+
// if (!_tmp.IsCompleted)
128+
// UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp);
129+
// _tmp.GetResult()
130+
131+
// PROTOTYPE: await dynamic will need runtime checks, see AsyncMethodToStateMachine.GenerateAwaitOnCompletedDynamic
132+
133+
var expr = VisitExpression(node.Expression);
134+
135+
var awaitableInfo = node.AwaitableInfo;
136+
var awaitablePlaceholder = awaitableInfo.AwaitableInstancePlaceholder;
137+
if (awaitablePlaceholder is not null)
138+
{
139+
_placeholderMap.Add(awaitablePlaceholder, expr);
140+
}
141+
142+
// expr.GetAwaiter()
143+
var getAwaiter = VisitExpression(awaitableInfo.GetAwaiter);
144+
Debug.Assert(getAwaiter is not null);
145+
146+
if (awaitablePlaceholder is not null)
147+
{
148+
_placeholderMap.Remove(awaitablePlaceholder);
149+
}
150+
151+
// var _tmp = expr.GetAwaiter();
152+
var tmp = _factory.StoreToTemp(getAwaiter, out BoundAssignmentOperator store, kind: SynthesizedLocalKind.Awaiter);
153+
154+
// _tmp.IsCompleted
155+
Debug.Assert(awaitableInfo.IsCompleted is not null);
156+
var isCompletedMethod = awaitableInfo.IsCompleted.GetMethod;
157+
Debug.Assert(isCompletedMethod is not null);
158+
var isCompletedCall = _factory.Call(tmp, isCompletedMethod);
159+
160+
// UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp)
161+
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
162+
var useUnsafeAwait = _factory.Compilation.Conversions.ClassifyImplicitConversionFromType(
163+
tmp.Type,
164+
_factory.Compilation.GetWellKnownType(WellKnownType.System_Runtime_CompilerServices_ICriticalNotifyCompletion),
165+
ref discardedUseSiteInfo).IsImplicit;
166+
167+
// PROTOTYPE: Make sure that we report an error in initial binding if these are missing
168+
var awaitMethod = (MethodSymbol?)_compilation.GetWellKnownTypeMember(useUnsafeAwait
169+
? WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter
170+
: WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter);
171+
172+
Debug.Assert(awaitMethod is { Arity: 1 });
173+
174+
var awaitCall = _factory.Call(
175+
receiver: null,
176+
awaitMethod.Construct(tmp.Type),
177+
tmp);
178+
179+
// if (!_tmp.IsCompleted) awaitCall
180+
var ifNotCompleted = _factory.If(_factory.Not(isCompletedCall), _factory.ExpressionStatement(awaitCall));
181+
182+
// _tmp.GetResult()
183+
var getResultMethod = awaitableInfo.GetResult;
184+
Debug.Assert(getResultMethod is not null);
185+
var getResultCall = _factory.Call(tmp, getResultMethod);
186+
187+
// final sequence
188+
return _factory.SpillSequence(
189+
locals: [tmp.LocalSymbol],
190+
sideEffects: [_factory.ExpressionStatement(store), ifNotCompleted],
191+
result: getResultCall);
192+
}
193+
194+
public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
195+
{
196+
return _placeholderMap[node];
197+
}
115198
}

0 commit comments

Comments
 (0)