Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions .idea/.idea.TUnit/.idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions TUnit.Core/StaticPropertyReflectionInitializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,6 @@ private static async Task InitializeStaticProperty(Type type, PropertyInfo prope
// Set the property value
property.SetValue(null, value);

// Initialize the value if it's an object
if (value != null)
{
await ObjectInitializer.InitializeAsync(value);
}

// Only use the first value for static properties
break;
}
Expand Down
2 changes: 1 addition & 1 deletion TUnit.Core/TestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ public async Task ReregisterTestWithArguments(object?[]? methodArguments = null,

internal AbstractExecutableTest InternalExecutableTest { get; set; } = null!;

internal HashSet<object> TrackedObjects { get; } = [];
internal ConcurrentDictionary<int, HashSet<object>> TrackedObjects { get; } = [];

public DateTimeOffset? TestEnd { get; set; }

Expand Down
13 changes: 10 additions & 3 deletions TUnit.Core/Tracking/ObjectTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,24 @@ internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphPr

public void TrackObjects(TestContext testContext)
{
var objects = trackableObjectGraphProvider.GetTrackableObjects(testContext);
var alreadyTracked = testContext.TrackedObjects.SelectMany(x => x.Value).ToHashSet();

foreach (var obj in objects)
var newTrackableObjects = trackableObjectGraphProvider.GetTrackableObjects(testContext)
.SelectMany(x => x.Value)
.Except(alreadyTracked)
.ToHashSet();

foreach (var obj in newTrackableObjects)
{
TrackObject(obj);
}
}

public async ValueTask UntrackObjects(TestContext testContext, List<Exception> cleanupExceptions)
{
foreach (var obj in testContext.TrackedObjects)
foreach (var obj in testContext.TrackedObjects
.SelectMany(x => x.Value)
.ToHashSet())
{
try
{
Expand Down
68 changes: 25 additions & 43 deletions TUnit.Core/Tracking/TrackableObjectGraphProvider.cs
Original file line number Diff line number Diff line change
@@ -1,54 +1,52 @@
using System.Collections.Concurrent;
using TUnit.Core.PropertyInjection;
using TUnit.Core.StaticProperties;

namespace TUnit.Core.Tracking;

internal class TrackableObjectGraphProvider
{
public IEnumerable<object> GetTrackableObjects(TestContext testContext)
public ConcurrentDictionary<int, HashSet<object>> GetTrackableObjects(TestContext testContext)
{
var visitedObjects = testContext.TrackedObjects;

var testDetails = testContext.TestDetails;

foreach (var classArgument in testDetails.TestClassArguments)
{
if (classArgument != null && visitedObjects.Add(classArgument))
if (classArgument != null && visitedObjects.GetOrAdd(0, []).Add(classArgument))
{
yield return classArgument;

foreach (var nested in GetNestedTrackableObjects(classArgument, visitedObjects))
{
yield return nested;
}
AddNestedTrackableObjects(classArgument, visitedObjects, 1);
}
}

foreach (var methodArgument in testDetails.TestMethodArguments)
{
if (methodArgument != null && visitedObjects.Add(methodArgument))
if (methodArgument != null && visitedObjects.GetOrAdd(0, []).Add(methodArgument))
{
yield return methodArgument;

foreach (var nested in GetNestedTrackableObjects(methodArgument, visitedObjects))
{
yield return nested;
}
AddNestedTrackableObjects(methodArgument, visitedObjects, 1);
}
}

foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values)
{
if (property != null && visitedObjects.Add(property))
if (property != null && visitedObjects.GetOrAdd(0, []).Add(property))
{
yield return property;

foreach (var nested in GetNestedTrackableObjects(property, visitedObjects))
{
yield return nested;
}
AddNestedTrackableObjects(property, visitedObjects, 1);
}
}

return visitedObjects;
}

private static void AddToLevel(Dictionary<int, List<object>> objectsByLevel, int level, object obj)
{
if (!objectsByLevel.TryGetValue(level, out var list))
{
list = [];
objectsByLevel[level] = list;
}
list.Add(obj);
}

/// <summary>
Expand All @@ -65,14 +63,8 @@ public IEnumerable<object> GetStaticPropertyTrackableObjects()
}
}

private IEnumerable<object> GetNestedTrackableObjects(object obj, HashSet<object> visitedObjects)
private void AddNestedTrackableObjects(object obj, ConcurrentDictionary<int, HashSet<object>> visitedObjects, int currentDepth)
{
// Prevent infinite recursion on circular references
if (!visitedObjects.Add(obj))
{
yield break;
}

var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType());

if(!SourceRegistrar.IsEnabled)
Expand All @@ -87,22 +79,17 @@ private IEnumerable<object> GetNestedTrackableObjects(object obj, HashSet<object
}

// Check if already visited before yielding to prevent duplicates
if (!visitedObjects.Add(value))
if (!visitedObjects.GetOrAdd(currentDepth, []).Add(value))
{
continue;
}

yield return value;

if (!PropertyInjectionCache.HasInjectableProperties(value.GetType()))
{
continue;
}

foreach (var nested in GetNestedTrackableObjects(value, visitedObjects))
{
yield return nested;
}
AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1);
}
}
else
Expand All @@ -124,22 +111,17 @@ private IEnumerable<object> GetNestedTrackableObjects(object obj, HashSet<object
}

// Check if already visited before yielding to prevent duplicates
if (!visitedObjects.Add(value))
if (!visitedObjects.GetOrAdd(currentDepth, []).Add(value))
{
continue;
}

yield return value;

if (!PropertyInjectionCache.HasInjectableProperties(value.GetType()))
{
continue;
}

foreach (var nested in GetNestedTrackableObjects(value, visitedObjects))
{
yield return nested;
}
AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1);
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions TUnit.Engine/Framework/TUnitServiceProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public ITestExecutionFilter? Filter
public PropertyInjectionService PropertyInjectionService { get; }
public DataSourceInitializer DataSourceInitializer { get; }
public ObjectRegistrationService ObjectRegistrationService { get; }
public ObjectInitializationService ObjectInitializationService { get; }
public bool AfterSessionHooksFailed { get; set; }

public TUnitServiceProvider(IExtension extension,
Expand Down Expand Up @@ -89,7 +88,6 @@ public TUnitServiceProvider(IExtension extension,

// NEW: Separate registration and execution services (replaces TestObjectInitializer)
ObjectRegistrationService = Register(new ObjectRegistrationService(PropertyInjectionService));
ObjectInitializationService = Register(new ObjectInitializationService());

// Initialize the circular dependencies
PropertyInjectionService.Initialize(ObjectRegistrationService);
Expand All @@ -115,7 +113,7 @@ public TUnitServiceProvider(IExtension extension,

CancellationToken = Register(new EngineCancellationToken());

EventReceiverOrchestrator = Register(new EventReceiverOrchestrator(Logger));
EventReceiverOrchestrator = Register(new EventReceiverOrchestrator(Logger, trackableObjectGraphProvider));
HookCollectionService = Register<IHookCollectionService>(new HookCollectionService(EventReceiverOrchestrator));

ParallelLimitLockProvider = Register(new ParallelLimitLockProvider());
Expand Down Expand Up @@ -157,7 +155,7 @@ public TUnitServiceProvider(IExtension extension,
// Create test finder service after discovery service so it can use its cache
TestFinder = Register<ITestFinder>(new TestFinder(DiscoveryService));

var testInitializer = new TestInitializer(EventReceiverOrchestrator, ObjectInitializationService, PropertyInjectionService, objectTracker);
var testInitializer = new TestInitializer(EventReceiverOrchestrator, PropertyInjectionService, objectTracker);

// Create the new TestCoordinator that orchestrates the granular services
var testCoordinator = Register<ITestCoordinator>(
Expand Down
10 changes: 4 additions & 6 deletions TUnit.Engine/Services/DataSourceInitializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal sealed class DataSourceInitializer
private readonly Dictionary<object, Task> _initializationTasks = new();
private readonly object _lock = new();
private PropertyInjectionService? _propertyInjectionService;

public void Initialize(PropertyInjectionService propertyInjectionService)
{
_propertyInjectionService = propertyInjectionService;
Expand Down Expand Up @@ -77,14 +77,12 @@ private async Task InitializeDataSourceAsync(
await _propertyInjectionService.InjectPropertiesIntoObjectAsync(
dataSource, objectBag, methodMetadata, events);
}

// Step 2: IAsyncInitializer
if (dataSource is IAsyncInitializer asyncInitializer)
{
await asyncInitializer.InitializeAsync();
await ObjectInitializer.InitializeAsync(asyncInitializer);
}

// Note: Tracking is now handled by ObjectRegistrationService during registration phase
}
catch (Exception ex)
{
Expand All @@ -103,4 +101,4 @@ public void ClearCache()
_initializationTasks.Clear();
}
}
}
}
20 changes: 5 additions & 15 deletions TUnit.Engine/Services/EventReceiverOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using TUnit.Core.Data;
using TUnit.Core.Helpers;
using TUnit.Core.Interfaces;
using TUnit.Core.Tracking;
using TUnit.Engine.Events;
using TUnit.Engine.Extensions;
using TUnit.Engine.Logging;
Expand All @@ -16,6 +17,7 @@ internal sealed class EventReceiverOrchestrator : IDisposable
{
private readonly EventReceiverRegistry _registry = new();
private readonly TUnitFrameworkLogger _logger;
private readonly TrackableObjectGraphProvider _trackableObjectGraphProvider;

// Track which assemblies/classes/sessions have had their "first" event invoked
private ThreadSafeDictionary<string, Task> _firstTestInAssemblyTasks = new();
Expand All @@ -33,25 +35,22 @@ internal sealed class EventReceiverOrchestrator : IDisposable
// Track registered First event receiver types to avoid duplicate registrations
private readonly ConcurrentHashSet<Type> _registeredFirstEventReceiverTypes = new();

public EventReceiverOrchestrator(TUnitFrameworkLogger logger)
public EventReceiverOrchestrator(TUnitFrameworkLogger logger, TrackableObjectGraphProvider trackableObjectGraphProvider)
{
_logger = logger;
_trackableObjectGraphProvider = trackableObjectGraphProvider;
}

public async ValueTask InitializeAllEligibleObjectsAsync(TestContext context, CancellationToken cancellationToken)
public void RegisterReceivers(TestContext context, CancellationToken cancellationToken)
{
var eligibleObjects = context.GetEligibleEventObjects().ToArray();

// Only initialize and register objects that haven't been processed yet
var newObjects = new List<object>();
var objectsToRegister = new List<object>();

foreach (var obj in eligibleObjects)
{
if (_initializedObjects.Add(obj)) // Add returns false if already present
{
newObjects.Add(obj);

// For First event receivers, only register one instance per type
var objType = obj.GetType();
bool isFirstEventReceiver = obj is IFirstTestInTestSessionEventReceiver ||
Expand Down Expand Up @@ -80,15 +79,6 @@ obj is IFirstTestInAssemblyEventReceiver ||
// Register only the objects that should be registered
_registry.RegisterReceivers(objectsToRegister);
}

if (newObjects.Count > 0)
{
// Initialize all new objects (even if not registered)
foreach (var obj in newObjects)
{
await ObjectInitializer.InitializeAsync(obj, cancellationToken);
}
}
}


Expand Down
Loading
Loading