22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5+ using System . Collections . Generic ;
56using System . Diagnostics ;
7+ using System . Diagnostics . CodeAnalysis ;
8+ using System . Runtime . InteropServices ;
69using Microsoft . CodeAnalysis . CSharp . Symbols ;
710
811namespace Microsoft . CodeAnalysis . CSharp ;
@@ -21,16 +24,19 @@ public static BoundStatement Rewrite(
2124 }
2225
2326 var rewriter = new RuntimeAsyncRewriter ( compilationState . Compilation , new SyntheticBoundNodeFactory ( method , node . Syntax , compilationState , diagnostics ) ) ;
24- return ( BoundStatement ) rewriter . Visit ( node ) ;
27+ var result = ( BoundStatement ) rewriter . Visit ( node ) ;
28+ return SpillSequenceSpiller . Rewrite ( result , method , compilationState , diagnostics ) ;
2529 }
2630
2731 private readonly CSharpCompilation _compilation ;
2832 private readonly SyntheticBoundNodeFactory _factory ;
33+ private readonly Dictionary < BoundAwaitableValuePlaceholder , BoundExpression > _placeholderMap ;
2934
3035 private RuntimeAsyncRewriter ( CSharpCompilation compilation , SyntheticBoundNodeFactory factory )
3136 {
3237 _compilation = compilation ;
3338 _factory = factory ;
39+ _placeholderMap = [ ] ;
3440 }
3541
3642 private NamedTypeSymbol Task
@@ -53,10 +59,11 @@ private NamedTypeSymbol ValueTaskT
5359 get => field ??= _compilation . GetWellKnownType ( WellKnownType . System_Threading_Tasks_ValueTask_T ) ;
5460 } = null ! ;
5561
56- public BoundExpression VisitExpression ( BoundExpression node )
62+ [ return : NotNullIfNotNull ( nameof ( node ) ) ]
63+ public BoundExpression ? VisitExpression ( BoundExpression ? node )
5764 {
5865 var result = Visit ( node ) ;
59- return ( BoundExpression ) result ;
66+ return ( BoundExpression ? ) result ;
6067 }
6168
6269 public override BoundNode ? VisitAwaitExpression ( BoundAwaitExpression node )
@@ -88,8 +95,7 @@ public BoundExpression VisitExpression(BoundExpression node)
8895 }
8996 else
9097 {
91- // PROTOTYPE: when it's not a method with Task/TaskT/ValueTask/ValueTaskT returns, use the helpers
92- return base . VisitAwaitExpression ( node ) ;
98+ return RewriteCustomAwaiterAwait ( node ) ;
9399 }
94100
95101 // PROTOTYPE: Make sure that we report an error in initial binding if these are missing
@@ -112,4 +118,81 @@ public BoundExpression VisitExpression(BoundExpression node)
112118 // System.Runtime.CompilerServices.RuntimeHelpers.Await(awaitedExpression)
113119 return _factory . Call ( receiver : null , awaitMethod , VisitExpression ( node . Expression ) ) ;
114120 }
121+
122+ private BoundExpression RewriteCustomAwaiterAwait ( BoundAwaitExpression node )
123+ {
124+ // await expr
125+ // becomes
126+ // var _tmp = expr.GetAwaiter();
127+ // if (!_tmp.IsCompleted)
128+ // UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp);
129+ // _tmp.GetResult()
130+
131+ // PROTOTYPE: await dynamic will need runtime checks, see AsyncMethodToStateMachine.GenerateAwaitOnCompletedDynamic
132+
133+ var expr = VisitExpression ( node . Expression ) ;
134+
135+ var awaitableInfo = node . AwaitableInfo ;
136+ var awaitablePlaceholder = awaitableInfo . AwaitableInstancePlaceholder ;
137+ if ( awaitablePlaceholder is not null )
138+ {
139+ _placeholderMap . Add ( awaitablePlaceholder , expr ) ;
140+ }
141+
142+ // expr.GetAwaiter()
143+ var getAwaiter = VisitExpression ( awaitableInfo . GetAwaiter ) ;
144+ Debug . Assert ( getAwaiter is not null ) ;
145+
146+ if ( awaitablePlaceholder is not null )
147+ {
148+ _placeholderMap . Remove ( awaitablePlaceholder ) ;
149+ }
150+
151+ // var _tmp = expr.GetAwaiter();
152+ var tmp = _factory . StoreToTemp ( getAwaiter , out BoundAssignmentOperator store , kind : SynthesizedLocalKind . Awaiter ) ;
153+
154+ // _tmp.IsCompleted
155+ Debug . Assert ( awaitableInfo . IsCompleted is not null ) ;
156+ var isCompletedMethod = awaitableInfo . IsCompleted . GetMethod ;
157+ Debug . Assert ( isCompletedMethod is not null ) ;
158+ var isCompletedCall = _factory . Call ( tmp , isCompletedMethod ) ;
159+
160+ // UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp)
161+ var discardedUseSiteInfo = CompoundUseSiteInfo < AssemblySymbol > . Discarded ;
162+ var useUnsafeAwait = _factory . Compilation . Conversions . ClassifyImplicitConversionFromType (
163+ tmp . Type ,
164+ _factory . Compilation . GetWellKnownType ( WellKnownType . System_Runtime_CompilerServices_ICriticalNotifyCompletion ) ,
165+ ref discardedUseSiteInfo ) . IsImplicit ;
166+
167+ // PROTOTYPE: Make sure that we report an error in initial binding if these are missing
168+ var awaitMethod = ( MethodSymbol ? ) _compilation . GetWellKnownTypeMember ( useUnsafeAwait
169+ ? WellKnownMember . System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter
170+ : WellKnownMember . System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter ) ;
171+
172+ Debug . Assert ( awaitMethod is { Arity : 1 } ) ;
173+
174+ var awaitCall = _factory . Call (
175+ receiver : null ,
176+ awaitMethod . Construct ( tmp . Type ) ,
177+ tmp ) ;
178+
179+ // if (!_tmp.IsCompleted) awaitCall
180+ var ifNotCompleted = _factory . If ( _factory . Not ( isCompletedCall ) , _factory . ExpressionStatement ( awaitCall ) ) ;
181+
182+ // _tmp.GetResult()
183+ var getResultMethod = awaitableInfo . GetResult ;
184+ Debug . Assert ( getResultMethod is not null ) ;
185+ var getResultCall = _factory . Call ( tmp , getResultMethod ) ;
186+
187+ // final sequence
188+ return _factory . SpillSequence (
189+ locals : [ tmp . LocalSymbol ] ,
190+ sideEffects : [ _factory . ExpressionStatement ( store ) , ifNotCompleted ] ,
191+ result : getResultCall ) ;
192+ }
193+
194+ public override BoundNode VisitAwaitableValuePlaceholder ( BoundAwaitableValuePlaceholder node )
195+ {
196+ return _placeholderMap [ node ] ;
197+ }
115198}
0 commit comments