diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index de2e963e6f..362445816e 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -25,24 +25,28 @@ public static async ValueTask InitializeAsync(object? obj, CancellationToken can { if (obj is IAsyncInitializer asyncInitializer) { - await GetInitializationTask(obj, asyncInitializer); + await GetInitializationTask(obj, asyncInitializer, cancellationToken); } } - private static Task GetInitializationTask(object obj, IAsyncInitializer asyncInitializer) + private static async Task GetInitializationTask(object obj, IAsyncInitializer asyncInitializer, CancellationToken cancellationToken) { + Task initializationTask; + lock (_lock) { - if (_initializationTasks.TryGetValue(obj, out var task)) + if (_initializationTasks.TryGetValue(obj, out var existingTask)) { - return task; + initializationTask = existingTask; + } + else + { + initializationTask = asyncInitializer.InitializeAsync(); + _initializationTasks.Add(obj, initializationTask); } - - var initializationTask = asyncInitializer.InitializeAsync(); - - _initializationTasks.Add(obj, initializationTask); - - return initializationTask; } + + // Wait for initialization with cancellation support + await initializationTask.WaitAsync(cancellationToken); } } diff --git a/TUnit.Engine/Services/DataSourceInitializer.cs b/TUnit.Engine/Services/DataSourceInitializer.cs index 4f91a1070e..5fa59eba0e 100644 --- a/TUnit.Engine/Services/DataSourceInitializer.cs +++ b/TUnit.Engine/Services/DataSourceInitializer.cs @@ -33,7 +33,8 @@ public async Task EnsureInitializedAsync( T dataSource, ConcurrentDictionary? objectBag = null, MethodMetadata? methodMetadata = null, - TestContextEvents? events = null) where T : notnull + TestContextEvents? events = null, + CancellationToken cancellationToken = default) where T : notnull { if (dataSource == null) { @@ -51,12 +52,22 @@ public async Task EnsureInitializedAsync( else { // Start initialization - existingTask = InitializeDataSourceAsync(dataSource, objectBag, methodMetadata, events); + existingTask = InitializeDataSourceAsync(dataSource, objectBag, methodMetadata, events, cancellationToken); _initializationTasks[dataSource] = existingTask; } } - await existingTask; + // Wait for initialization with cancellation support + if (cancellationToken.CanBeCanceled) + { + await existingTask.ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + } + else + { + await existingTask.ConfigureAwait(false); + } + return dataSource; } @@ -67,7 +78,8 @@ private async Task InitializeDataSourceAsync( object dataSource, ConcurrentDictionary? objectBag, MethodMetadata? methodMetadata, - TestContextEvents? events) + TestContextEvents? events, + CancellationToken cancellationToken) { try { @@ -85,12 +97,12 @@ await _propertyInjectionService.InjectPropertiesIntoObjectAsync( // Step 2: Initialize nested property-injected objects (deepest first) // This ensures that when the parent's IAsyncInitializer runs, all nested objects are already initialized - await InitializeNestedObjectsAsync(dataSource); + await InitializeNestedObjectsAsync(dataSource, cancellationToken); // Step 3: IAsyncInitializer on the data source itself if (dataSource is IAsyncInitializer asyncInitializer) { - await ObjectInitializer.InitializeAsync(asyncInitializer); + await ObjectInitializer.InitializeAsync(asyncInitializer, cancellationToken); } } catch (Exception ex) @@ -104,7 +116,7 @@ await _propertyInjectionService.InjectPropertiesIntoObjectAsync( /// Initializes all nested property-injected objects in depth-first order. /// This ensures that when the parent's IAsyncInitializer runs, all nested dependencies are already initialized. /// - private async Task InitializeNestedObjectsAsync(object rootObject) + private async Task InitializeNestedObjectsAsync(object rootObject, CancellationToken cancellationToken) { var objectsByDepth = new Dictionary>(capacity: 4); var visitedObjects = new HashSet(); @@ -120,7 +132,7 @@ private async Task InitializeNestedObjectsAsync(object rootObject) var objectsAtDepth = objectsByDepth[depth]; // Initialize all objects at this depth in parallel - await Task.WhenAll(objectsAtDepth.Select(obj => ObjectInitializer.InitializeAsync(obj).AsTask())); + await Task.WhenAll(objectsAtDepth.Select(obj => ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask())); } } diff --git a/TUnit.Engine/Services/TestExecution/TestCoordinator.cs b/TUnit.Engine/Services/TestExecution/TestCoordinator.cs index ea413cd0a0..4ae9652904 100644 --- a/TUnit.Engine/Services/TestExecution/TestCoordinator.cs +++ b/TUnit.Engine/Services/TestExecution/TestCoordinator.cs @@ -3,6 +3,7 @@ using TUnit.Core.Exceptions; using TUnit.Core.Logging; using TUnit.Core.Tracking; +using TUnit.Engine.Helpers; using TUnit.Engine.Interfaces; using TUnit.Engine.Logging; @@ -92,60 +93,75 @@ private async Task ExecuteTestInternalAsync(AbstractExecutableTest test, Cancell await _testExecutor.EnsureTestSessionHooksExecutedAsync(); // Execute test with retry logic - each retry gets a fresh instance + // Timeout is applied per retry attempt, not across all retries await RetryHelper.ExecuteWithRetry(test.Context, async () => { - test.Context.Metadata.TestDetails.ClassInstance = await test.CreateInstanceAsync(); + // Get timeout configuration for this attempt + var testTimeout = test.Context.Metadata.TestDetails.Timeout; + var timeoutMessage = testTimeout.HasValue + ? $"Test '{test.Context.Metadata.TestDetails.TestName}' timed out after {testTimeout.Value}" + : null; + + // Wrap entire lifecycle (instance creation, initialization, execution) with timeout + await TimeoutHelper.ExecuteWithTimeoutAsync( + async ct => + { + test.Context.Metadata.TestDetails.ClassInstance = await test.CreateInstanceAsync(); - // Invalidate cached eligible event objects since ClassInstance changed - test.Context.CachedEligibleEventObjects = null; + // Invalidate cached eligible event objects since ClassInstance changed + test.Context.CachedEligibleEventObjects = null; - // Check if this test should be skipped (after creating instance) - if (test.Context.Metadata.TestDetails.ClassInstance is SkippedTestInstance || - !string.IsNullOrEmpty(test.Context.SkipReason)) - { - await _stateManager.MarkSkippedAsync(test, test.Context.SkipReason ?? "Test was skipped"); + // Check if this test should be skipped (after creating instance) + if (test.Context.Metadata.TestDetails.ClassInstance is SkippedTestInstance || + !string.IsNullOrEmpty(test.Context.SkipReason)) + { + await _stateManager.MarkSkippedAsync(test, test.Context.SkipReason ?? "Test was skipped"); - await _eventReceiverOrchestrator.InvokeTestSkippedEventReceiversAsync(test.Context, cancellationToken); + await _eventReceiverOrchestrator.InvokeTestSkippedEventReceiversAsync(test.Context, ct); - await _eventReceiverOrchestrator.InvokeTestEndEventReceiversAsync(test.Context, cancellationToken); + await _eventReceiverOrchestrator.InvokeTestEndEventReceiversAsync(test.Context, ct); - return; - } + return; + } - try - { - await _testInitializer.InitializeTest(test, cancellationToken); - test.Context.RestoreExecutionContext(); - await _testExecutor.ExecuteAsync(test, cancellationToken); - } - finally - { - // Dispose test instance and fire OnDispose after each attempt - // This ensures each retry gets a fresh instance - if (test.Context.Events.OnDispose?.InvocationList != null) - { - foreach (var invocation in test.Context.Events.OnDispose.InvocationList) + try { + await _testInitializer.InitializeTest(test, ct); + test.Context.RestoreExecutionContext(); + await _testExecutor.ExecuteAsync(test, ct); + } + finally + { + // Dispose test instance and fire OnDispose after each attempt + // This ensures each retry gets a fresh instance + if (test.Context.Events.OnDispose?.InvocationList != null) + { + foreach (var invocation in test.Context.Events.OnDispose.InvocationList) + { + try + { + await invocation.InvokeAsync(test.Context, test.Context); + } + catch (Exception disposeEx) + { + await _logger.LogErrorAsync($"Error during OnDispose for {test.TestId}: {disposeEx}"); + } + } + } + try { - await invocation.InvokeAsync(test.Context, test.Context); + await TestExecutor.DisposeTestInstance(test); } catch (Exception disposeEx) { - await _logger.LogErrorAsync($"Error during OnDispose for {test.TestId}: {disposeEx}"); + await _logger.LogErrorAsync($"Error disposing test instance for {test.TestId}: {disposeEx}"); } } - } - - try - { - await TestExecutor.DisposeTestInstance(test); - } - catch (Exception disposeEx) - { - await _logger.LogErrorAsync($"Error disposing test instance for {test.TestId}: {disposeEx}"); - } - } + }, + testTimeout, + cancellationToken, + timeoutMessage); }); await _stateManager.MarkCompletedAsync(test); diff --git a/TUnit.Engine/TestExecutor.cs b/TUnit.Engine/TestExecutor.cs index 741710c761..dc5562796d 100644 --- a/TUnit.Engine/TestExecutor.cs +++ b/TUnit.Engine/TestExecutor.cs @@ -5,7 +5,6 @@ using TUnit.Core.Exceptions; using TUnit.Core.Interfaces; using TUnit.Core.Services; -using TUnit.Engine.Helpers; using TUnit.Engine.Services; namespace TUnit.Engine; @@ -98,16 +97,8 @@ await _eventReceiverOrchestrator.InvokeFirstTestInClassEventReceiversAsync( executableTest.Context.RestoreExecutionContext(); - var testTimeout = executableTest.Context.Metadata.TestDetails.Timeout; - var timeoutMessage = testTimeout.HasValue - ? $"Test '{executableTest.Context.Metadata.TestDetails.TestName}' execution timed out after {testTimeout.Value}" - : null; - - await TimeoutHelper.ExecuteWithTimeoutAsync( - ct => ExecuteTestAsync(executableTest, ct), - testTimeout, - cancellationToken, - timeoutMessage).ConfigureAwait(false); + // Timeout is now enforced at TestCoordinator level (wrapping entire lifecycle) + await ExecuteTestAsync(executableTest, cancellationToken).ConfigureAwait(false); executableTest.SetResult(TestState.Passed); } diff --git a/TUnit.Engine/TestInitializer.cs b/TUnit.Engine/TestInitializer.cs index 914456581c..00297f354f 100644 --- a/TUnit.Engine/TestInitializer.cs +++ b/TUnit.Engine/TestInitializer.cs @@ -36,10 +36,10 @@ await _propertyInjectionService.InjectPropertiesIntoObjectAsync( // Shouldn't retrack already tracked objects, but will track any new ones created during retries / initialization _objectTracker.TrackObjects(test.Context); - await InitializeTrackedObjects(test.Context); + await InitializeTrackedObjects(test.Context, cancellationToken); } - private async Task InitializeTrackedObjects(TestContext testContext) + private async Task InitializeTrackedObjects(TestContext testContext, CancellationToken cancellationToken) { // Initialize by level (deepest first), with objects at the same level in parallel var levels = testContext.TrackedObjects.Keys.OrderByDescending(level => level); @@ -47,10 +47,10 @@ private async Task InitializeTrackedObjects(TestContext testContext) foreach (var level in levels) { var objectsAtLevel = testContext.TrackedObjects[level]; - await Task.WhenAll(objectsAtLevel.Select(obj => ObjectInitializer.InitializeAsync(obj).AsTask())); + await Task.WhenAll(objectsAtLevel.Select(obj => ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask())); } // Finally, ensure the test class itself is initialized - await ObjectInitializer.InitializeAsync(testContext.Metadata.TestDetails.ClassInstance); + await ObjectInitializer.InitializeAsync(testContext.Metadata.TestDetails.ClassInstance, cancellationToken); } }