diff --git a/src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/RuntimeAsyncRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/RuntimeAsyncRewriter.cs index 58d347662435b..bd3a09a78f228 100644 --- a/src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/RuntimeAsyncRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/RuntimeAsyncRewriter.cs @@ -2,7 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; using Microsoft.CodeAnalysis.CSharp.Symbols; namespace Microsoft.CodeAnalysis.CSharp; @@ -21,16 +24,19 @@ public static BoundStatement Rewrite( } var rewriter = new RuntimeAsyncRewriter(compilationState.Compilation, new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics)); - return (BoundStatement)rewriter.Visit(node); + var result = (BoundStatement)rewriter.Visit(node); + return SpillSequenceSpiller.Rewrite(result, method, compilationState, diagnostics); } private readonly CSharpCompilation _compilation; private readonly SyntheticBoundNodeFactory _factory; + private readonly Dictionary _placeholderMap; private RuntimeAsyncRewriter(CSharpCompilation compilation, SyntheticBoundNodeFactory factory) { _compilation = compilation; _factory = factory; + _placeholderMap = []; } private NamedTypeSymbol Task @@ -53,10 +59,11 @@ private NamedTypeSymbol ValueTaskT get => field ??= _compilation.GetWellKnownType(WellKnownType.System_Threading_Tasks_ValueTask_T); } = null!; - public BoundExpression VisitExpression(BoundExpression node) + [return: NotNullIfNotNull(nameof(node))] + public BoundExpression? VisitExpression(BoundExpression? node) { var result = Visit(node); - return (BoundExpression)result; + return (BoundExpression?)result; } public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node) @@ -88,8 +95,7 @@ public BoundExpression VisitExpression(BoundExpression node) } else { - // PROTOTYPE: when it's not a method with Task/TaskT/ValueTask/ValueTaskT returns, use the helpers - return base.VisitAwaitExpression(node); + return RewriteCustomAwaiterAwait(node); } // PROTOTYPE: Make sure that we report an error in initial binding if these are missing @@ -112,4 +118,81 @@ public BoundExpression VisitExpression(BoundExpression node) // System.Runtime.CompilerServices.RuntimeHelpers.Await(awaitedExpression) return _factory.Call(receiver: null, awaitMethod, VisitExpression(node.Expression)); } + + private BoundExpression RewriteCustomAwaiterAwait(BoundAwaitExpression node) + { + // await expr + // becomes + // var _tmp = expr.GetAwaiter(); + // if (!_tmp.IsCompleted) + // UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp); + // _tmp.GetResult() + + // PROTOTYPE: await dynamic will need runtime checks, see AsyncMethodToStateMachine.GenerateAwaitOnCompletedDynamic + + var expr = VisitExpression(node.Expression); + + var awaitableInfo = node.AwaitableInfo; + var awaitablePlaceholder = awaitableInfo.AwaitableInstancePlaceholder; + if (awaitablePlaceholder is not null) + { + _placeholderMap.Add(awaitablePlaceholder, expr); + } + + // expr.GetAwaiter() + var getAwaiter = VisitExpression(awaitableInfo.GetAwaiter); + Debug.Assert(getAwaiter is not null); + + if (awaitablePlaceholder is not null) + { + _placeholderMap.Remove(awaitablePlaceholder); + } + + // var _tmp = expr.GetAwaiter(); + var tmp = _factory.StoreToTemp(getAwaiter, out BoundAssignmentOperator store, kind: SynthesizedLocalKind.Awaiter); + + // _tmp.IsCompleted + Debug.Assert(awaitableInfo.IsCompleted is not null); + var isCompletedMethod = awaitableInfo.IsCompleted.GetMethod; + Debug.Assert(isCompletedMethod is not null); + var isCompletedCall = _factory.Call(tmp, isCompletedMethod); + + // UnsafeAwaitAwaiterFromRuntimeAsync(_tmp) OR AwaitAwaiterFromRuntimeAsync(_tmp) + var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; + var useUnsafeAwait = _factory.Compilation.Conversions.ClassifyImplicitConversionFromType( + tmp.Type, + _factory.Compilation.GetWellKnownType(WellKnownType.System_Runtime_CompilerServices_ICriticalNotifyCompletion), + ref discardedUseSiteInfo).IsImplicit; + + // PROTOTYPE: Make sure that we report an error in initial binding if these are missing + var awaitMethod = (MethodSymbol?)_compilation.GetWellKnownTypeMember(useUnsafeAwait + ? WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter + : WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter); + + Debug.Assert(awaitMethod is { Arity: 1 }); + + var awaitCall = _factory.Call( + receiver: null, + awaitMethod.Construct(tmp.Type), + tmp); + + // if (!_tmp.IsCompleted) awaitCall + var ifNotCompleted = _factory.If(_factory.Not(isCompletedCall), _factory.ExpressionStatement(awaitCall)); + + // _tmp.GetResult() + var getResultMethod = awaitableInfo.GetResult; + Debug.Assert(getResultMethod is not null); + var getResultCall = _factory.Call(tmp, getResultMethod); + + // final sequence + return _factory.SpillSequence( + locals: [tmp.LocalSymbol], + sideEffects: [_factory.ExpressionStatement(store), ifNotCompleted], + result: getResultCall); + } + + public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node) + { + return _placeholderMap[node]; + } } diff --git a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs index 2807e211f1817..9db5670663bf0 100644 --- a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs +++ b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs @@ -6,15 +6,15 @@ using System.Collections.Generic; using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using Basic.Reference.Assemblies; using Microsoft.CodeAnalysis.CSharp.Symbols; using Microsoft.CodeAnalysis.CSharp.Test.Utilities; using Microsoft.CodeAnalysis.Test.Utilities; using Roslyn.Test.Utilities; using Roslyn.Utilities; using Xunit; -using Basic.Reference.Assemblies; -using System.Runtime.CompilerServices; -using System.Reflection; // PROTOTYPE: Verify execution of runtime async methods // PROTOTYPE: ILVerify for runtime async? @@ -27,9 +27,14 @@ public class CodeGenAsyncTests : EmitMetadataTestBase private const MethodImplAttributes MethodImplOptionsAsync = (MethodImplAttributes)1024; private static CSharpParseOptions WithRuntimeAsync(CSharpParseOptions options) => options.WithFeature("runtime-async", "on"); - internal static string ExpectedOutput(string output) + internal static string ExpectedOutput(string output, bool isRuntimeAsync = false) { - return ExecutionConditionUtil.IsMonoOrCoreClr ? output : null; + return ExecutionConditionUtil.IsMonoOrCoreClr + ? isRuntimeAsync + // PROTOTYPE: Verify runtime async output + ? null + : output + : null; } private static CSharpCompilation CreateCompilation(string source, IEnumerable references = null, CSharpCompilationOptions options = null) @@ -177,7 +182,7 @@ public static void Main() var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], targetFramework: TargetFramework.Net90, parseOptions: WithRuntimeAsync(TestOptions.RegularPreview)); comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput(expected, isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ReturnValueMissing("F", "0x2e"), }, symbolValidator: verify); @@ -240,11 +245,11 @@ public static async Task Main() } }"; - //var expected = "42"; + var expected = "42"; var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], targetFramework: TargetFramework.Net90, parseOptions: WithRuntimeAsync(TestOptions.RegularPreview)); comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput(expected, isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = $""" {ReturnValueMissing("F", "0xa")} @@ -306,7 +311,7 @@ public static void Main() var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], targetFramework: TargetFramework.Net90, parseOptions: WithRuntimeAsync(TestOptions.RegularPreview)); comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput(expected, isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = "[F]: Unexpected type on the stack. { Offset = 0x2e, Found = ref 'string', Expected = ref '[System.Runtime]System.Threading.Tasks.Task`1' }", }, symbolValidator: verify); @@ -363,11 +368,11 @@ public static async Task Main() Console.WriteLine(s); } }"; - //var expected = @"O brave new world..."; + var expected = @"O brave new world..."; var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], targetFramework: TargetFramework.Net90, parseOptions: WithRuntimeAsync(TestOptions.RegularPreview)); comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput(expected, isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = $$""" [F]: Unexpected type on the stack. { Offset = 0xa, Found = ref 'string', Expected = value '[System.Runtime]System.Threading.Tasks.ValueTask`1' } @@ -445,7 +450,7 @@ public static async Task Main() """, }; - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ilVerifyMessage }, symbolValidator: verify); @@ -585,7 +590,7 @@ public static async Task Main() """, }; - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ilVerifyMessage, }, symbolValidator: verify); @@ -710,7 +715,7 @@ public static async Task Main() """, }; - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ilVerifyMessage, }, symbolValidator: verify); @@ -824,7 +829,7 @@ public static async Task Main() """, }; - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ilVerifyMessage, }, symbolValidator: verify); @@ -956,7 +961,7 @@ public static async Task Main() """, }; - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ilVerifyMessage, }, symbolValidator: verify); @@ -1070,7 +1075,7 @@ public static async Task Main() """, }; - var verifier = CompileAndVerify(comp, verify: Verification.Fails with + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ilVerifyMessage, }, symbolValidator: verify); @@ -1127,75 +1132,135 @@ void verify(ModuleSymbol module) } } - [Fact] - public void Conformance_Awaiting_Methods_Generic01() + [Theory] + [CombinatorialData] + public void Conformance_Awaiting_Methods_Generic01(bool useCritical) { - var source = @" -using System; -using System.Runtime.CompilerServices; -using System.Threading; + var source = $$""" + using System; + using System.Runtime.CompilerServices; + using System.Threading; -//Implementation of you own async pattern -public class MyTask -{ - public MyTaskAwaiter GetAwaiter() - { - return new MyTaskAwaiter(); - } + //Implementation of you own async pattern + public class MyTask + { + public MyTaskAwaiter GetAwaiter() + { + return new MyTaskAwaiter(); + } - public async void Run(U u) where U : MyTask, new() - { - try - { - int tests = 0; + public async void Run(U u) where U : MyTask, new() + { + try + { + int tests = 0; - tests++; - var rez = await u; - if (rez == 0) - Driver.Count++; + tests++; + var rez = await u; + if (rez == 0) + Driver.Count++; - Driver.Result = Driver.Count - tests; - } - finally - { - //When test complete, set the flag. - Driver.CompletedSignal.Set(); - } - } -} -public class MyTaskAwaiter : INotifyCompletion -{ - public void OnCompleted(Action continuationAction) - { - } + Driver.Result = Driver.Count - tests; + } + finally + { + //When test complete, set the flag. + Driver.CompletedSignal.Set(); + } + } + } + public class MyTaskAwaiter : {{(useCritical ? "ICriticalNotifyCompletion" : "INotifyCompletion")}} + { + public void OnCompleted(Action continuationAction) + { + } - public T GetResult() - { - return default(T); - } + public void UnsafeOnCompleted(Action continuationAction) + { + } - public bool IsCompleted { get { return true; } } -} -//------------------------------------- + public T GetResult() + { + return default(T); + } -class Driver -{ - public static int Result = -1; - public static int Count = 0; - public static AutoResetEvent CompletedSignal = new AutoResetEvent(false); - static void Main() - { - new MyTask().Run>(new MyTask()); + public bool IsCompleted { get { return true; } } + } + //------------------------------------- - CompletedSignal.WaitOne(); + class Driver + { + public static int Result = -1; + public static int Count = 0; + public static AutoResetEvent CompletedSignal = new AutoResetEvent(false); + static void Main() + { + new MyTask().Run>(new MyTask()); + + CompletedSignal.WaitOne(); + + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result); + } + } + """; - // 0 - success - // 1 - failed (test completed) - // -1 - failed (test incomplete - deadlock, etc) - Console.WriteLine(Driver.Result); - } -}"; CompileAndVerify(source, "0"); + + var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], targetFramework: TargetFramework.Net90, parseOptions: WithRuntimeAsync(TestOptions.RegularPreview)); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("0", isRuntimeAsync: true), verify: Verification.FailsPEVerify); + verifier.VerifyDiagnostics(); + + verifier.VerifyIL("MyTask.Run", $$""" + { + // Code size 79 (0x4f) + .maxstack 2 + .locals init (int V_0, //tests + MyTaskAwaiter V_1) + .try + { + IL_0000: ldc.i4.0 + IL_0001: stloc.0 + IL_0002: ldloc.0 + IL_0003: ldc.i4.1 + IL_0004: add + IL_0005: stloc.0 + IL_0006: ldarg.1 + IL_0007: box "U" + IL_000c: callvirt "MyTaskAwaiter MyTask.GetAwaiter()" + IL_0011: stloc.1 + IL_0012: ldloc.1 + IL_0013: callvirt "bool MyTaskAwaiter.IsCompleted.get" + IL_0018: brtrue.s IL_0020 + IL_001a: ldloc.1 + IL_001b: call "void System.Runtime.CompilerServices.RuntimeHelpers.{{(useCritical ? "Unsafe" : "")}}AwaitAwaiterFromRuntimeAsync>(MyTaskAwaiter)" + IL_0020: ldloc.1 + IL_0021: callvirt "int MyTaskAwaiter.GetResult()" + IL_0026: brtrue.s IL_0034 + IL_0028: ldsfld "int Driver.Count" + IL_002d: ldc.i4.1 + IL_002e: add + IL_002f: stsfld "int Driver.Count" + IL_0034: ldsfld "int Driver.Count" + IL_0039: ldloc.0 + IL_003a: sub + IL_003b: stsfld "int Driver.Result" + IL_0040: leave.s IL_004e + } + finally + { + IL_0042: ldsfld "System.Threading.AutoResetEvent Driver.CompletedSignal" + IL_0047: callvirt "bool System.Threading.EventWaitHandle.Set()" + IL_004c: pop + IL_004d: endfinally + } + IL_004e: ret + } + """); } [Fact] @@ -2984,6 +3049,57 @@ static void Main() } }"; CompileAndVerify(source, "0"); + + var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("0", isRuntimeAsync: true), verify: Verification.FailsPEVerify); + verifier.VerifyIL("MyTask.Run", """ + { + // Code size 80 (0x50) + .maxstack 2 + .locals init (int V_0, //tests + MyTaskAwaiter V_1) + IL_0000: ldc.i4.0 + IL_0001: stloc.0 + .try + { + IL_0002: ldloc.0 + IL_0003: ldc.i4.1 + IL_0004: add + IL_0005: stloc.0 + IL_0006: newobj "MyTask..ctor()" + IL_000b: call "MyTaskAwaiter Extension.GetAwaiter(MyTask)" + IL_0010: stloc.1 + IL_0011: ldloc.1 + IL_0012: callvirt "bool MyTaskAwaiter.IsCompleted.get" + IL_0017: brtrue.s IL_001f + IL_0019: ldloc.1 + IL_001a: call "void System.Runtime.CompilerServices.RuntimeHelpers.AwaitAwaiterFromRuntimeAsync(MyTaskAwaiter)" + IL_001f: ldloc.1 + IL_0020: callvirt "int MyTaskAwaiter.GetResult()" + IL_0025: ldc.i4.s 123 + IL_0027: bne.un.s IL_0035 + IL_0029: ldsfld "int Driver.Count" + IL_002e: ldc.i4.1 + IL_002f: add + IL_0030: stsfld "int Driver.Count" + IL_0035: leave.s IL_004f + } + finally + { + IL_0037: ldsfld "int Driver.Count" + IL_003c: ldloc.0 + IL_003d: sub + IL_003e: stsfld "int Driver.Result" + IL_0043: ldsfld "System.Threading.AutoResetEvent Driver.CompletedSignal" + IL_0048: callvirt "bool System.Threading.EventWaitHandle.Set()" + IL_004d: pop + IL_004e: endfinally + } + IL_004f: ret + } + """); } [Fact] @@ -3060,6 +3176,57 @@ static void Main() } }"; CompileAndVerify(source, "0"); + + var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("0", isRuntimeAsync: true), verify: Verification.FailsPEVerify); + verifier.VerifyIL("MyTask.Run", """ + { + // Code size 80 (0x50) + .maxstack 2 + .locals init (int V_0, //tests + MyTaskAwaiter V_1) + IL_0000: ldc.i4.0 + IL_0001: stloc.0 + .try + { + IL_0002: ldloc.0 + IL_0003: ldc.i4.1 + IL_0004: add + IL_0005: stloc.0 + IL_0006: newobj "MyTask..ctor()" + IL_000b: callvirt "MyTaskAwaiter MyTask.GetAwaiter()" + IL_0010: stloc.1 + IL_0011: ldloc.1 + IL_0012: callvirt "bool MyTaskBaseAwaiter.IsCompleted.get" + IL_0017: brtrue.s IL_001f + IL_0019: ldloc.1 + IL_001a: call "void System.Runtime.CompilerServices.RuntimeHelpers.AwaitAwaiterFromRuntimeAsync(MyTaskAwaiter)" + IL_001f: ldloc.1 + IL_0020: callvirt "int MyTaskBaseAwaiter.GetResult()" + IL_0025: ldc.i4.s 123 + IL_0027: bne.un.s IL_0035 + IL_0029: ldsfld "int Driver.Count" + IL_002e: ldc.i4.1 + IL_002f: add + IL_0030: stsfld "int Driver.Count" + IL_0035: leave.s IL_004f + } + finally + { + IL_0037: ldsfld "int Driver.Count" + IL_003c: ldloc.0 + IL_003d: sub + IL_003e: stsfld "int Driver.Result" + IL_0043: ldsfld "System.Threading.AutoResetEvent Driver.CompletedSignal" + IL_0048: callvirt "bool System.Threading.EventWaitHandle.Set()" + IL_004d: pop + IL_004e: endfinally + } + IL_004f: ret + } + """); } [Fact] @@ -6722,7 +6889,38 @@ static async Task Main() } }"; var comp = CSharpTestBase.CreateCompilation(source, options: TestOptions.ReleaseExe); - CompileAndVerify(comp, expectedOutput: "StructAwaitable"); + var expected = "StructAwaitable"; + CompileAndVerify(comp, expectedOutput: expected); + + comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput(expected, isRuntimeAsync: true), verify: Verification.Fails with + { + ILVerifyMessage = ReturnValueMissing("Main", "0x2a") + }); + + verifier.VerifyIL("Program.Main()", """ + { + // Code size 43 (0x2b) + .maxstack 1 + .locals init (System.Runtime.CompilerServices.TaskAwaiter V_0, + StructAwaitable V_1) + IL_0000: ldloca.s V_1 + IL_0002: initobj "StructAwaitable" + IL_0008: ldloc.1 + IL_0009: box "StructAwaitable" + IL_000e: call "System.Runtime.CompilerServices.TaskAwaiter Extensions.GetAwaiter(IAwaitable)" + IL_0013: stloc.0 + IL_0014: ldloca.s V_0 + IL_0016: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_001b: brtrue.s IL_0023 + IL_001d: ldloc.0 + IL_001e: call "void System.Runtime.CompilerServices.RuntimeHelpers.UnsafeAwaitAwaiterFromRuntimeAsync(System.Runtime.CompilerServices.TaskAwaiter)" + IL_0023: ldloca.s V_0 + IL_0025: call "void System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_002a: ret + } + """); } [Fact] @@ -6755,7 +6953,39 @@ static async Task Main() } }"; var comp = CSharpTestBase.CreateCompilation(source, options: TestOptions.ReleaseExe); - CompileAndVerify(comp, expectedOutput: "StructAwaitable"); + var expected = "StructAwaitable"; + CompileAndVerify(comp, expectedOutput: expected); + + comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput(expected, isRuntimeAsync: true), verify: Verification.Fails with + { + ILVerifyMessage = ReturnValueMissing("Main", "0x2f") + }); + + verifier.VerifyIL("Program.Main()", """ + { + // Code size 48 (0x30) + .maxstack 1 + .locals init (StructAwaitable V_0, + System.Runtime.CompilerServices.TaskAwaiter V_1) + IL_0000: ldloca.s V_0 + IL_0002: initobj "StructAwaitable" + IL_0008: ldloc.0 + IL_0009: newobj "StructAwaitable?..ctor(StructAwaitable)" + IL_000e: box "StructAwaitable?" + IL_0013: call "System.Runtime.CompilerServices.TaskAwaiter Extensions.GetAwaiter(object)" + IL_0018: stloc.1 + IL_0019: ldloca.s V_1 + IL_001b: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_0020: brtrue.s IL_0028 + IL_0022: ldloc.1 + IL_0023: call "void System.Runtime.CompilerServices.RuntimeHelpers.UnsafeAwaitAwaiterFromRuntimeAsync(System.Runtime.CompilerServices.TaskAwaiter)" + IL_0028: ldloca.s V_1 + IL_002a: call "void System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_002f: ret + } + """); } [Fact, WorkItem(40251, "https://github.com/dotnet/roslyn/issues/40251")] @@ -7715,5 +7945,209 @@ .maxstack 1 } """); } + + [Theory] + [InlineData("INotifyCompletion")] + [InlineData("ICriticalNotifyCompletion")] + public void CustomAwaitable_NonGeneric(string notifyType) + { + var source = $$""" + var c = new C(); + await c; + + class C + { + public class Awaiter : System.Runtime.CompilerServices.{{notifyType}} + { + private bool isCompleted = false; + public void OnCompleted(System.Action continuation) {} + public void UnsafeOnCompleted(System.Action continuation) {} + public bool IsCompleted + { + get + { + var isCompleted = this.isCompleted; + this.isCompleted = true; + return isCompleted; + } + } + public void GetResult() => System.Console.WriteLine("42"); + } + + public Awaiter GetAwaiter() => new Awaiter(); + } + """; + + var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ReturnValueMissing("
$", "0x1f") }); + + var expectedAwait = notifyType == "INotifyCompletion" ? "AwaitAwaiterFromRuntimeAsync" : "UnsafeAwaitAwaiterFromRuntimeAsync"; + verifier.VerifyIL("", $$""" + { + // Code size 32 (0x20) + .maxstack 1 + .locals init (C.Awaiter V_0) + IL_0000: newobj "C..ctor()" + IL_0005: callvirt "C.Awaiter C.GetAwaiter()" + IL_000a: stloc.0 + IL_000b: ldloc.0 + IL_000c: callvirt "bool C.Awaiter.IsCompleted.get" + IL_0011: brtrue.s IL_0019 + IL_0013: ldloc.0 + IL_0014: call "void System.Runtime.CompilerServices.RuntimeHelpers.{{expectedAwait}}(C.Awaiter)" + IL_0019: ldloc.0 + IL_001a: callvirt "void C.Awaiter.GetResult()" + IL_001f: ret + } + """); + } + + [Theory] + [InlineData("System.Runtime.CompilerServices.INotifyCompletion")] + [InlineData("System.Runtime.CompilerServices.ICriticalNotifyCompletion")] + [InlineData("System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion")] + public void CustomAwaitable_WithNonVoidAwait(string notifyType) + { + var source = $$""" + var c = new C(); + System.Console.WriteLine(await c); + + class C + { + public class Awaiter : {{notifyType}} + { + private bool isCompleted = false; + public void OnCompleted(System.Action continuation) {} + public void UnsafeOnCompleted(System.Action continuation) {} + public bool IsCompleted + { + get + { + var isCompleted = this.isCompleted; + this.isCompleted = true; + return isCompleted; + } + } + public int GetResult() => 42; + } + + public Awaiter GetAwaiter() => new Awaiter(); + } + """; + + var comp = CreateCompilation([source, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("42", isRuntimeAsync: true), verify: Verification.Fails with { ILVerifyMessage = ReturnValueMissing("
$", "0x24") }); + + var expectedAwait = notifyType.Contains("Critical") ? "UnsafeAwaitAwaiterFromRuntimeAsync" : "AwaitAwaiterFromRuntimeAsync"; + verifier.VerifyIL("", $$""" + { + // Code size 37 (0x25) + .maxstack 1 + .locals init (C.Awaiter V_0) + IL_0000: newobj "C..ctor()" + IL_0005: callvirt "C.Awaiter C.GetAwaiter()" + IL_000a: stloc.0 + IL_000b: ldloc.0 + IL_000c: callvirt "bool C.Awaiter.IsCompleted.get" + IL_0011: brtrue.s IL_0019 + IL_0013: ldloc.0 + IL_0014: call "void System.Runtime.CompilerServices.RuntimeHelpers.{{expectedAwait}}(C.Awaiter)" + IL_0019: ldloc.0 + IL_001a: callvirt "int C.Awaiter.GetResult()" + IL_001f: call "void System.Console.WriteLine(int)" + IL_0024: ret + } + """); + } + + [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/77897")] + public void AwaitYield() + { + var code = """ + using System.Threading.Tasks; + class C + { + static bool doYields = true; + + static async Task Main() + { + System.Console.WriteLine(await Fib(10)); + } + + static async Task Fib(int i) + { + if (i <= 1) + { + if (doYields) + { + await Task.Yield(); + } + + return 1; + } + + int i1 = await Fib(i - 1); + int i2 = await Fib(i - 2); + + return i1 + i2; + } + } + """; + + var comp = CreateCompilation([code, RuntimeAsyncAwaitHelpers], parseOptions: WithRuntimeAsync(TestOptions.RegularPreview), targetFramework: TargetFramework.NetCoreApp); + comp.Assembly.SetOverrideRuntimeSupportsAsyncMethods(); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("55", isRuntimeAsync: true), verify: Verification.Fails with + { + ILVerifyMessage = $$""" + {{ReturnValueMissing("Main", "0x11")}} + [Fib]: Unexpected type on the stack. { Offset = 0x30, Found = Int32, Expected = ref '[System.Runtime]System.Threading.Tasks.Task`1' } + [Fib]: Unexpected type on the stack. { Offset = 0x4e, Found = Int32, Expected = ref '[System.Runtime]System.Threading.Tasks.Task`1' } + """ + }); + verifier.VerifyIL("C.Fib(int)", """ + { + // Code size 79 (0x4f) + .maxstack 3 + .locals init (int V_0, //i2 + System.Runtime.CompilerServices.YieldAwaitable.YieldAwaiter V_1, + System.Runtime.CompilerServices.YieldAwaitable V_2) + IL_0000: ldarg.0 + IL_0001: ldc.i4.1 + IL_0002: bgt.s IL_0031 + IL_0004: ldsfld "bool C.doYields" + IL_0009: brfalse.s IL_002f + IL_000b: call "System.Runtime.CompilerServices.YieldAwaitable System.Threading.Tasks.Task.Yield()" + IL_0010: stloc.2 + IL_0011: ldloca.s V_2 + IL_0013: call "System.Runtime.CompilerServices.YieldAwaitable.YieldAwaiter System.Runtime.CompilerServices.YieldAwaitable.GetAwaiter()" + IL_0018: stloc.1 + IL_0019: ldloca.s V_1 + IL_001b: call "bool System.Runtime.CompilerServices.YieldAwaitable.YieldAwaiter.IsCompleted.get" + IL_0020: brtrue.s IL_0028 + IL_0022: ldloc.1 + IL_0023: call "void System.Runtime.CompilerServices.RuntimeHelpers.UnsafeAwaitAwaiterFromRuntimeAsync(System.Runtime.CompilerServices.YieldAwaitable.YieldAwaiter)" + IL_0028: ldloca.s V_1 + IL_002a: call "void System.Runtime.CompilerServices.YieldAwaitable.YieldAwaiter.GetResult()" + IL_002f: ldc.i4.1 + IL_0030: ret + IL_0031: ldarg.0 + IL_0032: ldc.i4.1 + IL_0033: sub + IL_0034: call "System.Threading.Tasks.Task C.Fib(int)" + IL_0039: call "int System.Runtime.CompilerServices.RuntimeHelpers.Await(System.Threading.Tasks.Task)" + IL_003e: ldarg.0 + IL_003f: ldc.i4.2 + IL_0040: sub + IL_0041: call "System.Threading.Tasks.Task C.Fib(int)" + IL_0046: call "int System.Runtime.CompilerServices.RuntimeHelpers.Await(System.Threading.Tasks.Task)" + IL_004b: stloc.0 + IL_004c: ldloc.0 + IL_004d: add + IL_004e: ret + } + """); + } } } diff --git a/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs b/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs index d2e76bf4b05e6..9846f57e88268 100644 --- a/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs +++ b/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs @@ -1027,6 +1027,8 @@ public void AllWellKnownTypeMembers() case WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitTaskT_T: case WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTask: case WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T: + case WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter: + case WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter: // Not yet in the platform. continue; case WellKnownMember.Microsoft_CodeAnalysis_Runtime_Instrumentation__CreatePayloadForMethodsSpanningSingleFile: diff --git a/src/Compilers/Core/Portable/WellKnownMember.cs b/src/Compilers/Core/Portable/WellKnownMember.cs index 8704713648526..fc768dcca4fb8 100644 --- a/src/Compilers/Core/Portable/WellKnownMember.cs +++ b/src/Compilers/Core/Portable/WellKnownMember.cs @@ -154,6 +154,8 @@ internal enum WellKnownMember System_Runtime_CompilerServices_RuntimeHelpers__AwaitTaskT_T, System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTask, System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T, + System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter, + System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter, System_Runtime_CompilerServices_Unsafe__Add_T, System_Runtime_CompilerServices_Unsafe__As_T, diff --git a/src/Compilers/Core/Portable/WellKnownMembers.cs b/src/Compilers/Core/Portable/WellKnownMembers.cs index 30488c80ff189..0cb65cceb46e4 100644 --- a/src/Compilers/Core/Portable/WellKnownMembers.cs +++ b/src/Compilers/Core/Portable/WellKnownMembers.cs @@ -1067,6 +1067,22 @@ static WellKnownMembers() (byte)SignatureTypeCode.GenericTypeInstance, (byte)SignatureTypeCode.TypeHandle, (byte)WellKnownType.ExtSentinel, (byte)(WellKnownType.System_Threading_Tasks_ValueTask_T - WellKnownType.ExtSentinel), 1, (byte)SignatureTypeCode.GenericMethodParameter, 0, + // System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter + (byte)(MemberFlags.Method | MemberFlags.Static), // Flags + (byte)WellKnownType.System_Runtime_CompilerServices_RuntimeHelpers, // DeclaringTypeId + 1, // Arity + 1, // Method Signature + (byte)SignatureTypeCode.TypeHandle, (byte)SpecialType.System_Void, // Return Type + (byte)SignatureTypeCode.GenericMethodParameter, 0, + + // System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter + (byte)(MemberFlags.Method | MemberFlags.Static), // Flags + (byte)WellKnownType.System_Runtime_CompilerServices_RuntimeHelpers, // DeclaringTypeId + 1, // Arity + 1, // Method Signature + (byte)SignatureTypeCode.TypeHandle, (byte)SpecialType.System_Void, // Return Type + (byte)SignatureTypeCode.GenericMethodParameter, 0, + // System_Runtime_CompilerServices_Unsafe__Add_T (byte)(MemberFlags.Method | MemberFlags.Static), // Flags (byte)WellKnownType.ExtSentinel, (byte)(WellKnownType.System_Runtime_CompilerServices_Unsafe - WellKnownType.ExtSentinel), // DeclaringTypeId @@ -5363,6 +5379,8 @@ static WellKnownMembers() "Await", // System_Runtime_CompilerServices_RuntimeHelpers__AwaitTaskT_T "Await", // System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTask "Await", // System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T + "AwaitAwaiterFromRuntimeAsync", // System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter + "UnsafeAwaitAwaiterFromRuntimeAsync", // System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter "Add", // System_Runtime_CompilerServices_Unsafe__Add_T "As", // System_Runtime_CompilerServices_Unsafe__As_T, "AsRef", // System_Runtime_CompilerServices_Unsafe__AsRef_T, diff --git a/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb b/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb index a07d9fe406eeb..c1c6cea34fcb9 100644 --- a/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb +++ b/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb @@ -763,7 +763,9 @@ End Namespace WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitTask, WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitTaskT_T, WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTask, - WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T + WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T, + WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter, + WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter ' Not available yet, but will be in upcoming release. Continue For Case WellKnownMember.Microsoft_CodeAnalysis_Runtime_Instrumentation__CreatePayloadForMethodsSpanningSingleFile, @@ -978,7 +980,9 @@ End Namespace WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitTask, WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitTaskT_T, WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTask, - WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T + WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitValueTaskT_T, + WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__AwaitAwaiterFromRuntimeAsync_TAwaiter, + WellKnownMember.System_Runtime_CompilerServices_RuntimeHelpers__UnsafeAwaitAwaiterFromRuntimeAsync_TAwaiter ' Not available yet, but will be in upcoming release. Continue For Case WellKnownMember.Microsoft_CodeAnalysis_Runtime_Instrumentation__CreatePayloadForMethodsSpanningSingleFile,