Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 32 additions & 44 deletions TUnit.Core.SourceGenerator/Helpers/InterfaceCache.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Collections.Concurrent;
using Microsoft.CodeAnalysis;
using TUnit.Core.SourceGenerator.Extensions;

Expand All @@ -7,29 +6,38 @@ namespace TUnit.Core.SourceGenerator.Helpers;
/// <summary>
/// Caches interface implementation checks to avoid repeated AllInterfaces traversals
/// </summary>
internal static class InterfaceCache
public static class InterfaceCache
{
private static readonly ConcurrentDictionary<(ITypeSymbol Type, string InterfaceName), bool> _implementsCache = new(TypeStringTupleComparer.Default);
private static readonly ConcurrentDictionary<(ITypeSymbol Type, string GenericInterfacePattern), INamedTypeSymbol?> _genericInterfaceCache = new(TypeStringTupleComparer.Default);

/// <summary>
/// Checks if a type implements a specific interface
/// </summary>
public static bool ImplementsInterface(ITypeSymbol type, string fullyQualifiedInterfaceName)
{
return _implementsCache.GetOrAdd((type, fullyQualifiedInterfaceName), key =>
key.Type.AllInterfaces.Any(i => i.GloballyQualified() == key.InterfaceName));
foreach (var i in type.AllInterfaces)
{
if (i.GloballyQualified() == fullyQualifiedInterfaceName)
{
return true;
}
}

return false;
}

/// <summary>
/// Checks if a type implements a generic interface and returns the matching interface symbol
/// </summary>
public static INamedTypeSymbol? GetGenericInterface(ITypeSymbol type, string fullyQualifiedGenericPattern)
{
return _genericInterfaceCache.GetOrAdd((type, fullyQualifiedGenericPattern), key =>
key.Type.AllInterfaces.FirstOrDefault(i =>
i.IsGenericType &&
i.ConstructedFrom.GloballyQualified() == key.GenericInterfacePattern));
foreach (var i in type.AllInterfaces)
{
if (i.IsGenericType && i.ConstructedFrom.GloballyQualified() == fullyQualifiedGenericPattern)
{
return i;
}
}

return null;
}

/// <summary>
Expand All @@ -45,18 +53,15 @@ public static bool ImplementsGenericInterface(ITypeSymbol type, string fullyQual
/// </summary>
public static bool IsAsyncEnumerable(ITypeSymbol type)
{
return _implementsCache.GetOrAdd((type, "System.Collections.Generic.IAsyncEnumerable<T>"), key =>
if (type is INamedTypeSymbol { IsGenericType: true } namedType &&
namedType.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>")
{
if (key.Type is INamedTypeSymbol { IsGenericType: true } namedType &&
namedType.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>")
{
return true;
}
return true;
}

return key.Type.AllInterfaces.Any(i =>
i.IsGenericType &&
i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>");
});
return type.AllInterfaces.Any(i =>
i.IsGenericType &&
i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable<T>");
}

/// <summary>
Expand All @@ -69,29 +74,12 @@ public static bool IsEnumerable(ITypeSymbol type)
return false;
}

return _implementsCache.GetOrAdd((type, "System.Collections.IEnumerable"), key =>
key.Type.AllInterfaces.Any(i =>
i.OriginalDefinition.ToDisplayString() == "System.Collections.IEnumerable" ||
(i.IsGenericType && i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IEnumerable<T>")));
}
}

internal sealed class TypeStringTupleComparer : IEqualityComparer<(ITypeSymbol Type, string Name)>
{
public static readonly TypeStringTupleComparer Default = new();

private TypeStringTupleComparer() { }

public bool Equals((ITypeSymbol Type, string Name) x, (ITypeSymbol Type, string Name) y)
{
return Microsoft.CodeAnalysis.SymbolEqualityComparer.Default.Equals(x.Type, y.Type) && x.Name == y.Name;
}

public int GetHashCode((ITypeSymbol Type, string Name) obj)
{
unchecked
return type.AllInterfaces.Any(i =>
{
return (Microsoft.CodeAnalysis.SymbolEqualityComparer.Default.GetHashCode(obj.Type) * 397) ^ obj.Name.GetHashCode();
}
var originalDefintion = i.OriginalDefinition.ToDisplayString();
return originalDefintion == "System.Collections.IEnumerable" ||
(i.IsGenericType && originalDefintion ==
"System.Collections.Generic.IEnumerable<T>");
});
}
}
Loading