diff --git a/Source/Mockolate.SourceGenerators/Sources/Sources.ForMock.cs b/Source/Mockolate.SourceGenerators/Sources/Sources.ForMock.cs index 4e3e22f4..0141c33a 100644 --- a/Source/Mockolate.SourceGenerators/Sources/Sources.ForMock.cs +++ b/Source/Mockolate.SourceGenerators/Sources/Sources.ForMock.cs @@ -150,6 +150,10 @@ namespace Mockolate.Generated; .Append(mockClass.ClassFullName).Append(">.Mock => _mock;").AppendLine(); sb.Append("\t[DebuggerBrowsable(DebuggerBrowsableState.Never)]").AppendLine(); sb.Append("\tprivate readonly Mock<").Append(mockClass.ClassFullName).Append("> _mock;").AppendLine(); + if (mockClass.IsInterface) + { + sb.Append("\tprivate readonly ").Append(mockClass.ClassFullName).Append("? _wrapped;").AppendLine(); + } sb.AppendLine(); if (mockClass.Constructors?.Count > 0) { @@ -182,11 +186,12 @@ namespace Mockolate.Generated; if (mockClass.IsInterface) { sb.Append("\t/// ").AppendLine(); - sb.Append("\tpublic MockFor").Append(name).Append("(MockBehavior mockBehavior)").AppendLine(); + sb.Append("\tpublic MockFor").Append(name).Append("(MockBehavior mockBehavior, ").Append(mockClass.ClassFullName).Append("? wrapped = null)").AppendLine(); sb.Append("\t{").AppendLine(); sb.Append("\t\t_mock = new Mock<").Append(mockClass.ClassFullName) .Append(">(this, new MockRegistration(mockBehavior, \"").Append(mockClass.DisplayString) .Append("\"));").AppendLine(); + sb.Append("\t\tthis._wrapped = wrapped;").AppendLine(); sb.Append("\t}").AppendLine(); sb.AppendLine(); } @@ -335,10 +340,34 @@ private static void AppendMockSubject_ImplementClass_AddEvent(StringBuilder sb, } sb.AppendLine("\t{"); - sb.Append("\t\tadd => MockRegistrations.AddEvent(").Append(@event.GetUniqueNameString()) - .Append(", value?.Target, value?.Method);").AppendLine(); - sb.Append("\t\tremove => MockRegistrations.RemoveEvent(") - .Append(@event.GetUniqueNameString()).Append(", value?.Target, value?.Method);").AppendLine(); + if (isClassInterface && !explicitInterfaceImplementation && @event.ExplicitImplementation is null) + { + sb.Append("\t\tadd").AppendLine(); + sb.Append("\t\t{").AppendLine(); + sb.Append("\t\t\tMockRegistrations.AddEvent(").Append(@event.GetUniqueNameString()) + .Append(", value?.Target, value?.Method);").AppendLine(); + sb.Append("\t\t\tif (this._wrapped is not null)").AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\tthis._wrapped.").Append(@event.Name).Append(" += value;").AppendLine(); + sb.Append("\t\t\t}").AppendLine(); + sb.Append("\t\t}").AppendLine(); + sb.Append("\t\tremove").AppendLine(); + sb.Append("\t\t{").AppendLine(); + sb.Append("\t\t\tMockRegistrations.RemoveEvent(").Append(@event.GetUniqueNameString()) + .Append(", value?.Target, value?.Method);").AppendLine(); + sb.Append("\t\t\tif (this._wrapped is not null)").AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\tthis._wrapped.").Append(@event.Name).Append(" -= value;").AppendLine(); + sb.Append("\t\t\t}").AppendLine(); + sb.Append("\t\t}").AppendLine(); + } + else + { + sb.Append("\t\tadd => MockRegistrations.AddEvent(").Append(@event.GetUniqueNameString()) + .Append(", value?.Target, value?.Method);").AppendLine(); + sb.Append("\t\tremove => MockRegistrations.RemoveEvent(") + .Append(@event.GetUniqueNameString()).Append(", value?.Target, value?.Method);").AppendLine(); + } sb.AppendLine("\t}"); } @@ -396,7 +425,48 @@ property.IndexerParameters is not null sb.AppendLine("get"); sb.AppendLine("\t\t{"); - if (!isClassInterface && !property.IsAbstract) + if (isClassInterface && !explicitInterfaceImplementation && property.ExplicitImplementation is null) + { + if (property is { IsIndexer: true, IndexerParameters: not null, }) + { + string indexerResultVarName = + Helpers.GetUniqueLocalVariableName("indexerResult", property.IndexerParameters.Value); + string baseResultVarName = + Helpers.GetUniqueLocalVariableName("baseResult", property.IndexerParameters.Value); + + sb.Append("\t\t\tif (this._wrapped is null)").AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\treturn MockRegistrations.GetIndexer<") + .AppendTypeOrWrapper(property.Type).Append(">(") + .Append(string.Join(", ", property.IndexerParameters.Value.Select(p => p.ToNameOrWrapper()))) + .Append(").GetResult(() => ") + .AppendDefaultValueGeneratorFor(property.Type, "MockRegistrations.Behavior.DefaultValue") + .Append(");").AppendLine(); + sb.Append("\t\t\t}").AppendLine(); + sb.Append("\t\t\tvar ").Append(indexerResultVarName).Append(" = MockRegistrations.GetIndexer<") + .AppendTypeOrWrapper(property.Type).Append(">(") + .Append(string.Join(", ", property.IndexerParameters.Value.Select(p => p.ToNameOrWrapper()))) + .AppendLine(");"); + sb.Append("\t\t\tvar ").Append(baseResultVarName).Append(" = this._wrapped[") + .Append(string.Join(", ", property.IndexerParameters.Value.Select(p => p.Name))) + .Append("];").AppendLine(); + sb.Append("\t\t\treturn ").Append(indexerResultVarName).Append(".GetResult(") + .Append(baseResultVarName) + .Append(", () => ") + .AppendDefaultValueGeneratorFor(property.Type, "MockRegistrations.Behavior.DefaultValue") + .Append(");").AppendLine(); + } + else + { + sb.Append( + "\t\t\treturn MockRegistrations.GetProperty<") + .AppendTypeOrWrapper(property.Type).Append(">(") + .Append(property.GetUniqueNameString()).Append(", () => ") + .AppendDefaultValueGeneratorFor(property.Type, "MockRegistrations.Behavior.DefaultValue") + .Append(", this._wrapped is null ? null : () => this._wrapped.").Append(property.Name).Append(");").AppendLine(); + } + } + else if (!isClassInterface && !property.IsAbstract) { if (property is { IsIndexer: true, IndexerParameters: not null, }) { @@ -468,7 +538,36 @@ property.IndexerParameters is not null sb.AppendLine("set"); sb.AppendLine("\t\t{"); - if (property is { IsIndexer: true, IndexerParameters: not null, }) + if (isClassInterface && !explicitInterfaceImplementation && property.ExplicitImplementation is null) + { + if (property is { IsIndexer: true, IndexerParameters: not null, }) + { + sb.Append( + "\t\t\tMockRegistrations.SetIndexer<") + .Append(property.Type.Fullname) + .Append(">(value, ") + .Append(string.Join(", ", property.IndexerParameters.Value.Select(p => p.ToNameOrWrapper()))) + .Append(");").AppendLine(); + + sb.Append("\t\t\tif (this._wrapped is not null)").AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\tthis._wrapped[") + .Append(string.Join(", ", property.IndexerParameters.Value.Select(p => p.Name))) + .AppendLine("] = value;"); + sb.Append("\t\t\t}").AppendLine(); + } + else + { + sb.Append( + "\t\t\tMockRegistrations.SetProperty(").Append(property.GetUniqueNameString()) + .Append(", value);").AppendLine(); + sb.Append("\t\t\tif (this._wrapped is not null)").AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\tthis._wrapped.").Append(property.Name).Append(" = value;").AppendLine(); + sb.Append("\t\t\t}").AppendLine(); + } + } + else if (property is { IsIndexer: true, IndexerParameters: not null, }) { if (!isClassInterface && !property.IsAbstract) { @@ -609,6 +708,80 @@ private static void AppendMockSubject_ImplementClass_AddMethod(StringBuilder sb, if (isClassInterface || method.IsAbstract) { + if (!explicitInterfaceImplementation && isClassInterface) + { + string baseResultVarName = Helpers.GetUniqueLocalVariableName("baseResult", method.Parameters); + if (method.ReturnType != Type.Void) + { + sb.Append( + "\t\tif (this._wrapped is not null)") + .AppendLine(); + sb.Append("\t\t{").AppendLine(); + sb.Append("\t\t\tvar ").Append(baseResultVarName).Append(" = this._wrapped").Append(".") + .Append(method.Name).Append('(') + .Append(string.Join(", ", method.Parameters.Select(p => $"{p.RefKind.GetString()}{p.Name}"))) + .Append(");").AppendLine(); + } + else + { + sb.Append( + "\t\tif (this._wrapped is not null)") + .AppendLine(); + sb.Append("\t\t{").AppendLine(); + sb.Append("\t\t\tthis._wrapped").Append(".") + .Append(method.Name).Append('(') + .Append(string.Join(", ", method.Parameters.Select(p => $"{p.RefKind.GetString()}{p.Name}"))) + .Append(");").AppendLine(); + } + + foreach (MethodParameter parameter in method.Parameters) + { + if (parameter.RefKind == RefKind.Out) + { + sb.Append( + "\t\t\tif (").Append(methodExecutionVarName).Append(".HasSetupResult == true)") + .AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\t").Append(parameter.Name).Append(" = ").Append(methodExecutionVarName) + .Append(".SetOutParameter<") + .Append(parameter.Type.Fullname).Append(">(\"").Append(parameter.Name) + .Append("\", () => ") + .AppendDefaultValueGeneratorFor(parameter.Type, + "MockRegistrations.Behavior.DefaultValue") + .Append(");").AppendLine(); + sb.Append("\t\t\t}").AppendLine().AppendLine(); + } + else if (parameter.RefKind == RefKind.Ref) + { + sb.Append( + "\t\t\tif (").Append(methodExecutionVarName).Append(".HasSetupResult == true)") + .AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\t").Append(parameter.Name).Append(" = ").Append(methodExecutionVarName) + .Append(".SetRefParameter<") + .Append(parameter.Type.Fullname).Append(">(\"").Append(parameter.Name).Append("\", ") + .Append(parameter.Name).Append(");").AppendLine(); + sb.Append("\t\t\t}").AppendLine().AppendLine(); + } + } + + if (method.ReturnType != Type.Void) + { + sb.Append( + "\t\t\tif (!").Append(methodExecutionVarName).Append(".HasSetupResult)") + .AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + sb.Append("\t\t\t\t").Append(methodExecutionVarName).Append(".TriggerCallbacks(") + .Append( + string.Join(", ", method.Parameters.Select(p => p.ToNameOrNull()))) + .Append(");").AppendLine(); + sb.Append("\t\t\t\treturn ").Append(baseResultVarName).Append(";").AppendLine(); + sb.Append("\t\t\t}").AppendLine(); + } + + sb.Append("\t\t}").AppendLine(); + } + foreach (MethodParameter parameter in method.Parameters) { if (parameter.RefKind == RefKind.Out) diff --git a/Source/Mockolate.SourceGenerators/Sources/Sources.MockRegistration.cs b/Source/Mockolate.SourceGenerators/Sources/Sources.MockRegistration.cs index 93e9badc..2f78c1e6 100644 --- a/Source/Mockolate.SourceGenerators/Sources/Sources.MockRegistration.cs +++ b/Source/Mockolate.SourceGenerators/Sources/Sources.MockRegistration.cs @@ -240,6 +240,49 @@ namespace Mockolate; } sb.AppendLine("\t\t}"); + + sb.AppendLine(); + sb.AppendLine( + "\t\tpartial void GenerateWrapped(T instance, MockBehavior mockBehavior, Action>[] setups)"); + sb.Append("\t\t{").AppendLine(); + index = 0; + foreach ((string Name, MockClass MockClass) mock in mocks.Where(m => m.MockClass.AdditionalImplementations.Count == 0)) + { + if (index++ > 0) + { + sb.Append("\t\t\telse "); + } + else + { + sb.Append("\t\t\t"); + } + + sb.Append("if (typeof(T) == typeof(").Append(mock.MockClass.ClassFullName).Append("))").AppendLine(); + sb.Append("\t\t\t{").AppendLine(); + + sb.Append("\t\t\t\tMockRegistration mockRegistration = new MockRegistration(mockBehavior, \"") + .Append(mock.MockClass.DisplayString).Append("\");").AppendLine(); + + if (mock.MockClass.IsInterface) + { + sb.Append("\t\t\t\t_value = new MockFor").Append(mock.Name).Append("(mockBehavior, instance as ").Append(mock.MockClass.ClassFullName).Append(");").AppendLine(); + sb.Append("\t\t\t\tif (setups.Length > 0)").AppendLine(); + sb.Append("\t\t\t\t{").AppendLine(); + sb.Append("\t\t\t\t\tIMockSetup<").Append(mock.MockClass.ClassFullName) + .Append("> setupTarget = ((IMockSubject<").Append(mock.MockClass.ClassFullName) + .Append(">)_value).Mock;").AppendLine(); + sb.Append("\t\t\t\t\tforeach (Action> setup in setups)").AppendLine(); + sb.Append("\t\t\t\t\t{").AppendLine(); + sb.Append("\t\t\t\t\t\tsetup.Invoke(setupTarget);").AppendLine(); + sb.Append("\t\t\t\t\t}").AppendLine(); + sb.Append("\t\t\t\t}").AppendLine(); + } + + sb.Append("\t\t\t}").AppendLine(); + } + sb.AppendLine("\t\t}"); + sb.AppendLine("\t}"); sb.AppendLine(); } diff --git a/Source/Mockolate.SourceGenerators/Sources/Sources.cs b/Source/Mockolate.SourceGenerators/Sources/Sources.cs index 77bcd692..7bc9df14 100644 --- a/Source/Mockolate.SourceGenerators/Sources/Sources.cs +++ b/Source/Mockolate.SourceGenerators/Sources/Sources.cs @@ -395,6 +395,65 @@ public T Create(BaseClass.ConstructorParameters constructorParameters, params sb.AppendLine("\t}"); sb.AppendLine(); + sb.AppendLine(""" + /// + /// Wraps a concrete instance with a mock proxy that intercepts and delegates method calls, + /// supporting setup and verification on the wrapped instance. + /// + /// Type to wrap, which can be an interface or a class. + /// The concrete instance to wrap. + /// Optional setup actions to configure the mock. + /// + /// When no setup is specified for a method, the call is delegated to the wrapped instance. + /// Setup and verification work the same as with regular mocks. + /// + [MockGenerator] + public static T Wrap(T instance, params Action>[] setups) + where T : class + { + if (instance == null) + { + throw new ArgumentNullException(nameof(instance)); + } + + ThrowIfNotMockable(typeof(T)); + + return new MockGenerator().GetWrapped(instance, MockBehavior.Default, setups) + ?? throw new MockException("Could not generate wrapped Mock. Did the source generator run correctly?"); + } + """); + sb.AppendLine(); + + sb.AppendLine(""" + /// + /// Wraps a concrete instance with a mock proxy that intercepts and delegates method calls, + /// supporting setup and verification on the wrapped instance. + /// + /// Type to wrap, which can be an interface or a class. + /// The concrete instance to wrap. + /// The behavior settings for the mock. + /// Optional setup actions to configure the mock. + /// + /// When no setup is specified for a method, the call is delegated to the wrapped instance. + /// Setup and verification work the same as with regular mocks. + /// + [MockGenerator] + public static T Wrap(T instance, MockBehavior mockBehavior, params Action>[] setups) + where T : class + { + if (instance == null) + { + throw new ArgumentNullException(nameof(instance)); + } + + ThrowIfNotMockable(typeof(T)); + + return new MockGenerator().GetWrapped(instance, mockBehavior, setups) + ?? throw new MockException("Could not generate wrapped Mock. Did the source generator run correctly?"); + } + """); + sb.AppendLine(); + sb.AppendLine(""" private static void ThrowIfNotMockable(Type type) { @@ -414,6 +473,7 @@ private partial class MockGenerator #pragma warning restore CS0649 partial void Generate(BaseClass.ConstructorParameters? constructorParameters, MockBehavior mockBehavior, Action>[] setups, params Type[] types); + partial void GenerateWrapped(T instance, MockBehavior mockBehavior, Action>[] setups); public object? Get(BaseClass.ConstructorParameters? constructorParameters, MockBehavior mockBehavior, Type type) { @@ -442,6 +502,16 @@ private partial class MockGenerator """); } + sb.AppendLine(); + sb.AppendLine(""" + public T? GetWrapped(T instance, MockBehavior mockBehavior, Action>[] setups) + where T : class + { + GenerateWrapped(instance, mockBehavior, setups); + return _value as T; + } + """); + sb.AppendLine(""" } } diff --git a/Tests/Mockolate.SourceGenerators.Tests/GeneralTests.cs b/Tests/Mockolate.SourceGenerators.Tests/GeneralTests.cs index b3b1bddc..6dd49e37 100644 --- a/Tests/Mockolate.SourceGenerators.Tests/GeneralTests.cs +++ b/Tests/Mockolate.SourceGenerators.Tests/GeneralTests.cs @@ -459,11 +459,15 @@ public string SomeProperty { get { - return MockRegistrations.GetProperty("MyCode.IMyService.SomeProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(string)!), null); + return MockRegistrations.GetProperty("MyCode.IMyService.SomeProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(string)!), this._wrapped is null ? null : () => this._wrapped.SomeProperty); } set { MockRegistrations.SetProperty("MyCode.IMyService.SomeProperty", value); + if (this._wrapped is not null) + { + this._wrapped.SomeProperty = value; + } } } """).IgnoringNewlineStyle().And @@ -473,6 +477,15 @@ public string SomeProperty public string MyMethod(string message) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyService.MyMethod", p => MockRegistrations.Behavior.DefaultValue.Generate(default(string)!, p), message); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyMethod(message); + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(message); + return baseResult; + } + } methodExecution.TriggerCallbacks(message); return methodExecution.Result; } @@ -482,8 +495,22 @@ public string MyMethod(string message) [MyCode.Custom(true, (byte)42, 'X', 3.14, 2.71F, 100, 999L, (sbyte)-10, (short)500, "test", 123u, 456uL, (ushort)789, typeof(string), (MyCode.MyEnum)2, new int[]{1, 2, 3}, BoolParam = false, ByteParam = (byte)99, CharParam = 'Y', DoubleParam = 1.23, FloatParam = 4.56F, IntParam = 200, LongParam = 888L, SByteParam = (sbyte)-5, ShortParam = (short)300, StringParam = "named", UIntParam = 111u, ULongParam = 222uL, UShortParam = (ushort)333, ObjectParam = 42, TypeParam = typeof(int), EnumParam = (MyCode.MyFlagEnum)3, ArrayParam = new string[]{"a", "b"})] public event System.EventHandler? MyEvent { - add => MockRegistrations.AddEvent("MyCode.IMyService.MyEvent", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyService.MyEvent", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyService.MyEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyEvent += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyService.MyEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyEvent -= value; + } + } } """).IgnoringNewlineStyle(); } diff --git a/Tests/Mockolate.SourceGenerators.Tests/Sources/ForMockTests.ImplementClassTests.cs b/Tests/Mockolate.SourceGenerators.Tests/Sources/ForMockTests.ImplementClassTests.cs index 06a780a1..2ffb283a 100644 --- a/Tests/Mockolate.SourceGenerators.Tests/Sources/ForMockTests.ImplementClassTests.cs +++ b/Tests/Mockolate.SourceGenerators.Tests/Sources/ForMockTests.ImplementClassTests.cs @@ -72,16 +72,44 @@ await That(result.Sources).ContainsKey("MockForIMyService.g.cs").WhoseValue /// public event System.EventHandler? SomeEvent { - add => MockRegistrations.AddEvent("MyCode.IMyService.SomeEvent", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyService.SomeEvent", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyService.SomeEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.SomeEvent += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyService.SomeEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.SomeEvent -= value; + } + } } """).IgnoringNewlineStyle().And .Contains(""" /// public event System.EventHandler? SomeOtherEvent { - add => MockRegistrations.AddEvent("MyCode.IMyService.SomeOtherEvent", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyService.SomeOtherEvent", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyService.SomeOtherEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.SomeOtherEvent += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyService.SomeOtherEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.SomeOtherEvent -= value; + } + } } """).IgnoringNewlineStyle(); } @@ -129,32 +157,88 @@ await That(result.Sources).ContainsKey("MockForIMyService.g.cs").WhoseValue /// public event System.EventHandler? MyDirectEvent { - add => MockRegistrations.AddEvent("MyCode.IMyService.MyDirectEvent", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyService.MyDirectEvent", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyService.MyDirectEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyDirectEvent += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyService.MyDirectEvent", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyDirectEvent -= value; + } + } } """).IgnoringNewlineStyle().And .Contains(""" /// public event System.EventHandler? MyBaseEvent1 { - add => MockRegistrations.AddEvent("MyCode.IMyServiceBase1.MyBaseEvent1", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyServiceBase1.MyBaseEvent1", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyServiceBase1.MyBaseEvent1", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyBaseEvent1 += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyServiceBase1.MyBaseEvent1", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyBaseEvent1 -= value; + } + } } """).IgnoringNewlineStyle().And .Contains(""" /// public event System.EventHandler? MyBaseEvent2 { - add => MockRegistrations.AddEvent("MyCode.IMyServiceBase2.MyBaseEvent2", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyServiceBase2.MyBaseEvent2", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyServiceBase2.MyBaseEvent2", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyBaseEvent2 += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyServiceBase2.MyBaseEvent2", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyBaseEvent2 -= value; + } + } } """).IgnoringNewlineStyle().And .Contains(""" /// public event System.EventHandler? MyBaseEvent3 { - add => MockRegistrations.AddEvent("MyCode.IMyServiceBase3.MyBaseEvent3", value?.Target, value?.Method); - remove => MockRegistrations.RemoveEvent("MyCode.IMyServiceBase3.MyBaseEvent3", value?.Target, value?.Method); + add + { + MockRegistrations.AddEvent("MyCode.IMyServiceBase3.MyBaseEvent3", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyBaseEvent3 += value; + } + } + remove + { + MockRegistrations.RemoveEvent("MyCode.IMyServiceBase3.MyBaseEvent3", value?.Target, value?.Method); + if (this._wrapped is not null) + { + this._wrapped.MyBaseEvent3 -= value; + } + } } """).IgnoringNewlineStyle(); } @@ -240,11 +324,21 @@ public int this[int index] { get { - return MockRegistrations.GetIndexer(index).GetResult(() => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!)); + if (this._wrapped is null) + { + return MockRegistrations.GetIndexer(index).GetResult(() => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!)); + } + var indexerResult = MockRegistrations.GetIndexer(index); + var baseResult = this._wrapped[index]; + return indexerResult.GetResult(baseResult, () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!)); } set { MockRegistrations.SetIndexer(value, index); + if (this._wrapped is not null) + { + this._wrapped[index] = value; + } } } """).IgnoringNewlineStyle().And @@ -254,7 +348,13 @@ public int this[int index] { get { - return MockRegistrations.GetIndexer(index, isReadOnly).GetResult(() => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!)); + if (this._wrapped is null) + { + return MockRegistrations.GetIndexer(index, isReadOnly).GetResult(() => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!)); + } + var indexerResult = MockRegistrations.GetIndexer(index, isReadOnly); + var baseResult = this._wrapped[index, isReadOnly]; + return indexerResult.GetResult(baseResult, () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!)); } } """).IgnoringNewlineStyle().And @@ -265,6 +365,10 @@ public int this[int index] set { MockRegistrations.SetIndexer(value, index, isWriteOnly); + if (this._wrapped is not null) + { + this._wrapped[index, isWriteOnly] = value; + } } } """).IgnoringNewlineStyle(); @@ -539,6 +643,15 @@ await That(result.Sources).ContainsKey("MockForIMyService.g.cs").WhoseValue public bool MyMethod1(int index) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyService.MyMethod1", p => MockRegistrations.Behavior.DefaultValue.Generate(default(bool)!, p), index); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyMethod1(index); + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(index); + return baseResult; + } + } methodExecution.TriggerCallbacks(index); return methodExecution.Result; } @@ -548,6 +661,10 @@ public bool MyMethod1(int index) public void MyMethod2(int index, bool isReadOnly) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyService.MyMethod2", index, isReadOnly); + if (this._wrapped is not null) + { + this._wrapped.MyMethod2(index, isReadOnly); + } methodExecution.TriggerCallbacks(index, isReadOnly); } """).IgnoringNewlineStyle(); @@ -597,6 +714,15 @@ await That(result.Sources).ContainsKey("MockForIMyService.g.cs").WhoseValue public int MyDirectMethod(int value) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyService.MyDirectMethod", p => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!, p), value); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyDirectMethod(value); + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(value); + return baseResult; + } + } methodExecution.TriggerCallbacks(value); return methodExecution.Result; } @@ -606,6 +732,15 @@ public int MyDirectMethod(int value) public int MyBaseMethod1(int value) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyServiceBase1.MyBaseMethod1", p => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!, p), value); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyBaseMethod1(value); + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(value); + return baseResult; + } + } methodExecution.TriggerCallbacks(value); return methodExecution.Result; } @@ -615,6 +750,15 @@ public int MyBaseMethod1(int value) public int MyBaseMethod2(int value) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyServiceBase2.MyBaseMethod2", p => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!, p), value); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyBaseMethod2(value); + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(value); + return baseResult; + } + } methodExecution.TriggerCallbacks(value); return methodExecution.Result; } @@ -624,6 +768,15 @@ public int MyBaseMethod2(int value) public int MyBaseMethod3(int value) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyServiceBase3.MyBaseMethod3", p => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!, p), value); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyBaseMethod3(value); + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(value); + return baseResult; + } + } methodExecution.TriggerCallbacks(value); return methodExecution.Result; } @@ -733,6 +886,15 @@ await That(result.Sources).ContainsKey("MockForIMyService.g.cs").WhoseValue public void MyMethod1(ref int index) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyService.MyMethod1", index); + if (this._wrapped is not null) + { + this._wrapped.MyMethod1(ref index); + if (methodExecution.HasSetupResult == true) + { + index = methodExecution.SetRefParameter("index", index); + } + + } index = methodExecution.SetRefParameter("index", index); methodExecution.TriggerCallbacks(index); } @@ -742,6 +904,20 @@ public void MyMethod1(ref int index) public bool MyMethod2(int index, out bool isReadOnly) { MethodSetupResult methodExecution = MockRegistrations.InvokeMethod("MyCode.IMyService.MyMethod2", p => MockRegistrations.Behavior.DefaultValue.Generate(default(bool)!, p), index, null); + if (this._wrapped is not null) + { + var baseResult = this._wrapped.MyMethod2(index, out isReadOnly); + if (methodExecution.HasSetupResult == true) + { + isReadOnly = methodExecution.SetOutParameter("isReadOnly", () => MockRegistrations.Behavior.DefaultValue.Generate(default(bool)!)); + } + + if (!methodExecution.HasSetupResult) + { + methodExecution.TriggerCallbacks(index, isReadOnly); + return baseResult; + } + } isReadOnly = methodExecution.SetOutParameter("isReadOnly", () => MockRegistrations.Behavior.DefaultValue.Generate(default(bool)!)); methodExecution.TriggerCallbacks(index, isReadOnly); return methodExecution.Result; @@ -820,11 +996,15 @@ public int SomeProperty { get { - return MockRegistrations.GetProperty("MyCode.IMyService.SomeProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), null); + return MockRegistrations.GetProperty("MyCode.IMyService.SomeProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), this._wrapped is null ? null : () => this._wrapped.SomeProperty); } set { MockRegistrations.SetProperty("MyCode.IMyService.SomeProperty", value); + if (this._wrapped is not null) + { + this._wrapped.SomeProperty = value; + } } } """).IgnoringNewlineStyle().And @@ -834,7 +1014,7 @@ public bool? SomeReadOnlyProperty { get { - return MockRegistrations.GetProperty("MyCode.IMyService.SomeReadOnlyProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(bool?)!), null); + return MockRegistrations.GetProperty("MyCode.IMyService.SomeReadOnlyProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(bool?)!), this._wrapped is null ? null : () => this._wrapped.SomeReadOnlyProperty); } } """).IgnoringNewlineStyle().And @@ -845,6 +1025,10 @@ public bool? SomeWriteOnlyProperty set { MockRegistrations.SetProperty("MyCode.IMyService.SomeWriteOnlyProperty", value); + if (this._wrapped is not null) + { + this._wrapped.SomeWriteOnlyProperty = value; + } } } """).IgnoringNewlineStyle(); @@ -895,11 +1079,15 @@ public int MyDirectProperty { get { - return MockRegistrations.GetProperty("MyCode.IMyService.MyDirectProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), null); + return MockRegistrations.GetProperty("MyCode.IMyService.MyDirectProperty", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), this._wrapped is null ? null : () => this._wrapped.MyDirectProperty); } set { MockRegistrations.SetProperty("MyCode.IMyService.MyDirectProperty", value); + if (this._wrapped is not null) + { + this._wrapped.MyDirectProperty = value; + } } } """).IgnoringNewlineStyle().And @@ -909,11 +1097,15 @@ public int MyBaseProperty1 { get { - return MockRegistrations.GetProperty("MyCode.IMyServiceBase1.MyBaseProperty1", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), null); + return MockRegistrations.GetProperty("MyCode.IMyServiceBase1.MyBaseProperty1", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), this._wrapped is null ? null : () => this._wrapped.MyBaseProperty1); } set { MockRegistrations.SetProperty("MyCode.IMyServiceBase1.MyBaseProperty1", value); + if (this._wrapped is not null) + { + this._wrapped.MyBaseProperty1 = value; + } } } """).IgnoringNewlineStyle().And @@ -923,11 +1115,15 @@ public int MyBaseProperty2 { get { - return MockRegistrations.GetProperty("MyCode.IMyServiceBase2.MyBaseProperty2", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), null); + return MockRegistrations.GetProperty("MyCode.IMyServiceBase2.MyBaseProperty2", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), this._wrapped is null ? null : () => this._wrapped.MyBaseProperty2); } set { MockRegistrations.SetProperty("MyCode.IMyServiceBase2.MyBaseProperty2", value); + if (this._wrapped is not null) + { + this._wrapped.MyBaseProperty2 = value; + } } } """).IgnoringNewlineStyle().And @@ -937,11 +1133,15 @@ public int MyBaseProperty3 { get { - return MockRegistrations.GetProperty("MyCode.IMyServiceBase3.MyBaseProperty3", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), null); + return MockRegistrations.GetProperty("MyCode.IMyServiceBase3.MyBaseProperty3", () => MockRegistrations.Behavior.DefaultValue.Generate(default(int)!), this._wrapped is null ? null : () => this._wrapped.MyBaseProperty3); } set { MockRegistrations.SetProperty("MyCode.IMyServiceBase3.MyBaseProperty3", value); + if (this._wrapped is not null) + { + this._wrapped.MyBaseProperty3 = value; + } } } """).IgnoringNewlineStyle(); diff --git a/Tests/Mockolate.Tests/MockTests.WrapTests.cs b/Tests/Mockolate.Tests/MockTests.WrapTests.cs new file mode 100644 index 00000000..6d2e8d27 --- /dev/null +++ b/Tests/Mockolate.Tests/MockTests.WrapTests.cs @@ -0,0 +1,174 @@ +using System.Collections.Generic; +using Mockolate.Tests.TestHelpers; + +namespace Mockolate.Tests; + +public sealed partial class MockTests +{ + public sealed class WrapTests + { + [Fact] + public async Task Wrap_Events_ForwardEventsFromWrappedInstance() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + + string? eventType = null; + int eventAmount = 0; + + wrappedDispenser.ChocolateDispensed += (type, amt) => + { + eventType = type; + eventAmount = amt; + }; + + myDispenser.Dispense("Milk", 3); + + await That(eventType).IsEqualTo("Milk"); + await That(eventAmount).IsEqualTo(3); + } + + [Fact] + public async Task Wrap_Events_ForwardsFromWrapper() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + + string? eventType = null; + int eventAmount = 0; + + myDispenser.ChocolateDispensed += (type, amt) => + { + eventType = type; + eventAmount = amt; + }; + + wrappedDispenser.Dispense("Dark", 1); + + await That(eventType).IsEqualTo("Dark"); + await That(eventAmount).IsEqualTo(1); + } + + [Fact] + public async Task Wrap_Events_Unsubscribe_ShouldRemoveSubscription() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + + string? eventType = null; + int eventAmount = -1; + + wrappedDispenser.ChocolateDispensed += Handler; + + myDispenser.Dispense("Milk", 3); + + await That(eventType).IsEqualTo("Milk"); + await That(eventAmount).IsEqualTo(3); + + wrappedDispenser.ChocolateDispensed -= Handler; + + myDispenser.Dispense("Dark", 6); + + await That(eventType).IsEqualTo("Milk"); + await That(eventAmount).IsEqualTo(3); + + void Handler(string type, int amount) + { + eventType = type; + eventAmount = amount; + } + } + + [Fact] + public async Task Wrap_Indexer_ShouldDelegateToWrappedInstance() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + + wrappedDispenser["Dark"] = 12; + + await That(wrappedDispenser["Dark"]).IsEqualTo(12); + await That(myDispenser["Dark"]).IsEqualTo(12); + await That(wrappedDispenser["White"]).IsEqualTo(8); + await That(myDispenser["White"]).IsEqualTo(8); + } + + [Fact] + public async Task Wrap_Method_ShouldDelegateToWrappedInstance() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + + bool result = wrappedDispenser.Dispense("Dark", 4); + + await That(result).IsTrue(); + await That(wrappedDispenser["Dark"]).IsEqualTo(1); + await That(myDispenser["Dark"]).IsEqualTo(1); + } + + [Fact] + public async Task Wrap_Property_ShouldDelegateToWrappedInstance() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + + wrappedDispenser.TotalDispensed = 12; + + await That(wrappedDispenser.TotalDispensed).IsEqualTo(12); + await That(myDispenser.TotalDispensed).IsEqualTo(12); + } + + [Fact] + public async Task Wrap_WithSetup_ShouldOverrideMethod() + { + MyChocolateDispenser myDispenser = new(); + IChocolateDispenser wrappedDispenser = Mock.Wrap(myDispenser); + wrappedDispenser.SetupMock.Method.Dispense(It.IsAny(), It.IsAny()).Returns(false); + + bool result = wrappedDispenser.Dispense("Dark", 4); + + await That(result).IsFalse(); + await That(wrappedDispenser["Dark"]).IsEqualTo(1); + await That(myDispenser.TotalDispensed).IsEqualTo(4); + } + + private class MyChocolateDispenser : IChocolateDispenser + { + private readonly Dictionary _inventory = new() + { + { + "Milk", 10 + }, + { + "Dark", 5 + }, + { + "White", 8 + }, + }; + + public int this[string type] + { + get => _inventory[type]; + set => _inventory[type] = value; + } + + public int TotalDispensed { get; set; } + + public bool Dispense(string type, int amount) + { + if (_inventory[type] >= amount) + { + TotalDispensed += amount; + _inventory[type] -= amount; + ChocolateDispensed?.Invoke(type, amount); + return true; + } + + return false; + } + + public event ChocolateDispensedDelegate? ChocolateDispensed; + } + } +}