diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index dfd8663b4..387894689 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -82,5 +82,6 @@ dotnet run --no-build -c Release --framework net9.0 -- --list-tests ## Coding style * Honor StyleCop rules and fix any reported build warnings *after* getting tests to pass. -* In C# files, use namespace *statements* instead of namespace *blocks* for all new files. +* In C# files, use namespace *statements* instead of namespace *blocks* for all new files that define namespaces. +* Test files are *not* expected to declare namespaces. * Add API doc comments to all new public and internal members. diff --git a/samples/DisableProcessing.cs b/samples/DisableProcessing.cs new file mode 100644 index 000000000..0deddb1d1 --- /dev/null +++ b/samples/DisableProcessing.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +#pragma warning disable VSTHRD103 // Call async methods when in an async method + +using System.IO; +using Microsoft.VisualStudio.Threading; + +internal class DisableProcessing +{ + private readonly JoinableTaskFactory joinableTaskFactory = null!; + + private void Simple() + { + #region Simple + this.joinableTaskFactory.Run(async delegate + { + this.joinableTaskFactory.DisableProcessing(); + + // Synchronous I/O and lock contentions will NOT result in any reentrancy within this JoinableTask. + }); + #endregion + } + + private void Exhaustive() + { + #region Exhaustive + this.joinableTaskFactory.Run(async delegate + { + // Async I/O isn't expected to synchronously block, and thus would never allow unwanted reentrancy. + string content = await File.ReadAllTextAsync(@"somefile.txt"); + + // Here, synchronous I/O and lock contentions MAY allow certain reentrancy (e.g. COM RPC messages). + content = File.ReadAllText(@"somefile.txt"); + + using (this.joinableTaskFactory.DisableProcessing()) + { + // Within this block, synchronous I/O and lock contentions will NOT result in any reentrancy. + content = File.ReadAllText(@"somefile.txt"); + } + + // Just disable the synchronous wait message pump for the rest of this JoinableTask. + this.joinableTaskFactory.DisableProcessing(); + + // Sync I/O and lock contentions will NOT result in any reentrancy here. + content = File.ReadAllText(@"somefile.txt"); + }); + #endregion + } +} diff --git a/samples/Polyfill.cs b/samples/Polyfill.cs new file mode 100644 index 000000000..5853f02ea --- /dev/null +++ b/samples/Polyfill.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +#if NETFRAMEWORK + +using System; +using System.IO; +using System.Threading.Tasks; + +internal static class PolyfillExtensions +{ + extension(File) + { + internal static Task ReadAllTextAsync(string path) => throw new NotImplementedException(); + } +} + +#endif diff --git a/samples/ApiSamples.cs b/samples/SuppressRelevance.cs similarity index 92% rename from samples/ApiSamples.cs rename to samples/SuppressRelevance.cs index 6fa40c624..348eb2ce6 100644 --- a/samples/ApiSamples.cs +++ b/samples/SuppressRelevance.cs @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.Threading; -public class SuppressRelevanceSample +public class SuppressRelevance { private readonly ReentrantSemaphore semaphore = ReentrantSemaphore.Create(1, null, ReentrantSemaphore.ReentrancyMode.NotAllowed); @@ -19,7 +18,7 @@ await this.semaphore.ExecuteAsync(async delegate await Task.Yield(); // represents some async work // Fire and forget code that uses the semaphore, but should *not* - // inherit our own posession of the semaphore. + // inherit our own possession of the semaphore. using (this.semaphore.SuppressRelevance()) { this.DoSomethingLaterAsync().Forget(); // Don't await this, or a deadlock will occur. diff --git a/src/Microsoft.VisualStudio.Threading/DispatcherExtensions.cs b/src/Microsoft.VisualStudio.Threading/DispatcherExtensions.cs index 900977100..1b93c4bef 100644 --- a/src/Microsoft.VisualStudio.Threading/DispatcherExtensions.cs +++ b/src/Microsoft.VisualStudio.Threading/DispatcherExtensions.cs @@ -27,6 +27,13 @@ public static class DispatcherExtensions /// and for each asynchronous return to the main thread after an . /// /// A that may be used for scheduling async work with the specified priority. + /// + /// In addition to scheduling work on the UI thread with the specified priority, + /// this also ensures that any synchronous waits + /// on the main thread within objects created with the returned factory + /// will honor calls to , producing similar behavior + /// to . + /// public static JoinableTaskFactory WithPriority(this JoinableTaskFactory joinableTaskFactory, Dispatcher dispatcher, DispatcherPriority priority) { Requires.NotNull(joinableTaskFactory, nameof(joinableTaskFactory)); @@ -61,6 +68,9 @@ internal DispatcherJoinableTaskFactory(JoinableTaskFactory innerFactory, Dispatc : base(innerFactory) { this.dispatcher = dispatcher ?? throw new ArgumentNullException(nameof(dispatcher)); +#if NETFRAMEWORK // Avoid trim warnings on .NET, and only .NET Framework calls CoWait anyway. + this.DefaultWaitPolicy = new DispatcherSynchronizationContext(dispatcher); +#endif this.priority = priority; } diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTask+JoinableTaskSynchronizationContext.cs b/src/Microsoft.VisualStudio.Threading/JoinableTask+JoinableTaskSynchronizationContext.cs index 776f67679..00432f5ae 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTask+JoinableTaskSynchronizationContext.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTask+JoinableTaskSynchronizationContext.cs @@ -2,11 +2,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; +using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; +using Windows.Win32; +using Windows.Win32.Foundation; namespace Microsoft.VisualStudio.Threading { @@ -44,6 +44,11 @@ internal JoinableTaskSynchronizationContext(JoinableTaskFactory owner) this.jobFactory = owner; this.mainThreadAffinitized = true; + + if (owner.DefaultWaitPolicy is not null) + { + this.SetWaitNotificationRequired(); + } } /// @@ -56,6 +61,11 @@ internal JoinableTaskSynchronizationContext(JoinableTask joinableTask, bool main { this.job = joinableTask; this.mainThreadAffinitized = mainThreadAffinitized; + + if (joinableTask.DisableProcessing > 0) + { + this.SetWaitNotificationRequired(); + } } /// @@ -134,6 +144,50 @@ public override void Send(SendOrPostCallback d, object? state) } } + /// + /// Synchronously blocks without a message pump. + /// + /// An array of type that contains the native operating system handles. + /// true to wait for all handles; false to wait for any handle. + /// The number of milliseconds to wait, or (-1) to wait indefinitely. + /// + /// The array index of the object that satisfied the wait. + /// + public override unsafe int Wait(IntPtr[] waitHandles, bool waitAll, int millisecondsTimeout) + { + Requires.NotNull(waitHandles, nameof(waitHandles)); + + if (this.job?.DisableProcessing > 0) + { + // On .NET Framework we must take special care to NOT end up in a call to CoWait (which lets in RPC calls). + // Off Windows, we can't p/invoke to kernel32, but it appears that .NET never calls CoWait, so we can rely on default behavior. + // We're just going to use the OS as the switch instead of the runtime so that (one day) if we drop our .NET Framework specific target, + // and if .NET ever adds CoWait support on Windows, we'll still behave properly. +#if NET + if (OperatingSystem.IsWindowsVersionAtLeast(5, 1, 2600)) +#else + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) +#endif + { + fixed (IntPtr* pHandles = waitHandles) + { + return (int)PInvoke.WaitForMultipleObjects((uint)waitHandles.Length, (HANDLE*)pHandles, waitAll, (uint)millisecondsTimeout); + } + } + } + + // Use a surrogate default policy if provided. + if (this.jobFactory.DefaultWaitPolicy is { } waitPolicy) + { + return waitPolicy.Wait(waitHandles, waitAll, millisecondsTimeout); + } + + // Fallback to sync blocking such that CoWait might be called. + return WaitHelper(waitHandles, waitAll, millisecondsTimeout); + } + + internal void ConsiderDisableProcessing() => this.SetWaitNotificationRequired(); + /// /// Called by the joinable task when it has completed. /// diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTask.cs b/src/Microsoft.VisualStudio.Threading/JoinableTask.cs index 7490f3672..2cb35d17c 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTask.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTask.cs @@ -372,6 +372,23 @@ internal SynchronizationContext? ApplicableJobSyncContext } } + /// + /// Gets or sets a value indicating whether CoWait will be prohibited + /// during synchronously blocking waits from code actively running within this . + /// + internal int DisableProcessing + { + get => field; + set + { + field = value; + if (this.mainThreadJobSyncContext is { } syncContext) + { + syncContext.ConsiderDisableProcessing(); + } + } + } + /// /// Gets a weak reference to this object. /// diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs index 3417adb3a..b6bc0acfd 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs @@ -98,6 +98,17 @@ internal JoinableTaskCollection? Collection get { return this.jobCollection; } } + /// + /// Gets a on which + /// should be called from + /// when has not been called. + /// + /// + /// This allows a WPF-aware -derived class within this assembly + /// to match Dispatcher.DisableProcessing() behavior. + /// + internal SynchronizationContext? DefaultWaitPolicy { get; init; } + /// /// Gets or sets the timeout after which no activity while synchronously blocking /// suggests a hang has occurred. @@ -294,6 +305,55 @@ public JoinableTask RunAsync(Func> asyncMethod, string? parentToke return this.RunAsync(asyncMethod, synchronouslyBlocking: false, parentToken, creationOptions: creationOptions); } +#pragma warning disable SA1629 // Documentation text should end with a period + /// + /// Prevents filtered message pumps from running during synchronous waits for the ambient . + /// + /// + /// A value that may be disposed of when the need to suppress synchronous wait message pumps is ended. + /// Alternatively it may be discarded if the rest of the is intended to have processing disabled. + /// + /// Thrown when called outside the context of a . + /// + /// + /// During a yielding within a , no message pump ever runs + /// regardless of whether this method is called, except for the internal one that lets in only relevant work. + /// When user code runs within the delegate or its callees that ends up requiring + /// a synchronous block of the main thread (e.g. synchronous I/O or lock contention), this wait is typically + /// implemented by calling on . + /// The default implementation of this method allows for certain interruptions (e.g. COM RPC calls), which + /// may avoid deadlocks in certain situations. + /// + /// + /// Calling this method will replace the default implementation of + /// with one that will not allow such interruptions while that is active and in control of + /// . + /// As this method may be called multiple times, this effect remains on the target + /// until all invocations' return values are disposed (in any order). + /// The effect only applies to the direct . It does not affect any of its children or parents. + /// + /// + /// Disabling processing has no effect on non-Windows operating systems. + /// + /// + /// Disposing the resulting value will revert to the default behavior. + /// Callers need not ever dispose of this value if the intent is to disable processing for the remainder of that + /// 's execution. + /// + /// + /// + /// + /// Here is a simple, common usage of this method: + /// + /// + /// + /// Following are more examples of how it might be used: + /// + /// + /// + public ProcessingDisabledOperation DisableProcessing() => new(this.Context.AmbientTask ?? throw new InvalidOperationException(Strings.NoAmbientTask)); +#pragma warning restore SA1629 // Documentation text should end with a period + /// /// Responds to calls to /// by scheduling a continuation to execute on the Main thread. @@ -682,6 +742,34 @@ static bool FailFast(Exception ex) } } + /// + /// A struct whose disposal will revert the effect of an earlier call to . + /// + public struct ProcessingDisabledOperation : IDisposable + { + private JoinableTask? owner; + + /// + /// Initializes a new instance of the struct. + /// + /// The owner of this struct. + internal ProcessingDisabledOperation(JoinableTask owner) + { + owner.DisableProcessing++; + this.owner = owner; + } + + /// + public void Dispose() + { + if (this.owner is { } owner) + { + owner.DisableProcessing--; + this.owner = null; + } + } + } + /// /// An awaitable struct that facilitates an asynchronous transition to the Main thread. /// diff --git a/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs b/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs index feea9a501..fa92b754c 100644 --- a/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs +++ b/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Buffers; using System.Runtime.InteropServices; using System.Threading; using global::Windows.Win32; @@ -52,10 +51,10 @@ public override unsafe int Wait(IntPtr[] waitHandles, bool waitAll, int millisec Requires.NotNull(waitHandles, nameof(waitHandles)); // On .NET Framework we must take special care to NOT end up in a call to CoWait (which lets in RPC calls). - // Off Windows, we can't p/invoke to kernel32, but it appears that .NET Core never calls CoWait, so we can rely on default behavior. - // We're just going to use the OS as the switch instead of the framework so that (one day) if we drop our .NET Framework specific target, - // and if .NET Core ever adds CoWait support on Windows, we'll still behave properly. -#if NET5_0_OR_GREATER + // Off Windows, we can't p/invoke to kernel32, but it appears that .NET never calls CoWait, so we can rely on default behavior. + // We're just going to use the OS as the switch instead of the runtime so that (one day) if we drop our .NET Framework specific target, + // and if .NET ever adds CoWait support on Windows, we'll still behave properly. +#if NET if (OperatingSystem.IsWindowsVersionAtLeast(5, 1, 2600)) #else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) diff --git a/src/Microsoft.VisualStudio.Threading/ReentrantSemaphore.cs b/src/Microsoft.VisualStudio.Threading/ReentrantSemaphore.cs index 5ee34752a..fbe46fe2d 100644 --- a/src/Microsoft.VisualStudio.Threading/ReentrantSemaphore.cs +++ b/src/Microsoft.VisualStudio.Threading/ReentrantSemaphore.cs @@ -191,7 +191,7 @@ public static ReentrantSemaphore Create(int initialCount = 1, JoinableTaskContex /// /// The following snippet demonstrates a way to use this method. /// - /// + /// /// public virtual RevertRelevance SuppressRelevance() => default; diff --git a/src/Microsoft.VisualStudio.Threading/Strings.resx b/src/Microsoft.VisualStudio.Threading/Strings.resx index 2537d49ff..4983aee85 100644 --- a/src/Microsoft.VisualStudio.Threading/Strings.resx +++ b/src/Microsoft.VisualStudio.Threading/Strings.resx @@ -193,4 +193,7 @@ No SynchronizationContext to reach the main thread has been set. + + No JoinableTask is active. + \ No newline at end of file diff --git a/test/Microsoft.VisualStudio.Threading.Tests/AsyncCrossProcessMutexTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/AsyncCrossProcessMutexTests.cs index 409769123..1223a8491 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/AsyncCrossProcessMutexTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/AsyncCrossProcessMutexTests.cs @@ -157,14 +157,14 @@ public async Task TryEnterAsync_AbandonedMutex() using AsyncCrossProcessMutex mutex2 = new(this.mutex.Name); AsyncCrossProcessMutex.LockReleaser? abandonedReleaser = await this.mutex.TryEnterAsync(Timeout.InfiniteTimeSpan); - Assert.False(abandonedReleaser.Value.IsAbandoned); + Assert.False(abandonedReleaser?.IsAbandoned); // Dispose the mutex WITHOUT first releasing it. this.mutex.Dispose(); using (AsyncCrossProcessMutex.LockReleaser? releaser2 = await mutex2.TryEnterAsync(Timeout.InfiniteTimeSpan)) { - Assert.True(releaser2.Value.IsAbandoned); + Assert.True(releaser2?.IsAbandoned); } } diff --git a/test/Microsoft.VisualStudio.Threading.Tests/CoWaitMainThreadTransition.cs b/test/Microsoft.VisualStudio.Threading.Tests/CoWaitMainThreadTransition.cs new file mode 100644 index 000000000..f10cc753d --- /dev/null +++ b/test/Microsoft.VisualStudio.Threading.Tests/CoWaitMainThreadTransition.cs @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +#if NETFRAMEWORK + +using System; +using System.Reflection; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; +using System.Threading; + +/// +/// Probes whether the calling STA thread's current synchronous wait allows COM RPC calls to +/// penetrate. Spawns an MTA thread that marshals a COM call back to the STA thread; the call +/// succeeds if and only if the thread is performing a CoWait (message-pumping wait) rather than +/// a plain WaitForMultipleObjects wait. +/// +/// +/// Create the probe before blocking the STA thread, then call after unblocking +/// to learn whether the COM call was delivered. Dispose when done to cancel any pending call and +/// release resources. +/// +public sealed class CoWaitMainThreadTransition : IDisposable +{ + /// Best-effort delay in milliseconds used when cancelling or joining the caller thread. + private const int CallCancellationDelayMs = 500; + + private const int RpcECallCanceled = unchecked((int)0x80010002); + + private static readonly Guid IDispatchGuid = new("00020400-0000-0000-C000-000000000046"); + + private readonly ManualResetEventSlim signalReceived = new(); + private readonly ManualResetEventSlim callerReady = new(); + private readonly Thread callerThread; + private Exception? backgroundFailure; + private uint callerThreadId; + + /// + /// Initializes a new instance of the class and + /// immediately starts the background MTA thread that will attempt the COM call. + /// + internal CoWaitMainThreadTransition() + { + IMainThreadSignaler signaler = new MainThreadSignaler(this.signalReceived); + IntPtr signalerInterface = Marshal.GetIDispatchForObject(signaler); + try + { + Marshal.ThrowExceptionForHR(NativeMethods.CoMarshalInterThreadInterfaceInStream(in IDispatchGuid, signalerInterface, out IntPtr stream)); + + this.callerThread = new Thread(() => this.InvokeSignalOnBackgroundThread(stream)) + { + IsBackground = true, + }; + } + finally + { + Marshal.Release(signalerInterface); + } + +#pragma warning disable CA1416 // Apartment state is only relevant on Windows, and the probe is not used elsewhere. + this.callerThread.SetApartmentState(ApartmentState.MTA); +#pragma warning restore CA1416 + this.callerThread.Start(); + } + + /// + /// A COM-visible IDispatch interface used to signal the main STA thread from an MTA background thread. + /// + [ComVisible(true)] + [Guid("A1D1F0E7-564F-4B9F-8DB2-D40185F115FB")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface IMainThreadSignaler + { + /// Signals the main thread that it has received a COM RPC call. + [DispId(1)] + void Signal(); + } + + /// + public void Dispose() + { + if (this.callerThread.IsAlive) + { + this.CancelPendingCall(); + _ = this.callerThread.Join(CallCancellationDelayMs); + } + + this.callerReady.Dispose(); + this.signalReceived.Dispose(); + } + + /// + /// Blocks until the COM call completes or elapses. + /// + /// The maximum time to wait. + /// + /// if the COM call was delivered within ; + /// if the call did not penetrate the wait (i.e., no CoWait was used). + /// + internal bool Wait(TimeSpan timeout) + { + bool interruptedWait = this.signalReceived.Wait(timeout); + if (!interruptedWait) + { + this.CancelPendingCall(); + } + + if (interruptedWait) + { + Assert.True(this.callerThread.Join(timeout), "Timed out waiting for the COM call to finish."); + } + else + { + _ = this.callerThread.Join(CallCancellationDelayMs); + } + + if (this.backgroundFailure is object && (interruptedWait || this.backgroundFailure.HResult != RpcECallCanceled)) + { + ExceptionDispatchInfo.Capture(this.backgroundFailure).Throw(); + } + + return interruptedWait; + } + + private void CancelPendingCall() + { + Assert.True(this.callerReady.Wait(CallCancellationDelayMs), "Timed out waiting for the COM caller thread to initialize."); + if (this.callerThread.IsAlive && this.callerThreadId != 0) + { + _ = NativeMethods.CoCancelCall(this.callerThreadId, 0); + } + } + + private void InvokeSignalOnBackgroundThread(IntPtr stream) + { + try + { + this.callerThreadId = NativeMethods.GetCurrentThreadId(); + Marshal.ThrowExceptionForHR(NativeMethods.CoEnableCallCancellation(IntPtr.Zero)); + this.callerReady.Set(); + + Thread.Sleep(50); + Marshal.ThrowExceptionForHR(NativeMethods.CoGetInterfaceAndReleaseStream(stream, in IDispatchGuid, out object signaler)); + signaler.GetType().InvokeMember(nameof(IMainThreadSignaler.Signal), BindingFlags.InvokeMethod, binder: null, target: signaler, args: Array.Empty()); + } + catch (Exception ex) + { + this.backgroundFailure = ex; + this.callerReady.Set(); + } + finally + { + _ = NativeMethods.CoDisableCallCancellation(IntPtr.Zero); + } + } + + /// + /// COM-visible implementation of that uses the free-threaded + /// marshaler so the COM proxy routes calls back to whichever STA thread holds the object. + /// + [ComVisible(true)] + [ClassInterface(ClassInterfaceType.None)] + public sealed class MainThreadSignaler : StandardOleMarshalObject, IMainThreadSignaler + { + private readonly ManualResetEventSlim signalReceived; + + /// Initializes a new instance of the class. + /// The event to set when is called. + internal MainThreadSignaler(ManualResetEventSlim signalReceived) + { + this.signalReceived = signalReceived; + } + + /// + public void Signal() => this.signalReceived.Set(); + } + + private static class NativeMethods + { + [DllImport("ole32.dll")] + internal static extern int CoMarshalInterThreadInterfaceInStream(in Guid riid, IntPtr pUnk, out IntPtr ppStm); + + [DllImport("ole32.dll")] + internal static extern int CoGetInterfaceAndReleaseStream(IntPtr pStm, in Guid iid, [MarshalAs(UnmanagedType.IDispatch)] out object ppv); + + [DllImport("ole32.dll")] + internal static extern int CoEnableCallCancellation(IntPtr pReserved); + + [DllImport("ole32.dll")] + internal static extern int CoDisableCallCancellation(IntPtr pReserved); + + [DllImport("ole32.dll")] + internal static extern int CoCancelCall(uint dwThreadId, uint ulTimeout); + + [DllImport("kernel32.dll")] + internal static extern uint GetCurrentThreadId(); + } +} +#endif diff --git a/test/Microsoft.VisualStudio.Threading.Tests/DispatcherExtensionsTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/DispatcherExtensionsTests.cs index 65d324840..44b5cf011 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/DispatcherExtensionsTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/DispatcherExtensionsTests.cs @@ -66,6 +66,50 @@ public void WithPriority_LowPriorityCanBlockOnHighPriorityWork() await Task.WhenAll(idleTask.Task, normalTask.Task).WithCancellation(this.TimeoutToken); }); } + +#if NETFRAMEWORK + [StaFact] + public void WithPriority_MatchesDisableProcessingWithinDelegate() + { + this.SimulateUIThread(delegate + { + JoinableTaskFactory? normalPriorityJtf = this.asyncPump.WithPriority(Dispatcher.CurrentDispatcher, DispatcherPriority.Normal); + normalPriorityJtf.Run(delegate + { + this.AssertProcessingAllowed(); + + using (Dispatcher.CurrentDispatcher.DisableProcessing()) + { + this.AssertProcessingDisabled(); + } + + return Task.CompletedTask; + }); + + return Task.CompletedTask; + }); + } + + [StaFact] + public void WithPriority_MatchesDisableProcessingOutsideDelegate() + { + this.SimulateUIThread(delegate + { + JoinableTaskFactory? normalPriorityJtf = this.asyncPump.WithPriority(Dispatcher.CurrentDispatcher, DispatcherPriority.Normal); + using (Dispatcher.CurrentDispatcher.DisableProcessing()) + { + normalPriorityJtf.Run(delegate + { + this.AssertProcessingDisabled(); + + return Task.CompletedTask; + }); + } + + return Task.CompletedTask; + }); + } +#endif } diff --git a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskFactoryTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskFactoryTests.cs index 7f05a3d5b..b7b8832e0 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskFactoryTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskFactoryTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -138,6 +138,110 @@ public void SwitchToMainThreadAlwaysYield() }); } + [Fact] + public void DisableProcessing_ThrowsOutsideJoinableTask() + { + Assert.Throws(() => this.asyncPump.DisableProcessing()); + } + + [Fact] + public void DisableProcessing_InsideJoinableTask() + { + this.asyncPump.Run(delegate + { + using (this.asyncPump.DisableProcessing()) + { + } + + return Task.CompletedTask; + }); + } + + [Fact] + public void ProcessingDisabledOperation_Dispose_DoesNotThrowFromDefaultValue() + { + default(JoinableTaskFactory.ProcessingDisabledOperation).Dispose(); + } + +#if NETFRAMEWORK + [StaFact] + public void DisableProcessing() + { + this.asyncPump.Run(() => + { + this.AssertProcessingAllowed(); + + using (this.asyncPump.DisableProcessing()) + { + this.AssertProcessingDisabled(); + } + + this.AssertProcessingAllowed(); + return Task.CompletedTask; + }); + } + + [StaFact] + public void DisableProcessing_NestedProcessingDisabled() + { + this.asyncPump.Run(() => + { + using (this.asyncPump.DisableProcessing()) + { + using (this.asyncPump.DisableProcessing()) + { + this.AssertProcessingDisabled(); + } + + this.AssertProcessingDisabled(); + } + + this.AssertProcessingAllowed(); + return Task.CompletedTask; + }); + } + + [StaFact] + public void DisableProcessing_NestedTasks() + { + this.asyncPump.Run(() => + { + using (this.asyncPump.DisableProcessing()) + { + this.asyncPump.Run(() => + { + // Child JoinableTasks do not inherit the processing-disabled state of their parents. + this.AssertProcessingAllowed(); + + return Task.CompletedTask; + }); + } + + return Task.CompletedTask; + }); + } + + [StaFact] + public void DisableProcessing_RefCounted() + { + this.asyncPump.Run(() => + { + JoinableTaskFactory.ProcessingDisabledOperation first = this.asyncPump.DisableProcessing(); + JoinableTaskFactory.ProcessingDisabledOperation second = this.asyncPump.DisableProcessing(); + + // Dispose things in a FIFO order instead of a nested LIFO order. + // Processing should only be re-enabled after the last reference is disposed. + first.Dispose(); + this.AssertProcessingDisabled(); + second.Dispose(); + this.AssertProcessingAllowed(); + + return Task.CompletedTask; + }); + } + +#endif + /// /// A that allows a test to inject code /// in the main thread transition events. diff --git a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskTestBase.cs b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskTestBase.cs index f415f93ba..bd64031a9 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskTestBase.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskTestBase.cs @@ -96,4 +96,30 @@ protected void PushFrameTillQueueIsEmpty() this.dispatcherContext.Post(s => this.testFrame.Continue = false, null); this.PushFrame(); } + +#if NETFRAMEWORK + protected void AssertProcessingDisabled() + { + Assert.SkipUnless(MightCoWaitBeUsed, "DisableProcessing has no effect in this environment."); + + // For this check to work, we need to be on the main thread. + Assert.True(this.asyncPump.Context.IsOnMainThread); + Assert.Equal(ApartmentState.STA, Thread.CurrentThread.GetApartmentState()); + + using CoWaitMainThreadTransition transition = new(); + Assert.False(transition.Wait(ExpectedTimeout)); + } + + protected void AssertProcessingAllowed() + { + Assert.SkipUnless(MightCoWaitBeUsed, "DisableProcessing has no effect in this environment."); + + // For this check to work, we need to be on the main thread. + Assert.True(this.asyncPump.Context.IsOnMainThread); + Assert.Equal(ApartmentState.STA, Thread.CurrentThread.GetApartmentState()); + + using CoWaitMainThreadTransition transition = new(); + Assert.True(transition.Wait(UnexpectedTimeout)); + } +#endif } diff --git a/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs new file mode 100644 index 000000000..1c95be9d3 --- /dev/null +++ b/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading; + +/// +/// Tests for . +/// +public class NoMessagePumpSyncContextTests : TestBase +{ + /// + /// Initializes a new instance of the class. + /// + /// The logger to use for test output. + public NoMessagePumpSyncContextTests(ITestOutputHelper logger) + : base(logger) + { + } + + /// + /// Verifies that returns a usable singleton instance. + /// + [Fact] + public void Default_IsNonNull() + { + Assert.NotNull(NoMessagePumpSyncContext.Default); + } + + /// + /// Verifies that the default singleton is itself a . + /// + [Fact] + public void Default_IsNoMessagePumpSyncContext() + { + Assert.IsType(NoMessagePumpSyncContext.Default); + } + +#if NETFRAMEWORK + /// + /// Establishes the baseline: on a plain STA thread without a special synchronization context, + /// uses CoWaitForMultipleHandles, + /// which allows COM RPC calls to be dispatched to the thread while it waits. + /// + [StaFact] + public void Wait_ComRpcPenetratesDefaultStaWait() + { + using CoWaitMainThreadTransition probe = new(); + using ManualResetEvent mre = new(false); + + // Block the STA thread; the default CoWait allows the COM call to execute Signal(). + mre.WaitOne((int)ExpectedTimeout.TotalMilliseconds); + + // The COM call should have been delivered while the thread was blocked. + Assert.True(probe.Wait(TimeSpan.FromMilliseconds(AsyncDelay))); + } + + /// + /// Verifies that uses + /// WaitForMultipleObjects rather than CoWaitForMultipleHandles, preventing + /// COM RPC calls from being dispatched to the thread while it is synchronously waiting. + /// + [StaFact] + public void Wait_BlocksComRpcCalls() + { + using (NoMessagePumpSyncContext.Default.Apply()) + { + using CoWaitMainThreadTransition probe = new(); + using ManualResetEvent mre = new(false); + + // Block the STA thread; NoMessagePumpSyncContext uses WaitForMultipleObjects, + // so the COM call cannot be dispatched while the thread is waiting. + mre.WaitOne((int)ExpectedTimeout.TotalMilliseconds); + + // The COM call should NOT have been delivered. + Assert.False(probe.Wait(TimeSpan.FromMilliseconds(AsyncDelay))); + } + } +#endif +} diff --git a/test/Microsoft.VisualStudio.Threading.Tests/NonConcurrentSynchronizationContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/NonConcurrentSynchronizationContextTests.cs index 650a23f3b..4ac2e9b75 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/NonConcurrentSynchronizationContextTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/NonConcurrentSynchronizationContextTests.cs @@ -48,7 +48,7 @@ public async Task UnhandledException_WithHandler() var eventArgs = new TaskCompletionSource<(object?, Exception)>(); this.nonSticky.UnhandledException += (s, e) => eventArgs.SetResult((s, e)); this.nonSticky.Post(s => throw new InvalidOperationException(), null); - (object sender, Exception ex) = await eventArgs.Task.WithCancellation(this.TimeoutToken); + (object? sender, Exception? ex) = await eventArgs.Task.WithCancellation(this.TimeoutToken); Assert.Same(this.nonSticky, sender); Assert.IsType(ex); } diff --git a/test/Microsoft.VisualStudio.Threading.Tests/TestBase.cs b/test/Microsoft.VisualStudio.Threading.Tests/TestBase.cs index d3a837504..704e1440b 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/TestBase.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/TestBase.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Microsoft; @@ -35,6 +36,13 @@ protected TestBase(ITestOutputHelper logger) this.Logger = logger; } + protected static bool MightCoWaitBeUsed +#if NETFRAMEWORK + => RuntimeInformation.IsOSPlatform(OSPlatform.Windows); +#else + => false; +#endif + /// /// Gets or sets the source of that influences /// when tests consider themselves to be timed out. diff --git a/test/NativeAOTCompatibility.Test/NativeAOTCompatibility.Test.csproj b/test/NativeAOTCompatibility.Test/NativeAOTCompatibility.Test.csproj index 87b9e251d..1c29d7486 100644 --- a/test/NativeAOTCompatibility.Test/NativeAOTCompatibility.Test.csproj +++ b/test/NativeAOTCompatibility.Test/NativeAOTCompatibility.Test.csproj @@ -3,6 +3,7 @@ net10.0 false + false $(TargetFrameworks);net10.0-windows diff --git a/version.json b/version.json index 2f478c761..e53c28ed4 100644 --- a/version.json +++ b/version.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/dotnet/Nerdbank.GitVersioning/main/src/NerdBank.GitVersioning/version.schema.json", - "version": "18.0", + "version": "18.7", "publicReleaseRefSpec": [ "^refs/heads/main$", "^refs/heads/v\\d+(?:\\.\\d+)?$"