diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DataSourceAttributeHelper.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DataSourceAttributeHelper.cs index 40deca4e17..935fec6b69 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DataSourceAttributeHelper.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DataSourceAttributeHelper.cs @@ -1,5 +1,5 @@ using Microsoft.CodeAnalysis; -using TUnit.Core.SourceGenerator.Extensions; +using TUnit.Core.SourceGenerator.Helpers; namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; @@ -12,10 +12,10 @@ public static bool IsDataSourceAttribute(INamedTypeSymbol? attributeClass) return false; } - // Check if the attribute implements IDataSourceAttribute - return attributeClass.AllInterfaces.Any(i => i.GloballyQualified() == "global::TUnit.Core.IDataSourceAttribute"); + // Check if the attribute implements IDataSourceAttribute (using cache) + return InterfaceCache.ImplementsInterface(attributeClass, "global::TUnit.Core.IDataSourceAttribute"); } - + public static bool IsTypedDataSourceAttribute(INamedTypeSymbol? attributeClass) { if (attributeClass == null) @@ -23,12 +23,10 @@ public static bool IsTypedDataSourceAttribute(INamedTypeSymbol? attributeClass) return false; } - // Check if the attribute implements ITypedDataSourceAttribute - return attributeClass.AllInterfaces.Any(i => - i.IsGenericType && - i.ConstructedFrom.GloballyQualified() == "global::TUnit.Core.ITypedDataSourceAttribute`1"); + // Check if the attribute implements ITypedDataSourceAttribute (using cache) + return InterfaceCache.ImplementsGenericInterface(attributeClass, "global::TUnit.Core.ITypedDataSourceAttribute`1"); } - + public static ITypeSymbol? GetTypedDataSourceType(INamedTypeSymbol? attributeClass) { if (attributeClass == null) @@ -36,10 +34,8 @@ public static bool IsTypedDataSourceAttribute(INamedTypeSymbol? attributeClass) return null; } - var typedInterface = attributeClass.AllInterfaces - .FirstOrDefault(i => i.IsGenericType && - i.ConstructedFrom.GloballyQualified() == "global::TUnit.Core.ITypedDataSourceAttribute`1"); - + var typedInterface = InterfaceCache.GetGenericInterface(attributeClass, "global::TUnit.Core.ITypedDataSourceAttribute`1"); + return typedInterface?.TypeArguments.FirstOrDefault(); } } \ No newline at end of file diff --git a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs index c6fd4fd2cf..d4b2266807 100644 --- a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs @@ -35,32 +35,53 @@ public static string GetMetadataName(this Type type) public static IEnumerable GetMembersIncludingBase(this ITypeSymbol namedTypeSymbol, bool reverse = true) { - var list = new List(); - - var symbol = namedTypeSymbol; - - while (symbol is not null) + if (!reverse) { - if (symbol is IErrorTypeSymbol) + // Forward traversal - yield directly without allocations + var symbol = namedTypeSymbol; + while (symbol is not null && symbol.SpecialType != SpecialType.System_Object) { - throw new Exception($"ErrorTypeSymbol for {symbol.Name} - Have you added any missing file sources to the compilation?"); + if (symbol is IErrorTypeSymbol) + { + throw new Exception($"ErrorTypeSymbol for {symbol.Name} - Have you added any missing file sources to the compilation?"); + } + + foreach (var member in symbol.GetMembers()) + { + yield return member; + } + + symbol = symbol.BaseType; } - if (symbol.SpecialType == SpecialType.System_Object) + yield break; + } + + // Reverse traversal - collect hierarchy, then yield from base to derived + // Use stack to collect types (base to derived), then iterate members in forward order + var typeStack = new Stack(); + var current = namedTypeSymbol; + + while (current is not null && current.SpecialType != SpecialType.System_Object) + { + if (current is IErrorTypeSymbol) { - break; + throw new Exception($"ErrorTypeSymbol for {current.Name} - Have you added any missing file sources to the compilation?"); } - list.AddRange(reverse ? symbol.GetMembers().Reverse() : symbol.GetMembers()); - symbol = symbol.BaseType; + typeStack.Push(current); + current = current.BaseType; } - if (reverse) + // Yield members from base to derived + while (typeStack.Count > 0) { - list.Reverse(); + var type = typeStack.Pop(); + foreach (var member in type.GetMembers()) + { + yield return member; + } } - - return list; } public static IEnumerable GetSelfAndBaseTypes(this INamedTypeSymbol namedTypeSymbol) diff --git a/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs b/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs index 23084bee35..c25c168c84 100644 --- a/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs @@ -457,3 +457,20 @@ internal sealed class PropertyWithDataSourceAttribute public required IPropertySymbol Property { get; init; } public required AttributeData DataSourceAttribute { get; init; } } + +internal sealed class ClassWithDataSourcePropertiesComparer : IEqualityComparer +{ + public bool Equals(ClassWithDataSourceProperties? x, ClassWithDataSourceProperties? y) + { + if (ReferenceEquals(x, y)) return true; + if (x is null || y is null) return false; + + // Compare based on the class symbol - this handles partial classes correctly + return SymbolEqualityComparer.Default.Equals(x.ClassSymbol, y.ClassSymbol); + } + + public int GetHashCode(ClassWithDataSourceProperties obj) + { + return SymbolEqualityComparer.Default.GetHashCode(obj.ClassSymbol); + } +} diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 225bb2cff6..46bd3dd1ec 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -9,6 +9,7 @@ using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; using TUnit.Core.SourceGenerator.CodeGenerators.Writers; using TUnit.Core.SourceGenerator.Extensions; +using TUnit.Core.SourceGenerator.Helpers; using TUnit.Core.SourceGenerator.Models; namespace TUnit.Core.SourceGenerator.Generators; @@ -1273,17 +1274,8 @@ private static void GeneratePropertyDataSourceFactory(CodeWriter writer, IProper private static bool IsAsyncEnumerable(ITypeSymbol type) { - // Check if the type itself is an IAsyncEnumerable - if (type is INamedTypeSymbol { IsGenericType: true } namedType && - namedType.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable") - { - return true; - } - - // Check if the type implements IAsyncEnumerable - return type.AllInterfaces.Any(i => - i.IsGenericType && - i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable"); + // Use cached interface check + return InterfaceCache.IsAsyncEnumerable(type); } private static bool IsTask(ITypeSymbol type) @@ -1295,14 +1287,8 @@ private static bool IsTask(ITypeSymbol type) private static bool IsEnumerable(ITypeSymbol type) { - if (type.SpecialType == SpecialType.System_String) - { - return false; - } - - return type.AllInterfaces.Any(i => - i.OriginalDefinition.ToDisplayString() == "System.Collections.IEnumerable" || - (i.IsGenericType && i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IEnumerable")); + // Use cached interface check (already handles string exclusion) + return InterfaceCache.IsEnumerable(type); } private static void WriteTypedConstant(CodeWriter writer, TypedConstant constant) diff --git a/TUnit.Core.SourceGenerator/Helpers/InterfaceCache.cs b/TUnit.Core.SourceGenerator/Helpers/InterfaceCache.cs new file mode 100644 index 0000000000..0064250efa --- /dev/null +++ b/TUnit.Core.SourceGenerator/Helpers/InterfaceCache.cs @@ -0,0 +1,99 @@ +using System.Collections.Concurrent; +using Microsoft.CodeAnalysis; +using TUnit.Core.SourceGenerator.Extensions; + +namespace TUnit.Core.SourceGenerator.Helpers; + +/// +/// Caches interface implementation checks to avoid repeated AllInterfaces traversals +/// +internal static class InterfaceCache +{ + private static readonly ConcurrentDictionary<(ITypeSymbol Type, string InterfaceName), bool> _implementsCache = new(TypeStringTupleComparer.Default); + private static readonly ConcurrentDictionary<(ITypeSymbol Type, string GenericInterfacePattern), INamedTypeSymbol?> _genericInterfaceCache = new(TypeStringTupleComparer.Default); + + /// + /// Checks if a type implements a specific interface + /// + public static bool ImplementsInterface(ITypeSymbol type, string fullyQualifiedInterfaceName) + { + return _implementsCache.GetOrAdd((type, fullyQualifiedInterfaceName), key => + key.Type.AllInterfaces.Any(i => i.GloballyQualified() == key.InterfaceName)); + } + + /// + /// Checks if a type implements a generic interface and returns the matching interface symbol + /// + public static INamedTypeSymbol? GetGenericInterface(ITypeSymbol type, string fullyQualifiedGenericPattern) + { + return _genericInterfaceCache.GetOrAdd((type, fullyQualifiedGenericPattern), key => + key.Type.AllInterfaces.FirstOrDefault(i => + i.IsGenericType && + i.ConstructedFrom.GloballyQualified() == key.GenericInterfacePattern)); + } + + /// + /// Checks if a type implements a generic interface + /// + public static bool ImplementsGenericInterface(ITypeSymbol type, string fullyQualifiedGenericPattern) + { + return GetGenericInterface(type, fullyQualifiedGenericPattern) != null; + } + + /// + /// Checks if a type implements IAsyncEnumerable<T> + /// + public static bool IsAsyncEnumerable(ITypeSymbol type) + { + return _implementsCache.GetOrAdd((type, "System.Collections.Generic.IAsyncEnumerable"), key => + { + // Check if the type itself is an IAsyncEnumerable + if (key.Type is INamedTypeSymbol { IsGenericType: true } namedType && + namedType.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable") + { + return true; + } + + // Check if the type implements IAsyncEnumerable + return key.Type.AllInterfaces.Any(i => + i.IsGenericType && + i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IAsyncEnumerable"); + }); + } + + /// + /// Checks if a type implements IEnumerable (excluding string) + /// + public static bool IsEnumerable(ITypeSymbol type) + { + if (type.SpecialType == SpecialType.System_String) + { + return false; + } + + return _implementsCache.GetOrAdd((type, "System.Collections.IEnumerable"), key => + key.Type.AllInterfaces.Any(i => + i.OriginalDefinition.ToDisplayString() == "System.Collections.IEnumerable" || + (i.IsGenericType && i.OriginalDefinition.ToDisplayString() == "System.Collections.Generic.IEnumerable"))); + } +} + +internal sealed class TypeStringTupleComparer : IEqualityComparer<(ITypeSymbol Type, string Name)> +{ + public static readonly TypeStringTupleComparer Default = new(); + + private TypeStringTupleComparer() { } + + public bool Equals((ITypeSymbol Type, string Name) x, (ITypeSymbol Type, string Name) y) + { + return Microsoft.CodeAnalysis.SymbolEqualityComparer.Default.Equals(x.Type, y.Type) && x.Name == y.Name; + } + + public int GetHashCode((ITypeSymbol Type, string Name) obj) + { + unchecked + { + return (Microsoft.CodeAnalysis.SymbolEqualityComparer.Default.GetHashCode(obj.Type) * 397) ^ obj.Name.GetHashCode(); + } + } +} diff --git a/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs b/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs index b95834a9f4..9b52e5e691 100644 --- a/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs +++ b/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs @@ -7,3 +7,17 @@ public struct TypeWithDataSourceProperties public INamedTypeSymbol TypeSymbol { get; init; } public List Properties { get; init; } } + +public sealed class TypeWithDataSourcePropertiesComparer : IEqualityComparer +{ + public bool Equals(TypeWithDataSourceProperties x, TypeWithDataSourceProperties y) + { + // Compare based on the type symbol - this handles partial classes correctly + return SymbolEqualityComparer.Default.Equals(x.TypeSymbol, y.TypeSymbol); + } + + public int GetHashCode(TypeWithDataSourceProperties obj) + { + return SymbolEqualityComparer.Default.GetHashCode(obj.TypeSymbol); + } +} diff --git a/TUnit.Core/Models/AssemblyHookContext.cs b/TUnit.Core/Models/AssemblyHookContext.cs index 9387037a49..a7f43367b3 100644 --- a/TUnit.Core/Models/AssemblyHookContext.cs +++ b/TUnit.Core/Models/AssemblyHookContext.cs @@ -27,23 +27,31 @@ internal AssemblyHookContext(TestSessionContext testSessionContext) : base(testS public required Assembly Assembly { get; init; } private readonly List _testClasses = []; + private TestContext[]? _cachedAllTests; public void AddClass(ClassHookContext classHookContext) { _testClasses.Add(classHookContext); + InvalidateCache(); } public IReadOnlyList TestClasses => _testClasses; - public IReadOnlyList AllTests => TestClasses.SelectMany(x => x.Tests).ToArray(); + public IReadOnlyList AllTests => _cachedAllTests ??= TestClasses.SelectMany(x => x.Tests).ToArray(); public int TestCount => AllTests.Count; + private void InvalidateCache() + { + _cachedAllTests = null; + } + internal bool FirstTestStarted { get; set; } internal void RemoveClass(ClassHookContext classContext) { _testClasses.Remove(classContext); + InvalidateCache(); if (_testClasses.Count == 0) { diff --git a/TUnit.Core/Models/TestSessionContext.cs b/TUnit.Core/Models/TestSessionContext.cs index 09598aa39d..01d838115a 100644 --- a/TUnit.Core/Models/TestSessionContext.cs +++ b/TUnit.Core/Models/TestSessionContext.cs @@ -58,17 +58,26 @@ internal TestSessionContext(TestDiscoveryContext beforeTestDiscoveryContext) : b public required string? TestFilter { get; init; } private readonly List _assemblies = []; + private ClassHookContext[]? _cachedTestClasses; + private TestContext[]? _cachedAllTests; public void AddAssembly(AssemblyHookContext assemblyHookContext) { _assemblies.Add(assemblyHookContext); + InvalidateCaches(); } public IReadOnlyList Assemblies => _assemblies; - public IReadOnlyList TestClasses => Assemblies.SelectMany(x => x.TestClasses).ToArray(); + public IReadOnlyList TestClasses => _cachedTestClasses ??= Assemblies.SelectMany(x => x.TestClasses).ToArray(); - public IReadOnlyList AllTests => TestClasses.SelectMany(x => x.Tests).ToArray(); + public IReadOnlyList AllTests => _cachedAllTests ??= TestClasses.SelectMany(x => x.Tests).ToArray(); + + private void InvalidateCaches() + { + _cachedTestClasses = null; + _cachedAllTests = null; + } internal bool FirstTestStarted { get; set; } @@ -82,6 +91,7 @@ public void AddArtifact(Artifact artifact) internal void RemoveAssembly(AssemblyHookContext assemblyContext) { _assemblies.Remove(assemblyContext); + InvalidateCaches(); } internal override void SetAsyncLocalContext() diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index e2b69e526e..e2edf48255 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -255,7 +255,9 @@ public void AddLinkedCancellationToken(CancellationToken cancellationToken) else { var existingToken = LinkedCancellationTokens.Token; + var oldCts = LinkedCancellationTokens; LinkedCancellationTokens = CancellationTokenSource.CreateLinkedTokenSource(existingToken, cancellationToken); + oldCts.Dispose(); } CancellationToken = LinkedCancellationTokens.Token; diff --git a/TUnit.Engine/Discovery/ReflectionAttributeExtractor.cs b/TUnit.Engine/Discovery/ReflectionAttributeExtractor.cs index 01b78e1d4e..09b3a4aa5f 100644 --- a/TUnit.Engine/Discovery/ReflectionAttributeExtractor.cs +++ b/TUnit.Engine/Discovery/ReflectionAttributeExtractor.cs @@ -16,6 +16,11 @@ internal static class ReflectionAttributeExtractor /// private static readonly ConcurrentDictionary _attributeCache = new(); + /// + /// Cache for multiple attributes lookups to avoid repeated reflection calls + /// + private static readonly ConcurrentDictionary _attributesCache = new(); + /// /// Composite cache key combining type, method, and attribute type information /// @@ -103,17 +108,24 @@ public static IEnumerable GetAttributes(Type testClass, MethodInfo? testMe } #endif - var attributes = new List(); - - attributes.AddRange(testClass.Assembly.GetCustomAttributes()); - attributes.AddRange(testClass.GetCustomAttributes()); + var cacheKey = new AttributeCacheKey(testClass, testMethod, typeof(T)); - if (testMethod != null) + var cachedAttributes = _attributesCache.GetOrAdd(cacheKey, key => { - attributes.AddRange(testMethod.GetCustomAttributes()); - } + var attributes = new List(); + + attributes.AddRange(key.TestClass.Assembly.GetCustomAttributes(key.AttributeType)); + attributes.AddRange(key.TestClass.GetCustomAttributes(key.AttributeType)); + + if (key.TestMethod != null) + { + attributes.AddRange(key.TestMethod.GetCustomAttributes(key.AttributeType)); + } + + return attributes.ToArray(); + }); - return attributes; + return cachedAttributes.Cast(); } public static string[] ExtractCategories(Type testClass, MethodInfo testMethod) diff --git a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs index 890a6e4b8b..b61e758d35 100644 --- a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs +++ b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs @@ -22,8 +22,8 @@ namespace TUnit.Engine.Discovery; internal sealed class ReflectionTestDataCollector : ITestDataCollector { private static readonly ConcurrentDictionary _scannedAssemblies = new(); - private static readonly ConcurrentBag _discoveredTests = new(); - private static readonly Lock _resultsLock = new(); // Only for final results aggregation + private static readonly List _discoveredTests = new(capacity: 1000); // Pre-sized for typical test suites + private static readonly Lock _discoveredTestsLock = new(); // Lock for thread-safe access to _discoveredTests private static readonly ConcurrentDictionary _assemblyTypesCache = new(); private static readonly ConcurrentDictionary _typeMethodsCache = new(); @@ -41,7 +41,10 @@ private static Assembly[] GetCachedAssemblies() public static void ClearCaches() { _scannedAssemblies.Clear(); - while (_discoveredTests.TryTake(out _)) { } + lock (_discoveredTestsLock) + { + _discoveredTests.Clear(); + } _assemblyTypesCache.Clear(); _typeMethodsCache.Clear(); lock (_assemblyCacheLock) @@ -132,16 +135,11 @@ public async Task> CollectTestsAsync(string testSessio var dynamicTests = await DiscoverDynamicTests(testSessionId).ConfigureAwait(false); newTests.AddRange(dynamicTests); - // Add to concurrent collection without locking - foreach (var test in newTests) - { - _discoveredTests.Add(test); - } - - // Only lock when creating the final result list - lock (_resultsLock) + // Add to discovered tests with lock (better enumeration performance than ConcurrentBag) + lock (_discoveredTestsLock) { - return _discoveredTests.ToList(); + _discoveredTests.AddRange(newTests); + return new List(_discoveredTests); } } @@ -179,8 +177,10 @@ public async IAsyncEnumerable CollectTestsStreamingAsync( // Stream tests from this assembly await foreach (var test in DiscoverTestsInAssemblyStreamingAsync(assembly, cancellationToken)) { - // Use lock-free ConcurrentBag - _discoveredTests.Add(test); + lock (_discoveredTestsLock) + { + _discoveredTests.Add(test); + } yield return test; } } @@ -188,7 +188,10 @@ public async IAsyncEnumerable CollectTestsStreamingAsync( // Stream dynamic tests await foreach (var dynamicTest in DiscoverDynamicTestsStreamingAsync(testSessionId, cancellationToken)) { - _discoveredTests.Add(dynamicTest); + lock (_discoveredTestsLock) + { + _discoveredTests.Add(dynamicTest); + } yield return dynamicTest; } } diff --git a/TUnit.Engine/Services/CircularDependencyDetector.cs b/TUnit.Engine/Services/CircularDependencyDetector.cs index aa37fe6dd4..70bd5cf353 100644 --- a/TUnit.Engine/Services/CircularDependencyDetector.cs +++ b/TUnit.Engine/Services/CircularDependencyDetector.cs @@ -18,7 +18,7 @@ internal sealed class CircularDependencyDetector { var testList = tests.ToList(); var circularDependencies = new List<(AbstractExecutableTest Test, List DependencyChain)>(); - var visitedStates = new Dictionary(); + var visitedStates = new Dictionary(capacity: testList.Count); foreach (var test in testList) { diff --git a/TUnit.Engine/Services/DataSourceInitializer.cs b/TUnit.Engine/Services/DataSourceInitializer.cs index f0dd58ec96..06093c07b7 100644 --- a/TUnit.Engine/Services/DataSourceInitializer.cs +++ b/TUnit.Engine/Services/DataSourceInitializer.cs @@ -74,7 +74,7 @@ private async Task InitializeDataSourceAsync( try { // Ensure we have required context - objectBag ??= new Dictionary(); + objectBag ??= new Dictionary(capacity: 8); events ??= new TestContextEvents(); // Initialize the data source directly here @@ -111,7 +111,7 @@ await _propertyInjectionService.InjectPropertiesIntoObjectAsync( #endif private async Task InitializeNestedObjectsAsync(object rootObject) { - var objectsByDepth = new Dictionary>(); + var objectsByDepth = new Dictionary>(capacity: 4); var visitedObjects = new HashSet(); // Collect all nested property-injected objects grouped by depth diff --git a/TUnit.Engine/Services/TestDependencyResolver.cs b/TUnit.Engine/Services/TestDependencyResolver.cs index 0c99d8c8b3..42c97a1e2c 100644 --- a/TUnit.Engine/Services/TestDependencyResolver.cs +++ b/TUnit.Engine/Services/TestDependencyResolver.cs @@ -103,7 +103,7 @@ private bool ResolveDependenciesForTest(AbstractExecutableTest test) if (allResolved) { - var uniqueDependencies = new Dictionary(); + var uniqueDependencies = new Dictionary(capacity: 8); foreach (var dep in resolvedDependencies) { if (dep.Test == test) diff --git a/TUnit.Engine/Services/TestFinder.cs b/TUnit.Engine/Services/TestFinder.cs index aa573513fe..bf1e782825 100644 --- a/TUnit.Engine/Services/TestFinder.cs +++ b/TUnit.Engine/Services/TestFinder.cs @@ -20,8 +20,14 @@ public TestFinder(TestDiscoveryService discoveryService) /// public IEnumerable GetTests(Type classType) { - return _discoveryService.GetCachedTestContexts() - .Where(t => t.TestDetails?.ClassType == classType); + var allTests = _discoveryService.GetCachedTestContexts(); + foreach (var test in allTests) + { + if (test.TestDetails?.ClassType == classType) + { + yield return test; + } + } } /// @@ -30,26 +36,48 @@ public IEnumerable GetTests(Type classType) public TestContext[] GetTestsByNameAndParameters(string testName, IEnumerable methodParameterTypes, Type classType, IEnumerable classParameterTypes, IEnumerable classArguments) { - var paramTypes = methodParameterTypes?.ToArray() ?? [ - ]; - var classParamTypes = classParameterTypes?.ToArray() ?? [ - ]; + var paramTypes = methodParameterTypes?.ToArray() ?? []; + var classParamTypes = classParameterTypes?.ToArray() ?? []; var allTests = _discoveryService.GetCachedTestContexts(); + var results = new List(); // If no parameter types are specified, match by name and class type only if (paramTypes.Length == 0 && classParamTypes.Length == 0) { - return allTests.Where(t => - t.TestName == testName && - t.TestDetails?.ClassType == classType).ToArray(); + foreach (var test in allTests) + { + if (test.TestName == testName && test.TestDetails?.ClassType == classType) + { + results.Add(test); + } + } + return results.ToArray(); + } + + // Match with parameter types + foreach (var test in allTests) + { + if (test.TestName != testName || test.TestDetails?.ClassType != classType) + { + continue; + } + + var testParams = test.TestDetails.MethodMetadata.Parameters.ToArray(); + var testParamTypes = new Type[testParams.Length]; + for (int i = 0; i < testParams.Length; i++) + { + testParamTypes[i] = testParams[i].Type; + } + + if (ParameterTypesMatch(testParamTypes, paramTypes) && + ClassParametersMatch(test, classParamTypes, classArguments)) + { + results.Add(test); + } } - return allTests.Where(t => - t.TestName == testName && - t.TestDetails?.ClassType == classType && - ParameterTypesMatch(t.TestDetails.MethodMetadata.Parameters.Select(p => p.Type).ToArray(), paramTypes) && - ClassParametersMatch(t, classParamTypes, classArguments)).ToArray(); + return results.ToArray(); } private bool ParameterTypesMatch(Type[]? testParamTypes, Type[] expectedParamTypes) diff --git a/TUnit.Engine/Services/TestGroupingService.cs b/TUnit.Engine/Services/TestGroupingService.cs index 4ac3655d5b..786ec9131f 100644 --- a/TUnit.Engine/Services/TestGroupingService.cs +++ b/TUnit.Engine/Services/TestGroupingService.cs @@ -17,7 +17,7 @@ internal sealed class TestGroupingService : ITestGroupingService private struct TestSortKey { public int ExecutionPriority { get; init; } - public string? ClassFullName { get; init; } + public string ClassFullName { get; init; } // Cached to avoid repeated property access public int NotInParallelOrder { get; init; } public NotInParallelConstraint? NotInParallelConstraint { get; init; } } @@ -40,7 +40,7 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable GroupTestsByConstraintsAsync(IEnumerable(); - var keyedNotInParallelList = new List<(AbstractExecutableTest Test, IReadOnlyList ConstraintKeys, TestPriority Priority)>(); + var notInParallelList = new List<(AbstractExecutableTest Test, string ClassName, TestPriority Priority)>(); + var keyedNotInParallelList = new List<(AbstractExecutableTest Test, string ClassName, IReadOnlyList ConstraintKeys, TestPriority Priority)>(); var parallelTests = new List(); - var parallelGroups = new Dictionary>>(); - var constrainedParallelGroups = new Dictionary Unconstrained, List<(AbstractExecutableTest, IReadOnlyList, TestPriority)> Keyed)>(); + var parallelGroups = new Dictionary>>(capacity: 16); + var constrainedParallelGroups = new Dictionary Unconstrained, List<(AbstractExecutableTest, string, IReadOnlyList, TestPriority)> Keyed)>(capacity: 16); foreach (var (test, sortKey) in testsWithKeys) { @@ -77,11 +77,11 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable GroupTestsByConstraintsAsync(IEnumerable GroupTestsByConstraintsAsync(IEnumerable { - var classA = a.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; - var classB = b.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; - var classCompare = string.CompareOrdinal(classA, classB); + var classCompare = string.CompareOrdinal(a.ClassName, b.ClassName); if (classCompare != 0) return classCompare; - + var priorityCompare = b.Priority.Priority.CompareTo(a.Priority.Priority); if (priorityCompare != 0) return priorityCompare; - + return a.Priority.Order.CompareTo(b.Priority.Order); }); - + var sortedNotInParallel = new AbstractExecutableTest[notInParallelList.Count]; for (int i = 0; i < notInParallelList.Count; i++) { @@ -121,17 +119,15 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable { - var classA = a.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; - var classB = b.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; - var classCompare = string.CompareOrdinal(classA, classB); + var classCompare = string.CompareOrdinal(a.ClassName, b.ClassName); if (classCompare != 0) return classCompare; - + var priorityCompare = b.Priority.Priority.CompareTo(a.Priority.Priority); if (priorityCompare != 0) return priorityCompare; - + return a.Priority.Order.CompareTo(b.Priority.Order); }); - + var keyedArrays = new (AbstractExecutableTest, IReadOnlyList, int)[keyedNotInParallelList.Count]; for (int i = 0; i < keyedNotInParallelList.Count; i++) { @@ -140,7 +136,7 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable(); + var finalConstrainedGroups = new Dictionary(capacity: constrainedParallelGroups.Count); foreach (var kvp in constrainedParallelGroups) { var groupName = kvp.Key; @@ -149,22 +145,20 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable { - var classA = a.Item1.Context.ClassContext?.ClassType?.FullName ?? string.Empty; - var classB = b.Item1.Context.ClassContext?.ClassType?.FullName ?? string.Empty; - var classCompare = string.CompareOrdinal(classA, classB); + var classCompare = string.CompareOrdinal(a.Item2, b.Item2); if (classCompare != 0) return classCompare; - - var priorityCompare = b.Item3.Priority.CompareTo(a.Item3.Priority); + + var priorityCompare = b.Item4.Priority.CompareTo(a.Item4.Priority); if (priorityCompare != 0) return priorityCompare; - - return a.Item3.Order.CompareTo(b.Item3.Order); + + return a.Item4.Order.CompareTo(b.Item4.Order); }); - + var sortedKeyed = new (AbstractExecutableTest, IReadOnlyList, int)[keyed.Count]; for (int i = 0; i < keyed.Count; i++) { var item = keyed[i]; - sortedKeyed[i] = (item.Item1, item.Item2, item.Item3.GetHashCode()); + sortedKeyed[i] = (item.Item1, item.Item3, item.Item4.GetHashCode()); } finalConstrainedGroups[groupName] = new GroupedConstrainedTests @@ -188,9 +182,10 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable notInParallelList, - List<(AbstractExecutableTest Test, IReadOnlyList ConstraintKeys, TestPriority Priority)> keyedNotInParallelList) + List<(AbstractExecutableTest Test, string ClassName, TestPriority Priority)> notInParallelList, + List<(AbstractExecutableTest Test, string ClassName, IReadOnlyList ConstraintKeys, TestPriority Priority)> keyedNotInParallelList) { var order = constraint.Order; var priority = test.Context.ExecutionPriority; @@ -198,12 +193,12 @@ private static void ProcessNotInParallelConstraint( if (constraint.NotInParallelConstraintKeys.Count == 0) { - notInParallelList.Add((test, testPriority)); + notInParallelList.Add((test, className, testPriority)); } else { // Add test only once with all its constraint keys - keyedNotInParallelList.Add((test, constraint.NotInParallelConstraintKeys, testPriority)); + keyedNotInParallelList.Add((test, className, constraint.NotInParallelConstraintKeys, testPriority)); } } @@ -229,29 +224,30 @@ private static void ProcessParallelGroupConstraint( private static void ProcessCombinedConstraints( AbstractExecutableTest test, + string className, ParallelGroupConstraint parallelGroup, NotInParallelConstraint notInParallel, - Dictionary Unconstrained, List<(AbstractExecutableTest, IReadOnlyList, TestPriority)> Keyed)> constrainedGroups) + Dictionary Unconstrained, List<(AbstractExecutableTest, string, IReadOnlyList, TestPriority)> Keyed)> constrainedGroups) { if (!constrainedGroups.TryGetValue(parallelGroup.Group, out var group)) { - group = (new List(), new List<(AbstractExecutableTest, IReadOnlyList, TestPriority)>()); + group = (new List(), new List<(AbstractExecutableTest, string, IReadOnlyList, TestPriority)>()); constrainedGroups[parallelGroup.Group] = group; } - + // Add to keyed tests within the parallel group var order = notInParallel.Order; var priority = test.Context.ExecutionPriority; var testPriority = new TestPriority(priority, order); - + if (notInParallel.NotInParallelConstraintKeys.Count > 0) { - group.Keyed.Add((test, notInParallel.NotInParallelConstraintKeys, testPriority)); + group.Keyed.Add((test, className, notInParallel.NotInParallelConstraintKeys, testPriority)); } else { // NotInParallel without keys means sequential within the group - group.Keyed.Add((test, new List { "__global__" }, testPriority)); + group.Keyed.Add((test, className, new List { "__global__" }, testPriority)); } } }