Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Microsoft.CodeAnalysis;
using TUnit.Core.SourceGenerator.Extensions;
using TUnit.Core.SourceGenerator.Helpers;

namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers;

Expand All @@ -12,34 +12,30 @@ public static bool IsDataSourceAttribute(INamedTypeSymbol? attributeClass)
return false;
}

// Check if the attribute implements IDataSourceAttribute
return attributeClass.AllInterfaces.Any(i => i.GloballyQualified() == "global::TUnit.Core.IDataSourceAttribute");
// Check if the attribute implements IDataSourceAttribute (using cache)
return InterfaceCache.ImplementsInterface(attributeClass, "global::TUnit.Core.IDataSourceAttribute");
}

public static bool IsTypedDataSourceAttribute(INamedTypeSymbol? attributeClass)
{
if (attributeClass == null)
{
return false;
}

// Check if the attribute implements ITypedDataSourceAttribute<T>
return attributeClass.AllInterfaces.Any(i =>
i.IsGenericType &&
i.ConstructedFrom.GloballyQualified() == "global::TUnit.Core.ITypedDataSourceAttribute`1");
// Check if the attribute implements ITypedDataSourceAttribute<T> (using cache)
return InterfaceCache.ImplementsGenericInterface(attributeClass, "global::TUnit.Core.ITypedDataSourceAttribute`1");
}

public static ITypeSymbol? GetTypedDataSourceType(INamedTypeSymbol? attributeClass)
{
if (attributeClass == null)
{
return null;
}

var typedInterface = attributeClass.AllInterfaces
.FirstOrDefault(i => i.IsGenericType &&
i.ConstructedFrom.GloballyQualified() == "global::TUnit.Core.ITypedDataSourceAttribute`1");

var typedInterface = InterfaceCache.GetGenericInterface(attributeClass, "global::TUnit.Core.ITypedDataSourceAttribute`1");

return typedInterface?.TypeArguments.FirstOrDefault();
}
}
51 changes: 36 additions & 15 deletions TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,53 @@ public static string GetMetadataName(this Type type)

public static IEnumerable<ISymbol> GetMembersIncludingBase(this ITypeSymbol namedTypeSymbol, bool reverse = true)
{
var list = new List<ISymbol>();

var symbol = namedTypeSymbol;

while (symbol is not null)
if (!reverse)
{
if (symbol is IErrorTypeSymbol)
// Forward traversal - yield directly without allocations
var symbol = namedTypeSymbol;
while (symbol is not null && symbol.SpecialType != SpecialType.System_Object)
{
throw new Exception($"ErrorTypeSymbol for {symbol.Name} - Have you added any missing file sources to the compilation?");
if (symbol is IErrorTypeSymbol)
{
throw new Exception($"ErrorTypeSymbol for {symbol.Name} - Have you added any missing file sources to the compilation?");
}

foreach (var member in symbol.GetMembers())
{
yield return member;
}

symbol = symbol.BaseType;
}

if (symbol.SpecialType == SpecialType.System_Object)
yield break;
}

// Reverse traversal - collect hierarchy, then yield from base to derived
// Use stack to collect types (base to derived), then iterate members in forward order
var typeStack = new Stack<ITypeSymbol>();
var current = namedTypeSymbol;

while (current is not null && current.SpecialType != SpecialType.System_Object)
{
if (current is IErrorTypeSymbol)
{
break;
throw new Exception($"ErrorTypeSymbol for {current.Name} - Have you added any missing file sources to the compilation?");
}

list.AddRange(reverse ? symbol.GetMembers().Reverse() : symbol.GetMembers());
symbol = symbol.BaseType;
typeStack.Push(current);
current = current.BaseType;
}

if (reverse)
// Yield members from base to derived
while (typeStack.Count > 0)
{
list.Reverse();
var type = typeStack.Pop();
foreach (var member in type.GetMembers())
{
yield return member;
}
}

return list;
}

public static IEnumerable<INamedTypeSymbol> GetSelfAndBaseTypes(this INamedTypeSymbol namedTypeSymbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,20 @@ internal sealed class PropertyWithDataSourceAttribute
public required IPropertySymbol Property { get; init; }
public required AttributeData DataSourceAttribute { get; init; }
}

internal sealed class ClassWithDataSourcePropertiesComparer : IEqualityComparer<ClassWithDataSourceProperties>
{
public bool Equals(ClassWithDataSourceProperties? x, ClassWithDataSourceProperties? y)
{
if (ReferenceEquals(x, y)) return true;
if (x is null || y is null) return false;

// Compare based on the class symbol - this handles partial classes correctly
return SymbolEqualityComparer.Default.Equals(x.ClassSymbol, y.ClassSymbol);
}

public int GetHashCode(ClassWithDataSourceProperties obj)
{
return SymbolEqualityComparer.Default.GetHashCode(obj.ClassSymbol);
}
}
24 changes: 5 additions & 19 deletions TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
using TUnit.Core.SourceGenerator.CodeGenerators.Writers;
using TUnit.Core.SourceGenerator.Extensions;
using TUnit.Core.SourceGenerator.Helpers;
using TUnit.Core.SourceGenerator.Models;

namespace TUnit.Core.SourceGenerator.Generators;
Expand Down Expand Up @@ -1273,17 +1274,8 @@ private static void GeneratePropertyDataSourceFactory(CodeWriter writer, IProper

private static bool IsAsyncEnumerable(ITypeSymbol type)
{
// Check if the type itself is an IAsyncEnumerable<T>
if (type is INamedTypeSymbol { IsGenericType: true } namedType &&
namedType.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>")
{
return true;
}

// Check if the type implements IAsyncEnumerable<T>
return type.AllInterfaces.Any(i =>
i.IsGenericType &&
i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>");
// Use cached interface check
return InterfaceCache.IsAsyncEnumerable(type);
}

private static bool IsTask(ITypeSymbol type)
Expand All @@ -1295,14 +1287,8 @@ private static bool IsTask(ITypeSymbol type)

private static bool IsEnumerable(ITypeSymbol type)
{
if (type.SpecialType == SpecialType.System_String)
{
return false;
}

return type.AllInterfaces.Any(i =>
i.OriginalDefinition.ToDisplayString() == "System.Collections.IEnumerable" ||
(i.IsGenericType && i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IEnumerable<T>"));
// Use cached interface check (already handles string exclusion)
return InterfaceCache.IsEnumerable(type);
}

private static void WriteTypedConstant(CodeWriter writer, TypedConstant constant)
Expand Down
99 changes: 99 additions & 0 deletions TUnit.Core.SourceGenerator/Helpers/InterfaceCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using System.Collections.Concurrent;
using Microsoft.CodeAnalysis;
using TUnit.Core.SourceGenerator.Extensions;

namespace TUnit.Core.SourceGenerator.Helpers;

/// <summary>
/// Caches interface implementation checks to avoid repeated AllInterfaces traversals
/// </summary>
internal static class InterfaceCache
{
private static readonly ConcurrentDictionary<(ITypeSymbol Type, string InterfaceName), bool> _implementsCache = new(TypeStringTupleComparer.Default);
private static readonly ConcurrentDictionary<(ITypeSymbol Type, string GenericInterfacePattern), INamedTypeSymbol?> _genericInterfaceCache = new(TypeStringTupleComparer.Default);

/// <summary>
/// Checks if a type implements a specific interface
/// </summary>
public static bool ImplementsInterface(ITypeSymbol type, string fullyQualifiedInterfaceName)
{
return _implementsCache.GetOrAdd((type, fullyQualifiedInterfaceName), key =>
key.Type.AllInterfaces.Any(i => i.GloballyQualified() == key.InterfaceName));
}

/// <summary>
/// Checks if a type implements a generic interface and returns the matching interface symbol
/// </summary>
public static INamedTypeSymbol? GetGenericInterface(ITypeSymbol type, string fullyQualifiedGenericPattern)
{
return _genericInterfaceCache.GetOrAdd((type, fullyQualifiedGenericPattern), key =>
key.Type.AllInterfaces.FirstOrDefault(i =>
i.IsGenericType &&
i.ConstructedFrom.GloballyQualified() == key.GenericInterfacePattern));
}

/// <summary>
/// Checks if a type implements a generic interface
/// </summary>
public static bool ImplementsGenericInterface(ITypeSymbol type, string fullyQualifiedGenericPattern)
{
return GetGenericInterface(type, fullyQualifiedGenericPattern) != null;
}

/// <summary>
/// Checks if a type implements IAsyncEnumerable&lt;T&gt;
/// </summary>
public static bool IsAsyncEnumerable(ITypeSymbol type)
{
return _implementsCache.GetOrAdd((type, "System.Collections.Generic.IAsyncEnumerable<T>"), key =>
{
// Check if the type itself is an IAsyncEnumerable<T>
if (key.Type is INamedTypeSymbol { IsGenericType: true } namedType &&
namedType.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>")
{
return true;
}

// Check if the type implements IAsyncEnumerable<T>
return key.Type.AllInterfaces.Any(i =>
i.IsGenericType &&
i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>");
});
}

/// <summary>
/// Checks if a type implements IEnumerable (excluding string)
/// </summary>
public static bool IsEnumerable(ITypeSymbol type)
{
if (type.SpecialType == SpecialType.System_String)
{
return false;
}

return _implementsCache.GetOrAdd((type, "System.Collections.IEnumerable"), key =>
key.Type.AllInterfaces.Any(i =>
i.OriginalDefinition.ToDisplayString() == "System.Collections.IEnumerable" ||
(i.IsGenericType && i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IEnumerable<T>")));
}
}

internal sealed class TypeStringTupleComparer : IEqualityComparer<(ITypeSymbol Type, string Name)>
{
public static readonly TypeStringTupleComparer Default = new();

private TypeStringTupleComparer() { }

public bool Equals((ITypeSymbol Type, string Name) x, (ITypeSymbol Type, string Name) y)
{
return Microsoft.CodeAnalysis.SymbolEqualityComparer.Default.Equals(x.Type, y.Type) && x.Name == y.Name;
}

public int GetHashCode((ITypeSymbol Type, string Name) obj)
{
unchecked
{
return (Microsoft.CodeAnalysis.SymbolEqualityComparer.Default.GetHashCode(obj.Type) * 397) ^ obj.Name.GetHashCode();
}
}
}
14 changes: 14 additions & 0 deletions TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,17 @@ public struct TypeWithDataSourceProperties
public INamedTypeSymbol TypeSymbol { get; init; }
public List<PropertyWithDataSource> Properties { get; init; }
}

public sealed class TypeWithDataSourcePropertiesComparer : IEqualityComparer<TypeWithDataSourceProperties>
{
public bool Equals(TypeWithDataSourceProperties x, TypeWithDataSourceProperties y)
{
// Compare based on the type symbol - this handles partial classes correctly
return SymbolEqualityComparer.Default.Equals(x.TypeSymbol, y.TypeSymbol);
}

public int GetHashCode(TypeWithDataSourceProperties obj)
{
return SymbolEqualityComparer.Default.GetHashCode(obj.TypeSymbol);
}
}
10 changes: 9 additions & 1 deletion TUnit.Core/Models/AssemblyHookContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,31 @@ internal AssemblyHookContext(TestSessionContext testSessionContext) : base(testS
public required Assembly Assembly { get; init; }

private readonly List<ClassHookContext> _testClasses = [];
private TestContext[]? _cachedAllTests;

public void AddClass(ClassHookContext classHookContext)
{
_testClasses.Add(classHookContext);
InvalidateCache();
}

public IReadOnlyList<ClassHookContext> TestClasses => _testClasses;

public IReadOnlyList<TestContext> AllTests => TestClasses.SelectMany(x => x.Tests).ToArray();
public IReadOnlyList<TestContext> AllTests => _cachedAllTests ??= TestClasses.SelectMany(x => x.Tests).ToArray();

public int TestCount => AllTests.Count;

private void InvalidateCache()
{
_cachedAllTests = null;
}

internal bool FirstTestStarted { get; set; }

internal void RemoveClass(ClassHookContext classContext)
{
_testClasses.Remove(classContext);
InvalidateCache();

if (_testClasses.Count == 0)
{
Expand Down
14 changes: 12 additions & 2 deletions TUnit.Core/Models/TestSessionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,26 @@ internal TestSessionContext(TestDiscoveryContext beforeTestDiscoveryContext) : b
public required string? TestFilter { get; init; }

private readonly List<AssemblyHookContext> _assemblies = [];
private ClassHookContext[]? _cachedTestClasses;
private TestContext[]? _cachedAllTests;

public void AddAssembly(AssemblyHookContext assemblyHookContext)
{
_assemblies.Add(assemblyHookContext);
InvalidateCaches();
}

public IReadOnlyList<AssemblyHookContext> Assemblies => _assemblies;

public IReadOnlyList<ClassHookContext> TestClasses => Assemblies.SelectMany(x => x.TestClasses).ToArray();
public IReadOnlyList<ClassHookContext> TestClasses => _cachedTestClasses ??= Assemblies.SelectMany(x => x.TestClasses).ToArray();

public IReadOnlyList<TestContext> AllTests => TestClasses.SelectMany(x => x.Tests).ToArray();
public IReadOnlyList<TestContext> AllTests => _cachedAllTests ??= TestClasses.SelectMany(x => x.Tests).ToArray();

private void InvalidateCaches()
{
_cachedTestClasses = null;
_cachedAllTests = null;
}

internal bool FirstTestStarted { get; set; }

Expand All @@ -82,6 +91,7 @@ public void AddArtifact(Artifact artifact)
internal void RemoveAssembly(AssemblyHookContext assemblyContext)
{
_assemblies.Remove(assemblyContext);
InvalidateCaches();
}

internal override void SetAsyncLocalContext()
Expand Down
2 changes: 2 additions & 0 deletions TUnit.Core/TestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ public void AddLinkedCancellationToken(CancellationToken cancellationToken)
else
{
var existingToken = LinkedCancellationTokens.Token;
var oldCts = LinkedCancellationTokens;
LinkedCancellationTokens = CancellationTokenSource.CreateLinkedTokenSource(existingToken, cancellationToken);
oldCts.Dispose();
}

CancellationToken = LinkedCancellationTokens.Token;
Expand Down
Loading
Loading