Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
<ItemGroup>
<PackageVersion Include="Aspire.Hosting.AppHost" Version="13.2.1" />
<PackageVersion Include="Aspire.Hosting.Testing" Version="13.2.1" />
<PackageVersion Include="Azure.Data.Tables" Version="12.11.0" />
<PackageVersion Include="Azure.Storage.Blobs" Version="12.27.0" />
<PackageVersion Include="AutoFixture" Version="4.18.1" />
<PackageVersion Include="BenchmarkDotNet" Version="0.15.8" />
<PackageVersion Include="BenchmarkDotNet.Annotations" Version="0.15.8" />
Expand Down
12 changes: 6 additions & 6 deletions TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
writer.AppendLine("return;");
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"base.{method.Name}({argPassList});");
writer.AppendLine($"base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsVoid && method.IsAsync)
{
Expand All @@ -589,7 +589,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsAsync)
{
Expand Down Expand Up @@ -619,7 +619,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsRefStructReturn)
{
Expand All @@ -638,7 +638,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsReturnTypeStaticAbstractInterface)
{
Expand All @@ -650,7 +650,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
writer.AppendLine("return __result;");
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else
{
Expand All @@ -662,7 +662,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
writer.AppendLine("return __result;");
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
}

Expand Down
107 changes: 91 additions & 16 deletions TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,19 @@ public static string Build(MockTypeModel model)
{
bool firstMember = true;

// Pre-compute which methods need their `out` parameters kept in the extension
// signature to avoid CS0111 collisions. A method needs disambiguation when
// some other method on the model shares the same name AND the same
// matchable-parameter signature (i.e. parameters excluding out).
var needsOutDisambiguation = ComputeOutDisambiguationSet(model.Methods);

// Methods
foreach (var method in model.Methods)
{
if (!firstMember) writer.AppendLine();
firstMember = false;
GenerateMemberMethod(writer, method, model, safeName);
GenerateMemberMethod(writer, method, model, safeName,
keepOutParams: needsOutDisambiguation.Contains(method.MemberId));
}

// Properties -- extension properties via C# 14 extension blocks
Expand Down Expand Up @@ -84,6 +91,15 @@ public static string Build(MockTypeModel model)
return writer.ToString();
}

private static void EmitOutParamDefaults(CodeWriter writer, MockMemberModel method, bool keepOutParams)
{
if (!keepOutParams) return;
foreach (var op in method.Parameters.Where(p => p.Direction == ParameterDirection.Out))
{
writer.AppendLine($"{op.Name} = default!;");
}
}

private static bool ShouldGenerateTypedWrapper(MockMemberModel method, bool hasEvents)
{
if (method.IsGenericMethod) return false;
Expand Down Expand Up @@ -558,25 +574,60 @@ private static string CastArg(MockParameterModel p, int index)
return $"({p.FullyQualifiedType})args[{index}]{bang}";
}

private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName)
private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool keepOutParams)
{
if (method.HasRefStructParams)
{
writer.AppendLine("#if NET9_0_OR_GREATER");
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: true);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: true);
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: true, keepOutParams);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: true, keepOutParams);
writer.AppendLine("#else");
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false);
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
writer.AppendLine("#endif");
}
else
{
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false);
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
}
}

private static HashSet<int> ComputeOutDisambiguationSet(EquatableArray<MockMemberModel> methods)
{
// Group methods by (name, matchable-parameter signature). Any group with >1 entry
// contains methods that would otherwise emit colliding extension overloads — flag
// every member of such a group whose original method has out parameters.
var result = new HashSet<int>();
var byKey = new Dictionary<string, List<MockMemberModel>>(System.StringComparer.Ordinal);
foreach (var m in methods)
{
var matchable = string.Join(",", m.Parameters
.Where(p => p.Direction != ParameterDirection.Out)
.Select(p => p.FullyQualifiedType));
var typeArity = m.TypeParameters.Length;
var key = $"{m.Name}`{typeArity}({matchable})";
if (!byKey.TryGetValue(key, out var list))
{
list = new List<MockMemberModel>();
byKey[key] = list;
}
list.Add(m);
}
foreach (var group in byKey.Values)
{
if (group.Count < 2) continue;
foreach (var m in group)
{
if (m.Parameters.Any(p => p.Direction == ParameterDirection.Out))
{
result.Add(m.MemberId);
}
}
}
return result;
}

private static (bool UseTypedWrapper, string ReturnType, string SetupReturnType) GetReturnTypeInfo(
MockMemberModel method, MockTypeModel model, string safeName)
{
Expand Down Expand Up @@ -608,11 +659,11 @@ private static (bool UseTypedWrapper, string ReturnType, string SetupReturnType)
return (useTypedWrapper, returnType, setupReturnType);
}

private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool includeRefStructArgs)
private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool includeRefStructArgs, bool keepOutParams)
{
var (useTypedWrapper, returnType, setupReturnType) = GetReturnTypeInfo(method, model, safeName);

var paramList = GetArgParameterList(method, includeRefStructArgs);
var paramList = GetArgParameterList(method, includeRefStructArgs, keepOutParams);
var typeParams = MockImplBuilder.GetTypeParameterList(method);
var constraints = MockImplBuilder.GetConstraintClauses(method);

Expand All @@ -633,6 +684,8 @@ private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel meth

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
EmitOutParamDefaults(writer, method, keepOutParams);

// Build matchers array
var matchableParams = includeRefStructArgs
? method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList()
Expand Down Expand Up @@ -684,7 +737,7 @@ private static List<int> GetFuncEligibleParamIndices(MockMemberModel method)
}

private static void EmitFuncOverloads(CodeWriter writer, MockMemberModel method, MockTypeModel model,
string safeName, bool includeRefStructArgs)
string safeName, bool includeRefStructArgs, bool keepOutParams)
{
var eligible = GetFuncEligibleParamIndices(method);
if (eligible.Count == 0 || eligible.Count > MaxFuncOverloadParams) return;
Expand All @@ -693,12 +746,12 @@ private static void EmitFuncOverloads(CodeWriter writer, MockMemberModel method,
for (int mask = 1; mask <= totalMasks; mask++)
{
writer.AppendLine();
EmitSingleFuncOverload(writer, method, model, safeName, eligible, mask, includeRefStructArgs);
EmitSingleFuncOverload(writer, method, model, safeName, eligible, mask, includeRefStructArgs, keepOutParams);
}
}

private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel method, MockTypeModel model,
string safeName, List<int> eligibleIndices, int funcMask, bool includeRefStructArgs)
string safeName, List<int> eligibleIndices, int funcMask, bool includeRefStructArgs, bool keepOutParams)
{
// Determine which parameter indices use Func<T, bool>
var funcIndices = new HashSet<int>();
Expand All @@ -717,7 +770,15 @@ private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel me
for (int i = 0; i < method.Parameters.Length; i++)
{
var p = method.Parameters[i];
if (p.Direction == ParameterDirection.Out) continue;
if (p.Direction == ParameterDirection.Out)
{
// Keep out params only when needed to disambiguate colliding overloads.
if (keepOutParams)
{
paramParts.Add($"out {p.FullyQualifiedType} {p.Name}");
}
continue;
}

if (funcIndices.Contains(i))
{
Expand Down Expand Up @@ -757,6 +818,8 @@ private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel me

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
EmitOutParamDefaults(writer, method, keepOutParams);

// Convert Func params to Arg<T> via implicit conversion
foreach (var idx in funcIndices.OrderBy(i => i))
{
Expand Down Expand Up @@ -871,12 +934,24 @@ private static void GenerateRaiseExtensionMethods(CodeWriter writer, MockTypeMod
}
}

private static string GetArgParameterList(MockMemberModel method, bool includeRefStructArgs)
private static string GetArgParameterList(MockMemberModel method, bool includeRefStructArgs, bool keepOutParams)
{
var parts = new List<string>();
foreach (var p in method.Parameters)
{
if (p.Direction == ParameterDirection.Out) continue;
if (p.Direction == ParameterDirection.Out)
{
// Normally out params are omitted from the extension signature so callers
// don't have to write `out _`. But when another overload of this method has
// the same matchable-parameter signature (e.g. GenerateSasUri(perms, expires)
// vs GenerateSasUri(perms, expires, out string)), we MUST keep the out param
// in the signature, otherwise CS0111 fires on the generated extensions.
if (keepOutParams)
{
parts.Add($"out {p.FullyQualifiedType} {p.Name}");
}
continue;
}
if (p.IsRefStruct)
{
if (includeRefStructArgs)
Expand Down
25 changes: 25 additions & 0 deletions TUnit.Mocks.Tests/Issue5434Tests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Azure.Data.Tables;
using Azure.Storage.Blobs;

namespace TUnit.Mocks.Tests;

// Reproduction for https://github.com/thomhurst/TUnit/issues/5434
// BlobClient: CS0111 duplicate GenerateSasUri / GenerateUserDelegationSasUri members in generated extensions.
// TableClient: CS0411 type inference failures for generic methods (GetEntity<T>, GetEntityAsync<T>,
// GetEntityIfExists<T>, GetEntityIfExistsAsync<T>, Query<T>, QueryAsync<T>) in generated impl factory.
public class Issue5434Tests
{
[Test]
public void Can_Mock_BlobClient()
{
var mock = Mock.Of<BlobClient>(MockBehavior.Strict);
_ = mock.Object;
}

[Test]
public void Can_Mock_TableClient()
{
var mock = Mock.Of<TableClient>(MockBehavior.Strict);
_ = mock.Object;
}
}
2 changes: 2 additions & 0 deletions TUnit.Mocks.Tests/TUnit.Mocks.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Data.Tables" />
<PackageReference Include="Azure.Storage.Blobs" />
<ProjectReference Include="..\TUnit.Mocks\TUnit.Mocks.csproj" />
<ProjectReference Include="..\TUnit.Mocks.Assertions\TUnit.Mocks.Assertions.csproj" />

Expand Down
Loading