diff --git a/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs b/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs index 2adbb23931..dd423b5248 100644 --- a/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs +++ b/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs @@ -29,6 +29,13 @@ public STATestMethodAttribute(TestMethodAttribute testMethodAttribute) : base(testMethodAttribute.DeclaringFilePath, testMethodAttribute.DeclaringLineNumber ?? -1) => _testMethodAttribute = testMethodAttribute; + /// + /// Gets or sets a value indicating whether the attribute will set a that preserves the same + /// STA thread for async continuations. + /// The default is . + /// + public bool UseSTASynchronizationContext { get; set; } + /// /// The core execution of STA test method, which happens on the STA thread. /// @@ -38,18 +45,39 @@ protected virtual Task ExecuteCoreAsync(ITestMethod testMethod) => _testMethodAttribute is null ? base.ExecuteAsync(testMethod) : _testMethodAttribute.ExecuteAsync(testMethod); /// - public override Task ExecuteAsync(ITestMethod testMethod) + public override async Task ExecuteAsync(ITestMethod testMethod) { + if (UseSTASynchronizationContext) + { + SynchronizationContext? originalContext = SynchronizationContext.Current; + var syncContext = new SingleThreadedSTASynchronizationContext(); + try + { + SynchronizationContext.SetSynchronizationContext(syncContext); + + // The yield ensures that we switch to the STA thread created by SingleThreadedSTASynchronizationContext. + await Task.Yield(); + TestResult[] testResults = await ExecuteCoreAsync(testMethod).ConfigureAwait(false); + return testResults; + } + finally + { + SynchronizationContext.SetSynchronizationContext(originalContext); + syncContext.Complete(); + syncContext.Dispose(); + } + } + if (Thread.CurrentThread.GetApartmentState() == ApartmentState.STA) { - return ExecuteCoreAsync(testMethod); + return await ExecuteCoreAsync(testMethod).ConfigureAwait(false); } #if !NETFRAMEWORK if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { // TODO: Throw? - return ExecuteCoreAsync(testMethod); + return await ExecuteCoreAsync(testMethod).ConfigureAwait(false); } #endif @@ -61,6 +89,6 @@ public override Task ExecuteAsync(ITestMethod testMethod) t.SetApartmentState(ApartmentState.STA); t.Start(); t.Join(); - return Task.FromResult(results!); + return results!; } } diff --git a/src/TestFramework/TestFramework/Attributes/TestMethod/SingleThreadedSTASynchronizationContext.cs b/src/TestFramework/TestFramework/Attributes/TestMethod/SingleThreadedSTASynchronizationContext.cs new file mode 100644 index 0000000000..047b133922 --- /dev/null +++ b/src/TestFramework/TestFramework/Attributes/TestMethod/SingleThreadedSTASynchronizationContext.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.VisualStudio.TestTools.UnitTesting; + +internal sealed class SingleThreadedSTASynchronizationContext : SynchronizationContext, IDisposable +{ + private readonly BlockingCollection _queue = []; + private readonly Thread _thread; + + public SingleThreadedSTASynchronizationContext() + { +#if !NETFRAMEWORK + if (!OperatingSystem.IsWindows()) + { + throw new NotSupportedException("SingleThreadedSTASynchronizationContext is only supported on Windows."); + } +#endif + + _thread = new Thread(() => + { + SetSynchronizationContext(this); + foreach (Action callback in _queue.GetConsumingEnumerable()) + { + callback(); + } + }) + { + IsBackground = true, + }; + _thread.SetApartmentState(ApartmentState.STA); + _thread.Start(); + } + + public override void Post(SendOrPostCallback d, object? state) + => _queue.Add(() => d(state)); + + public override void Send(SendOrPostCallback d, object? state) + { + if (Environment.CurrentManagedThreadId == _thread.ManagedThreadId) + { + d(state); + } + else + { + using var done = new ManualResetEventSlim(); + _queue.Add(() => + { + try + { + d(state); + } + finally + { + done.Set(); + } + }); + done.Wait(); + } + } + + public void Complete() => _queue.CompleteAdding(); + + public void Dispose() => _queue.Dispose(); +} diff --git a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt index c456e2b3e3..cfabfe29c5 100644 --- a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt @@ -43,6 +43,8 @@ Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsExactInstance Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsNotExactInstanceOfTypeInterpolatedStringHandler Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsNotExactInstanceOfTypeInterpolatedStringHandler +Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.get -> bool +Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.set -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.ContainsSingle(System.Func! predicate, System.Collections.IEnumerable! collection, string? message = "", string! predicateExpression = "", string! collectionExpression = "") -> object? static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.IsExactInstanceOfType(object? value, System.Type? expectedType, string? message = "", string! valueExpression = "") -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.IsExactInstanceOfType(object? value, System.Type? expectedType, ref Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler message, string! valueExpression = "") -> void diff --git a/test/UnitTests/MSTest.SelfRealExamples.UnitTests/STATestMethodSyncContext.cs b/test/UnitTests/MSTest.SelfRealExamples.UnitTests/STATestMethodSyncContext.cs new file mode 100644 index 0000000000..b895effd33 --- /dev/null +++ b/test/UnitTests/MSTest.SelfRealExamples.UnitTests/STATestMethodSyncContext.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace MSTest.SelfRealExamples.UnitTests; + +[TestClass] +public class STATestMethodSyncContext +{ + [STATestMethod] + [OSCondition(OperatingSystems.Windows)] + public void STAByDefaultDoesNotUseSynchronizationContext() + { + Assert.IsNull(SynchronizationContext.Current); + Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState()); + } + + [STATestMethod(UseSTASynchronizationContext = true)] + [OSCondition(OperatingSystems.Windows)] + public async Task STAWithSynchronizationContextIsCorrect() + { + Assert.IsNotNull(SynchronizationContext.Current); + Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState()); + + await Task.Delay(100); + + Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState()); + } +}