diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs index 8ae2ab5016..28fbb908b0 100644 --- a/TUnit.Core/Discovery/ObjectGraph.cs +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -13,7 +13,7 @@ namespace TUnit.Core.Discovery; /// internal readonly struct ObjectGraph { - private readonly ConcurrentDictionary> _objectsByDepth; + private readonly Dictionary> _objectsByDepth; // Cached sorted depths (computed once in constructor) private readonly int[] _sortedDepthsDescending; @@ -22,8 +22,7 @@ internal readonly struct ObjectGraph /// Creates a new object graph from the discovered objects. /// /// Objects organized by depth level. - /// All unique objects in the graph. - public ObjectGraph(ConcurrentDictionary> objectsByDepth) + public ObjectGraph(Dictionary> objectsByDepth) { _objectsByDepth = objectsByDepth; diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index fe0bf8b8c6..de94eebb9b 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -2,7 +2,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Reflection; -using TUnit.Core.Helpers; +using TUnit.Core.Extensions; using TUnit.Core.Interfaces; using TUnit.Core.Interfaces.SourceGenerator; using TUnit.Core.PropertyInjection; @@ -99,24 +99,18 @@ public static void ClearDiscoveryErrors() /// public ObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) { - var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(ReferenceComparer); - var allObjectsLock = new object(); // Thread-safety for allObjects HashSet - var visitedObjects = new ConcurrentDictionary(ReferenceComparer); + var objectsByDepth = new Dictionary>(); + var visitedObjects = new HashSet(ReferenceComparer); // Standard mode add callback (thread-safe) bool TryAddStandard(object obj, int depth) { - if (!visitedObjects.TryAdd(obj, 0)) + if (!visitedObjects.Add(obj)) { return false; } AddToDepth(objectsByDepth, depth, obj); - lock (allObjectsLock) - { - allObjects.Add(obj); - } return true; } @@ -125,7 +119,7 @@ bool TryAddStandard(object obj, int depth) CollectRootObjects( testContext.Metadata.TestDetails, TryAddStandard, - obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken), + obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, currentDepth: 1, cancellationToken), cancellationToken); return new ObjectGraph(objectsByDepth); @@ -134,20 +128,14 @@ bool TryAddStandard(object obj, int depth) /// public ObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default) { - var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(ReferenceComparer); - var allObjectsLock = new object(); // Thread-safety for allObjects HashSet - var visitedObjects = new ConcurrentDictionary(ReferenceComparer); + var objectsByDepth = new Dictionary>(); + var visitedObjects = new HashSet(ReferenceComparer); - if (visitedObjects.TryAdd(rootObject, 0)) + if (visitedObjects.Add(rootObject)) { AddToDepth(objectsByDepth, 0, rootObject); - lock (allObjectsLock) - { - allObjects.Add(rootObject); - } - DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken); + DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, currentDepth: 1, cancellationToken); } return new ObjectGraph(objectsByDepth); @@ -180,10 +168,8 @@ public ConcurrentDictionary> DiscoverAndTrackObjects(TestCo /// private void DiscoverNestedObjects( object obj, - ConcurrentDictionary> objectsByDepth, - ConcurrentDictionary visitedObjects, - HashSet allObjects, - object allObjectsLock, + Dictionary> objectsByDepth, + HashSet visitedObjects, int currentDepth, CancellationToken cancellationToken) { @@ -197,16 +183,12 @@ private void DiscoverNestedObjects( // Standard mode add callback: visitedObjects + objectsByDepth + allObjects (thread-safe) bool TryAddStandard(object value, int depth) { - if (!visitedObjects.TryAdd(value, 0)) + if (!visitedObjects.Add(value)) { return false; } AddToDepth(objectsByDepth, depth, value); - lock (allObjectsLock) - { - allObjects.Add(value); - } return true; } @@ -214,7 +196,7 @@ bool TryAddStandard(object value, int depth) // Recursive callback void Recurse(object value, int depth) { - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, allObjectsLock, depth, cancellationToken); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, depth, cancellationToken); } // Traverse injectable properties (useSourceRegistrarCheck = false) @@ -281,15 +263,11 @@ private static bool ShouldSkipType(Type type) /// /// Adds an object to the specified depth level. - /// Thread-safe: uses lock to protect HashSet modifications. - /// - private static void AddToDepth(ConcurrentDictionary> objectsByDepth, int depth, object obj) + /// + private static void AddToDepth(Dictionary> objectsByDepth, int depth, object obj) { var hashSet = objectsByDepth.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); - lock (hashSet) - { - hashSet.Add(obj); - } + hashSet.Add(obj); } /// diff --git a/TUnit.Core/Extensions/DictionaryExtensions.cs b/TUnit.Core/Extensions/DictionaryExtensions.cs new file mode 100644 index 0000000000..1486cf39df --- /dev/null +++ b/TUnit.Core/Extensions/DictionaryExtensions.cs @@ -0,0 +1,18 @@ +namespace TUnit.Core.Extensions; + +internal static class DictionaryExtensions +{ + public static TValue GetOrAdd( + this IDictionary dictionary, + TKey key, + Func valueFactory) + { + if (!dictionary.TryGetValue(key, out TValue? value)) + { + value = valueFactory(key); + dictionary.Add(key, value); + } + + return value; + } +}