diff --git a/TUnit.Mocks.SourceGenerator.Tests/MockGeneratorTests.cs b/TUnit.Mocks.SourceGenerator.Tests/MockGeneratorTests.cs index 0b97591a19..a45c90a025 100644 --- a/TUnit.Mocks.SourceGenerator.Tests/MockGeneratorTests.cs +++ b/TUnit.Mocks.SourceGenerator.Tests/MockGeneratorTests.cs @@ -677,4 +677,30 @@ void M() return VerifyGeneratorOutput(source); } + + [Test] + public Task Interface_With_Unconstrained_Nullable_Generic() + { + var source = """ + using System.Threading.Tasks; + using TUnit.Mocks; + + public interface IFoo + { + Task DoSomethingAsync(); + T? GetValue(); + (T?, string) GetPair(); + } + + public class TestUsage + { + void M() + { + var mock = Mock.Of(); + } + } + """; + + return VerifyGeneratorOutput(source); + } } diff --git a/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Unconstrained_Nullable_Generic.verified.txt b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Unconstrained_Nullable_Generic.verified.txt new file mode 100644 index 0000000000..f555f82417 --- /dev/null +++ b/TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_Unconstrained_Nullable_Generic.verified.txt @@ -0,0 +1,160 @@ +// +#nullable enable + +namespace TUnit.Mocks.Generated +{ + public sealed class IFooMock : global::TUnit.Mocks.Mock, global::IFoo + { + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal IFooMock(global::IFoo mockObject, global::TUnit.Mocks.MockEngine engine) + : base(mockObject, engine) { } + + global::System.Threading.Tasks.Task global::IFoo.DoSomethingAsync() where T : default => Object.DoSomethingAsync(); + + T? global::IFoo.GetValue() where T : default => Object.GetValue(); + + (T?, string) global::IFoo.GetPair() where T : default => Object.GetPair(); + } +} + + +// ===== FILE SEPARATOR ===== + +// +#nullable enable + +namespace TUnit.Mocks.Generated +{ + internal static class IFooMockFactory + { + [global::System.Runtime.CompilerServices.ModuleInitializer] + internal static void Register() + { + global::TUnit.Mocks.MockRegistry.RegisterFactory(Create); + } + + internal static global::TUnit.Mocks.Mock Create(global::TUnit.Mocks.MockBehavior behavior, object[] constructorArgs) + { + if (constructorArgs.Length > 0) throw new global::System.ArgumentException($"Interface mock 'global::IFoo' does not support constructor arguments, but {constructorArgs.Length} were provided."); + var engine = new global::TUnit.Mocks.MockEngine(behavior); + var impl = new IFooMockImpl(engine); + engine.Raisable = impl; + var mock = new IFooMock(impl, engine); + return mock; + } + } +} + + +// ===== FILE SEPARATOR ===== + +// +#nullable enable + +namespace TUnit.Mocks.Generated +{ + internal sealed class IFooMockImpl : global::IFoo, global::TUnit.Mocks.IRaisable, global::TUnit.Mocks.IMockObject + { + private readonly global::TUnit.Mocks.MockEngine _engine; + + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + global::TUnit.Mocks.IMock? global::TUnit.Mocks.IMockObject.MockWrapper { get; set; } + + internal IFooMockImpl(global::TUnit.Mocks.MockEngine engine) + { + _engine = engine; + } + + public global::System.Threading.Tasks.Task DoSomethingAsync() + { + try + { + var __result = _engine.HandleCallWithReturn(0, "DoSomethingAsync", global::System.Array.Empty(), default); + if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync)) + { + if (__rawAsync is global::System.Threading.Tasks.Task __typedAsync) return __typedAsync; + throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.Task but got {__rawAsync?.GetType().Name ?? "null"}"); + } + return global::System.Threading.Tasks.Task.FromResult(__result); + } + catch (global::System.Exception __ex) + { + return global::System.Threading.Tasks.Task.FromException(__ex); + } + } + + public T? GetValue() + { + return _engine.HandleCallWithReturn(1, "GetValue", global::System.Array.Empty(), default); + } + + public (T?, string) GetPair() + { + return _engine.HandleCallWithReturn<(T?, string)>(2, "GetPair", global::System.Array.Empty(), default); + } + + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + public void RaiseEvent(string eventName, object? args) + { + throw new global::System.InvalidOperationException($"No event named '{eventName}' exists on this mock."); + } + } +} + + +// ===== FILE SEPARATOR ===== + +// +#nullable enable + +namespace TUnit.Mocks.Generated +{ + public static class IFoo_MockMemberExtensions + { + public static global::TUnit.Mocks.MockMethodCall DoSomethingAsync(this global::TUnit.Mocks.Mock mock) + { + var matchers = global::System.Array.Empty(); + return new global::TUnit.Mocks.MockMethodCall(global::TUnit.Mocks.MockRegistry.GetEngine(mock), 0, "DoSomethingAsync", matchers); + } + + public static global::TUnit.Mocks.MockMethodCall GetValue(this global::TUnit.Mocks.Mock mock) + { + var matchers = global::System.Array.Empty(); + return new global::TUnit.Mocks.MockMethodCall(global::TUnit.Mocks.MockRegistry.GetEngine(mock), 1, "GetValue", matchers); + } + + public static global::TUnit.Mocks.MockMethodCall<(T?, string)> GetPair(this global::TUnit.Mocks.Mock mock) + { + var matchers = global::System.Array.Empty(); + return new global::TUnit.Mocks.MockMethodCall<(T?, string)>(global::TUnit.Mocks.MockRegistry.GetEngine(mock), 2, "GetPair", matchers); + } + } +} + + +// ===== FILE SEPARATOR ===== + +// +#nullable enable + +namespace TUnit.Mocks +{ + public static class IFoo_MockStaticExtension + { + extension(global::IFoo) + { + public static global::TUnit.Mocks.Generated.IFooMock Mock(global::TUnit.Mocks.MockBehavior behavior = global::TUnit.Mocks.MockBehavior.Loose) + { + return (global::TUnit.Mocks.Generated.IFooMock)global::TUnit.Mocks.Generated.IFooMockFactory.Create(behavior, []); + } + } + } +} + + +// ===== FILE SEPARATOR ===== + +// +#nullable enable + +namespace TUnit.Mocks.Generated; \ No newline at end of file diff --git a/TUnit.Mocks.SourceGenerator/Builders/MockBridgeBuilder.cs b/TUnit.Mocks.SourceGenerator/Builders/MockBridgeBuilder.cs index fb3cf4c934..7acbd90561 100644 --- a/TUnit.Mocks.SourceGenerator/Builders/MockBridgeBuilder.cs +++ b/TUnit.Mocks.SourceGenerator/Builders/MockBridgeBuilder.cs @@ -90,7 +90,7 @@ private static void GenerateStaticMethodDim(CodeWriter writer, MockMemberModel m var signatureReturnType = (method.IsVoid && !method.IsAsync) ? "void" : method.ReturnType; var paramList = MockImplBuilder.GetParameterList(method); var typeParams = MockImplBuilder.GetTypeParameterList(method); - var constraints = MockImplBuilder.GetConstraintClauses(method); + var constraints = MockImplBuilder.GetConstraintClauses(method, forExplicitImplementation: true); using (writer.Block($"static {signatureReturnType} {method.ExplicitInterfaceName}.{method.Name}{typeParams}({paramList}){constraints}")) { diff --git a/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs b/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs index 9300335c3b..ca2fef5dab 100644 --- a/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs +++ b/TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs @@ -1110,10 +1110,10 @@ private static string FormatTypeParameterList(EquatableArray - FormatConstraintClauses(method.TypeParameters); + internal static string GetConstraintClauses(MockMemberModel method, bool forExplicitImplementation = false) => + FormatConstraintClauses(method.TypeParameters, forExplicitImplementation); - private static string FormatConstraintClauses(EquatableArray typeParameters) + private static string FormatConstraintClauses(EquatableArray typeParameters, bool forExplicitImplementation = false) { var clauses = new List(); foreach (var tp in typeParameters) @@ -1122,6 +1122,10 @@ private static string FormatConstraintClauses(EquatableArray 0 ? " " + string.Join(' ', clauses) : ""; } diff --git a/TUnit.Mocks.SourceGenerator/Builders/MockWrapperTypeBuilder.cs b/TUnit.Mocks.SourceGenerator/Builders/MockWrapperTypeBuilder.cs index 5ea4290aa4..d9e9b824d7 100644 --- a/TUnit.Mocks.SourceGenerator/Builders/MockWrapperTypeBuilder.cs +++ b/TUnit.Mocks.SourceGenerator/Builders/MockWrapperTypeBuilder.cs @@ -74,7 +74,7 @@ private static void GenerateMethodForwarding(CodeWriter writer, MockMemberModel var interfaceName = method.ExplicitInterfaceName ?? method.DeclaringInterfaceName ?? model.FullyQualifiedName; var paramList = MockImplBuilder.GetParameterList(method); var typeParams = MockImplBuilder.GetTypeParameterList(method); - var constraints = MockImplBuilder.GetConstraintClauses(method); + var constraints = MockImplBuilder.GetConstraintClauses(method, forExplicitImplementation: true); var argPassList = MockImplBuilder.GetArgPassList(method); var returnType = (method.IsVoid && !method.IsAsync) ? "void" : method.ReturnType; diff --git a/TUnit.Mocks.SourceGenerator/Discovery/MemberDiscovery.cs b/TUnit.Mocks.SourceGenerator/Discovery/MemberDiscovery.cs index a4bf2b2d8b..d07fa565bb 100644 --- a/TUnit.Mocks.SourceGenerator/Discovery/MemberDiscovery.cs +++ b/TUnit.Mocks.SourceGenerator/Discovery/MemberDiscovery.cs @@ -506,7 +506,8 @@ private static MockMemberModel CreateMethodModel(IMethodSymbol method, ref int m method.TypeParameters.Select(tp => new MockTypeParameterModel { Name = tp.Name, - Constraints = tp.GetGenericConstraints() + Constraints = tp.GetGenericConstraints(), + HasAnnotatedNullableUsage = tp.IsUnconstrainedWithNullableUsage(method) }).ToImmutableArray() ), ExplicitInterfaceName = explicitInterfaceName, diff --git a/TUnit.Mocks.SourceGenerator/Extensions/MethodSymbolExtensions.cs b/TUnit.Mocks.SourceGenerator/Extensions/MethodSymbolExtensions.cs index 7114cd0a62..3b39d75650 100644 --- a/TUnit.Mocks.SourceGenerator/Extensions/MethodSymbolExtensions.cs +++ b/TUnit.Mocks.SourceGenerator/Extensions/MethodSymbolExtensions.cs @@ -19,8 +19,19 @@ public static ParameterDirection GetParameterDirection(this IParameterSymbol par }; } + private static bool IsUnconstrained(this ITypeParameterSymbol typeParam) => + !typeParam.HasReferenceTypeConstraint && + !typeParam.HasValueTypeConstraint && + !typeParam.HasUnmanagedTypeConstraint && + !typeParam.HasNotNullConstraint && + typeParam.ConstraintTypes.Length == 0 && + !typeParam.HasConstructorConstraint; + public static string GetGenericConstraints(this ITypeParameterSymbol typeParam) { + if (typeParam.IsUnconstrained()) + return ""; + var constraints = new List(); if (typeParam.HasReferenceTypeConstraint) @@ -40,7 +51,53 @@ public static string GetGenericConstraints(this ITypeParameterSymbol typeParam) if (typeParam.HasConstructorConstraint) constraints.Add("new()"); - return constraints.Count > 0 ? string.Join(", ", constraints) : ""; + return string.Join(", ", constraints); + } + + public static bool IsUnconstrainedWithNullableUsage(this ITypeParameterSymbol typeParam, IMethodSymbol method) + { + if (!typeParam.IsUnconstrained()) + { + return false; + } + + if (HasNullableTypeParameter(method.ReturnType, typeParam)) + return true; + + foreach (var param in method.Parameters) + { + if (HasNullableTypeParameter(param.Type, typeParam)) + return true; + } + + return false; + } + + private static bool HasNullableTypeParameter(ITypeSymbol type, ITypeParameterSymbol typeParam) + { + if (type is ITypeParameterSymbol tp && + SymbolEqualityComparer.Default.Equals(tp.OriginalDefinition, typeParam.OriginalDefinition) && + tp.NullableAnnotation == NullableAnnotation.Annotated) + { + return true; + } + + if (type is INamedTypeSymbol named) + { + foreach (var arg in named.TypeArguments) + { + if (HasNullableTypeParameter(arg, typeParam)) + return true; + } + } + + if (type is IArrayTypeSymbol array) + { + if (HasNullableTypeParameter(array.ElementType, typeParam)) + return true; + } + + return false; } public static string GetParameterList(this IMethodSymbol method) diff --git a/TUnit.Mocks.SourceGenerator/Models/MockTypeParameterModel.cs b/TUnit.Mocks.SourceGenerator/Models/MockTypeParameterModel.cs index 0ff354c16f..c8c7a2cbb5 100644 --- a/TUnit.Mocks.SourceGenerator/Models/MockTypeParameterModel.cs +++ b/TUnit.Mocks.SourceGenerator/Models/MockTypeParameterModel.cs @@ -6,11 +6,12 @@ internal sealed record MockTypeParameterModel : IEquatable