Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers;

public class AttributeWriter
public class AttributeWriter(Compilation compilation)
{
public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation compilation,
private readonly Dictionary<AttributeData, string> _attributeObjectInitializerCache = new();

public void WriteAttributes(ICodeWriter sourceCodeWriter,
ImmutableArray<AttributeData> attributeDatas)
{
var attributesToWrite = new List<AttributeData>();
Expand Down Expand Up @@ -49,7 +51,7 @@ public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation com
{
var attributeData = attributesToWrite[index];

WriteAttribute(sourceCodeWriter, compilation, attributeData);
WriteAttribute(sourceCodeWriter, attributeData);

if (index != attributesToWrite.Count - 1)
{
Expand All @@ -58,8 +60,7 @@ public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation com
}
}

public static void WriteAttribute(ICodeWriter sourceCodeWriter, Compilation compilation,
AttributeData attributeData)
public void WriteAttribute(ICodeWriter sourceCodeWriter, AttributeData attributeData)
{
if (attributeData.ApplicationSyntaxReference is null)
{
Expand All @@ -70,12 +71,23 @@ public static void WriteAttribute(ICodeWriter sourceCodeWriter, Compilation comp
else
{
// For attributes from the current compilation, use the syntax-based approach
sourceCodeWriter.Append(GetAttributeObjectInitializer(compilation, attributeData));
sourceCodeWriter.Append(GetAttributeObjectInitializer(attributeData));
}
}

public string GetAttributeObjectInitializer(AttributeData attributeData)
{
if (_attributeObjectInitializerCache.TryGetValue(attributeData, out var initializer))
{
return initializer;
}

initializer = GetAttributeObjectInitializerInner(compilation, attributeData);
_attributeObjectInitializerCache.Add(attributeData, initializer);
return initializer;
}

public static string GetAttributeObjectInitializer(Compilation compilation,
AttributeData attributeData)
private static string GetAttributeObjectInitializerInner(Compilation compilation, AttributeData attributeData)
{
var sourceCodeWriter = new CodeWriter("", includeHeader: false);

Expand Down Expand Up @@ -123,7 +135,6 @@ public static string GetAttributeObjectInitializer(Compilation compilation,
return sourceCodeWriter.ToString();
}


private static string FormatConstructorArgument(Compilation compilation, AttributeArgumentSyntax attributeArgumentSyntax)
{
if (attributeArgumentSyntax.NameColon is not null)
Expand Down
56 changes: 36 additions & 20 deletions TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,31 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return !string.Equals(value, "false", StringComparison.OrdinalIgnoreCase);
});

var compilationContext = context
.CompilationProvider
.Select(static (c, _) =>
new CompilationContext(
(CSharpCompilation)c,
new AttributeWriter(c)
));

var testMethodsProvider = context.SyntaxProvider
.ForAttributeWithMetadataName(
"TUnit.Core.TestAttribute",
predicate: static (node, _) => node is MethodDeclarationSyntax,
transform: static (ctx, _) => GetTestMethodMetadata(ctx))
transform: static (ctx, _) => ctx)
.Combine(compilationContext)
.Select(static (ctx, _) => GetTestMethodMetadata(ctx.Left, ctx.Right))
.Where(static m => m is not null)
.Combine(enabledProvider);

var inheritsTestsClassesProvider = context.SyntaxProvider
.ForAttributeWithMetadataName(
"TUnit.Core.InheritsTestsAttribute",
predicate: static (node, _) => node is ClassDeclarationSyntax,
transform: static (ctx, _) => GetInheritsTestsClassMetadata(ctx))
transform: static (ctx, _) => ctx)
.Combine(compilationContext)
.Select(static (ctx, _) => GetInheritsTestsClassMetadata(ctx.Left, ctx.Right))
.Where(static m => m is not null)
.Combine(enabledProvider);

Expand All @@ -67,7 +79,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
});
}

private static InheritsTestsClassMetadata? GetInheritsTestsClassMetadata(GeneratorAttributeSyntaxContext context)
private static InheritsTestsClassMetadata? GetInheritsTestsClassMetadata(GeneratorAttributeSyntaxContext context, CompilationContext compilationContext)
{
var classSyntax = (ClassDeclarationSyntax)context.TargetNode;

Expand All @@ -85,11 +97,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
TypeSymbol = classSymbol,
ClassSyntax = classSyntax,
Context = context
Context = context,
CompilationContext = compilationContext
};
}

private static TestMethodMetadata? GetTestMethodMetadata(GeneratorAttributeSyntaxContext context)
private static TestMethodMetadata? GetTestMethodMetadata(GeneratorAttributeSyntaxContext context, CompilationContext compilationContext)
{
var methodSyntax = (MethodDeclarationSyntax)context.TargetNode;
var methodSymbol = context.TargetSymbol as IMethodSymbol;
Expand Down Expand Up @@ -121,6 +134,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
LineNumber = lineNumber,
TestAttribute = context.Attributes.First(),
Context = context,
CompilationContext = compilationContext,
MethodSyntax = methodSyntax,
IsGenericType = isGenericType,
IsGenericMethod = isGenericMethod,
Expand Down Expand Up @@ -186,6 +200,7 @@ private static void GenerateInheritedTestSources(SourceProductionContext context
LineNumber = lineNumber,
TestAttribute = testAttribute,
Context = classInfo.Context, // Use class context to access Compilation
CompilationContext = classInfo.CompilationContext,
MethodSyntax = null, // No syntax for inherited methods
IsGenericType = typeForMetadata.IsGenericType,
IsGenericMethod = (concreteMethod ?? method).IsGenericMethod,
Expand Down Expand Up @@ -458,7 +473,7 @@ private static void GenerateMetadata(CodeWriter writer, TestMethodMetadata testM
.Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes())
.ToImmutableArray();

AttributeWriter.WriteAttributes(writer, compilation, attributes);
testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes);

writer.Unindent();
writer.AppendLine("],");
Expand Down Expand Up @@ -504,7 +519,7 @@ private static void GenerateMetadataForConcreteInstantiation(CodeWriter writer,
.Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes())
.ToImmutableArray();

AttributeWriter.WriteAttributes(writer, compilation, attributes);
testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes);

writer.Unindent();
writer.AppendLine("],");
Expand Down Expand Up @@ -564,7 +579,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te

foreach (var attr in methodDataSources)
{
GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, typeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, typeSymbol);
}

writer.Unindent();
Expand All @@ -584,7 +599,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te

foreach (var attr in classDataSources)
{
GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, typeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, typeSymbol);
}

writer.Unindent();
Expand All @@ -595,7 +610,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te
GeneratePropertyDataSources(writer, testMethod);
}

private static void GenerateDataSourceAttribute(CodeWriter writer, Compilation compilation, AttributeData attr, IMethodSymbol methodSymbol, INamedTypeSymbol typeSymbol)
private static void GenerateDataSourceAttribute(CodeWriter writer, CompilationContext compilationContext, AttributeData attr, IMethodSymbol methodSymbol, INamedTypeSymbol typeSymbol)
{
var attrClass = attr.AttributeClass;
if (attrClass == null)
Expand All @@ -613,18 +628,18 @@ private static void GenerateDataSourceAttribute(CodeWriter writer, Compilation c
{
try
{
GenerateArgumentsAttributeWithParameterTypes(writer, compilation, attr, methodSymbol);
GenerateArgumentsAttributeWithParameterTypes(writer, compilationContext.Compilation, attr, methodSymbol);
}
catch
{
// Fall back to default behavior if parameter type matching fails
AttributeWriter.WriteAttribute(writer, compilation, attr);
compilationContext.AttributeWriter.WriteAttribute(writer, attr);
writer.AppendLine(",");
}
}
else
{
AttributeWriter.WriteAttribute(writer, compilation, attr);
compilationContext.AttributeWriter.WriteAttribute(writer, attr);
writer.AppendLine(",");
}
}
Expand Down Expand Up @@ -1535,7 +1550,7 @@ private static void GeneratePropertyDataSources(CodeWriter writer, TestMethodMet
writer.AppendLine($"PropertyName = \"{property.Name}\",");
writer.AppendLine($"PropertyType = typeof({property.Type.GloballyQualified()}),");
writer.Append("DataSource = ");
GenerateDataSourceAttribute(writer, compilation, dataSourceAttr, testMethod.MethodSymbol, typeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, dataSourceAttr, testMethod.MethodSymbol, typeSymbol);
writer.Unindent();
writer.AppendLine("},");
}
Expand Down Expand Up @@ -4805,7 +4820,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources(
writer.AppendLine("AttributeFactory = static () =>");
writer.AppendLine("[");
writer.Indent();
AttributeWriter.WriteAttributes(writer, compilation, filteredAttributes.ToImmutableArray());
testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, filteredAttributes.ToImmutableArray());
writer.Unindent();
writer.AppendLine("],");

Expand Down Expand Up @@ -4866,7 +4881,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources(

foreach (var attr in methodDataSources)
{
GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, concreteTypeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, concreteTypeSymbol);
}

writer.Unindent();
Expand All @@ -4886,7 +4901,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources(

foreach (var attr in classDataSources)
{
GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, concreteTypeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, concreteTypeSymbol);
}

writer.Unindent();
Expand Down Expand Up @@ -5133,7 +5148,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric(
.Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes())
.ToImmutableArray();

AttributeWriter.WriteAttributes(writer, compilation, attributes);
testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes);

writer.Unindent();
writer.AppendLine("],");
Expand All @@ -5154,7 +5169,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric(
writer.AppendLine("DataSources = new global::TUnit.Core.IDataSourceAttribute[]");
writer.AppendLine("{");
writer.Indent();
GenerateDataSourceAttribute(writer, compilation, methodDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, methodDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol);
writer.Unindent();
writer.AppendLine("},");
}
Expand All @@ -5169,7 +5184,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric(
writer.AppendLine("ClassDataSources = new global::TUnit.Core.IDataSourceAttribute[]");
writer.AppendLine("{");
writer.Indent();
GenerateDataSourceAttribute(writer, compilation, classDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol);
GenerateDataSourceAttribute(writer, testMethod.CompilationContext, classDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol);
writer.Unindent();
writer.AppendLine("},");
}
Expand Down Expand Up @@ -5282,5 +5297,6 @@ public class InheritsTestsClassMetadata
public required INamedTypeSymbol TypeSymbol { get; init; }
public required ClassDeclarationSyntax ClassSyntax { get; init; }
public GeneratorAttributeSyntaxContext Context { get; init; }
public required CompilationContext CompilationContext { get; init; }
}

7 changes: 6 additions & 1 deletion TUnit.Core.SourceGenerator/Models/TestMethodMetadata.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.CodeGenerators.Writers;

namespace TUnit.Core.SourceGenerator.Models;

public record CompilationContext(CSharpCompilation Compilation, AttributeWriter AttributeWriter);

/// <summary>
/// Contains all the metadata about a test method discovered by the source generator.
/// </summary>
Expand All @@ -15,6 +19,7 @@ public class TestMethodMetadata : IEquatable<TestMethodMetadata>
public required int LineNumber { get; init; }
public required AttributeData TestAttribute { get; init; }
public GeneratorAttributeSyntaxContext? Context { get; init; }
public required CompilationContext CompilationContext { get; init; }
public required MethodDeclarationSyntax? MethodSyntax { get; init; }
public bool IsGenericType { get; init; }
public bool IsGenericMethod { get; init; }
Expand All @@ -23,7 +28,7 @@ public class TestMethodMetadata : IEquatable<TestMethodMetadata>
/// All attributes on the method, stored for later use during data combination generation
/// </summary>
public ImmutableArray<AttributeData> MethodAttributes { get; init; } = ImmutableArray<AttributeData>.Empty;

/// <summary>
/// The inheritance depth of this test method.
/// 0 = method is declared directly in the test class
Expand Down
Loading