diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_Inheriting_Multiple_Interfaces.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_Inheriting_Multiple_Interfaces.verified.txt index ef90b8fdde..f28d3e5304 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_Inheriting_Multiple_Interfaces.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_Inheriting_Multiple_Interfaces.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -129,6 +129,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public IReadWriter_Write_M2_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public IReadWriter_Write_M2_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Async_Methods.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Async_Methods.verified.txt index 6dd209e2bc..3f08ae5dc2 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Async_Methods.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Async_Methods.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -47,6 +47,11 @@ namespace TUnit.Mocks.Generated try { var __result = _engine.HandleCallWithReturn(0, "GetValueAsync", new object?[] { key }, ""); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.Task __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.Task but got {__rawAsync?.GetType().Name ?? "null"}"); + } return global::System.Threading.Tasks.Task.FromResult(__result); } catch (global::System.Exception __ex) @@ -60,6 +65,11 @@ namespace TUnit.Mocks.Generated try { _engine.HandleCall(1, "DoWorkAsync", global::System.Array.Empty()); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.Task __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.Task but got {__rawAsync?.GetType().Name ?? "null"}"); + } return global::System.Threading.Tasks.Task.CompletedTask; } catch (global::System.Exception __ex) @@ -73,6 +83,11 @@ namespace TUnit.Mocks.Generated try { var __result = _engine.HandleCallWithReturn(2, "ComputeAsync", new object?[] { input }, default); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.ValueTask __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.ValueTask but got {__rawAsync?.GetType().Name ?? "null"}"); + } return new global::System.Threading.Tasks.ValueTask(__result); } catch (global::System.Exception __ex) @@ -86,6 +101,11 @@ namespace TUnit.Mocks.Generated try { _engine.HandleCall(3, "InitializeAsync", new object?[] { ct }); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.ValueTask __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.ValueTask but got {__rawAsync?.GetType().Name ?? "null"}"); + } return default(global::System.Threading.Tasks.ValueTask); } catch (global::System.Exception __ex) @@ -125,10 +145,10 @@ namespace TUnit.Mocks.Generated return new IAsyncService_GetValueAsync_M0_MockCall(global::TUnit.Mocks.Mock.GetEngine(mock), 0, "GetValueAsync", matchers); } - public static global::TUnit.Mocks.VoidMockMethodCall DoWorkAsync(this global::TUnit.Mocks.Mock mock) + public static IAsyncService_DoWorkAsync_M1_MockCall DoWorkAsync(this global::TUnit.Mocks.Mock mock) { var matchers = global::System.Array.Empty(); - return new global::TUnit.Mocks.VoidMockMethodCall(global::TUnit.Mocks.Mock.GetEngine(mock), 1, "DoWorkAsync", matchers); + return new IAsyncService_DoWorkAsync_M1_MockCall(global::TUnit.Mocks.Mock.GetEngine(mock), 1, "DoWorkAsync", matchers); } public static IAsyncService_ComputeAsync_M2_MockCall ComputeAsync(this global::TUnit.Mocks.Mock mock, global::TUnit.Mocks.Arguments.Arg input) @@ -212,6 +232,11 @@ namespace TUnit.Mocks.Generated /// public IAsyncService_GetValueAsync_M0_MockCall Then() { EnsureSetup().Then(); return this; } + /// Return a pre-built Task directly (e.g., from a TaskCompletionSource). + public IAsyncService_GetValueAsync_M0_MockCall ReturnsAsync(global::System.Threading.Tasks.Task task) { EnsureSetup().ReturnsRaw(task); return this; } + /// Return a pre-built Task from a factory, invoked on each call. + public IAsyncService_GetValueAsync_M0_MockCall ReturnsAsync(global::System.Func> taskFactory) { EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; } + /// Configure a typed computed return value using the actual method parameters. public IAsyncService_GetValueAsync_M0_MockCall Returns(global::System.Func factory) { @@ -248,6 +273,75 @@ namespace TUnit.Mocks.Generated public void WasNeverCalled(string? message) => _engine.CreateVerification(_memberId, _memberName, _matchers).WasNeverCalled(message); } + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + public sealed class IAsyncService_DoWorkAsync_M1_MockCall : global::TUnit.Mocks.Verification.ICallVerification + { + private readonly global::TUnit.Mocks.IMockEngineAccess _engine; + private readonly int _memberId; + private readonly string _memberName; + private readonly global::TUnit.Mocks.Arguments.IArgumentMatcher[] _matchers; + private global::TUnit.Mocks.Setup.VoidMethodSetupBuilder? _builder; + private bool _builderInitialized; + private object? _builderLock; + + internal IAsyncService_DoWorkAsync_M1_MockCall(global::TUnit.Mocks.IMockEngineAccess engine, int memberId, string memberName, global::TUnit.Mocks.Arguments.IArgumentMatcher[] matchers) + { + _engine = engine; + _memberId = memberId; + _memberName = memberName; + _matchers = matchers; + _ = EnsureSetup(); + } + + private global::TUnit.Mocks.Setup.VoidMethodSetupBuilder EnsureSetup() => + global::System.Threading.LazyInitializer.EnsureInitialized(ref _builder, ref _builderInitialized, ref _builderLock, () => + { + var setup = new global::TUnit.Mocks.Setup.MethodSetup(_memberId, _matchers, _memberName); + _engine.AddSetup(setup); + return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); + })!; + + /// + public IAsyncService_DoWorkAsync_M1_MockCall Returns() { EnsureSetup().Returns(); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Throws(global::System.Exception exception) { EnsureSetup().Throws(exception); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Callback(global::System.Action callback) { EnsureSetup().Callback(callback); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Callback(global::System.Action callback) { EnsureSetup().Callback(callback); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Throws(global::System.Func exceptionFactory) { EnsureSetup().Throws(exceptionFactory); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Raises(string eventName, object? args = null) { EnsureSetup().Raises(eventName, args); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall SetsOutParameter(int paramIndex, object? value) { EnsureSetup().SetsOutParameter(paramIndex, value); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall TransitionsTo(string stateName) { EnsureSetup().TransitionsTo(stateName); return this; } + /// + public IAsyncService_DoWorkAsync_M1_MockCall Then() { EnsureSetup().Then(); return this; } + + /// Return a pre-built Task directly (e.g., from a TaskCompletionSource). + public IAsyncService_DoWorkAsync_M1_MockCall ReturnsAsync(global::System.Threading.Tasks.Task task) { EnsureSetup().ReturnsRaw(task); return this; } + /// Return a pre-built Task from a factory, invoked on each call. + public IAsyncService_DoWorkAsync_M1_MockCall ReturnsAsync(global::System.Func taskFactory) { EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; } + + // ICallVerification + /// + public void WasCalled() => _engine.CreateVerification(_memberId, _memberName, _matchers).WasCalled(); + /// + public void WasCalled(global::TUnit.Mocks.Times times) => _engine.CreateVerification(_memberId, _memberName, _matchers).WasCalled(times); + /// + public void WasCalled(global::TUnit.Mocks.Times times, string? message) => _engine.CreateVerification(_memberId, _memberName, _matchers).WasCalled(times, message); + /// + public void WasCalled(string? message) => _engine.CreateVerification(_memberId, _memberName, _matchers).WasCalled(message); + /// + public void WasNeverCalled() => _engine.CreateVerification(_memberId, _memberName, _matchers).WasNeverCalled(); + /// + public void WasNeverCalled(string? message) => _engine.CreateVerification(_memberId, _memberName, _matchers).WasNeverCalled(message); + } + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] public sealed class IAsyncService_ComputeAsync_M2_MockCall : global::TUnit.Mocks.Verification.ICallVerification { @@ -302,6 +396,13 @@ namespace TUnit.Mocks.Generated /// public IAsyncService_ComputeAsync_M2_MockCall Then() { EnsureSetup().Then(); return this; } + /// Return a pre-built ValueTask directly (e.g., from a TaskCompletionSource). + /// The same ValueTask instance is returned on every call. Since ValueTask may only be awaited once, + /// use the factory overload if the mock will be called multiple times, or ensure the ValueTask is backed by a Task. + public IAsyncService_ComputeAsync_M2_MockCall ReturnsAsync(global::System.Threading.Tasks.ValueTask task) { EnsureSetup().ReturnsRaw(task); return this; } + /// Return a pre-built ValueTask from a factory, invoked on each call. + public IAsyncService_ComputeAsync_M2_MockCall ReturnsAsync(global::System.Func> taskFactory) { EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; } + /// Configure a typed computed return value using the actual method parameters. public IAsyncService_ComputeAsync_M2_MockCall Returns(global::System.Func factory) { @@ -366,6 +467,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public IAsyncService_InitializeAsync_M3_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public IAsyncService_InitializeAsync_M3_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// @@ -385,6 +488,13 @@ namespace TUnit.Mocks.Generated /// public IAsyncService_InitializeAsync_M3_MockCall Then() { EnsureSetup().Then(); return this; } + /// Return a pre-built ValueTask directly (e.g., from a TaskCompletionSource). + /// The same ValueTask instance is returned on every call. Since ValueTask may only be awaited once, + /// use the factory overload if the mock will be called multiple times, or ensure the ValueTask is backed by a Task. + public IAsyncService_InitializeAsync_M3_MockCall ReturnsAsync(global::System.Threading.Tasks.ValueTask task) { EnsureSetup().ReturnsRaw(task); return this; } + /// Return a pre-built ValueTask from a factory, invoked on each call. + public IAsyncService_InitializeAsync_M3_MockCall ReturnsAsync(global::System.Func taskFactory) { EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; } + /// Execute a typed callback using the actual method parameters. public IAsyncService_InitializeAsync_M3_MockCall Callback(global::System.Action callback) { diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Events.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Events.verified.txt index 9b5534c09c..46150bc827 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Events.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Events.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -164,6 +164,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public INotifier_Notify_M0_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public INotifier_Notify_M0_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Keyword_Parameter_Names.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Keyword_Parameter_Names.verified.txt index 8851539fb3..5a29ef3950 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Keyword_Parameter_Names.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Keyword_Parameter_Names.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -140,6 +140,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public ITest_Test_M0_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public ITest_Test_M0_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Mixed_Members.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Mixed_Members.verified.txt index 8e56e6d8d4..02d88599a6 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Mixed_Members.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Mixed_Members.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -78,6 +78,11 @@ namespace TUnit.Mocks.Generated try { var __result = _engine.HandleCallWithReturn(3, "GetAsync", new object?[] { id }, ""); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.Task __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.Task but got {__rawAsync?.GetType().Name ?? "null"}"); + } return global::System.Threading.Tasks.Task.FromResult(__result); } catch (global::System.Exception __ex) @@ -236,6 +241,11 @@ namespace TUnit.Mocks.Generated /// public IService_GetAsync_M3_MockCall Then() { EnsureSetup().Then(); return this; } + /// Return a pre-built Task directly (e.g., from a TaskCompletionSource). + public IService_GetAsync_M3_MockCall ReturnsAsync(global::System.Threading.Tasks.Task task) { EnsureSetup().ReturnsRaw(task); return this; } + /// Return a pre-built Task from a factory, invoked on each call. + public IService_GetAsync_M3_MockCall ReturnsAsync(global::System.Func> taskFactory) { EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; } + /// Configure a typed computed return value using the actual method parameters. public IService_GetAsync_M3_MockCall Returns(global::System.Func factory) { @@ -303,6 +313,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public IService_Process_M4_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public IService_Process_M4_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Nullable_Reference_Type_Parameters.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Nullable_Reference_Type_Parameters.verified.txt index c55ec140c1..cc500a2dcf 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Nullable_Reference_Type_Parameters.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Nullable_Reference_Type_Parameters.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -62,6 +62,11 @@ namespace TUnit.Mocks.Generated try { var __result = _engine.HandleCallWithReturn(3, "GetAsync", new object?[] { key }, default); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.Task __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.Task but got {__rawAsync?.GetType().Name ?? "null"}"); + } return global::System.Threading.Tasks.Task.FromResult(__result); } catch (global::System.Exception __ex) @@ -243,6 +248,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public IFoo_Bar_M0_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public IFoo_Bar_M0_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// @@ -409,6 +416,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public IFoo_Process_M2_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public IFoo_Process_M2_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// @@ -511,6 +520,11 @@ namespace TUnit.Mocks.Generated /// public IFoo_GetAsync_M3_MockCall Then() { EnsureSetup().Then(); return this; } + /// Return a pre-built Task directly (e.g., from a TaskCompletionSource). + public IFoo_GetAsync_M3_MockCall ReturnsAsync(global::System.Threading.Tasks.Task task) { EnsureSetup().ReturnsRaw(task); return this; } + /// Return a pre-built Task from a factory, invoked on each call. + public IFoo_GetAsync_M3_MockCall ReturnsAsync(global::System.Func> taskFactory) { EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; } + /// Configure a typed computed return value using the actual method parameters. public IFoo_GetAsync_M3_MockCall Returns(global::System.Func factory) { diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Out_Ref_Parameters.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Out_Ref_Parameters.verified.txt index c2d2bbb4fe..e0e5a758cf 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Out_Ref_Parameters.verified.txt +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Out_Ref_Parameters.verified.txt @@ -1,4 +1,4 @@ -// +// #nullable enable namespace TUnit.Mocks.Generated @@ -225,6 +225,8 @@ namespace TUnit.Mocks.Generated return new global::TUnit.Mocks.Setup.VoidMethodSetupBuilder(setup); })!; + /// + public IDictionary_Swap_M1_MockCall Returns() { EnsureSetup().Returns(); return this; } /// public IDictionary_Swap_M1_MockCall Throws() where TException : global::System.Exception, new() { EnsureSetup().Throws(); return this; } /// diff --git a/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs b/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs index 99c83d7e4b..42398360c9 100644 --- a/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs +++ b/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs @@ -585,6 +585,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel writer.AppendLine("{"); writer.IncreaseIndent(); EmitOutRefReadback(writer, method); + EmitRawReturnCheck(writer, method); if (method.IsValueTask) { writer.AppendLine("return default(global::System.Threading.Tasks.ValueTask);"); @@ -614,6 +615,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel writer.IncreaseIndent(); } EmitOutRefReadback(writer, method); + EmitRawReturnCheck(writer, method); if (method.IsValueTask) { writer.AppendLine($"return new global::System.Threading.Tasks.ValueTask<{method.UnwrappedReturnType}>(__result);"); @@ -699,6 +701,7 @@ private static void GenerateEngineDispatchBody(CodeWriter writer, MockMemberMode { writer.AppendLine($"_engine.HandleCall({method.MemberId}, \"{method.Name}\", {argsArray});"); EmitOutRefReadback(writer, method); + EmitRawReturnCheck(writer, method); if (method.IsValueTask) { writer.AppendLine("return default(global::System.Threading.Tasks.ValueTask);"); @@ -736,6 +739,7 @@ private static void GenerateEngineDispatchBody(CodeWriter writer, MockMemberMode writer.AppendLine($"var __result = _engine.HandleCallWithReturn<{unwrappedArg}>({method.MemberId}, \"{method.Name}\", {argsArray}, {unwrappedDefault});"); } EmitOutRefReadback(writer, method); + EmitRawReturnCheck(writer, method); if (method.IsValueTask) { writer.AppendLine($"return new global::System.Threading.Tasks.ValueTask<{method.UnwrappedReturnType}>(__result);"); @@ -1142,6 +1146,23 @@ internal static void EmitOutRefReadback(CodeWriter writer, MockMemberModel metho } } + /// + /// For async methods: emits code to check + /// and return the raw Task/ValueTask directly if one was set by a ReturnsAsync setup. + /// + private static void EmitRawReturnCheck(CodeWriter writer, MockMemberModel method) + { + if (!method.IsAsync) return; + + // IMPORTANT: This check must appear synchronously (no await) after the engine + // dispatch call. The [ThreadStatic] RawReturnContext requires same-thread consumption. + writer.AppendLine($"if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync))"); + writer.OpenBrace(); + writer.AppendLine($"if (__rawAsync is {method.ReturnType} __typedAsync) return __typedAsync;"); + writer.AppendLine($"throw new global::System.InvalidOperationException($\"ReturnsAsync: expected {method.ReturnType} but got {{__rawAsync?.GetType().Name ?? \"null\"}}\");"); + writer.CloseBrace(); + } + /// /// For ref struct return methods with span support: emits code to consume OutRefContext, /// read back out/ref params, extract span return value, and return. diff --git a/TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs b/TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs index c3339ed0a2..71a857d0bc 100644 --- a/TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs +++ b/TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs @@ -88,6 +88,9 @@ private static bool ShouldGenerateTypedWrapper(MockMemberModel method, bool hasE { if (method.IsGenericMethod) return false; + // Async methods need a typed wrapper for the generated ReturnsAsync() method + if (method.IsAsync) return true; + // Exclude out params and ref struct params (can't be boxed or used as type args) var matchableParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList(); if (matchableParams.Count == 0) @@ -124,7 +127,8 @@ private static void GenerateUnifiedSealedClass(CodeWriter writer, MockMemberMode // Ref struct returns use the void wrapper (can't use ref structs as generic type args) if (method.IsVoid || method.IsRefStructReturn) { - GenerateVoidUnifiedClass(writer, wrapperName, matchableParams, events, method.Parameters, hasRefStructParams, allNonOutParams, method.SpanReturnElementType, method.ReturnType); + GenerateVoidUnifiedClass(writer, wrapperName, matchableParams, events, method.Parameters, hasRefStructParams, allNonOutParams, method.SpanReturnElementType, method.ReturnType, + isAsync: method.IsAsync, isValueTask: method.IsValueTask); } else if (method.IsReturnTypeStaticAbstractInterface) { @@ -134,13 +138,15 @@ private static void GenerateUnifiedSealedClass(CodeWriter writer, MockMemberMode } else { - GenerateReturnUnifiedClass(writer, wrapperName, matchableParams, setupReturnType, events, method.Parameters, hasRefStructParams, allNonOutParams); + GenerateReturnUnifiedClass(writer, wrapperName, matchableParams, setupReturnType, events, method.Parameters, hasRefStructParams, allNonOutParams, + isAsync: method.IsAsync, isValueTask: method.IsValueTask, fullReturnType: method.ReturnType); } } private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapperName, List nonOutParams, string returnType, EquatableArray events, - EquatableArray allParameters, bool hasRefStructParams, List allNonOutParams) + EquatableArray allParameters, bool hasRefStructParams, List allNonOutParams, + bool isAsync = false, bool isValueTask = false, string? fullReturnType = null) { var builderType = $"global::TUnit.Mocks.Setup.MethodSetupBuilder<{returnType}>"; var hasOutRef = allParameters.Any(p => p.Direction == ParameterDirection.Out || p.Direction == ParameterDirection.Ref); @@ -207,6 +213,11 @@ private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapper writer.AppendLine($"/// "); writer.AppendLine($"public {wrapperName} Then() {{ EnsureSetup().Then(); return this; }}"); + if (isAsync && fullReturnType is not null) + { + EmitReturnsAsyncOverloads(writer, wrapperName, fullReturnType, isValueTask); + } + // Typed parameter overloads (only for methods with typed params) if (nonOutParams.Count >= 1) { @@ -272,7 +283,8 @@ private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapper private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperName, List nonOutParams, EquatableArray events, EquatableArray allParameters, bool hasRefStructParams, List allNonOutParams, - string? spanReturnElementType = null, string? spanReturnType = null) + string? spanReturnElementType = null, string? spanReturnType = null, + bool isAsync = false, bool isValueTask = false) { var builderType = "global::TUnit.Mocks.Setup.VoidMethodSetupBuilder"; var hasOutRef = allParameters.Any(p => p.Direction == ParameterDirection.Out || p.Direction == ParameterDirection.Ref); @@ -311,6 +323,8 @@ private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperNa // Public self-returning setup methods writer.AppendLine($"/// "); + writer.AppendLine($"public {wrapperName} Returns() {{ EnsureSetup().Returns(); return this; }}"); + writer.AppendLine($"/// "); writer.AppendLine($"public {wrapperName} Throws() where TException : global::System.Exception, new() {{ EnsureSetup().Throws(); return this; }}"); writer.AppendLine($"/// "); writer.AppendLine($"public {wrapperName} Throws(global::System.Exception exception) {{ EnsureSetup().Throws(exception); return this; }}"); @@ -341,6 +355,14 @@ private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperNa writer.AppendLine($"public {wrapperName} Returns({spanReturnType} value) {{ EnsureSetup().SetsOutParameter(global::TUnit.Mocks.Setup.OutRefContext.SpanReturnValueIndex, value.ToArray()); return this; }}"); } + if (isAsync) + { + var taskType = isValueTask + ? "global::System.Threading.Tasks.ValueTask" + : "global::System.Threading.Tasks.Task"; + EmitReturnsAsyncOverloads(writer, wrapperName, taskType, isValueTask); + } + // Typed parameter overloads (only for methods with typed params) if (nonOutParams.Count >= 1) { @@ -859,6 +881,25 @@ private static string GetArgParameterList(MockMemberModel method, bool includeRe return string.Join(", ", parts); } + private static void EmitReturnsAsyncOverloads(CodeWriter writer, string wrapperName, string taskType, bool isValueTask) + { + var taskLabel = isValueTask ? "ValueTask" : "Task"; + writer.AppendLine(); + if (isValueTask) + { + writer.AppendLine($"/// Return a pre-built {taskLabel} directly (e.g., from a TaskCompletionSource)."); + writer.AppendLine($"/// The same {taskLabel} instance is returned on every call. Since {taskLabel} may only be awaited once,"); + writer.AppendLine($"/// use the factory overload if the mock will be called multiple times, or ensure the {taskLabel} is backed by a Task."); + } + else + { + writer.AppendLine($"/// Return a pre-built {taskLabel} directly (e.g., from a TaskCompletionSource)."); + } + writer.AppendLine($"public {wrapperName} ReturnsAsync({taskType} task) {{ EnsureSetup().ReturnsRaw(task); return this; }}"); + writer.AppendLine($"/// Return a pre-built {taskLabel} from a factory, invoked on each call."); + writer.AppendLine($"public {wrapperName} ReturnsAsync(global::System.Func<{taskType}> taskFactory) {{ EnsureSetup().ReturnsRaw(() => (object?)taskFactory()); return this; }}"); + } + private static void EmitEnsureSetup(CodeWriter writer, string builderType) { writer.AppendLine($"private {builderType} EnsureSetup() =>"); diff --git a/TUnit.Mocks.Tests/AsyncTests.cs b/TUnit.Mocks.Tests/AsyncTests.cs index 87393948e1..8c2f2c96ba 100644 --- a/TUnit.Mocks.Tests/AsyncTests.cs +++ b/TUnit.Mocks.Tests/AsyncTests.cs @@ -191,4 +191,144 @@ public async Task Async_Method_Verify_Called() mock.GetValueAsync().WasCalled(Times.Exactly(2)); await Assert.That(true).IsTrue(); } + + [Test] + public async Task ReturnsAsync_Task_With_TaskCompletionSource() + { + // Arrange + var tcs = new TaskCompletionSource(); + var mock = Mock.Of(); + mock.GetValueAsync().ReturnsAsync(tcs.Task); + + IAsyncService service = mock.Object; + + // Act — the task is not yet completed + var task = service.GetValueAsync(); + await Assert.That(task.IsCompleted).IsFalse(); + + // Complete the TCS + tcs.SetResult(42); + var result = await task; + + // Assert + await Assert.That(result).IsEqualTo(42); + } + + [Test] + public async Task ReturnsAsync_ValueTask_With_TaskCompletionSource() + { + // Arrange + var tcs = new TaskCompletionSource(); + var mock = Mock.Of(); + mock.GetValueValueTaskAsync().ReturnsAsync(new ValueTask(tcs.Task)); + + IAsyncService service = mock.Object; + + // Act — the ValueTask wraps the TCS + var vtask = service.GetValueValueTaskAsync(); + await Assert.That(vtask.IsCompleted).IsFalse(); + + // Complete the TCS + tcs.SetResult(99); + var result = await vtask; + + // Assert + await Assert.That(result).IsEqualTo(99); + } + + [Test] + public async Task ReturnsAsync_Void_Task_With_TaskCompletionSource() + { + // Arrange + var tcs = new TaskCompletionSource(); + var mock = Mock.Of(); + mock.DoWorkAsync().ReturnsAsync(tcs.Task); + + IAsyncService service = mock.Object; + + // Act — the task is not yet completed + var task = service.DoWorkAsync(); + await Assert.That(task.IsCompleted).IsFalse(); + + // Complete the TCS + tcs.SetResult(); + await task; + + // Assert — task completed successfully + await Assert.That(task.Status).IsEqualTo(TaskStatus.RanToCompletion); + } + + [Test] + public async Task ReturnsAsync_Factory_Returns_Different_Tasks() + { + // Arrange + var tcs1 = new TaskCompletionSource(); + var tcs2 = new TaskCompletionSource(); + var callCount = 0; + var mock = Mock.Of(); + mock.GetValueAsync().ReturnsAsync(() => (++callCount == 1) ? tcs1.Task : tcs2.Task); + + IAsyncService service = mock.Object; + + // Act — first call gets tcs1 + var task1 = service.GetValueAsync(); + tcs1.SetResult(10); + var result1 = await task1; + + // Second call gets tcs2 + var task2 = service.GetValueAsync(); + tcs2.SetResult(20); + var result2 = await task2; + + // Assert + await Assert.That(result1).IsEqualTo(10); + await Assert.That(result2).IsEqualTo(20); + } + + [Test] + public async Task ReturnsAsync_Then_Returns_Sequence() + { + // Arrange — mix ReturnsAsync and Returns in a sequence + var tcs = new TaskCompletionSource(); + var mock = Mock.Of(); + mock.GetValueAsync() + .Returns(1) + .Then() + .ReturnsAsync(tcs.Task) + .Then() + .Returns(3); + + IAsyncService service = mock.Object; + + // First call returns immediately + var result1 = await service.GetValueAsync(); + await Assert.That(result1).IsEqualTo(1); + + // Second call returns the TCS task (not yet completed) + var task2 = service.GetValueAsync(); + await Assert.That(task2.IsCompleted).IsFalse(); + tcs.SetResult(2); + var result2 = await task2; + await Assert.That(result2).IsEqualTo(2); + + // Third call returns immediately + var result3 = await service.GetValueAsync(); + await Assert.That(result3).IsEqualTo(3); + } + + [Test] + public async Task ReturnsAsync_Already_Completed_Task() + { + // Arrange — pass an already-completed task + var mock = Mock.Of(); + mock.GetValueAsync().ReturnsAsync(Task.FromResult(123)); + + IAsyncService service = mock.Object; + + // Act + var result = await service.GetValueAsync(); + + // Assert + await Assert.That(result).IsEqualTo(123); + } } diff --git a/TUnit.Mocks.Tests/SequentialBehaviorTests.cs b/TUnit.Mocks.Tests/SequentialBehaviorTests.cs index db6d2ffc81..14956c1f5f 100644 --- a/TUnit.Mocks.Tests/SequentialBehaviorTests.cs +++ b/TUnit.Mocks.Tests/SequentialBehaviorTests.cs @@ -104,6 +104,37 @@ public async Task Returns_Then_Throws_Sequence() await Assert.That(exception).IsNotNull(); } + [Test] + public async Task Void_Returns_Then_Throws_Then_Returns() + { + // Arrange + var mock = Mock.Of(); + mock.Log(Any()) + .Returns() + .Then() + .Throws() + .Then() + .Returns() + .Then() + .Throws(); + + ICalculator calc = mock.Object; + + // First call succeeds (Returns) + calc.Log("first"); + + // Second call throws InvalidOperationException + var ex1 = Assert.Throws(() => calc.Log("second")); + await Assert.That(ex1).IsNotNull(); + + // Third call succeeds (Returns) + calc.Log("third"); + + // Fourth call throws ArgumentException + var ex2 = Assert.Throws(() => calc.Log("fourth")); + await Assert.That(ex2).IsNotNull(); + } + [Test] public async Task Chained_Returns_With_Then() { diff --git a/TUnit.Mocks/MockEngine.cs b/TUnit.Mocks/MockEngine.cs index 2e69abad53..bb63ace818 100644 --- a/TUnit.Mocks/MockEngine.cs +++ b/TUnit.Mocks/MockEngine.cs @@ -139,6 +139,7 @@ ICallVerification IMockEngineAccess.CreateVerification(int memberId, string memb /// public void HandleCall(int memberId, string memberName, object?[] args) { + RawReturnContext.Clear(); var callRecord = RecordCall(memberId, memberName, args); // Auto-track property setters: store value keyed by property name @@ -151,13 +152,24 @@ public void HandleCall(int memberId, string memberName, object?[] args) if (behavior is not null) { - behavior.Execute(args); - // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks - OutRefContext.Set(matchedSetup?.OutRefAssignments); - if (matchedSetup is not null) + var behaviorResult = behavior.Execute(args); + if (behaviorResult is RawReturn raw) { - - RaiseEventsForSetup(matchedSetup); + RawReturnContext.Set(raw); + } + try + { + // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks + OutRefContext.Set(matchedSetup?.OutRefAssignments); + if (matchedSetup is not null) + { + RaiseEventsForSetup(matchedSetup); + } + } + catch + { + RawReturnContext.Clear(); + throw; } return; } @@ -170,7 +182,6 @@ public void HandleCall(int memberId, string memberName, object?[] args) { if (matchedSetup is not null) { - RaiseEventsForSetup(matchedSetup); } return; @@ -192,6 +203,7 @@ public void HandleCall(int memberId, string memberName, object?[] args) /// public TReturn HandleCallWithReturn(int memberId, string memberName, object?[] args, TReturn defaultValue) { + RawReturnContext.Clear(); var callRecord = RecordCall(memberId, memberName, args); var (setupFound, behavior, matchedSetup) = FindMatchingSetup(memberId, args); @@ -199,15 +211,30 @@ public TReturn HandleCallWithReturn(int memberId, string memberName, ob if (behavior is not null) { var result = behavior.Execute(args); - // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks - OutRefContext.Set(matchedSetup?.OutRefAssignments); - if (matchedSetup is not null) + if (result is RawReturn raw) { - - RaiseEventsForSetup(matchedSetup); + RawReturnContext.Set(raw); + } + try + { + // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks + OutRefContext.Set(matchedSetup?.OutRefAssignments); + if (matchedSetup is not null) + { + RaiseEventsForSetup(matchedSetup); + } + } + catch + { + RawReturnContext.Clear(); + throw; } if (result is TReturn typed) return typed; if (result is null) return default(TReturn)!; + if (result is RawReturn) + { + return defaultValue; + } throw new InvalidOperationException( $"Setup for method returning {typeof(TReturn).Name} returned incompatible type {result.GetType().Name}."); } @@ -284,19 +311,31 @@ public TReturn HandleCallWithReturn(int memberId, string memberName, ob [EditorBrowsable(EditorBrowsableState.Never)] public bool TryHandleCall(int memberId, string memberName, object?[] args) { + RawReturnContext.Clear(); var callRecord = RecordCall(memberId, memberName, args); var (setupFound, behavior, matchedSetup) = FindMatchingSetup(memberId, args); if (behavior is not null) { - behavior.Execute(args); - // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks - OutRefContext.Set(matchedSetup?.OutRefAssignments); - if (matchedSetup is not null) + var behaviorResult = behavior.Execute(args); + if (behaviorResult is RawReturn raw) { - - RaiseEventsForSetup(matchedSetup); + RawReturnContext.Set(raw); + } + try + { + // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks + OutRefContext.Set(matchedSetup?.OutRefAssignments); + if (matchedSetup is not null) + { + RaiseEventsForSetup(matchedSetup); + } + } + catch + { + RawReturnContext.Clear(); + throw; } return true; } @@ -306,7 +345,6 @@ public bool TryHandleCall(int memberId, string memberName, object?[] args) if (setupFound && matchedSetup is not null) { - RaiseEventsForSetup(matchedSetup); } @@ -333,6 +371,7 @@ public bool TryHandleCall(int memberId, string memberName, object?[] args) [EditorBrowsable(EditorBrowsableState.Never)] public bool TryHandleCallWithReturn(int memberId, string memberName, object?[] args, TReturn defaultValue, out TReturn result) { + RawReturnContext.Clear(); var callRecord = RecordCall(memberId, memberName, args); var (setupFound, behavior, matchedSetup) = FindMatchingSetup(memberId, args); @@ -340,15 +379,30 @@ public bool TryHandleCallWithReturn(int memberId, string memberName, ob if (behavior is not null) { var behaviorResult = behavior.Execute(args); - // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks - OutRefContext.Set(matchedSetup?.OutRefAssignments); - if (matchedSetup is not null) + if (behaviorResult is RawReturn raw) { - - RaiseEventsForSetup(matchedSetup); + RawReturnContext.Set(raw); + } + try + { + // Set out/ref assignments after Execute to avoid reentrancy overwrite from callbacks + OutRefContext.Set(matchedSetup?.OutRefAssignments); + if (matchedSetup is not null) + { + RaiseEventsForSetup(matchedSetup); + } + } + catch + { + RawReturnContext.Clear(); + throw; } if (behaviorResult is TReturn typed) result = typed; else if (behaviorResult is null) result = default(TReturn)!; + else if (behaviorResult is RawReturn) + { + result = defaultValue; + } else throw new InvalidOperationException( $"Setup for method returning {typeof(TReturn).Name} returned incompatible type {behaviorResult.GetType().Name}."); return true; diff --git a/TUnit.Mocks/MockMethodCall.cs b/TUnit.Mocks/MockMethodCall.cs index 0faeac9627..4aae941b8e 100644 --- a/TUnit.Mocks/MockMethodCall.cs +++ b/TUnit.Mocks/MockMethodCall.cs @@ -118,6 +118,20 @@ public ISetupChain TransitionsTo(string stateName) return this; } + [EditorBrowsable(EditorBrowsableState.Never)] + public ISetupChain ReturnsRaw(object? rawValue) + { + EnsureSetup().ReturnsRaw(rawValue); + return this; + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public ISetupChain ReturnsRaw(Func factory) + { + EnsureSetup().ReturnsRaw(factory); + return this; + } + // ISetupChain implementation public IMethodSetup Then() diff --git a/TUnit.Mocks/Setup/Behaviors/RawReturnBehavior.cs b/TUnit.Mocks/Setup/Behaviors/RawReturnBehavior.cs new file mode 100644 index 0000000000..bcf8e64f04 --- /dev/null +++ b/TUnit.Mocks/Setup/Behaviors/RawReturnBehavior.cs @@ -0,0 +1,19 @@ +namespace TUnit.Mocks.Setup.Behaviors; + +internal sealed class RawReturnBehavior : IBehavior +{ + private readonly RawReturn _wrapper; + + public RawReturnBehavior(object? rawValue) => _wrapper = new RawReturn(rawValue); + + public object? Execute(object?[] arguments) => _wrapper; +} + +internal sealed class ComputedRawReturnBehavior : IBehavior +{ + private readonly Func _factory; + + public ComputedRawReturnBehavior(Func factory) => _factory = factory; + + public object? Execute(object?[] arguments) => new RawReturn(_factory()); +} diff --git a/TUnit.Mocks/Setup/Behaviors/VoidReturnBehavior.cs b/TUnit.Mocks/Setup/Behaviors/VoidReturnBehavior.cs new file mode 100644 index 0000000000..f97398055d --- /dev/null +++ b/TUnit.Mocks/Setup/Behaviors/VoidReturnBehavior.cs @@ -0,0 +1,8 @@ +namespace TUnit.Mocks.Setup.Behaviors; + +internal sealed class VoidReturnBehavior : IBehavior +{ + public static VoidReturnBehavior Instance { get; } = new(); + + public object? Execute(object?[] arguments) => null; +} diff --git a/TUnit.Mocks/Setup/IVoidMethodSetup.cs b/TUnit.Mocks/Setup/IVoidMethodSetup.cs index 7913be3166..5554527fab 100644 --- a/TUnit.Mocks/Setup/IVoidMethodSetup.cs +++ b/TUnit.Mocks/Setup/IVoidMethodSetup.cs @@ -7,6 +7,9 @@ namespace TUnit.Mocks.Setup; /// public interface IVoidMethodSetup { + /// Configure the method to return normally (no-op). Useful for sequential behavior chains. + IVoidSetupChain Returns(); + /// Configure an exception to throw. IVoidSetupChain Throws() where TException : Exception, new(); diff --git a/TUnit.Mocks/Setup/MethodSetupBuilder.cs b/TUnit.Mocks/Setup/MethodSetupBuilder.cs index a53490646b..7b51ed052b 100644 --- a/TUnit.Mocks/Setup/MethodSetupBuilder.cs +++ b/TUnit.Mocks/Setup/MethodSetupBuilder.cs @@ -93,5 +93,19 @@ public ISetupChain TransitionsTo(string stateName) return this; } + [EditorBrowsable(EditorBrowsableState.Never)] + public ISetupChain ReturnsRaw(object? rawValue) + { + _setup.AddBehavior(new RawReturnBehavior(rawValue)); + return this; + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public ISetupChain ReturnsRaw(Func factory) + { + _setup.AddBehavior(new ComputedRawReturnBehavior(factory)); + return this; + } + public IMethodSetup Then() => this; } diff --git a/TUnit.Mocks/Setup/PropertySetterSetupBuilder.cs b/TUnit.Mocks/Setup/PropertySetterSetupBuilder.cs index fb861f99f0..1406e775af 100644 --- a/TUnit.Mocks/Setup/PropertySetterSetupBuilder.cs +++ b/TUnit.Mocks/Setup/PropertySetterSetupBuilder.cs @@ -17,6 +17,7 @@ public PropertySetterSetupBuilder(MethodSetup setup) _inner = new VoidMethodSetupBuilder(setup); } + public IVoidSetupChain Returns() => _inner.Returns(); public IVoidSetupChain Throws() where TException : Exception, new() => _inner.Throws(); public IVoidSetupChain Throws(Exception exception) => _inner.Throws(exception); public IVoidSetupChain Callback(Action callback) => _inner.Callback(callback); diff --git a/TUnit.Mocks/Setup/RawReturn.cs b/TUnit.Mocks/Setup/RawReturn.cs new file mode 100644 index 0000000000..ed3094bceb --- /dev/null +++ b/TUnit.Mocks/Setup/RawReturn.cs @@ -0,0 +1,60 @@ +using System.ComponentModel; + +namespace TUnit.Mocks.Setup; + +/// +/// Marker type wrapping a raw return value (e.g., a Task or ValueTask) that should +/// bypass the engine's normal type coercion. When the engine encounters this as a +/// behavior result, it stores the inner value in +/// for the generated code to consume directly. +/// Public for generated code access. Not intended for direct use. +/// +[EditorBrowsable(EditorBrowsableState.Never)] +public sealed class RawReturn +{ + public object? Value { get; } + + public RawReturn(object? value) => Value = value; +} + +/// +/// Thread-local storage for raw return values from behaviors. +/// The generated mock implementation reads from this after calling the engine, +/// allowing async methods to return pre-built Task/ValueTask instances directly +/// (e.g., from a ). +/// Public for generated code access. Not intended for direct use. +/// +/// +/// IMPORTANT: RawReturnContext must be consumed synchronously in the same execution +/// context as HandleCall*/TryHandleCall*. No await may appear between the engine +/// dispatch call and TryConsume in the generated code. +/// +[EditorBrowsable(EditorBrowsableState.Never)] +public static class RawReturnContext +{ + [ThreadStatic] + private static RawReturn? _pending; + + /// Stores a for the generated code to consume. Accepts the marker directly to avoid re-wrapping. + [EditorBrowsable(EditorBrowsableState.Never)] + public static void Set(RawReturn raw) => _pending = raw; + + /// Consumes and returns the raw return value, if one was set. Clears the slot. + [EditorBrowsable(EditorBrowsableState.Never)] + public static bool TryConsume(out object? value) + { + if (_pending is { } raw) + { + value = raw.Value; + _pending = null; + return true; + } + + value = null; + return false; + } + + /// Clears any stale value from a previous call. Called at dispatch entry to prevent leaks. + [EditorBrowsable(EditorBrowsableState.Never)] + public static void Clear() => _pending = null; +} diff --git a/TUnit.Mocks/Setup/VoidMethodSetupBuilder.cs b/TUnit.Mocks/Setup/VoidMethodSetupBuilder.cs index 3e2c09e7f2..aad04447cf 100644 --- a/TUnit.Mocks/Setup/VoidMethodSetupBuilder.cs +++ b/TUnit.Mocks/Setup/VoidMethodSetupBuilder.cs @@ -17,6 +17,12 @@ public VoidMethodSetupBuilder(MethodSetup setup) _setup = setup; } + public IVoidSetupChain Returns() + { + _setup.AddBehavior(VoidReturnBehavior.Instance); + return this; + } + public IVoidSetupChain Throws() where TException : Exception, new() { _setup.AddBehavior(new ComputedThrowBehavior(_ => new TException())); @@ -65,5 +71,19 @@ public IVoidSetupChain TransitionsTo(string stateName) return this; } + [EditorBrowsable(EditorBrowsableState.Never)] + public IVoidSetupChain ReturnsRaw(object? rawValue) + { + _setup.AddBehavior(new RawReturnBehavior(rawValue)); + return this; + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public IVoidSetupChain ReturnsRaw(Func factory) + { + _setup.AddBehavior(new ComputedRawReturnBehavior(factory)); + return this; + } + public IVoidMethodSetup Then() => this; } diff --git a/TUnit.Mocks/VoidMockMethodCall.cs b/TUnit.Mocks/VoidMockMethodCall.cs index c2bf80012c..b10f9a261c 100644 --- a/TUnit.Mocks/VoidMockMethodCall.cs +++ b/TUnit.Mocks/VoidMockMethodCall.cs @@ -55,6 +55,12 @@ private VoidMethodSetupBuilder EnsureSetup() => // IVoidMethodSetup implementation + public IVoidSetupChain Returns() + { + EnsureSetup().Returns(); + return this; + } + public IVoidSetupChain Throws() where TException : Exception, new() { EnsureSetup().Throws(); @@ -105,6 +111,20 @@ public IVoidSetupChain TransitionsTo(string stateName) return this; } + [EditorBrowsable(EditorBrowsableState.Never)] + public IVoidSetupChain ReturnsRaw(object? rawValue) + { + EnsureSetup().ReturnsRaw(rawValue); + return this; + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public IVoidSetupChain ReturnsRaw(Func factory) + { + EnsureSetup().ReturnsRaw(factory); + return this; + } + // IVoidSetupChain implementation public IVoidMethodSetup Then()