diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs index 220532b62b9b..2a0349580a4a 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis.Elfie.Diagnostics; @@ -13,6 +15,7 @@ using StreamJsonRpc; using Xunit; using static Microsoft.CommonLanguageServerProtocol.Framework.UnitTests.HandlerProviderTests; +using static Microsoft.CommonLanguageServerProtocol.Framework.UnitTests.RequestExecutionQueueTests; namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; @@ -31,14 +34,29 @@ protected override ILspServices ConstructLspServices() } private const string MethodName = "SomeMethod"; + private const string CancellingMethod = "CancellingMethod"; + private const string CompletingMethod = "CompletingMethod"; + private const string MutatingMethod = "MutatingMethod"; - private static RequestExecutionQueue GetRequestExecutionQueue(IMethodHandler? methodHandler = null) + private static RequestExecutionQueue GetRequestExecutionQueue(bool cancelInProgressWorkUponMutatingRequest, params IMethodHandler[] methodHandlers) { var handlerProvider = new Mock(MockBehavior.Strict); - var handler = methodHandler ?? GetTestMethodHandler(); - handlerProvider.Setup(h => h.GetMethodHandler(MethodName, TestMethodHandler.RequestType, TestMethodHandler.ResponseType)).Returns(handler); + if (methodHandlers.Length == 0) + { + var handler = GetTestMethodHandler(); + handlerProvider.Setup(h => h.GetMethodHandler(MethodName, TestMethodHandler.RequestType, TestMethodHandler.ResponseType)).Returns(handler); + } + + foreach (var methodHandler in methodHandlers) + { + var methodType = methodHandler.GetType(); + var methodAttribute = methodType.GetCustomAttribute(); + var method = methodAttribute.Method; - var executionQueue = new RequestExecutionQueue(new MockServer(), NoOpLspLogger.Instance, handlerProvider.Object); + handlerProvider.Setup(h => h.GetMethodHandler(method, typeof(int), typeof(string))).Returns(methodHandler); + } + + var executionQueue = new TestRequestExecutionQueue(new MockServer(), NoOpLspLogger.Instance, handlerProvider.Object, cancelInProgressWorkUponMutatingRequest); executionQueue.Start(); return executionQueue; @@ -65,19 +83,45 @@ private static TestMethodHandler GetTestMethodHandler() [Fact] public async Task ExecuteAsync_ThrowCompletes() { + // Arrange var throwingHandler = new ThrowingHandler(); - var requestExecutionQueue = GetRequestExecutionQueue(throwingHandler); - var request = 1; + var requestExecutionQueue = GetRequestExecutionQueue(false, throwingHandler); var lspServices = GetLspServices(); - await Assert.ThrowsAsync(() => requestExecutionQueue.ExecuteAsync(request, MethodName, lspServices, CancellationToken.None)); + // Act & Assert + await Assert.ThrowsAsync(() => requestExecutionQueue.ExecuteAsync(1, MethodName, lspServices, CancellationToken.None)); + } + + [Fact] + public async Task ExecuteAsync_WithCancelInProgressWork_CancelsInProgressWorkWhenMutatingRequestArrives() + { + // Let's try it a bunch of times to try to find timing issues. + for (var i = 0; i < 20; i++) + { + // Arrange + var mutatingHandler = new MutatingHandler(); + var cancellingHandler = new CancellingHandler(); + var completingHandler = new CompletingHandler(); + var requestExecutionQueue = GetRequestExecutionQueue(cancelInProgressWorkUponMutatingRequest: true, methodHandlers: new IMethodHandler[] { cancellingHandler, completingHandler, mutatingHandler }); + var lspServices = GetLspServices(); + + var cancellingRequestCancellationToken = new CancellationToken(); + var completingRequestCancellationToken = new CancellationToken(); + + var _ = requestExecutionQueue.ExecuteAsync(1, CancellingMethod, lspServices, cancellingRequestCancellationToken); + var _1 = requestExecutionQueue.ExecuteAsync(1, CompletingMethod, lspServices, completingRequestCancellationToken); + + // Act & Assert + // A Debug.Assert would throw if the tasks hadn't completed when the mutating request is called. + await requestExecutionQueue.ExecuteAsync(1, MutatingMethod, lspServices, CancellationToken.None); + } } [Fact] public async Task Dispose_MultipleTimes_Succeeds() { // Arrange - var requestExecutionQueue = GetRequestExecutionQueue(); + var requestExecutionQueue = GetRequestExecutionQueue(false); // Act await requestExecutionQueue.DisposeAsync(); @@ -86,20 +130,10 @@ public async Task Dispose_MultipleTimes_Succeeds() // Assert, it didn't fail } - public class ThrowingHandler : IRequestHandler - { - public bool MutatesSolutionState => false; - - public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } - } - [Fact] public async Task ExecuteAsync_CompletesTask() { - var requestExecutionQueue = GetRequestExecutionQueue(); + var requestExecutionQueue = GetRequestExecutionQueue(false); var request = 1; var lspServices = GetLspServices(); @@ -111,7 +145,7 @@ public async Task ExecuteAsync_CompletesTask() [Fact] public async Task Queue_DrainsOnShutdown() { - var requestExecutionQueue = GetRequestExecutionQueue(); + var requestExecutionQueue = GetRequestExecutionQueue(false); var request = 1; var lspServices = GetLspServices(); @@ -124,7 +158,75 @@ public async Task Queue_DrainsOnShutdown() Assert.True(task2.IsCompleted); } - private class TestResponse + private class TestRequestExecutionQueue : RequestExecutionQueue + { + private readonly bool _cancelInProgressWorkUponMutatingRequest; + + public TestRequestExecutionQueue(AbstractLanguageServer languageServer, ILspLogger logger, IHandlerProvider handlerProvider, bool cancelInProgressWorkUponMutatingRequest) + : base(languageServer, logger, handlerProvider) + { + _cancelInProgressWorkUponMutatingRequest = cancelInProgressWorkUponMutatingRequest; + } + + protected override bool CancelInProgressWorkUponMutatingRequest => _cancelInProgressWorkUponMutatingRequest; + } + + [LanguageServerEndpoint(MutatingMethod)] + public class MutatingHandler : IRequestHandler { + public MutatingHandler() + { + } + + public bool MutatesSolutionState => true; + + public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + return Task.FromResult(string.Empty); + } + } + + [LanguageServerEndpoint(CompletingMethod)] + public class CompletingHandler : IRequestHandler + { + public bool MutatesSolutionState => false; + + public async Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + while (true) + { + if (cancellationToken.IsCancellationRequested) + { + return "I completed!"; + } + await Task.Delay(100); + } + } + } + + [LanguageServerEndpoint(CancellingMethod)] + public class CancellingHandler : IRequestHandler + { + public bool MutatesSolutionState => false; + + public async Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Delay(100); + } + } + } + + [LanguageServerEndpoint(MethodName)] + public class ThrowingHandler : IRequestHandler + { + public bool MutatesSolutionState => false; + + public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } } } diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs index 6c5fe52ab713..d37e859ee6f8 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs @@ -3,16 +3,17 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Concurrent; using System.Diagnostics; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.Threading; -using System.Collections.Immutable; namespace Microsoft.CommonLanguageServerProtocol.Framework; /// -/// Coordinates the exectution of LSP messages to ensure correct results are sent back. +/// Coordinates the execution of LSP messages to ensure correct results are sent back. /// /// /// @@ -21,7 +22,7 @@ namespace Microsoft.CommonLanguageServerProtocol.Framework; /// (via textDocument/didChange for example). /// /// -/// This class acheives this by distinguishing between mutating and non-mutating requests, and ensuring that +/// This class achieves this by distinguishing between mutating and non-mutating requests, and ensuring that /// when a mutating request comes in, its processing blocks all subsequent requests. As each request comes in /// it is added to a queue, and a queue item will not be retrieved while a mutating request is running. Before /// any request is handled the solution state is created by merging workspace solution state, which could have @@ -89,6 +90,19 @@ protected IMethodHandler GetMethodHandler(string methodName return handler; } + /// + /// Indicates this queue requires in-progress work to be cancelled before servicing + /// a mutating request. + /// + /// + /// This was added for WebTools consumption as they aren't resilient to + /// incomplete requests continuing execution during didChange notifications. As their + /// parse trees are mutable, a didChange notification requires all previous requests + /// to be completed before processing. This is similar to the O# + /// WithContentModifiedSupport(false) behavior. + /// + protected virtual bool CancelInProgressWorkUponMutatingRequest => false; + /// /// Queues a request to be handled by the specified handler, with mutating requests blocking subsequent requests /// from starting until the mutation is complete. @@ -156,6 +170,8 @@ private async Task ProcessQueueAsync() ILspServices? lspServices = null; try { + var concurrentlyExecutingTasks = new ConcurrentDictionary(); + while (!_cancelSource.IsCancellationRequested) { // First attempt to de-queue the work item in its own try-catch. @@ -175,9 +191,27 @@ private async Task ProcessQueueAsync() try { var (work, activityId, cancellationToken) = queueItem; + CancellationTokenSource? currentWorkCts = null; lspServices = work.LspServices; - var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken, cancellationToken); + if (CancelInProgressWorkUponMutatingRequest) + { + try + { + currentWorkCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken, cancellationToken); + } + catch (ObjectDisposedException) + { + // Explicitly ignore this exception as this can occur during the CreateLinkTokenSource call, and means one of the + // linked cancellationTokens has been cancelled. If this occurs, skip to the next loop iteration as this + // queueItem requires no processing + continue; + } + + // Use the linked cancellation token so it's task can be cancelled if necessary during a mutating request + // on a queue that specifies CancelInProgressWorkUponMutatingRequest + cancellationToken = currentWorkCts.Token; + } // Restore our activity id so that logging/tracking works across asynchronous calls. Trace.CorrelationManager.ActivityId = activityId; @@ -186,23 +220,60 @@ private async Task ProcessQueueAsync() var context = await work.CreateRequestContextAsync(cancellationToken).ConfigureAwait(false); if (work.MutatesServerState) { + if (CancelInProgressWorkUponMutatingRequest) + { + // Cancel all concurrently executing tasks + var concurrentlyExecutingTasksArray = concurrentlyExecutingTasks.ToArray(); + for (var i = 0; i < concurrentlyExecutingTasksArray.Length; i++) + { + concurrentlyExecutingTasksArray[i].Value.Cancel(); + } + + // wait for all pending tasks to complete their cancellation, ignoring any exceptions + await Task.WhenAll(concurrentlyExecutingTasksArray.Select(kvp => kvp.Key)).NoThrowAwaitableInternal(captureContext: false); + } + + Debug.Assert(!concurrentlyExecutingTasks.Any(), "The tasks should have all been drained before continuing"); // Mutating requests block other requests from starting to ensure an up to date snapshot is used. // Since we're explicitly awaiting exceptions to mutating requests will bubble up here. await WrapStartRequestTaskAsync(work.StartRequestAsync(context, cancellationToken), rethrowExceptions: true).ConfigureAwait(false); } else { - // Non mutating are fire-and-forget because they are by definition readonly. Any errors + // Non mutating are fire-and-forget because they are by definition read-only. Any errors // will be sent back to the client but they can also be captured via HandleNonMutatingRequestError, // though these errors don't put us into a bad state as far as the rest of the queue goes. // Furthermore we use Task.Run here to protect ourselves against synchronous execution of work - // blocking the request queue for longer periods of time (it enforces parallelizabilty). - _ = WrapStartRequestTaskAsync(Task.Run(() => work.StartRequestAsync(context, cancellationToken), cancellationToken), rethrowExceptions: false); + // blocking the request queue for longer periods of time (it enforces parallelizability). + var currentWorkTask = WrapStartRequestTaskAsync(Task.Run(() => work.StartRequestAsync(context, cancellationToken), cancellationToken), rethrowExceptions: false); + + if (CancelInProgressWorkUponMutatingRequest) + { + if (currentWorkCts is null) + { + throw new InvalidOperationException($"unexpected null value for {nameof(currentWorkCts)}"); + } + + if (!concurrentlyExecutingTasks.TryAdd(currentWorkTask, currentWorkCts)) + { + throw new InvalidOperationException($"unable to add {nameof(currentWorkTask)} into {nameof(concurrentlyExecutingTasks)}"); + } + + _ = currentWorkTask.ContinueWith(t => + { + if (!concurrentlyExecutingTasks.TryRemove(t, out var concurrentlyExecutingTaskCts)) + { + throw new InvalidOperationException($"unexpected failure to remove task from {nameof(concurrentlyExecutingTasks)}"); + } + + concurrentlyExecutingTaskCts.Dispose(); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + } } } - catch (OperationCanceledException ex) when (ex.CancellationToken == queueItem.cancellationToken) + catch (OperationCanceledException) { - // Explicitly ignore this exception as cancellation occured as a result of our linked cancellation token. + // Explicitly ignore this exception as cancellation occurred as a result of our linked cancellation token. // This means either the queue is shutting down or the request itself was cancelled. // 1. If the queue is shutting down, then while loop will exit before the next iteration since it checks for cancellation. // 2. Request cancellations are normal so no need to report anything there. @@ -227,7 +298,7 @@ private async Task ProcessQueueAsync() } /// - /// Provides an extensiblity point to log or otherwise inspect errors thrown from non-mutating requests, + /// Provides an extensibility point to log or otherwise inspect errors thrown from non-mutating requests, /// which would otherwise be lost to the fire-and-forget task in the queue. /// /// The task to be inspected. diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/TaskExtensions.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/TaskExtensions.cs new file mode 100644 index 000000000000..2ce0886ee3c7 --- /dev/null +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/TaskExtensions.cs @@ -0,0 +1,138 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.CommonLanguageServerProtocol.Framework +{ + internal static partial class TaskExtensions + { + // Following code is copied from Microsoft.VisualStudio.Threading.TplExtensions (renamed to avoid ambiguity) + // https://github.com/microsoft/vs-threading/blob/main/src/Microsoft.VisualStudio.Threading/TplExtensions.cs + + /// + /// Returns an awaitable for the specified task that will never throw, even if the source task + /// faults or is canceled. + /// + /// The task whose completion should signal the completion of the returned awaitable. + /// if set to true the continuation will be scheduled on the caller's context; false to always execute the continuation on the threadpool. + /// An awaitable. + public static NoThrowTaskAwaitable NoThrowAwaitableInternal(this Task task, bool captureContext = true) + { + return new NoThrowTaskAwaitable(task, captureContext); + } + + /// + /// An awaitable that wraps a task and never throws an exception when waited on. + /// + public readonly struct NoThrowTaskAwaitable + { + /// + /// The task. + /// + private readonly Task _task; + + /// + /// A value indicating whether the continuation should be scheduled on the current sync context. + /// + private readonly bool _captureContext; + + /// + /// Initializes a new instance of the struct. + /// + /// The task. + /// Whether the continuation should be scheduled on the current sync context. + public NoThrowTaskAwaitable(Task task, bool captureContext) + { + if (task is null) + { + throw new InvalidOperationException(nameof(task)); + } + _task = task; + _captureContext = captureContext; + } + + /// + /// Gets the awaiter. + /// + /// The awaiter. + public NoThrowTaskAwaiter GetAwaiter() + { + return new NoThrowTaskAwaiter(_task, _captureContext); + } + } + + /// + /// An awaiter that wraps a task and never throws an exception when waited on. + /// + public readonly struct NoThrowTaskAwaiter : ICriticalNotifyCompletion + { + /// + /// The task. + /// + private readonly Task _task; + + /// + /// A value indicating whether the continuation should be scheduled on the current sync context. + /// + private readonly bool _captureContext; + + /// + /// Initializes a new instance of the struct. + /// + /// The task. + /// if set to true [capture context]. + public NoThrowTaskAwaiter(Task task, bool captureContext) + { + if (task is null) + { + throw new InvalidOperationException(nameof(task)); + } + _task = task; + _captureContext = captureContext; + } + + /// + /// Gets a value indicating whether the task has completed. + /// + public bool IsCompleted + { + get { return _task.IsCompleted; } + } + + /// + /// Schedules a delegate for execution at the conclusion of a task's execution. + /// + /// The action. + public void OnCompleted(Action continuation) + { + _task.ConfigureAwait(_captureContext).GetAwaiter().OnCompleted(continuation); + } + + /// + /// Schedules a delegate for execution at the conclusion of a task's execution + /// without capturing the ExecutionContext. + /// + /// The action. + public void UnsafeOnCompleted(Action continuation) + { + _task.ConfigureAwait(_captureContext).GetAwaiter().UnsafeOnCompleted(continuation); + } + + /// + /// Does nothing. + /// +#pragma warning disable CA1822 // Mark members as static + public void GetResult() +#pragma warning restore CA1822 // Mark members as static + { + // Never throw here. + } + } + } +}