Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions TUnit.Core/Discovery/ObjectGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace TUnit.Core.Discovery;
/// </remarks>
internal readonly struct ObjectGraph
{
private readonly ConcurrentDictionary<int, HashSet<object>> _objectsByDepth;
private readonly Dictionary<int, HashSet<object>> _objectsByDepth;

// Cached sorted depths (computed once in constructor)
private readonly int[] _sortedDepthsDescending;
Expand All @@ -22,8 +22,7 @@ internal readonly struct ObjectGraph
/// Creates a new object graph from the discovered objects.
/// </summary>
/// <param name="objectsByDepth">Objects organized by depth level.</param>
/// <param name="allObjects">All unique objects in the graph.</param>
public ObjectGraph(ConcurrentDictionary<int, HashSet<object>> objectsByDepth)
public ObjectGraph(Dictionary<int, HashSet<object>> objectsByDepth)
{
_objectsByDepth = objectsByDepth;

Expand Down
54 changes: 16 additions & 38 deletions TUnit.Core/Discovery/ObjectGraphDiscoverer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using TUnit.Core.Helpers;
using TUnit.Core.Extensions;
using TUnit.Core.Interfaces;
using TUnit.Core.Interfaces.SourceGenerator;
using TUnit.Core.PropertyInjection;
Expand Down Expand Up @@ -99,24 +99,18 @@ public static void ClearDiscoveryErrors()
/// <inheritdoc />
public ObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default)
{
var objectsByDepth = new ConcurrentDictionary<int, HashSet<object>>();
var allObjects = new HashSet<object>(ReferenceComparer);
var allObjectsLock = new object(); // Thread-safety for allObjects HashSet
var visitedObjects = new ConcurrentDictionary<object, byte>(ReferenceComparer);
var objectsByDepth = new Dictionary<int, HashSet<object>>();
var visitedObjects = new HashSet<object>(ReferenceComparer);

// Standard mode add callback (thread-safe)
bool TryAddStandard(object obj, int depth)
{
if (!visitedObjects.TryAdd(obj, 0))
if (!visitedObjects.Add(obj))
{
return false;
}

AddToDepth(objectsByDepth, depth, obj);
lock (allObjectsLock)
{
allObjects.Add(obj);
}

return true;
}
Expand All @@ -125,7 +119,7 @@ bool TryAddStandard(object obj, int depth)
CollectRootObjects(
testContext.Metadata.TestDetails,
TryAddStandard,
obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken),
obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, currentDepth: 1, cancellationToken),
cancellationToken);

return new ObjectGraph(objectsByDepth);
Expand All @@ -134,20 +128,14 @@ bool TryAddStandard(object obj, int depth)
/// <inheritdoc />
public ObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default)
{
var objectsByDepth = new ConcurrentDictionary<int, HashSet<object>>();
var allObjects = new HashSet<object>(ReferenceComparer);
var allObjectsLock = new object(); // Thread-safety for allObjects HashSet
var visitedObjects = new ConcurrentDictionary<object, byte>(ReferenceComparer);
var objectsByDepth = new Dictionary<int, HashSet<object>>();
var visitedObjects = new HashSet<object>(ReferenceComparer);

if (visitedObjects.TryAdd(rootObject, 0))
if (visitedObjects.Add(rootObject))
{
AddToDepth(objectsByDepth, 0, rootObject);
lock (allObjectsLock)
{
allObjects.Add(rootObject);
}

DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken);
DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, currentDepth: 1, cancellationToken);
}

return new ObjectGraph(objectsByDepth);
Expand Down Expand Up @@ -180,10 +168,8 @@ public ConcurrentDictionary<int, HashSet<object>> DiscoverAndTrackObjects(TestCo
/// </summary>
private void DiscoverNestedObjects(
object obj,
ConcurrentDictionary<int, HashSet<object>> objectsByDepth,
ConcurrentDictionary<object, byte> visitedObjects,
HashSet<object> allObjects,
object allObjectsLock,
Dictionary<int, HashSet<object>> objectsByDepth,
HashSet<object> visitedObjects,
int currentDepth,
CancellationToken cancellationToken)
{
Expand All @@ -197,24 +183,20 @@ private void DiscoverNestedObjects(
// Standard mode add callback: visitedObjects + objectsByDepth + allObjects (thread-safe)
bool TryAddStandard(object value, int depth)
{
if (!visitedObjects.TryAdd(value, 0))
if (!visitedObjects.Add(value))
{
return false;
}

AddToDepth(objectsByDepth, depth, value);
lock (allObjectsLock)
{
allObjects.Add(value);
}

return true;
}

// Recursive callback
void Recurse(object value, int depth)
{
DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, allObjectsLock, depth, cancellationToken);
DiscoverNestedObjects(value, objectsByDepth, visitedObjects, depth, cancellationToken);
}

// Traverse injectable properties (useSourceRegistrarCheck = false)
Expand Down Expand Up @@ -281,15 +263,11 @@ private static bool ShouldSkipType(Type type)

/// <summary>
/// Adds an object to the specified depth level.
/// Thread-safe: uses lock to protect HashSet modifications.
/// </summary>
private static void AddToDepth(ConcurrentDictionary<int, HashSet<object>> objectsByDepth, int depth, object obj)
/// </summary>
private static void AddToDepth(Dictionary<int, HashSet<object>> objectsByDepth, int depth, object obj)
{
var hashSet = objectsByDepth.GetOrAdd(depth, _ => new HashSet<object>(ReferenceComparer));
lock (hashSet)
{
hashSet.Add(obj);
}
hashSet.Add(obj);
}

/// <summary>
Expand Down
18 changes: 18 additions & 0 deletions TUnit.Core/Extensions/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace TUnit.Core.Extensions;

internal static class DictionaryExtensions
{
public static TValue GetOrAdd<TKey, TValue>(
this IDictionary<TKey, TValue> dictionary,
TKey key,
Func<TKey, TValue> valueFactory)
{
if (!dictionary.TryGetValue(key, out TValue? value))
{
value = valueFactory(key);
dictionary.Add(key, value);
}

return value;
}
}
Loading