diff --git a/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs b/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs
index 2adbb23931..dd423b5248 100644
--- a/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs
+++ b/src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs
@@ -29,6 +29,13 @@ public STATestMethodAttribute(TestMethodAttribute testMethodAttribute)
: base(testMethodAttribute.DeclaringFilePath, testMethodAttribute.DeclaringLineNumber ?? -1)
=> _testMethodAttribute = testMethodAttribute;
+ ///
+ /// Gets or sets a value indicating whether the attribute will set a that preserves the same
+ /// STA thread for async continuations.
+ /// The default is .
+ ///
+ public bool UseSTASynchronizationContext { get; set; }
+
///
/// The core execution of STA test method, which happens on the STA thread.
///
@@ -38,18 +45,39 @@ protected virtual Task ExecuteCoreAsync(ITestMethod testMethod)
=> _testMethodAttribute is null ? base.ExecuteAsync(testMethod) : _testMethodAttribute.ExecuteAsync(testMethod);
///
- public override Task ExecuteAsync(ITestMethod testMethod)
+ public override async Task ExecuteAsync(ITestMethod testMethod)
{
+ if (UseSTASynchronizationContext)
+ {
+ SynchronizationContext? originalContext = SynchronizationContext.Current;
+ var syncContext = new SingleThreadedSTASynchronizationContext();
+ try
+ {
+ SynchronizationContext.SetSynchronizationContext(syncContext);
+
+ // The yield ensures that we switch to the STA thread created by SingleThreadedSTASynchronizationContext.
+ await Task.Yield();
+ TestResult[] testResults = await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
+ return testResults;
+ }
+ finally
+ {
+ SynchronizationContext.SetSynchronizationContext(originalContext);
+ syncContext.Complete();
+ syncContext.Dispose();
+ }
+ }
+
if (Thread.CurrentThread.GetApartmentState() == ApartmentState.STA)
{
- return ExecuteCoreAsync(testMethod);
+ return await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
}
#if !NETFRAMEWORK
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// TODO: Throw?
- return ExecuteCoreAsync(testMethod);
+ return await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
}
#endif
@@ -61,6 +89,6 @@ public override Task ExecuteAsync(ITestMethod testMethod)
t.SetApartmentState(ApartmentState.STA);
t.Start();
t.Join();
- return Task.FromResult(results!);
+ return results!;
}
}
diff --git a/src/TestFramework/TestFramework/Attributes/TestMethod/SingleThreadedSTASynchronizationContext.cs b/src/TestFramework/TestFramework/Attributes/TestMethod/SingleThreadedSTASynchronizationContext.cs
new file mode 100644
index 0000000000..047b133922
--- /dev/null
+++ b/src/TestFramework/TestFramework/Attributes/TestMethod/SingleThreadedSTASynchronizationContext.cs
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.VisualStudio.TestTools.UnitTesting;
+
+internal sealed class SingleThreadedSTASynchronizationContext : SynchronizationContext, IDisposable
+{
+ private readonly BlockingCollection _queue = [];
+ private readonly Thread _thread;
+
+ public SingleThreadedSTASynchronizationContext()
+ {
+#if !NETFRAMEWORK
+ if (!OperatingSystem.IsWindows())
+ {
+ throw new NotSupportedException("SingleThreadedSTASynchronizationContext is only supported on Windows.");
+ }
+#endif
+
+ _thread = new Thread(() =>
+ {
+ SetSynchronizationContext(this);
+ foreach (Action callback in _queue.GetConsumingEnumerable())
+ {
+ callback();
+ }
+ })
+ {
+ IsBackground = true,
+ };
+ _thread.SetApartmentState(ApartmentState.STA);
+ _thread.Start();
+ }
+
+ public override void Post(SendOrPostCallback d, object? state)
+ => _queue.Add(() => d(state));
+
+ public override void Send(SendOrPostCallback d, object? state)
+ {
+ if (Environment.CurrentManagedThreadId == _thread.ManagedThreadId)
+ {
+ d(state);
+ }
+ else
+ {
+ using var done = new ManualResetEventSlim();
+ _queue.Add(() =>
+ {
+ try
+ {
+ d(state);
+ }
+ finally
+ {
+ done.Set();
+ }
+ });
+ done.Wait();
+ }
+ }
+
+ public void Complete() => _queue.CompleteAdding();
+
+ public void Dispose() => _queue.Dispose();
+}
diff --git a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt
index c456e2b3e3..cfabfe29c5 100644
--- a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt
+++ b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt
@@ -43,6 +43,8 @@ Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsExactInstance
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsNotExactInstanceOfTypeInterpolatedStringHandler
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsNotExactInstanceOfTypeInterpolatedStringHandler
+Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.get -> bool
+Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.set -> void
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.ContainsSingle(System.Func