From 88c81dba2f8ac93842903de826089fc30d5212c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Breu=C3=9F?= Date: Sat, 20 Jun 2026 14:09:41 +0200 Subject: [PATCH] feat: expose hidden base-interface members via Mock.As() When a derived interface hides a base member with `new`, the hidden base slot was unreachable from the setup/verify surface. The mock now also implements the base interface's IMockSetupFor.../IMockVerifyFor... surfaces, so IDerived.CreateMock().Mock.As() can configure and verify the hidden slot without an explicit Implementing<>() call. Methods, properties and events get separate base/derived slots; indexers share storage by parameter signature. Bases with static members are out of scope. --- .../Entities/MockClass.cs | 113 +++++++++++++++++- .../MockGenerator.cs | 75 +++++++++++- .../Sources/Sources.MockClass.cs | 87 +++++++++++++- .../MockTests.HiddenInterfaceMemberTests.cs | 96 +++++++++++++++ 4 files changed, 360 insertions(+), 11 deletions(-) create mode 100644 Tests/Mockolate.Tests/MockTests.HiddenInterfaceMemberTests.cs diff --git a/Source/Mockolate.SourceGenerators/Entities/MockClass.cs b/Source/Mockolate.SourceGenerators/Entities/MockClass.cs index cc4f7afa..77df75ff 100644 --- a/Source/Mockolate.SourceGenerators/Entities/MockClass.cs +++ b/Source/Mockolate.SourceGenerators/Entities/MockClass.cs @@ -1,3 +1,4 @@ +using System.Collections.Immutable; using Microsoft.CodeAnalysis; namespace Mockolate.SourceGenerators.Entities; @@ -11,6 +12,11 @@ public MockClass(ITypeSymbol[] types, IAssemblySymbol sourceAssembly) : base(typ AdditionalImplementations = new EquatableArray( types.Skip(1).Select(x => new Class(x, sourceAssembly)).ToArray()); + HiddenBaseInterfaces = IsInterface + ? new EquatableArray(GetHiddenBaseInterfaces(types[0]) + .Select(x => new Class(x, sourceAssembly)).ToArray()) + : new EquatableArray([]); + if (!IsInterface && types[0] is INamedTypeSymbol namedTypeSymbol) { Constructors = @@ -34,13 +40,21 @@ public MockClass(ITypeSymbol[] types, IAssemblySymbol sourceAssembly) : base(typ public EquatableArray AdditionalImplementations { get; } + /// + /// Base interfaces whose members are hidden (via ) by the mocked + /// interface. Their setup/verify surfaces are generated and implemented so the hidden slots are + /// reachable through .Mock.As<TBase>(). Distinct from + /// (the user's explicit Implementing<T>()). + /// + public EquatableArray HiddenBaseInterfaces { get; } + /// /// MockClass equality is keyed on plus a content-derived /// hash that folds the base surface together with the mock-only fields - /// (, , - /// ). Two mocks of the same root with different additional - /// interfaces, different constructor surfaces, or different delegate signatures must hash - /// apart so Roslyn's incremental cache invalidates when any of those change. + /// (, , + /// , ). Two mocks of the same root with + /// different additional interfaces, different constructor surfaces, or different delegate + /// signatures must hash apart so Roslyn's incremental cache invalidates when any of those change. /// public bool Equals(MockClass? other) => ReferenceEquals(this, other) || @@ -55,6 +69,11 @@ public IEnumerable AllImplementations() { yield return additionalImplementation; } + + foreach (Class hiddenBaseInterface in HiddenBaseInterfaces) + { + yield return hiddenBaseInterface; + } } public override bool Equals(Class? other) => other is MockClass mc && Equals(mc); @@ -67,6 +86,7 @@ private int ComputeMockSurfaceHash() { int hash = base.GetHashCode(); hash = unchecked((hash * 17) + AdditionalImplementations.GetHashCode()); + hash = unchecked((hash * 17) + HiddenBaseInterfaces.GetHashCode()); if (Constructors is { } constructors) { hash = unchecked((hash * 17) + constructors.GetHashCode()); @@ -79,4 +99,89 @@ private int ComputeMockSurfaceHash() return hash; } + + /// + /// Base interfaces of that declare a member which a more-derived + /// interface in the hierarchy hides (a member with a matching signature). + /// The hidden base member is a separate interface slot, so its setup/verify surface must be + /// generated explicitly. Ordinary (non-hidden) inheritance returns nothing. + /// + private static IEnumerable GetHiddenBaseInterfaces(ITypeSymbol type) + { + ImmutableArray allInterfaces = type.AllInterfaces; + foreach (INamedTypeSymbol baseInterface in allInterfaces) + { + if (baseInterface.GetMembers().Any(member => member.IsStatic)) + { + continue; + } + + bool hasHiddenMember = false; + foreach (ISymbol baseMember in baseInterface.GetMembers()) + { + if (!IsHidableMember(baseMember)) + { + continue; + } + + if (HidesMember(type, baseMember) || + allInterfaces.Any(intermediate => + !SymbolEqualityComparer.Default.Equals(intermediate, baseInterface) && + intermediate.AllInterfaces.Contains(baseInterface, SymbolEqualityComparer.Default) && + HidesMember(intermediate, baseMember))) + { + hasHiddenMember = true; + break; + } + } + + if (hasHiddenMember) + { + yield return baseInterface; + } + } + } + + private static bool HidesMember(ITypeSymbol hidingType, ISymbol baseMember) + => hidingType.GetMembers(baseMember.Name) + .Any(candidate => !SymbolEqualityComparer.Default.Equals(candidate.ContainingType, baseMember.ContainingType) && + SignatureMatches(candidate, baseMember)); + + private static bool IsHidableMember(ISymbol member) + => member switch + { + IMethodSymbol { MethodKind: MethodKind.Ordinary, } => true, + IPropertySymbol => true, + IEventSymbol => true, + _ => false, + }; + + private static bool SignatureMatches(ISymbol a, ISymbol b) + => a.Kind == b.Kind && (a, b) switch + { + (IMethodSymbol ma, IMethodSymbol mb) => ma.TypeParameters.Length == mb.TypeParameters.Length && + ParametersMatch(ma.Parameters, mb.Parameters), + (IPropertySymbol pa, IPropertySymbol pb) => ParametersMatch(pa.Parameters, pb.Parameters), + (IEventSymbol, IEventSymbol) => true, + _ => false, + }; + + private static bool ParametersMatch(ImmutableArray a, ImmutableArray b) + { + if (a.Length != b.Length) + { + return false; + } + + for (int i = 0; i < a.Length; i++) + { + if (a[i].RefKind != b[i].RefKind || + !SymbolEqualityComparer.Default.Equals(a[i].Type, b[i].Type)) + { + return false; + } + } + + return true; + } } diff --git a/Source/Mockolate.SourceGenerators/MockGenerator.cs b/Source/Mockolate.SourceGenerators/MockGenerator.cs index 9c20aade..af9bf52d 100644 --- a/Source/Mockolate.SourceGenerators/MockGenerator.cs +++ b/Source/Mockolate.SourceGenerators/MockGenerator.cs @@ -356,8 +356,18 @@ private static EquatableArray CreateNamedMocks(EquatableArray result = new(arr.Length); HashSet seenBaseClasses = new(StringComparer.Ordinal); - // Pass 1: assign disambiguated names to every distinct base/additional class. The order - // here must be deterministic so the same input set always yields the same names. + Dictionary primaryMocks = new(StringComparer.Ordinal); + foreach (MockClass mc in arr) + { + if (IsValidMockDeclaration(mc)) + { + primaryMocks[mc.ClassFullName] = mc; + } + } + + // Pass 1a: assign disambiguated names to every distinct base/additional/hidden-base class. The + // order here must be deterministic so the same input set always yields the same names. + List orderedClasses = new(); foreach (MockClass mc in arr) { if (!IsValidMockDeclaration(mc)) @@ -381,8 +391,25 @@ private static EquatableArray CreateNamedMocks(EquatableArray? hiddenBases = null; + if (primaryMocks.TryGetValue(@class.ClassFullName, out MockClass? primaryMock) && + primaryMock.HiddenBaseInterfaces.Count > 0) + { + hiddenBases = new EquatableArray(primaryMock.HiddenBaseInterfaces + .Select(hiddenBase => new NamedClass(baseClassNames[hiddenBase.ClassFullName], hiddenBase)) + .ToArray()); } + + result.Add(new NamedMock(actualName, actualName, @class, null, hiddenBases)); } // Pass 2: combination mocks (additional implementations). @@ -420,6 +447,15 @@ private static EquatableArray CollectAsExtensionPairs(Equat List ordered = new(); foreach (NamedMock nm in arr) { + if (nm.HiddenBases is { } hiddenBases) + { + foreach (NamedClass hiddenBase in hiddenBases) + { + AddIfNew(seen, ordered, MockAsExtensionPair.Create( + nm.ParentName, nm.Mock.ClassFullName, hiddenBase.Name, hiddenBase.Class.ClassFullName)); + } + } + if (nm.AdditionalClasses is not { } additional || additional.Count == 0) { continue; @@ -471,8 +507,19 @@ private static void EmitMockFile(SourceProductionContext context, NamedMock name if (named.AdditionalClasses is not { } additional || additional.Count == 0) { + (string Name, Class Class)[] hiddenBaseArr = []; + if (named.HiddenBases is { } hiddenBases && hiddenBases.Count > 0) + { + NamedClass[] hiddenNamed = hiddenBases.AsArray(); + hiddenBaseArr = new (string Name, Class Class)[hiddenNamed.Length]; + for (int i = 0; i < hiddenNamed.Length; i++) + { + hiddenBaseArr[i] = (hiddenNamed[i].Name, hiddenNamed[i].Class); + } + } + context.AddSource($"Mock.{fileName}.g.cs", - ToSource(Sources.Sources.MockClass(named.ParentName, @class, hasOverloadResolutionPriority))); + ToSource(Sources.Sources.MockClass(named.ParentName, @class, hasOverloadResolutionPriority, hiddenBaseArr))); return; } @@ -535,18 +582,21 @@ public bool Equals(RefStructAggregate? other) internal sealed class NamedMock : IEquatable { - public NamedMock(string fileName, string parentName, Class mock, EquatableArray? additionalClasses) + public NamedMock(string fileName, string parentName, Class mock, EquatableArray? additionalClasses, + EquatableArray? hiddenBases = null) { FileName = fileName; ParentName = parentName; Mock = mock; AdditionalClasses = additionalClasses; + HiddenBases = hiddenBases; } public string FileName { get; } public string ParentName { get; } public Class Mock { get; } public EquatableArray? AdditionalClasses { get; } + public EquatableArray? HiddenBases { get; } public bool Equals(NamedMock? other) { @@ -565,6 +615,11 @@ public bool Equals(NamedMock? other) return false; } + if (!NullableEquals(HiddenBases, other.HiddenBases)) + { + return false; + } + if (AdditionalClasses is null) { return other.AdditionalClasses is null; @@ -576,6 +631,11 @@ public bool Equals(NamedMock? other) } return AdditionalClasses.Value.Equals(other.AdditionalClasses.Value); + + static bool NullableEquals(EquatableArray? a, EquatableArray? b) + { + return a is null ? b is null : b is not null && a.Value.Equals(b.Value); + } } public override bool Equals(object? obj) => Equals(obj as NamedMock); @@ -590,6 +650,11 @@ public override int GetHashCode() hash = unchecked((hash * 17) + additional.GetHashCode()); } + if (HiddenBases is { } hiddenBases) + { + hash = unchecked((hash * 17) + hiddenBases.GetHashCode()); + } + return hash; } } diff --git a/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs b/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs index 9f29ca8a..1ea8c81b 100644 --- a/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs +++ b/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs @@ -10,8 +10,13 @@ internal static partial class Sources { private const int MaxExplicitParameters = 4; - public static string MockClass(string name, Class @class, bool hasOverloadResolutionPriority = false) + public static string MockClass( + string name, + Class @class, + bool hasOverloadResolutionPriority = false, + (string Name, Class Class)[]? hiddenBaseInterfaces = null) { + hiddenBaseInterfaces ??= []; EquatableArray? constructors = (@class as MockClass)?.Constructors; bool hasParameterizedConstructor = !@class.IsInterface && constructors?.Any(m => m.Parameters.Count > 0) == true; @@ -23,7 +28,14 @@ public static string MockClass(string name, Class @class, bool hasOverloadResolu bool hasProtectedMembers = !@class.IsInterface && (@class.AllMethods().Any(method => method.IsProtected) || @class.AllProperties().Any(property => property.IsProtected)); string setupType = hasProtectedMembers ? $"IMockSetupInitializationFor{name}" : $"global::Mockolate.Mock.IMockSetupFor{name}"; string mockRegistryName = @class.GetUniqueName("MockRegistry", "MockolateMockRegistry"); - MemberIdTable memberIds = ComputeMemberIds(@class); + Class[] memberIdClasses = new Class[1 + hiddenBaseInterfaces.Length]; + memberIdClasses[0] = @class; + for (int i = 0; i < hiddenBaseInterfaces.Length; i++) + { + memberIdClasses[i + 1] = hiddenBaseInterfaces[i].Class; + } + + MemberIdTable memberIds = ComputeMemberIds(memberIdClasses); string memberIdPrefix = $"global::Mockolate.Mock.{name}."; StringBuilder sb = InitializeBuilder(); @@ -86,6 +98,37 @@ public static string MockClass(string name, Class @class, bool hasOverloadResolu sb.Append(',').AppendLine(); + foreach ((string Name, Class Class) hiddenBase in hiddenBaseInterfaces) + { + bool hbHasStaticMembers = hiddenBase.Class.AllMethods().Any(x => x.IsStatic) || + hiddenBase.Class.AllProperties().Any(x => x.IsStatic); + bool hbHasStaticEvents = hiddenBase.Class.AllEvents().Any(x => x.IsStatic); + bool hbHasInstanceEvents = hiddenBase.Class.AllEvents().Any(x => !x.IsStatic); + sb.Append("\t\tIMockFor").Append(hiddenBase.Name).Append(", IMockSetupFor").Append(hiddenBase.Name); + if (hbHasStaticMembers) + { + sb.Append(", IMockStaticSetupFor").Append(hiddenBase.Name); + } + + if (hbHasInstanceEvents) + { + sb.Append(", IMockRaiseOn").Append(hiddenBase.Name); + } + + if (hbHasStaticEvents) + { + sb.Append(", IMockStaticRaiseOn").Append(hiddenBase.Name); + } + + sb.Append(", IMockVerifyFor").Append(hiddenBase.Name); + if (hbHasStaticMembers || hbHasStaticEvents) + { + sb.Append(", IMockStaticVerifyFor").Append(hiddenBase.Name); + } + + sb.Append(",").AppendLine(); + } + sb.Append("\t\tglobal::Mockolate.IMock").AppendLine(); sb.Append("\t{").AppendLine(); @@ -130,6 +173,16 @@ public static string MockClass(string name, Class @class, bool hasOverloadResolu ImplementMockForInterface(sb, mockRegistryName, name, hasEvents, hasProtectedMembers, hasProtectedEvents, hasStaticMembers, hasStaticEvents); + foreach ((string Name, Class Class) hiddenBase in hiddenBaseInterfaces) + { + ImplementMockForInterface(sb, mockRegistryName, hiddenBase.Name, + hiddenBase.Class.AllEvents().Any(x => !x.IsStatic), + false /* interfaces have no protected members */, + false /* interfaces have no protected events */, + hiddenBase.Class.AllMethods().Any(x => x.IsStatic) || hiddenBase.Class.AllProperties().Any(x => x.IsStatic), + hiddenBase.Class.AllEvents().Any(x => x.IsStatic)); + } + sb.Append("\t\t/// ").AppendLine(); sb.Append("\t\tstring global::Mockolate.IMock.ToString()").AppendLine(); sb.Append("\t\t\t=> \"").Append(@class.DisplayString).Append(" mock\";").AppendLine(); @@ -189,6 +242,15 @@ public static string MockClass(string name, Class @class, bool hasOverloadResolu sb.Append("\t\t#endregion IMockStaticSetupFor").Append(name).AppendLine(); } + foreach ((string Name, Class Class) hiddenBase in hiddenBaseInterfaces) + { + sb.AppendLine(); + sb.Append("\t\t#region IMockSetupFor").Append(hiddenBase.Name).AppendLine(); + sb.AppendLine(); + ImplementSetupInterface(sb, hiddenBase.Class, mockRegistryName, $"IMockSetupFor{hiddenBase.Name}", MemberType.Public, memberIds, memberIdPrefix); + sb.Append("\t\t#endregion IMockSetupFor").Append(hiddenBase.Name).AppendLine(); + } + #endregion Mock.Setup #region Mock.Raise @@ -221,6 +283,18 @@ public static string MockClass(string name, Class @class, bool hasOverloadResolu sb.Append("\t\t#endregion IMockStaticRaiseOn").Append(name).AppendLine(); } + foreach ((string Name, Class Class) hiddenBase in hiddenBaseInterfaces) + { + if (hiddenBase.Class.AllEvents().Any(x => !x.IsStatic)) + { + sb.AppendLine(); + sb.Append("\t\t#region IMockRaiseOn").Append(hiddenBase.Name).AppendLine(); + sb.AppendLine(); + ImplementRaiseInterface(sb, hiddenBase.Class, mockRegistryName, $"IMockRaiseOn{hiddenBase.Name}", MemberType.Public); + sb.Append("\t\t#endregion IMockRaiseOn").Append(hiddenBase.Name).AppendLine(); + } + } + #endregion Mock.Raise #region Mock.Verify @@ -249,6 +323,15 @@ public static string MockClass(string name, Class @class, bool hasOverloadResolu sb.Append("\t\t#endregion IMockStaticVerifyFor").Append(name).AppendLine(); } + foreach ((string Name, Class Class) hiddenBase in hiddenBaseInterfaces) + { + sb.AppendLine(); + sb.Append("\t\t#region IMockVerifyFor").Append(hiddenBase.Name).AppendLine(); + sb.AppendLine(); + ImplementVerifyInterface(sb, hiddenBase.Class, mockRegistryName, $"IMockVerifyFor{hiddenBase.Name}", MemberType.Public, memberIds, memberIdPrefix, false); + sb.Append("\t\t#endregion IMockVerifyFor").Append(hiddenBase.Name).AppendLine(); + } + #endregion Mock.Verify sb.AppendLine("\t}"); diff --git a/Tests/Mockolate.Tests/MockTests.HiddenInterfaceMemberTests.cs b/Tests/Mockolate.Tests/MockTests.HiddenInterfaceMemberTests.cs new file mode 100644 index 00000000..80d3483d --- /dev/null +++ b/Tests/Mockolate.Tests/MockTests.HiddenInterfaceMemberTests.cs @@ -0,0 +1,96 @@ +namespace Mockolate.Tests; + +public sealed partial class MockTests +{ + public sealed class HiddenInterfaceMemberTests + { + [Fact] + public async Task HiddenEvent_AsBase_ShouldVerifyBaseSlotSeparately() + { + IDerivedEvents mock = IDerivedEvents.CreateMock(); + EventHandler handler = (_, _) => { }; + ((IBaseEvents)mock).Changed += handler; + + await That(mock.Mock.As().Verify.Changed.Subscribed()).Once(); + await That(mock.Mock.Verify.Changed.Subscribed()).Never(); + } + + [Fact] + public async Task HiddenIndexer_AsBase_ShouldVerifyBaseAccess() + { + IDerivedIndexer mock = IDerivedIndexer.CreateMock(); + + _ = ((IBaseIndexer)mock)[5]; + + // Indexers are keyed by their parameter signature (not the declaring interface), so base and + // derived access share storage; As still reaches the same recorded interaction. + await That(mock.Mock.As().Verify[It.IsAny()].Got()).Once(); + } + + [Fact] + public async Task HiddenMethod_AsBase_ShouldConfigureAndVerifyBaseSlotSeparately() + { + IDerivedService mock = IDerivedService.CreateMock(); + mock.Mock.Setup.GetValue().Returns(42); + mock.Mock.As().Setup.GetValue().Returns(43); + + await That(mock.GetValue()).IsEqualTo(42); + await That(((IBaseService)mock).GetValue()).IsEqualTo(43); + await That(mock.Mock.Verify.GetValue()).Once(); + await That(mock.Mock.As().Verify.GetValue()).Once(); + } + + [Fact] + public async Task HiddenProperty_AsBase_ShouldReadBaseSlot() + { + IDerivedProperty mock = IDerivedProperty.CreateMock(); + mock.Mock.Setup.SomeProperty.InitializeWith("derived"); + mock.Mock.As().Setup.SomeProperty.Returns("base"); + + mock.SomeProperty = "updated"; + + await That(mock.SomeProperty).IsEqualTo("updated"); + await That(((IBaseProperty)mock).SomeProperty).IsEqualTo("base"); + } + + internal interface IBaseService + { + int GetValue(); + } + + internal interface IDerivedService : IBaseService + { + new int GetValue(); + } + + internal interface IBaseProperty + { + string SomeProperty { get; } + } + + internal interface IDerivedProperty : IBaseProperty + { + new string SomeProperty { get; set; } + } + + internal interface IBaseEvents + { + event EventHandler? Changed; + } + + internal interface IDerivedEvents : IBaseEvents + { + new event EventHandler? Changed; + } + + internal interface IBaseIndexer + { + int this[int index] { get; } + } + + internal interface IDerivedIndexer : IBaseIndexer + { + new int this[int index] { get; set; } + } + } +}