diff --git a/TUnit.Assertions.SourceGenerator.IncrementalTests/AssertionMethodGeneratorIncrementalTests.cs b/TUnit.Assertions.SourceGenerator.IncrementalTests/AssertionMethodGeneratorIncrementalTests.cs new file mode 100644 index 0000000000..f91b8f9199 --- /dev/null +++ b/TUnit.Assertions.SourceGenerator.IncrementalTests/AssertionMethodGeneratorIncrementalTests.cs @@ -0,0 +1,114 @@ +using Microsoft.CodeAnalysis.CSharp; + +namespace TUnit.Assertions.SourceGenerator.IncrementalTests; + +public class AssertionMethodGeneratorIncrementalTests +{ + private const string DefaultAssertion = + """ + #nullable enabled + using System.ComponentModel; + using TUnit.Assertions.Attributes; + + public static partial class IntAssertionExtensions + { + [GenerateAssertion(ExpectationMessage = "to be positive")] + public static bool IsPositive(this int value) + { + return value > 0; + } + + public static bool IsNegative(this int value) + { + return value < 0; + } + } + """; + + [Fact] + public void AddUnrelatedMethodShouldNotRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultAssertion, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + TestHelper.AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New); + + var compilation2 = compilation1.AddSyntaxTrees(CSharpSyntaxTree.ParseText("struct MyValue {}")); + var driver2 = driver1.RunGenerators(compilation2); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached); + } + + [Fact] + public void AddNewTypeAssertionShouldRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultAssertion, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + TestHelper.AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New); + + var compilation2 = compilation1.AddSyntaxTrees(CSharpSyntaxTree.ParseText( + """ + using TUnit.Assertions.Attributes; + + public static partial class LongAssertionExtensions + { + [GenerateAssertion(ExpectationMessage = "to be positive")] + public static bool IsPositive(this long value) + { + return value > 0; + } + } + """)); + var driver2 = driver1.RunGenerators(compilation2); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached, 0); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.New, 1); + } + + [Fact] + public void AddNewSameTypeAssertionShouldRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultAssertion, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + TestHelper.AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New); + + var compilation2 = TestHelper.ReplaceMethodDeclaration(compilation1, "IsNegative", + """ + [GenerateAssertion(ExpectationMessage = "to be less than zero")] + public static bool IsNegative(this int value) + { + return value < 0; + } + """ + ); + var driver2 = driver1.RunGenerators(compilation2); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached, 0); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.New, 1); + } + + [Fact] + public void ModifyMessageShouldRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultAssertion, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + TestHelper.AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New); + + var compilation2 = TestHelper.ReplaceMethodDeclaration(compilation1, "IsPositive", + """ + [GenerateAssertion(ExpectationMessage = "to be more than zero")] + public static bool IsPositive(this int value) + { + return value > 0; + } + """ + ); + var driver2 = driver1.RunGenerators(compilation2); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Modified); + } + +} diff --git a/TUnit.Assertions.SourceGenerator.IncrementalTests/Fixture.cs b/TUnit.Assertions.SourceGenerator.IncrementalTests/Fixture.cs new file mode 100644 index 0000000000..bafda094b1 --- /dev/null +++ b/TUnit.Assertions.SourceGenerator.IncrementalTests/Fixture.cs @@ -0,0 +1,77 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using TUnit.Assertions.Attributes; + +namespace TUnit.Assertions.SourceGenerator.IncrementalTests; + +public static class Fixture +{ + public static readonly Assembly[] ImportantAssemblies = new[] + { + typeof(object).Assembly, + typeof(Console).Assembly, + typeof(GenerateAssertionAttribute).Assembly, + typeof(MulticastDelegate).Assembly, + typeof(IServiceProvider).Assembly, + }; + + public static Assembly[] AssemblyReferencesForCodegen => + AppDomain + .CurrentDomain.GetAssemblies() + .Concat(ImportantAssemblies) + .Distinct() + .Where(a => !a.IsDynamic) + .ToArray(); + + public static DirectoryInfo GetSolutionDirectoryInfo() + { + var slnDir = SolutionDir(); + var directory = new DirectoryInfo(slnDir); + // Assert.True(directory.Exists); + return directory; + } + + private static string SolutionDir([CallerFilePath] string thisFilePath = "") => + Path.GetFullPath(Path.Join(thisFilePath, "../../../")); + + public static CSharpCompilation CreateLibrary(params string[] source) => + CreateLibrary(source.Select(s => CSharpSyntaxTree.ParseText(s)).ToArray()); + + public static CSharpCompilation CreateLibrary(params SyntaxTree[] source) + { + var references = new List(); + var assemblies = AssemblyReferencesForCodegen; + foreach (Assembly assembly in assemblies) + { + if (!assembly.IsDynamic) + { + references.Add(MetadataReference.CreateFromFile(assembly.Location)); + } + } + + var compilation = CSharpCompilation.Create( + "Library", + source, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + ); + + return compilation; + } + + public static async Task SourceFromResourceFile(string file) + { + var currentDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + Assert.NotNull(currentDir); + var resourcesDir = Path.Combine(currentDir, "resources"); + + return await File.ReadAllTextAsync(Path.Combine(resourcesDir, file)); + } +} diff --git a/TUnit.Assertions.SourceGenerator.IncrementalTests/TUnit.Assertions.SourceGenerator.IncrementalTests.csproj b/TUnit.Assertions.SourceGenerator.IncrementalTests/TUnit.Assertions.SourceGenerator.IncrementalTests.csproj new file mode 100644 index 0000000000..24329cbc95 --- /dev/null +++ b/TUnit.Assertions.SourceGenerator.IncrementalTests/TUnit.Assertions.SourceGenerator.IncrementalTests.csproj @@ -0,0 +1,31 @@ + + + + net10.0 + enable + enable + false + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/TUnit.Assertions.SourceGenerator.IncrementalTests/TestHelper.cs b/TUnit.Assertions.SourceGenerator.IncrementalTests/TestHelper.cs new file mode 100644 index 0000000000..8475138ed3 --- /dev/null +++ b/TUnit.Assertions.SourceGenerator.IncrementalTests/TestHelper.cs @@ -0,0 +1,149 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Assertions.SourceGenerator.Generators; + +namespace TUnit.Assertions.SourceGenerator.IncrementalTests; + +internal static class TestHelper +{ + private static readonly GeneratorDriverOptions EnableIncrementalTrackingDriverOptions = new( + IncrementalGeneratorOutputKind.None, + trackIncrementalGeneratorSteps: true + ); + + internal static GeneratorDriver GenerateTracked(Compilation compilation) + { + var generator = new MethodAssertionGenerator(); + + var driver = CSharpGeneratorDriver.Create( + new[] { generator.AsSourceGenerator() }, + driverOptions: EnableIncrementalTrackingDriverOptions + ); + return driver.RunGenerators(compilation); + } + + internal static CSharpCompilation ReplaceMemberDeclaration( + CSharpCompilation compilation, + string memberName, + string newMember + ) + { + var syntaxTree = compilation.SyntaxTrees.Single(); + var memberDeclaration = syntaxTree + .GetCompilationUnitRoot() + .DescendantNodes() + .OfType() + .Single(x => x.Identifier.Text == memberName); + var updatedMemberDeclaration = SyntaxFactory.ParseMemberDeclaration(newMember)!; + + var newRoot = syntaxTree.GetCompilationUnitRoot().ReplaceNode(memberDeclaration, updatedMemberDeclaration); + var newTree = syntaxTree.WithRootAndOptions(newRoot, syntaxTree.Options); + + return compilation.ReplaceSyntaxTree(compilation.SyntaxTrees.First(), newTree); + } + + internal static CSharpCompilation ReplaceLocalDeclaration( + CSharpCompilation compilation, + string variableName, + string newDeclaration + ) + { + var syntaxTree = compilation.SyntaxTrees.Single(); + + var memberDeclaration = syntaxTree + .GetCompilationUnitRoot() + .DescendantNodes() + .OfType() + .Single(x => x.Declaration.Variables.Any(x => x.Identifier.ToString() == variableName)); + var updatedMemberDeclaration = SyntaxFactory.ParseStatement(newDeclaration)!; + + var newRoot = syntaxTree.GetCompilationUnitRoot().ReplaceNode(memberDeclaration, updatedMemberDeclaration); + var newTree = syntaxTree.WithRootAndOptions(newRoot, syntaxTree.Options); + + return compilation.ReplaceSyntaxTree(compilation.SyntaxTrees.First(), newTree); + } + + internal static CSharpCompilation ReplaceMethodDeclaration( + CSharpCompilation compilation, + string methodName, + string newDeclaration + ) + { + var syntaxTree = compilation.SyntaxTrees.Single(); + + var memberDeclaration = syntaxTree + .GetCompilationUnitRoot() + .DescendantNodes() + .OfType() + .First(x => x.Identifier.Text == methodName); + var updatedMemberDeclaration = SyntaxFactory.ParseMemberDeclaration(newDeclaration)!; + + var newRoot = syntaxTree.GetCompilationUnitRoot().ReplaceNode(memberDeclaration, updatedMemberDeclaration); + var newTree = syntaxTree.WithRootAndOptions(newRoot, syntaxTree.Options); + + return compilation.ReplaceSyntaxTree(compilation.SyntaxTrees.First(), newTree); + } + + + internal static void AssertRunReasons( + GeneratorDriver driver, + IncrementalGeneratorRunReasons reasons, + int outputIndex = 0 + ) + { + var runResult = driver.GetRunResult().Results[0]; + + AssertRunReason(runResult, MethodAssertionGenerator.BuildAssertion, reasons.BuildMethodAssertionStep, outputIndex); + } + + private static void AssertRunReason( + GeneratorRunResult runResult, + string stepName, + IncrementalStepRunReason expectedStepReason, + int outputIndex + ) + { + var actualStepReason = runResult + .TrackedSteps[stepName] + .SelectMany(x => x.Outputs) + .ElementAt(outputIndex) + .Reason; + + if (actualStepReason != expectedStepReason) + { + throw new Exception($"Incremental generator step {stepName} at index {outputIndex} failed " + + $"with the expected reason: {expectedStepReason}, with the actual reason: {actualStepReason}."); + } + } +} + +internal record IncrementalGeneratorRunReasons( + IncrementalStepRunReason BuildMethodAssertionStep, + IncrementalStepRunReason ReportDiagnosticsStep +) +{ + public static readonly IncrementalGeneratorRunReasons New = new( + IncrementalStepRunReason.New, + IncrementalStepRunReason.New + ); + + public static readonly IncrementalGeneratorRunReasons Cached = new( + // compilation step should always be modified as each time a new compilation is passed + IncrementalStepRunReason.Cached, + IncrementalStepRunReason.Cached + ); + + public static readonly IncrementalGeneratorRunReasons Modified = Cached with + { + ReportDiagnosticsStep = IncrementalStepRunReason.Modified, + BuildMethodAssertionStep = IncrementalStepRunReason.Modified, + }; + + public static readonly IncrementalGeneratorRunReasons ModifiedSource = Cached with + { + ReportDiagnosticsStep = IncrementalStepRunReason.Unchanged, + BuildMethodAssertionStep = IncrementalStepRunReason.Modified, + }; +} + diff --git a/TUnit.Assertions.SourceGenerator/Generators/MethodAssertionGenerator.cs b/TUnit.Assertions.SourceGenerator/Generators/MethodAssertionGenerator.cs index 95bbfcae85..ba3106a611 100644 --- a/TUnit.Assertions.SourceGenerator/Generators/MethodAssertionGenerator.cs +++ b/TUnit.Assertions.SourceGenerator/Generators/MethodAssertionGenerator.cs @@ -1,13 +1,10 @@ -using System; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; using System.Text; using System.Text.RegularExpressions; -using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Assertions.SourceGenerator.Models; namespace TUnit.Assertions.SourceGenerator.Generators; @@ -20,6 +17,8 @@ namespace TUnit.Assertions.SourceGenerator.Generators; [Generator] public sealed class MethodAssertionGenerator : IIncrementalGenerator { + public static string BuildAssertion = "MethodAssertionData"; + private static readonly DiagnosticDescriptor MethodMustBeStaticRule = new DiagnosticDescriptor( id: "TUNITGEN001", title: "Method must be static", @@ -68,7 +67,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Split into methods and diagnostics var methods = assertionMethodsOrDiagnostics .Where(x => x.Data != null) - .Select((x, _) => x.Data!); + .Select((x, _) => x.Data!) + .WithTrackingName(BuildAssertion); var diagnostics = assertionMethodsOrDiagnostics .Where(x => x.Diagnostic != null) @@ -132,11 +132,15 @@ private static (AssertionMethodData? Data, Diagnostic? Diagnostic) GetAssertionM // First parameter is the target type (what becomes IAssertionSource) var targetType = methodSymbol.Parameters[0].Type; - var additionalParameters = methodSymbol.Parameters.Skip(1).ToImmutableArray(); - - // Check if it's an extension method - var isExtensionMethod = methodSymbol.IsExtensionMethod || - (methodSymbol.Parameters.Length > 0 && methodSymbol.Parameters[0].IsThis); + var additionalParameters = methodSymbol.Parameters.Skip(1).Select(p => new ParameterData() + { + Name = p.Name, + Type = p.Type.ToDisplayString(), + IsRefStruct = IsRefStruct(p.Type), + IsParams = p.IsParams, + IsInterpolatedStringHandler = IsInterpolatedStringHandler(p.Type), + SimpleTypeName = GetSimpleTypeName(p.Type), + }).ToImmutableEquatableArray(); // Extract custom expectation message and inlining preference if provided string? customExpectation = null; @@ -272,34 +276,67 @@ private static (AssertionMethodData? Data, Diagnostic? Diagnostic) GetAssertionM // Ref structs cannot be stored as class fields, so we need to inline the method body foreach (var param in additionalParameters) { - if (IsRefStruct(param.Type) && string.IsNullOrEmpty(methodBody)) + if (param.IsRefStruct && string.IsNullOrEmpty(methodBody)) { var diagnostic = Diagnostic.Create( RefStructRequiresInliningRule, location, methodSymbol.Name, param.Name, - param.Type.ToDisplayString()); + param.Type); return (null, diagnostic); } } + + ContainingTypeData? containingTypeData = null; + + if (methodSymbol.ContainingSymbol != null) + { + containingTypeData = new ContainingTypeData( + methodSymbol.ContainingType.Name, + methodSymbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + methodSymbol.ContainingType.ContainingNamespace.ToDisplayString() + ); + } + + var methodData = new MethodData() + { + Name = methodSymbol.Name, + FirstParamName = methodSymbol.Parameters[0].Name, + GenericTypeParameters = GetGenericTypeParameters(targetType, methodSymbol), + GenericConstraints = CollectGenericConstraints(methodSymbol), + MethodCallExpression = BuildMethodCallExpression(methodSymbol, additionalParameters), + ContainingType = containingTypeData, + }; + + var targetTypeData = new TargetTypeData() + { + TypeName = targetType.ToDisplayString(), + SimpleTypeName = GetSimpleTypeName(targetType), + IsNullable = targetType.IsReferenceType || + targetType.NullableAnnotation == NullableAnnotation.Annotated, + }; + var data = new AssertionMethodData( - methodSymbol, - targetType, + methodData, + targetTypeData, additionalParameters, returnTypeInfo.Value, - isExtensionMethod, customExpectation, isFileScoped, methodBody, - suppressionAttributesForCheckAsync.ToImmutableArray(), - diagnosticAttributesForExtensionMethod.ToImmutableArray() + suppressionAttributesForCheckAsync.ToImmutableEquatableArray(), + diagnosticAttributesForExtensionMethod.ToImmutableEquatableArray() ); return (data, null); } + private static bool IsExtensionMethod(IMethodSymbol methodSymbol) => + methodSymbol.IsExtensionMethod || + (methodSymbol.Parameters.Length > 0 && methodSymbol.Parameters[0].IsThis); + /// /// Checks if a type is file-scoped (has 'file' accessibility) /// File-scoped types have specific metadata that we can check. @@ -423,14 +460,14 @@ public TypeQualifyingRewriter(SemanticModel semanticModel) // Task if (innerType.SpecialType == SpecialType.System_Boolean) { - return new ReturnTypeInfo(ReturnTypeKind.TaskBool, innerType); + return new ReturnTypeInfo(ReturnTypeKind.TaskBool); } // Task if (innerType.Name == "AssertionResult" && innerType.ContainingNamespace?.ToDisplayString() == "TUnit.Assertions.Core") { - return new ReturnTypeInfo(ReturnTypeKind.TaskAssertionResult, innerType); + return new ReturnTypeInfo(ReturnTypeKind.TaskAssertionResult); } } } @@ -439,14 +476,14 @@ public TypeQualifyingRewriter(SemanticModel semanticModel) if (namedType.Name == "AssertionResult" && namedType.ContainingNamespace?.ToDisplayString() == "TUnit.Assertions.Core") { - return new ReturnTypeInfo(ReturnTypeKind.AssertionResult, namedType); + return new ReturnTypeInfo(ReturnTypeKind.AssertionResult); } } // bool if (returnType.SpecialType == SpecialType.System_Boolean) { - return new ReturnTypeInfo(ReturnTypeKind.Bool, returnType); + return new ReturnTypeInfo(ReturnTypeKind.Bool); } return null; @@ -462,9 +499,9 @@ private static void GenerateAssertions( } // Group by containing class to generate one file per class - foreach (var methodGroup in methods.GroupBy(m => m.Method.ContainingType, SymbolEqualityComparer.Default)) + foreach (var methodGroup in methods.GroupBy(m => m.Method.ContainingType?.FullContainingType)) { - var containingType = methodGroup.Key as INamedTypeSymbol; + var containingType = methodGroup.First().Method.ContainingType; if (containingType == null) { continue; @@ -476,7 +513,7 @@ private static void GenerateAssertions( var namespaceName = "TUnit.Assertions.Extensions"; // Get the original namespace where the helper methods are defined - var originalNamespace = containingType.ContainingNamespace?.ToDisplayString(); + var originalNamespace = containingType.ContainingNamespace; // File header sourceBuilder.AppendLine("#nullable enable"); @@ -537,13 +574,13 @@ private static void GenerateAssertions( private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData data) { var className = GenerateClassName(data); - var targetTypeName = data.TargetType.ToDisplayString(); - var genericParams = GetGenericTypeParameters(data.TargetType, data.Method); - var genericDeclaration = genericParams.Length > 0 ? $"<{string.Join(", ", genericParams)}>" : ""; - var isNullable = data.TargetType.IsReferenceType || data.TargetType.NullableAnnotation == NullableAnnotation.Annotated; + var targetTypeName = data.TargetType.TypeName; + var genericParams = data.Method.GenericTypeParameters; + var genericDeclaration = genericParams.Count > 0 ? $"<{string.Join(", ", genericParams)}>" : ""; + var isNullable = data.TargetType.IsNullable; // Collect generic constraints from the method - var genericConstraints = CollectGenericConstraints(data.Method); + var genericConstraints = data.Method.GenericConstraints; // Class declaration sb.AppendLine($"/// "); @@ -551,7 +588,7 @@ private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData sb.AppendLine($"/// "); // Add suppression for generic types to avoid trimming warnings - if (genericParams.Length > 0) + if (genericParams.Count > 0) { sb.AppendLine($"[System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage(\"Trimming\", \"IL2091\", Justification = \"Generic type parameter is only used for property access, not instantiation\")]"); } @@ -573,11 +610,11 @@ private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData // Note: Ref struct types (like DefaultInterpolatedStringHandler) are stored as string foreach (var param in data.AdditionalParameters) { - var fieldType = IsRefStruct(param.Type) ? "string" : param.Type.ToDisplayString(); + var fieldType = param.IsRefStruct ? "string" : param.Type; sb.AppendLine($" private readonly {fieldType} _{param.Name};"); } - if (data.AdditionalParameters.Length > 0) + if (data.AdditionalParameters.Count > 0) { sb.AppendLine(); } @@ -587,7 +624,7 @@ private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData sb.Append($" public {className}(AssertionContext<{targetTypeName}> context"); foreach (var param in data.AdditionalParameters) { - var paramType = IsRefStruct(param.Type) ? "string" : param.Type.ToDisplayString(); + var paramType = param.IsRefStruct ? "string" : param.Type; sb.Append($", {paramType} {param.Name}"); } sb.AppendLine(")"); @@ -601,12 +638,11 @@ private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData sb.AppendLine(); // CheckAsync method - only async if we need await - var needsAsync = data.ReturnTypeInfo.Kind == ReturnTypeKind.TaskBool || - data.ReturnTypeInfo.Kind == ReturnTypeKind.TaskAssertionResult; + var needsAsync = data.ReturnTypeInfo.Kind is ReturnTypeKind.TaskBool or ReturnTypeKind.TaskAssertionResult; var asyncKeyword = needsAsync ? "async " : ""; // Add suppression attributes to CheckAsync method when method body is inlined - if (!string.IsNullOrEmpty(data.MethodBody) && data.SuppressionAttributesForCheckAsync.Length > 0) + if (!string.IsNullOrEmpty(data.MethodBody) && data.SuppressionAttributesForCheckAsync.Count > 0) { foreach (var suppressionAttr in data.SuppressionAttributesForCheckAsync) { @@ -650,7 +686,7 @@ private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData // Use custom expectation message // Replace parameter placeholders like {param} with {_param} (field references) var expectation = data.CustomExpectation; - if (data.AdditionalParameters.Length > 0) + if (data.AdditionalParameters.Count > 0) { // Replace each parameter placeholder {paramName} with {_paramName} foreach (var param in data.AdditionalParameters) @@ -673,7 +709,7 @@ private static void GenerateAssertionClass(StringBuilder sb, AssertionMethodData else { // Use default expectation message - if (data.AdditionalParameters.Length > 0) + if (data.AdditionalParameters.Count > 0) { var paramList = string.Join(", ", data.AdditionalParameters.Select(p => $"{{_{p.Name}}}")); sb.AppendLine($" return $\"to satisfy {data.Method.Name}({paramList})\";"); @@ -696,7 +732,7 @@ private static void GenerateMethodCall(StringBuilder sb, AssertionMethodData dat var shouldInline = !string.IsNullOrEmpty(data.MethodBody); var methodCall = shouldInline ? BuildInlinedExpression(data) - : BuildMethodCallExpression(data); + : data.Method.MethodCallExpression; switch (data.ReturnTypeInfo.Kind) { @@ -733,13 +769,13 @@ private static string BuildInlinedExpression(AssertionMethodData data) if (string.IsNullOrEmpty(data.MethodBody)) { // Fallback to method call if body is not available - return BuildMethodCallExpression(data); + return data.Method.MethodCallExpression; } var inlinedBody = data.MethodBody; // Replace first parameter name with "value" (already named value in our context) - var firstParamName = data.Method.Parameters[0].Name; + var firstParamName = data.Method.FirstParamName; if (firstParamName != "value") { // Use word boundary replacement to avoid partial matches @@ -763,7 +799,7 @@ private static string BuildInlinedExpression(AssertionMethodData data) // remove calls to .ToStringAndClear() and .ToString() since the value is already a string foreach (var param in data.AdditionalParameters) { - if (IsRefStruct(param.Type)) + if (param.IsRefStruct) { var fieldName = $"_{param.Name}"; // Remove .ToStringAndClear() - the value is already a string @@ -781,7 +817,7 @@ private static string BuildInlinedExpression(AssertionMethodData data) // Add null-forgiving operator for reference types if not already present // This is safe because we've already checked for null above - var isNullable = data.TargetType.IsReferenceType || data.TargetType.NullableAnnotation == NullableAnnotation.Annotated; + var isNullable = data.TargetType.IsNullable; if (isNullable && !string.IsNullOrEmpty(inlinedBody) && !inlinedBody.StartsWith("value!")) { // Replace null-conditional operators with null-forgiving + regular operators @@ -813,31 +849,31 @@ private static string BuildInlinedExpression(AssertionMethodData data) return inlinedBody ?? string.Empty; } - private static string BuildMethodCallExpression(AssertionMethodData data) + private static string BuildMethodCallExpression(IMethodSymbol method, ImmutableEquatableArray additionalParameters) { - var containingType = data.Method.ContainingType.ToDisplayString(); - var methodName = data.Method.Name; + var containingType = method.ContainingType.ToDisplayString(); + var methodName = method.Name; // Build type arguments if the method is generic var typeArguments = ""; - if (data.Method.IsGenericMethod && data.Method.TypeParameters.Length > 0) + if (method is { IsGenericMethod: true, TypeParameters.Length: > 0 }) { - var typeParams = string.Join(", ", data.Method.TypeParameters.Select(tp => tp.Name)); + var typeParams = string.Join(", ", method.TypeParameters.Select(tp => tp.Name)); typeArguments = $"<{typeParams}>"; } - if (data.IsExtensionMethod) + if (IsExtensionMethod(method)) { // Extension method syntax: value!.MethodName(params) // Use null-forgiving operator since we've already checked for null above - var paramList = string.Join(", ", data.AdditionalParameters.Select(p => $"_{p.Name}")); + var paramList = string.Join(", ", additionalParameters.Select(p => $"_{p.Name}")); return $"value!.{methodName}{typeArguments}({paramList})"; } else { // Static method syntax: ContainingType.MethodName(value, params) var allParams = new List { "value" }; - allParams.AddRange(data.AdditionalParameters.Select(p => $"_{p.Name}")); + allParams.AddRange(additionalParameters.Select(p => $"_{p.Name}")); var paramList = string.Join(", ", allParams); return $"{containingType}.{methodName}{typeArguments}({paramList})"; } @@ -846,13 +882,13 @@ private static string BuildMethodCallExpression(AssertionMethodData data) private static void GenerateExtensionMethod(StringBuilder sb, AssertionMethodData data) { var className = GenerateClassName(data); - var targetTypeName = data.TargetType.ToDisplayString(); + var targetTypeName = data.TargetType.TypeName; var methodName = data.Method.Name; - var genericParams = GetGenericTypeParameters(data.TargetType, data.Method); - var genericDeclaration = genericParams.Length > 0 ? $"<{string.Join(", ", genericParams)}>" : ""; + var genericParams = data.Method.GenericTypeParameters; + var genericDeclaration = genericParams.Count > 0 ? $"<{string.Join(", ", genericParams)}>" : ""; // Collect generic constraints from the method - var genericConstraints = CollectGenericConstraints(data.Method); + var genericConstraints = data.Method.GenericConstraints; // XML documentation sb.AppendLine(" /// "); @@ -860,13 +896,13 @@ private static void GenerateExtensionMethod(StringBuilder sb, AssertionMethodDat sb.AppendLine(" /// "); // Add suppression for generic types to avoid trimming warnings - if (genericParams.Length > 0) + if (genericParams.Count > 0) { sb.AppendLine($" [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage(\"Trimming\", \"IL2091\", Justification = \"Generic type parameter is only used for property access, not instantiation\")]"); } // Add diagnostic attributes (RequiresUnreferencedCode, RequiresDynamicCode) to extension method - if (data.DiagnosticAttributesForExtensionMethod.Length > 0) + if (data.DiagnosticAttributesForExtensionMethod.Count > 0) { foreach (var diagnosticAttr in data.DiagnosticAttributesForExtensionMethod) { @@ -882,11 +918,11 @@ private static void GenerateExtensionMethod(StringBuilder sb, AssertionMethodDat foreach (var param in data.AdditionalParameters) { var paramsModifier = param.IsParams ? "params " : ""; - sb.Append($", {paramsModifier}{param.Type.ToDisplayString()} {param.Name}"); + sb.Append($", {paramsModifier}{param.Type} {param.Name}"); } // CallerArgumentExpression parameters (skip for params since params must be last) - for (int i = 0; i < data.AdditionalParameters.Length; i++) + for (int i = 0; i < data.AdditionalParameters.Count; i++) { var param = data.AdditionalParameters[i]; if (!param.IsParams) @@ -909,7 +945,7 @@ private static void GenerateExtensionMethod(StringBuilder sb, AssertionMethodDat sb.AppendLine(" {"); // Build expression string - if (data.AdditionalParameters.Length > 0) + if (data.AdditionalParameters.Count > 0) { // For params parameters, use parameter name directly (no Expression suffix since we didn't generate it) var exprList = string.Join(", ", data.AdditionalParameters.Select(p => @@ -926,11 +962,11 @@ private static void GenerateExtensionMethod(StringBuilder sb, AssertionMethodDat sb.Append($" return new {className}{genericDeclaration}(source.Context"); foreach (var param in data.AdditionalParameters) { - if (IsRefStruct(param.Type)) + if (param.IsRefStruct) { // Convert ref struct to string - use ToStringAndClear for interpolated string handlers // or ToString() for other ref structs - var conversion = IsInterpolatedStringHandler(param.Type) + var conversion = param.IsInterpolatedStringHandler ? $"{param.Name}.ToStringAndClear()" : $"{param.Name}.ToString()"; sb.Append($", {conversion}"); @@ -948,25 +984,25 @@ private static void GenerateExtensionMethod(StringBuilder sb, AssertionMethodDat private static string GenerateClassName(AssertionMethodData data) { var methodName = data.Method.Name; - var targetTypeName = GetSimpleTypeName(data.TargetType); + var targetTypeName =data.TargetType.SimpleTypeName; - if (data.AdditionalParameters.Length == 0) + if (data.AdditionalParameters.Count == 0) { return $"{targetTypeName}_{methodName}_Assertion"; } // Include parameter types to distinguish overloads - var paramTypes = string.Join("_", data.AdditionalParameters.Select(p => GetSimpleTypeName(p.Type))); + var paramTypes = string.Join("_", data.AdditionalParameters.Select(p => p.SimpleTypeName)); return $"{targetTypeName}_{methodName}_{paramTypes}_Assertion"; } - private static string[] GetGenericTypeParameters(ITypeSymbol type, IMethodSymbol method) + private static ImmutableEquatableArray GetGenericTypeParameters(ITypeSymbol type, IMethodSymbol method) { // For extension methods, if the method has generic parameters, those define ALL the type parameters // (including any used in the target type like Lazy or T[]) if (method != null && method.IsGenericMethod) { - return method.TypeParameters.Select(t => t.Name).ToArray(); + return method.TypeParameters.Select(t => t.Name).ToImmutableEquatableArray(); } // If the method is not generic, check if the type itself has unbound generic parameters @@ -975,10 +1011,10 @@ private static string[] GetGenericTypeParameters(ITypeSymbol type, IMethodSymbol return namedType.TypeArguments .OfType() .Select(t => t.Name) - .ToArray(); + .ToImmutableEquatableArray(); } - return Array.Empty(); + return ImmutableEquatableArray.Empty(); } private static string GetSimpleTypeName(ITypeSymbol type) @@ -1020,13 +1056,13 @@ private static string GetSimpleTypeName(ITypeSymbol type) /// Collects generic constraints from method type parameters. /// Returns a list of constraint strings in the format "where T : constraint1, constraint2" /// - private static List CollectGenericConstraints(IMethodSymbol method) + private static ImmutableEquatableArray CollectGenericConstraints(IMethodSymbol method) { var constraints = new List(); if (!method.IsGenericMethod || method.TypeParameters.Length == 0) { - return constraints; + return constraints.ToImmutableEquatableArray(); } foreach (var typeParameter in method.TypeParameters) @@ -1060,7 +1096,7 @@ private static List CollectGenericConstraints(IMethodSymbol method) } } - return constraints; + return constraints.ToImmutableEquatableArray(); } /// @@ -1126,18 +1162,43 @@ private enum ReturnTypeKind TaskAssertionResult } - private readonly record struct ReturnTypeInfo(ReturnTypeKind Kind, ITypeSymbol Type); + private readonly record struct ReturnTypeInfo(ReturnTypeKind Kind); private record AssertionMethodData( - IMethodSymbol Method, - ITypeSymbol TargetType, - ImmutableArray AdditionalParameters, + MethodData Method, + TargetTypeData TargetType, + ImmutableEquatableArray AdditionalParameters, ReturnTypeInfo ReturnTypeInfo, - bool IsExtensionMethod, string? CustomExpectation, bool IsFileScoped, string? MethodBody, - ImmutableArray SuppressionAttributesForCheckAsync, - ImmutableArray DiagnosticAttributesForExtensionMethod + ImmutableEquatableArray SuppressionAttributesForCheckAsync, + ImmutableEquatableArray DiagnosticAttributesForExtensionMethod + ); + + private record ContainingTypeData( + string Name, + string FullContainingType, + string ContainingNamespace + ); + + private record struct TargetTypeData(string TypeName, string SimpleTypeName, bool IsNullable); + + private record struct MethodData( + string Name, + ContainingTypeData? ContainingType, + string FirstParamName, + string MethodCallExpression, + ImmutableEquatableArray GenericConstraints, + ImmutableEquatableArray GenericTypeParameters + ); + + private record struct ParameterData( + string Name, + string Type, + bool IsRefStruct, + bool IsParams, + bool IsInterpolatedStringHandler, + string SimpleTypeName ); } diff --git a/TUnit.Assertions.SourceGenerator/Models/ImmutableEquatableArray.cs b/TUnit.Assertions.SourceGenerator/Models/ImmutableEquatableArray.cs new file mode 100644 index 0000000000..6e3fa10e38 --- /dev/null +++ b/TUnit.Assertions.SourceGenerator/Models/ImmutableEquatableArray.cs @@ -0,0 +1,85 @@ +using System.Collections; + +namespace TUnit.Assertions.SourceGenerator.Models; + +// From https://github.com/dotnet/runtime/blob/6316c17c26c7ad25bd3449ce477a0882a48916dd/src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs#L15 +/// +/// Provides an immutable list implementation which implements sequence equality. +/// +public sealed class ImmutableEquatableArray : IEquatable>, IReadOnlyList + where T : IEquatable +{ + public static ImmutableEquatableArray Empty { get; } = new ImmutableEquatableArray(Array.Empty()); + + private readonly T[] _values; + public T this[int index] => _values[index]; + public int Count => _values.Length; + + public ImmutableEquatableArray(IEnumerable values) + => _values = values.ToArray(); + + public bool Equals(ImmutableEquatableArray? other) + => other != null && ((ReadOnlySpan)_values).SequenceEqual(other._values); + + public override bool Equals(object? obj) + => obj is ImmutableEquatableArray other && Equals(other); + + public override int GetHashCode() + { + var hash = 0; + foreach (T value in _values) + { + hash = Combine(hash, value.GetHashCode()); + } + + static int Combine(int h1, int h2) + { + // RyuJIT optimizes this to use the ROL instruction + // Related GitHub pull request: https://github.com/dotnet/coreclr/pull/1830 + uint rol5 = ((uint)h1 << 5) | ((uint)h1 >> 27); + return ((int)rol5 + h1) ^ h2; + } + + return hash; + } + + public Enumerator GetEnumerator() => new Enumerator(_values); + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_values).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => _values.GetEnumerator(); + + public struct Enumerator + { + private readonly T[] _values; + private int _index; + + internal Enumerator(T[] values) + { + _values = values; + _index = -1; + } + + public bool MoveNext() + { + int newIndex = _index + 1; + + if ((uint)newIndex < (uint)_values.Length) + { + _index = newIndex; + return true; + } + + return false; + } + + public readonly T Current => _values[_index]; + } +} + +internal static class ImmutableEquatableArray +{ + public static ImmutableEquatableArray Empty() + where T : IEquatable => ImmutableEquatableArray.Empty; + + public static ImmutableEquatableArray ToImmutableEquatableArray(this IEnumerable values) where T : IEquatable + => new(values); +} diff --git a/TUnit.sln b/TUnit.sln index 3350695215..61e6f38db2 100644 --- a/TUnit.sln +++ b/TUnit.sln @@ -155,6 +155,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TUnit.FsCheck", "TUnit.FsCh EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TUnit.Example.FsCheck.TestProject", "TUnit.Example.FsCheck.TestProject\TUnit.Example.FsCheck.TestProject.csproj", "{3428D7AD-B362-4647-B1B0-72674CF3BC7C}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TUnit.Assertions.SourceGenerator.IncrementalTests", "TUnit.Assertions.SourceGenerator.IncrementalTests\TUnit.Assertions.SourceGenerator.IncrementalTests.csproj", "{93A728CE-CC78-4F9B-897B-AA6F72E870F2}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -897,6 +899,18 @@ Global {3428D7AD-B362-4647-B1B0-72674CF3BC7C}.Release|x64.Build.0 = Release|Any CPU {3428D7AD-B362-4647-B1B0-72674CF3BC7C}.Release|x86.ActiveCfg = Release|Any CPU {3428D7AD-B362-4647-B1B0-72674CF3BC7C}.Release|x86.Build.0 = Release|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Debug|x64.ActiveCfg = Debug|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Debug|x64.Build.0 = Debug|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Debug|x86.ActiveCfg = Debug|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Debug|x86.Build.0 = Debug|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Release|Any CPU.Build.0 = Release|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Release|x64.ActiveCfg = Release|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Release|x64.Build.0 = Release|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Release|x86.ActiveCfg = Release|Any CPU + {93A728CE-CC78-4F9B-897B-AA6F72E870F2}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -965,6 +979,7 @@ Global {6134813B-F928-443F-A629-F6726A1112F9} = {503DA9FA-045D-4910-8AF6-905E6048B1F1} {3428D7AD-B362-4647-B1B0-72674CF3BC7C} = {0BA988BF-ADCE-4343-9098-B4EF65C43709} {6846A70E-2232-4BEF-9CE5-03F28A221335} = {1B56B580-4D59-4E83-9F80-467D58DADAC1} + {93A728CE-CC78-4F9B-897B-AA6F72E870F2} = {62AD1EAF-43C4-4AC0-B9FA-CD59739B3850} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {109D285A-36B3-4503-BCDF-8E26FB0E2C5B}