Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using Microsoft.CodeAnalysis.CSharp.Symbols;

namespace Microsoft.CodeAnalysis.CSharp;
Expand All @@ -21,16 +24,19 @@ public static BoundStatement Rewrite(
}

var rewriter = new RuntimeAsyncRewriter(compilationState.Compilation, new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics));
return (BoundStatement)rewriter.Visit(node);
var result = (BoundStatement)rewriter.Visit(node);
return SpillSequenceSpiller.Rewrite(result, method, compilationState, diagnostics);
}

private readonly CSharpCompilation _compilation;
private readonly SyntheticBoundNodeFactory _factory;
private readonly Dictionary<BoundAwaitableValuePlaceholder, BoundExpression> _placeholderMap;

private RuntimeAsyncRewriter(CSharpCompilation compilation, SyntheticBoundNodeFactory factory)
{
_compilation = compilation;
_factory = factory;
_placeholderMap = [];
}

private NamedTypeSymbol Task
Expand All @@ -53,10 +59,11 @@ private NamedTypeSymbol ValueTaskT
get => field ??= _compilation.GetWellKnownType(WellKnownType.System_Threading_Tasks_ValueTask_T);
} = null!;

public BoundExpression VisitExpression(BoundExpression node)
[return: NotNullIfNotNull(nameof(node))]
public BoundExpression? VisitExpression(BoundExpression? node)
{
var result = Visit(node);
return (BoundExpression)result;
return (BoundExpression?)result;
}

public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node)
Expand Down Expand Up @@ -88,8 +95,7 @@ public BoundExpression VisitExpression(BoundExpression node)
}
else
{
// PROTOTYPE: when it's not a method with Task/TaskT/ValueTask/ValueTaskT returns, use the helpers
return base.VisitAwaitExpression(node);
return RewriteCustomAwaiterAwait(node);
}

// PROTOTYPE: Make sure that we report an error in initial binding if these are missing
Expand All @@ -112,4 +118,81 @@ public BoundExpression VisitExpression(BoundExpression node)
// System.Runtime.CompilerServices.RuntimeHelpers.Await(awaitedExpression)
return _factory.Call(receiver: null, awaitMethod, VisitExpression(node.Expression));
}

private BoundExpression RewriteCustomAwaiterAwait(BoundAwaitExpression node)
{
// await expr
// becomes
// var _tmp = expr.GetAwaiter();
// if (!_tmp.IsCompleted)
// UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp);
// _tmp.GetResult()

// PROTOTYPE: await dynamic will need runtime checks, see AsyncMethodToStateMachine.GenerateAwaitOnCompletedDynamic

var expr = VisitExpression(node.Expression);

var awaitableInfo = node.AwaitableInfo;
var awaitablePlaceholder = awaitableInfo.AwaitableInstancePlaceholder;
if (awaitablePlaceholder is not null)
Copy link
Contributor

@cston cston Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is awaitablePlaceholder == null? (If it is null, it looks like we'll not execute expr.) Consider asserting instead of using if.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was basing it off the similar handling in the existing async rewriter, but I believe you're correct, this should never be null.

Copy link
Member Author

@333fred 333fred Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it can be null. See https://github.com/dotnet/roslyn/blob/main/src/Compilers/CSharp/Portable/Binder/UsingStatementBinder.cs#L153-L154. It appears to only be the case when we're dealing with dynamic, which is already covered under the prototype comment above.

{
_placeholderMap.Add(awaitablePlaceholder, expr);
}

// expr.GetAwaiter()
var getAwaiter = VisitExpression(awaitableInfo.GetAwaiter);
Debug.Assert(getAwaiter is not null);

if (awaitablePlaceholder is not null)
{
_placeholderMap.Remove(awaitablePlaceholder);
}

// var _tmp = expr.GetAwaiter();
var tmp = _factory.StoreToTemp(getAwaiter, out BoundAssignmentOperator store, kind: SynthesizedLocalKind.Awaiter);

// _tmp.IsCompleted
Debug.Assert(awaitableInfo.IsCompleted is not null);
var isCompletedMethod = awaitableInfo.IsCompleted.GetMethod;
Debug.Assert(isCompletedMethod is not null);
var isCompletedCall = _factory.Call(tmp, isCompletedMethod);

// UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp)
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
var useUnsafeAwait = _factory.Compilation.Conversions.ClassifyImplicitConversionFromType(
tmp.Type,
_factory.Compilation.GetWellKnownType(WellKnownType.System_Runtime_CompilerServices_ICriticalNotifyCompletion),
ref discardedUseSiteInfo).IsImplicit;

// PROTOTYPE: Make sure that we report an error in initial binding if these are missing
var awaitMethod = (MethodSymbol?)_compilation.GetWellKnownTypeMember(useUnsafeAwait
? WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter
: WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter);

Debug.Assert(awaitMethod is { Arity: 1 });

var awaitCall = _factory.Call(
receiver: null,
awaitMethod.Construct(tmp.Type),
tmp);

// if (!_tmp.IsCompleted) awaitCall
var ifNotCompleted = _factory.If(_factory.Not(isCompletedCall), _factory.ExpressionStatement(awaitCall));

// _tmp.GetResult()
var getResultMethod = awaitableInfo.GetResult;
Debug.Assert(getResultMethod is not null);
var getResultCall = _factory.Call(tmp, getResultMethod);

// final sequence
return _factory.SpillSequence(
locals: [tmp.LocalSymbol],
sideEffects: [_factory.ExpressionStatement(store), ifNotCompleted],
result: getResultCall);
}

public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
{
return _placeholderMap[node];
}
}
Loading