Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@ namespace TUnit.Mocks.Generated

public void Process(global::System.ReadOnlySpan<byte> data)
{
_engine.HandleCall(0, "Process", global::System.Array.Empty<object?>());
#if NET9_0_OR_GREATER
var __args = new object?[] { null };
#else
var __args = global::System.Array.Empty<object?>();
#endif
_engine.HandleCall(0, "Process", __args);
}

public int Parse(global::System.ReadOnlySpan<char> text)
{
return _engine.HandleCallWithReturn<int>(1, "Parse", global::System.Array.Empty<object?>(), default);
#if NET9_0_OR_GREATER
var __args = new object?[] { null };
#else
var __args = global::System.Array.Empty<object?>();
#endif
return _engine.HandleCallWithReturn<int>(1, "Parse", __args, default);
}

public string GetName()
Expand All @@ -72,17 +82,33 @@ namespace TUnit.Mocks.Generated
{
public static class IBufferProcessor_MockMemberExtensions
{
#if NET9_0_OR_GREATER
public static global::TUnit.Mocks.VoidMockMethodCall Process(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock, global::TUnit.Mocks.Arguments.RefStructArg<global::System.ReadOnlySpan<byte>> data)
{
var matchers = new global::TUnit.Mocks.Arguments.IArgumentMatcher[] { data.Matcher };
return new global::TUnit.Mocks.VoidMockMethodCall(mock.Engine, 0, "Process", matchers);
}
#else
public static global::TUnit.Mocks.VoidMockMethodCall Process(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
{
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
return new global::TUnit.Mocks.VoidMockMethodCall(mock.Engine, 0, "Process", matchers);
}
#endif

#if NET9_0_OR_GREATER
public static global::TUnit.Mocks.MockMethodCall<int> Parse(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock, global::TUnit.Mocks.Arguments.RefStructArg<global::System.ReadOnlySpan<char>> text)
{
var matchers = new global::TUnit.Mocks.Arguments.IArgumentMatcher[] { text.Matcher };
return new global::TUnit.Mocks.MockMethodCall<int>(mock.Engine, 1, "Parse", matchers);
}
#else
public static global::TUnit.Mocks.MockMethodCall<int> Parse(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
{
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
return new global::TUnit.Mocks.MockMethodCall<int>(mock.Engine, 1, "Parse", matchers);
}
#endif

public static global::TUnit.Mocks.MockMethodCall<string> GetName(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
{
Expand Down
36 changes: 27 additions & 9 deletions TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ private static void GenerateWrapMethodBody(CodeWriter writer, MockMemberModel me
}
}

var argsArray = GetArgsArrayExpression(method);
var argsArray = EmitArgsArrayVariable(writer, method);
var argPassList = GetArgPassList(method);

if (method.IsVoid && !method.IsAsync)
Expand Down Expand Up @@ -461,7 +461,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
}

var argsArray = GetArgsArrayExpression(method);
var argsArray = EmitArgsArrayVariable(writer, method);
var argPassList = GetArgPassList(method);

if (method.IsVoid && !method.IsAsync)
Expand Down Expand Up @@ -551,7 +551,7 @@ private static void GenerateEngineDispatchBody(CodeWriter writer, MockMemberMode
}
}

var argsArray = GetArgsArrayExpression(method);
var argsArray = EmitArgsArrayVariable(writer, method);

var hasOutRef = HasOutRefParams(method);

Expand Down Expand Up @@ -955,14 +955,32 @@ private static void EmitOutRefReadback(CodeWriter writer, MockMemberModel method
}
}

private static string GetArgsArrayExpression(MockMemberModel method)
private static string EmitArgsArrayVariable(CodeWriter writer, MockMemberModel method)
{
// Only include non-out, non-ref-struct parameters in args array
// (ref structs cannot be boxed into object?[])
var matchableParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
if (!method.HasRefStructParams)
return GetArgsArrayExpression(method, false);

writer.AppendLine("#if NET9_0_OR_GREATER");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, true)};");
writer.AppendLine("#else");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, false)};");
writer.AppendLine("#endif");
return "__args";
}

private static string GetArgsArrayExpression(MockMemberModel method, bool includeRefStructSentinels)
{
var nonOutParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList();
if (includeRefStructSentinels)
{
if (nonOutParams.Count == 0) return "global::System.Array.Empty<object?>()";
var args = string.Join(", ", nonOutParams.Select(p => p.IsRefStruct ? "null" : p.Name));
return $"new object?[] {{ {args} }}";
}
var matchableParams = nonOutParams.Where(p => !p.IsRefStruct).ToList();
if (matchableParams.Count == 0) return "global::System.Array.Empty<object?>()";
var args = string.Join(", ", matchableParams.Select(p => p.Name));
return $"new object?[] {{ {args} }}";
var argsStr = string.Join(", ", matchableParams.Select(p => p.Name));
return $"new object?[] {{ {argsStr} }}";
}

/// <summary>
Expand Down
128 changes: 97 additions & 31 deletions TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,23 @@ private static void GenerateUnifiedSealedClass(CodeWriter writer, MockMemberMode

var wrapperName = GetWrapperName(safeName, method);
var matchableParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
var hasRefStructParams = method.HasRefStructParams;
var allNonOutParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList();

// Ref struct returns use the void wrapper (can't use generic type args with ref structs)
if (method.IsVoid || method.IsRefStructReturn)
{
GenerateVoidUnifiedClass(writer, wrapperName, matchableParams, events, method.Parameters);
GenerateVoidUnifiedClass(writer, wrapperName, matchableParams, events, method.Parameters, hasRefStructParams, allNonOutParams);
}
else
{
GenerateReturnUnifiedClass(writer, wrapperName, matchableParams, setupReturnType, events, method.Parameters);
GenerateReturnUnifiedClass(writer, wrapperName, matchableParams, setupReturnType, events, method.Parameters, hasRefStructParams, allNonOutParams);
}
}

private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapperName,
List<MockParameterModel> nonOutParams, string returnType, EquatableArray<MockEventModel> events,
EquatableArray<MockParameterModel> allParameters)
EquatableArray<MockParameterModel> allParameters, bool hasRefStructParams, List<MockParameterModel> allNonOutParams)
{
var builderType = $"global::TUnit.Mocks.Setup.MethodSetupBuilder<{returnType}>";
var hasOutRef = allParameters.Any(p => p.Direction == ParameterDirection.Out || p.Direction == ParameterDirection.Ref);
Expand Down Expand Up @@ -198,11 +200,30 @@ private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapper
if (nonOutParams.Count >= 1)
{
writer.AppendLine();
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName);
writer.AppendLine();
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
if (hasRefStructParams)
{
writer.AppendLine("#if NET9_0_OR_GREATER");
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName, allNonOutParams);
writer.AppendLine();
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName, allNonOutParams);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName, allNonOutParams);
writer.AppendLine("#else");
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName);
writer.AppendLine();
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
writer.AppendLine("#endif");
}
else
{
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName);
writer.AppendLine();
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
}
}

// Typed out/ref parameter setters
Expand Down Expand Up @@ -239,7 +260,7 @@ private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapper

private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperName,
List<MockParameterModel> nonOutParams, EquatableArray<MockEventModel> events,
EquatableArray<MockParameterModel> allParameters)
EquatableArray<MockParameterModel> allParameters, bool hasRefStructParams, List<MockParameterModel> allNonOutParams)
{
var builderType = "global::TUnit.Mocks.Setup.VoidMethodSetupBuilder";
var hasOutRef = allParameters.Any(p => p.Direction == ParameterDirection.Out || p.Direction == ParameterDirection.Ref);
Expand Down Expand Up @@ -307,9 +328,24 @@ private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperNa
if (nonOutParams.Count >= 1)
{
writer.AppendLine();
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
if (hasRefStructParams)
{
writer.AppendLine("#if NET9_0_OR_GREATER");
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName, allNonOutParams);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName, allNonOutParams);
writer.AppendLine("#else");
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
writer.AppendLine("#endif");
}
else
{
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
writer.AppendLine();
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
}
}

// Typed out/ref parameter setters
Expand Down Expand Up @@ -345,11 +381,11 @@ private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperNa
}

private static void GenerateTypedReturnsOverload(CodeWriter writer, List<MockParameterModel> nonOutParams,
string returnType, string wrapperName)
string returnType, string wrapperName, List<MockParameterModel>? allNonOutParams = null)
{
var typeList = string.Join(", ", nonOutParams.Select(p => p.FullyQualifiedType));
var funcType = $"global::System.Func<{typeList}, {returnType}>";
var castArgs = BuildCastArgs(nonOutParams);
var castArgs = BuildCastArgs(nonOutParams, allNonOutParams);

writer.AppendLine("/// <summary>Configure a typed computed return value using the actual method parameters.</summary>");
using (writer.Block($"public {wrapperName} Returns({funcType} factory)"))
Expand All @@ -360,11 +396,11 @@ private static void GenerateTypedReturnsOverload(CodeWriter writer, List<MockPar
}

private static void GenerateTypedCallbackOverload(CodeWriter writer, List<MockParameterModel> nonOutParams,
string wrapperName)
string wrapperName, List<MockParameterModel>? allNonOutParams = null)
{
var typeList = string.Join(", ", nonOutParams.Select(p => p.FullyQualifiedType));
var actionType = $"global::System.Action<{typeList}>";
var castArgs = BuildCastArgs(nonOutParams);
var castArgs = BuildCastArgs(nonOutParams, allNonOutParams);

writer.AppendLine("/// <summary>Execute a typed callback using the actual method parameters.</summary>");
using (writer.Block($"public {wrapperName} Callback({actionType} callback)"))
Expand All @@ -375,11 +411,11 @@ private static void GenerateTypedCallbackOverload(CodeWriter writer, List<MockPa
}

private static void GenerateTypedThrowsOverload(CodeWriter writer, List<MockParameterModel> nonOutParams,
string wrapperName)
string wrapperName, List<MockParameterModel>? allNonOutParams = null)
{
var typeList = string.Join(", ", nonOutParams.Select(p => p.FullyQualifiedType));
var funcType = $"global::System.Func<{typeList}, global::System.Exception>";
var castArgs = BuildCastArgs(nonOutParams);
var castArgs = BuildCastArgs(nonOutParams, allNonOutParams);

writer.AppendLine("/// <summary>Configure a typed computed exception using the actual method parameters.</summary>");
using (writer.Block($"public {wrapperName} Throws({funcType} exceptionFactory)"))
Expand Down Expand Up @@ -442,13 +478,32 @@ private static void GenerateTypedOutRefMethods(CodeWriter writer, EquatableArray
private static string ToPascalCase(string name)
=> string.IsNullOrEmpty(name) ? name : char.ToUpperInvariant(name[0]) + name[1..];

private static string BuildCastArgs(List<MockParameterModel> nonOutParams)
private static string BuildCastArgs(List<MockParameterModel> nonOutParams, List<MockParameterModel>? allNonOutParams = null)
{
return string.Join(", ", nonOutParams.Select((p, i) =>
$"({p.FullyQualifiedType})args[{i}]!"));
if (allNonOutParams is null)
return string.Join(", ", nonOutParams.Select((p, i) => $"({p.FullyQualifiedType})args[{i}]!"));

var indexMap = allNonOutParams.Select((p, i) => (p, i)).ToDictionary(x => x.p, x => x.i);
return string.Join(", ", nonOutParams.Select(p => $"({p.FullyQualifiedType})args[{indexMap[p]}]!"));
}

private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName)
{
if (method.HasRefStructParams)
{
writer.AppendLine("#if NET9_0_OR_GREATER");
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: true);
writer.AppendLine("#else");
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
writer.AppendLine("#endif");
}
else
{
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
}
}

private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool includeRefStructArgs)
{
// For async methods (Task<T>/ValueTask<T>), unwrap the return type so users write .Returns(5) not .Returns(Task.FromResult(5))
// For void-async methods (Task/ValueTask), IsVoid is already true
Expand All @@ -474,7 +529,7 @@ private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel meth
returnType = $"global::TUnit.Mocks.MockMethodCall<{setupReturnType}>";
}

var paramList = GetArgParameterList(method);
var paramList = GetArgParameterList(method, includeRefStructArgs);
var typeParams = GetTypeParameterList(method);
var constraints = GetConstraintClauses(method);

Expand All @@ -484,9 +539,10 @@ private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel meth

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
// Build matchers array (exclude out and ref struct params)
var matchableParams = method.Parameters
.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
// Build matchers array
var matchableParams = includeRefStructArgs
? method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList()
: method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();

if (matchableParams.Count == 0)
{
Expand Down Expand Up @@ -576,13 +632,23 @@ private static void GenerateRaiseExtensionMethods(CodeWriter writer, MockTypeMod
}
}

private static string GetArgParameterList(MockMemberModel method)
private static string GetArgParameterList(MockMemberModel method, bool includeRefStructArgs)
{
// Only include non-out, non-ref-struct parameters as Arg<T> in setup
// (ref structs cannot be used as generic type arguments)
return string.Join(", ", method.Parameters
.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct)
.Select(p => $"global::TUnit.Mocks.Arguments.Arg<{p.FullyQualifiedType}> {p.Name}"));
var parts = new List<string>();
foreach (var p in method.Parameters)
{
if (p.Direction == ParameterDirection.Out) continue;
if (p.IsRefStruct)
{
if (includeRefStructArgs)
parts.Add($"global::TUnit.Mocks.Arguments.RefStructArg<{p.FullyQualifiedType}> {p.Name}");
}
else
{
parts.Add($"global::TUnit.Mocks.Arguments.Arg<{p.FullyQualifiedType}> {p.Name}");
}
}
return string.Join(", ", parts);
}

private static string GetTypeParameterList(MockMemberModel method)
Expand Down
7 changes: 7 additions & 0 deletions TUnit.Mocks.SourceGenerator/Models/MockMemberModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq;

namespace TUnit.Mocks.SourceGenerator.Models;

Expand Down Expand Up @@ -31,6 +32,12 @@ internal sealed record MockMemberModel : IEquatable<MockMemberModel>
public bool IsProtected { get; init; }
public bool IsRefStructReturn { get; init; }

/// <summary>
/// Returns true if the method has any non-out ref struct parameters.
/// Computed from <see cref="Parameters"/> — does not participate in equality.
/// </summary>
public bool HasRefStructParams => Parameters.Any(p => p.IsRefStruct && p.Direction != ParameterDirection.Out);

public bool Equals(MockMemberModel? other)
{
if (other is null) return false;
Expand Down
Loading
Loading