Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion TUnit.Core.SourceGenerator/CodeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,25 @@ public ICodeWriter SetIndentLevel(int level)
private string GetIndentation(int level)
{
var key = (_indentString, level);
return _indentCache.GetOrAdd(key, static k => string.Concat(Enumerable.Repeat(k.Item1, k.Item2)));
return _indentCache.GetOrAdd(key, static k =>
{
var (indent, count) = k;

// Fast path: a single-space indent (or any whitespace-only indent) can be built
// directly without allocating an intermediate sequence.
if (indent == " ")
{
return new string(' ', count);
}

var builder = new StringBuilder(indent.Length * count);
for (var i = 0; i < count; i++)
{
builder.Append(indent);
}

return builder.ToString();
});
}

/// <summary>
Expand Down
14 changes: 11 additions & 3 deletions TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@ public static AttributeData GetRequiredTestAttribute(this IMethodSymbol methodSy
return null;
}

return attributes
.FirstOrDefault(x => x.AttributeClass?.BaseType?.GloballyQualified()
== WellKnownFullyQualifiedClassNames.BaseTestAttribute.WithGlobalPrefix);
var baseTestAttribute = WellKnownFullyQualifiedClassNames.BaseTestAttribute.WithGlobalPrefix;

foreach (var attribute in attributes)
{
if (attribute.AttributeClass?.BaseType?.GloballyQualified() == baseTestAttribute)
{
return attribute;
}
}

return null;
}
}
58 changes: 47 additions & 11 deletions TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,46 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
});
}

var methodArgumentsAttributes = testMethod.MethodAttributes
.Where(a => a.AttributeClass?.Name == "ArgumentsAttribute")
.ToArray();
// Single-pass classification of method attributes into named buckets, replacing the
// repeated `.Where(...).ToArray()` scans that previously walked MethodAttributes 6-8 times.
List<AttributeData>? methodArgumentsBucket = null;
List<AttributeData>? methodDataSourceBucket = null;
List<AttributeData>? typedDataSourceBucket = null;
List<AttributeData>? generateGenericTestBucket = null;

foreach (var attribute in testMethod.MethodAttributes)
{
var attributeClass = attribute.AttributeClass;
if (attributeClass is null)
{
continue;
}

// Buckets are independent: an attribute may match more than one (e.g.
// MethodDataSourceAttribute is also an IDataSourceAttribute), matching the original
// behaviour where each `.Where(...)` scan was evaluated separately.
switch (attributeClass.Name)
{
case "ArgumentsAttribute":
(methodArgumentsBucket ??= []).Add(attribute);
break;
case "MethodDataSourceAttribute":
(methodDataSourceBucket ??= []).Add(attribute);
break;
}

if (DataSourceAttributeHelper.IsDataSourceAttribute(attributeClass))
{
(typedDataSourceBucket ??= []).Add(attribute);
}

if (attributeClass.IsOrInherits("global::TUnit.Core.GenerateGenericTestAttribute"))
{
(generateGenericTestBucket ??= []).Add(attribute);
}
}

var methodArgumentsAttributes = methodArgumentsBucket ?? [];

var classArgumentsAttributes = testMethod.IsGenericType
? testMethod.TypeSymbol.GetAttributes()
Expand Down Expand Up @@ -679,7 +716,7 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
}

// Handle generic classes with non-generic methods that have method-level Arguments
if (testMethod is { IsGenericType: true, IsGenericMethod: false } && methodArgumentsAttributes.Length > 0)
if (testMethod is { IsGenericType: true, IsGenericMethod: false } && methodArgumentsAttributes.Count > 0)
{
foreach (var methodArgAttr in methodArgumentsAttributes)
{
Expand All @@ -692,7 +729,7 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
}

// Process typed data source attributes
foreach (var dataSourceAttr in testMethod.MethodAttributes.Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)))
foreach (var dataSourceAttr in typedDataSourceBucket ?? Enumerable.Empty<AttributeData>())
{
var inferredTypes = InferTypesFromDataSourceAttribute(testMethod.MethodSymbol, dataSourceAttr);
if (inferredTypes is { Length: > 0 })
Expand All @@ -714,7 +751,7 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
// Process MethodDataSource attributes for generic classes (non-generic methods)
if (testMethod is { IsGenericType: true, IsGenericMethod: false })
{
foreach (var mdsAttr in testMethod.MethodAttributes.Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute"))
foreach (var mdsAttr in methodDataSourceBucket ?? Enumerable.Empty<AttributeData>())
{
var inferredTypes = InferClassTypesFromMethodDataSource(testMethod, mdsAttr);
if (inferredTypes is { Length: > 0 })
Expand All @@ -734,7 +771,7 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
// Process MethodDataSource attributes for generic methods
if (testMethod.IsGenericMethod)
{
foreach (var mdsAttr in testMethod.MethodAttributes.Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute"))
foreach (var mdsAttr in methodDataSourceBucket ?? Enumerable.Empty<AttributeData>())
{
var inferredTypes = InferTypesFromMethodDataSource(testMethod, mdsAttr);
if (inferredTypes is { Length: > 0 })
Expand All @@ -754,7 +791,7 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
{
if (testMethod.IsGenericMethod)
{
foreach (var methodArgAttr in testMethod.MethodAttributes.Where(a => a.AttributeClass?.Name == "ArgumentsAttribute"))
foreach (var methodArgAttr in methodArgumentsAttributes)
{
var methodInferredTypes = InferTypesFromArgumentsAttribute(testMethod.MethodSymbol, methodArgAttr, compilation);
if (methodInferredTypes is { Length: > 0 })
Expand All @@ -774,9 +811,8 @@ void TryAddInstantiation(ITypeSymbol[] typeArguments, AttributeData? specificAtt
// Process GenerateGenericTest attributes
// GenerateGenericTestAttribute takes params Type[] in its constructor, so extract from constructor args
{
var methodGenericTestAttrs = testMethod.IsGenericMethod
? testMethod.MethodAttributes
.Where(a => a.AttributeClass?.IsOrInherits("global::TUnit.Core.GenerateGenericTestAttribute") is true)
var methodGenericTestAttrs = testMethod.IsGenericMethod && generateGenericTestBucket is not null
? generateGenericTestBucket
.Select(ExtractTypeArgsFromGenerateGenericTestAttribute)
.Where(t => t is { Length: > 0 })
.ToList()
Expand Down
36 changes: 28 additions & 8 deletions TUnit.Core.SourceGenerator/Helpers/InterfaceCache.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,47 @@
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using TUnit.Core.SourceGenerator.Extensions;

namespace TUnit.Core.SourceGenerator.Helpers;

/// <summary>
/// Caches interface implementation checks to avoid repeated AllInterfaces traversals
/// Caches interface implementation checks to avoid repeated AllInterfaces traversals.
/// </summary>
/// <remarks>
/// Cache entries are keyed by the <see cref="ITypeSymbol"/> itself and held in a
/// <see cref="ConditionalWeakTable{TKey,TValue}"/>. This ties each entry's lifetime to the
/// symbol's lifetime, so the cache is reclaimed automatically when a <see cref="Compilation"/>
/// is collected. This avoids the cross-compilation symbol leak a long-lived static dictionary
/// would cause in extended IDE sessions.
/// </remarks>
public static class InterfaceCache
{
/// <summary>
/// Checks if a type implements a specific interface
/// Per-type cache of the globally-qualified names of every interface the type implements.
/// </summary>
public static bool ImplementsInterface(ITypeSymbol type, string fullyQualifiedInterfaceName)
private static readonly ConditionalWeakTable<ITypeSymbol, ImmutableHashSet<string>> InterfaceNames = new();

private static ImmutableHashSet<string> GetInterfaceNames(ITypeSymbol type)
{
foreach (var i in type.AllInterfaces)
return InterfaceNames.GetValue(type, static t =>
{
if (i.GloballyQualified() == fullyQualifiedInterfaceName)
var builder = ImmutableHashSet.CreateBuilder<string>(StringComparer.Ordinal);
foreach (var i in t.AllInterfaces)
{
return true;
builder.Add(i.GloballyQualified());
}
}

return false;
return builder.ToImmutable();
});
}

/// <summary>
/// Checks if a type implements a specific interface
/// </summary>
public static bool ImplementsInterface(ITypeSymbol type, string fullyQualifiedInterfaceName)
{
return GetInterfaceNames(type).Contains(fullyQualifiedInterfaceName);
}

/// <summary>
Expand Down
Loading