diff --git a/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanCacheMiddleware.cs b/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanCacheMiddleware.cs index 5a50a46f788..3084a159ce9 100644 --- a/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanCacheMiddleware.cs +++ b/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanCacheMiddleware.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using HotChocolate.Caching.Memory; using HotChocolate.Execution; using HotChocolate.Fusion.Diagnostics; @@ -10,6 +11,8 @@ internal sealed class OperationPlanCacheMiddleware { private readonly Cache _cache; private readonly IFusionExecutionDiagnosticEvents _diagnosticEvents; + private readonly ConcurrentDictionary>> _inFlightPlans = + new(StringComparer.Ordinal); private OperationPlanCacheMiddleware(Cache cache, IFusionExecutionDiagnosticEvents diagnosticEvents) { @@ -32,29 +35,95 @@ public async ValueTask InvokeAsync(RequestContext context, RequestDelegate next) : $"{documentInfo.Hash.Value}.{context.Request.OperationName ?? "Default"}"; context.SetOperationId(operationId); - var isPlanCached = false; + var isSingleFlightLeader = false; + Lazy>? inFlightPlan = null; if (_cache.TryGet(operationId, out var plan)) { context.SetOperationPlan(plan); - isPlanCached = true; _diagnosticEvents.RetrievedOperationPlanFromCache(context, operationId); } + else if (_inFlightPlans.TryGetValue(operationId, out inFlightPlan)) + { + // Another request is already planning this operation. + // Await the leader's result to avoid redundant planning work. + var coalescedPlan = await inFlightPlan.Value.Task + .WaitAsync(context.RequestAborted) + .ConfigureAwait(false); + context.SetOperationPlan(coalescedPlan); + } + else + { + // No plan is cached and no planning is in progress. + // Use a Lazy so that under burst conditions only one TCS is materialized + // even if multiple requests race through GetOrAdd concurrently. + inFlightPlan = new Lazy>( + static () => new TaskCompletionSource( + TaskCreationOptions.RunContinuationsAsynchronously)); + var cachedInFlightPlan = _inFlightPlans.GetOrAdd(operationId, inFlightPlan); - await next(context); + if (ReferenceEquals(cachedInFlightPlan, inFlightPlan)) + { + // We won the race! This request is the single-flight leader + // responsible for planning and signaling all followers. + isSingleFlightLeader = true; + context.Features.Set(inFlightPlan.Value); + } + else + { + // We lost the race! Another request claimed leadership between + // TryGetValue and GetOrAdd. So we simply await the leader's result. + var coalescedPlan = await cachedInFlightPlan.Value.Task + .WaitAsync(context.RequestAborted) + .ConfigureAwait(false); + context.SetOperationPlan(coalescedPlan); + } + } - if (!isPlanCached) + try + { + await next(context); + } + catch (Exception ex) { - // We retrieve the execution plan from the context. - // If there is no execution plan, we can exit early as something must have - // gone wrong in the pipeline. If we get, however, an execution plan, - // we try to cache it. - var executionPlan = context.GetOperationPlan(); + // Propagate the exception to all waiting followers. + if (isSingleFlightLeader && inFlightPlan is not null) + { + inFlightPlan.Value.TrySetException(ex); + } - if (executionPlan is not null) + throw; + } + finally + { + if (isSingleFlightLeader) { - _cache.TryAdd(operationId, executionPlan); - _diagnosticEvents.AddedOperationPlanToCache(context, operationId); + // Guard against a faulty diagnostic event handler preventing cleanup. + // Without this, a throw from the cache or diagnostics would leak the + // in-flight entry, causing _inFlightPlans to grow indefinitely. + try + { + if (context.GetOperationPlan() is { } operationPlan) + { + // Cache the plan before removing the in-flight entry so that + // there is no window where the plan is in neither structure. + _cache.TryAdd(operationId, operationPlan); + _diagnosticEvents.AddedOperationPlanToCache(context, operationId); + inFlightPlan?.Value.TrySetResult(operationPlan); + } + else if (inFlightPlan?.Value.Task.IsCompleted == false) + { + // The pipeline completed without producing a plan and without + // throwing. Signal followers so they do not hang indefinitely. + inFlightPlan.Value.TrySetException( + new InvalidOperationException( + "The operation plan task completed without a result.")); + } + } + finally + { + _inFlightPlans.TryRemove(operationId, out _); + } } } } diff --git a/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanMiddleware.cs b/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanMiddleware.cs index 39c679206ee..8b170d2302f 100644 --- a/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanMiddleware.cs +++ b/src/HotChocolate/Fusion-vnext/src/Fusion.Execution/Execution/Pipeline/OperationPlanMiddleware.cs @@ -57,6 +57,7 @@ private void PlanOperation( var operationShortHash = operationHash[..8]; using var scope = _diagnosticsEvents.PlanOperation(context, operationId); + var inFlightPlan = context.Features.Get>(); try { @@ -74,9 +75,11 @@ private void PlanOperation( context.RequestAborted); OnAfterPlanCompleted(operationDocumentInfo, operationPlan); context.SetOperationPlan(operationPlan); + inFlightPlan?.TrySetResult(operationPlan); } catch (Exception ex) { + inFlightPlan?.TrySetException(ex); _diagnosticsEvents.PlanOperationError(context, operationId, ex); throw; diff --git a/src/HotChocolate/Fusion-vnext/test/Fusion.Execution.Tests/Execution/OperationPlanSingleFlightTests.cs b/src/HotChocolate/Fusion-vnext/test/Fusion.Execution.Tests/Execution/OperationPlanSingleFlightTests.cs new file mode 100644 index 00000000000..adcafd4cbe6 --- /dev/null +++ b/src/HotChocolate/Fusion-vnext/test/Fusion.Execution.Tests/Execution/OperationPlanSingleFlightTests.cs @@ -0,0 +1,366 @@ +using System.Collections.Concurrent; +using System.Diagnostics.Tracing; +using HotChocolate.Collections.Immutable; +using HotChocolate.Execution; +using HotChocolate.Fusion.Execution.Nodes; +using HotChocolate.Fusion.Execution.Pipeline; +using HotChocolate.Fusion.Planning; +using Microsoft.Extensions.DependencyInjection; + +namespace HotChocolate.Fusion.Execution; + +public sealed class OperationPlanSingleFlightTests : FusionTestBase +{ + [Fact] + public async Task Concurrent_Same_Operation_Should_Be_Coalesced_To_One_Planning_Run() + { + // arrange + const int requestCount = 8; + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + using var listener = new PlannerEventListener(); + var operationIds = new ConcurrentBag(); + var gate = new RequestGate(requestCount); + + var executor = await new ServiceCollection() + .AddGraphQLGateway() + .UseDefaultPipeline() + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanCacheMiddleware, + (_, next) => CreateGateMiddleware(next, gate)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanMiddleware, + (_, next) => CreateSingleFlightLeaderDelayMiddleware(next, TimeSpan.FromMilliseconds(100))) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanMiddleware, + (_, next) => CreateOperationIdCaptureMiddleware(next, operationIds)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationExecutionMiddleware, + (_, _) => CreatePlanCaptureMiddleware()) + .AddInMemoryConfiguration( + ComposeSchemaDocument( + """ + type Query { + foo: String + } + """)) + .Services + .BuildServiceProvider() + .GetRequestExecutorAsync(cancellationToken: cts.Token); + + // act + const string operationText = + """ + query SameOpCoalesce { + foo + } + """; + var results = await Task.WhenAll( + Enumerable.Range(0, requestCount) + .Select(_ => executor.ExecuteAsync(operationText, cts.Token))); + + // assert + Assert.All(results, t => Assert.Empty(t.ExpectOperationResult().Errors)); + + var operationId = Assert.Single(operationIds.Distinct()); + Assert.Equal(1, listener.Count(PlannerEventSource.PlanStartEventId, operationId)); + } + + [Fact] + public async Task Concurrent_Distinct_Operations_Should_Not_Be_Coalesced() + { + // arrange + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + using var listener = new PlannerEventListener(); + var operationIds = new ConcurrentBag(); + var gate = new RequestGate(expectedRequests: 2); + + var executor = await new ServiceCollection() + .AddGraphQLGateway() + .UseDefaultPipeline() + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanCacheMiddleware, + (_, next) => CreateGateMiddleware(next, gate)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanMiddleware, + (_, next) => CreateOperationIdCaptureMiddleware(next, operationIds)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationExecutionMiddleware, + (_, _) => CreatePlanCaptureMiddleware()) + .AddInMemoryConfiguration( + ComposeSchemaDocument( + """ + type Query { + foo: String + } + """)) + .Services + .BuildServiceProvider() + .GetRequestExecutorAsync(cancellationToken: cts.Token); + + const string operationText1 = + """ + query DistinctOpOne { + foo + } + """; + const string operationText2 = + """ + query DistinctOpTwo { + __typename + } + """; + + // act + var results = await Task.WhenAll( + executor.ExecuteAsync(operationText1, cts.Token), + executor.ExecuteAsync(operationText2, cts.Token)); + + // assert + Assert.All(results, t => Assert.Empty(t.ExpectOperationResult().Errors)); + + var ids = operationIds.Distinct().ToArray(); + Assert.Equal(2, ids.Length); + Assert.All(ids, id => Assert.Equal(1, listener.Count(PlannerEventSource.PlanStartEventId, id))); + } + + [Fact] + public async Task Leader_Planning_Failure_Should_Be_Observed_By_Followers() + { + // arrange + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + using var listener = new PlannerEventListener(); + var operationIds = new ConcurrentBag(); + var gate = new RequestGate(expectedRequests: 2); + + var executor = await new ServiceCollection() + .AddGraphQLGateway() + .UseDefaultPipeline() + .ModifyPlannerOptions(o => o.MaxPlanningTime = TimeSpan.FromTicks(1)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanCacheMiddleware, + (_, next) => CreateGateMiddleware(next, gate)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanMiddleware, + (_, next) => CreateSingleFlightLeaderDelayMiddleware(next, TimeSpan.FromMilliseconds(100))) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanMiddleware, + (_, next) => CreateOperationIdCaptureMiddleware(next, operationIds)) + .AddInMemoryConfiguration( + ComposeSchemaDocument( + """ + type Query { + foo: String + } + """)) + .Services + .BuildServiceProvider() + .GetRequestExecutorAsync(cancellationToken: cts.Token); + + const string operationText = + """ + query FailureCoalesce { + foo + } + """; + + // act + var results = await Task.WhenAll( + executor.ExecuteAsync(operationText, cts.Token), + executor.ExecuteAsync(operationText, cts.Token)); + + // assert + Assert.All(results, t => Assert.NotEmpty(t.ExpectOperationResult().Errors)); + + var operationId = Assert.Single(operationIds.Distinct()); + Assert.Equal(1, listener.Count(PlannerEventSource.PlanStartEventId, operationId)); + Assert.Equal(1, listener.Count(PlannerEventSource.PlanErrorEventId, operationId)); + } + + [Fact] + public async Task Follower_Cancellation_Should_Not_Cancel_Leader_Planning() + { + // arrange + using var leaderCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + using var followerCts = new CancellationTokenSource(TimeSpan.FromMilliseconds(150)); + using var listener = new PlannerEventListener(); + var operationIds = new ConcurrentBag(); + var blockingInterceptor = new BlockingPlannerInterceptor(); + + var executor = await new ServiceCollection() + .AddGraphQLGateway() + .UseDefaultPipeline() + .AddOperationPlannerInterceptor(_ => blockingInterceptor) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationPlanMiddleware, + (_, next) => CreateOperationIdCaptureMiddleware(next, operationIds)) + .InsertUseRequest( + before: WellKnownRequestMiddleware.OperationExecutionMiddleware, + (_, _) => CreatePlanCaptureMiddleware()) + .AddInMemoryConfiguration( + ComposeSchemaDocument( + """ + type Query { + foo: String + } + """)) + .Services + .BuildServiceProvider() + .GetRequestExecutorAsync(cancellationToken: leaderCts.Token); + + const string operationText = + """ + query CancelFollowerOnly { + foo + } + """; + + // act + var leaderTask = Task.Run( + () => executor.ExecuteAsync(operationText, leaderCts.Token), + CancellationToken.None); + Assert.True(blockingInterceptor.WaitForEntry(TimeSpan.FromSeconds(5))); + + var followerTask = Task.Run( + () => executor.ExecuteAsync(operationText, followerCts.Token), + CancellationToken.None); + var followerCompletion = await Task.WhenAny( + followerTask, + Task.Delay(TimeSpan.FromSeconds(2), leaderCts.Token)); + blockingInterceptor.Release(); + Assert.Same(followerTask, followerCompletion); + + var followerResult = await followerTask; + var leaderResult = await leaderTask; + + // assert + var followerErrors = followerResult.ExpectOperationResult().Errors; + Assert.NotEmpty(followerErrors); + Assert.Contains( + followerErrors, + e => e.Message.Contains("cancel", StringComparison.OrdinalIgnoreCase)); + + Assert.Empty(leaderResult.ExpectOperationResult().Errors); + + var operationId = Assert.Single(operationIds.Distinct()); + Assert.Equal(1, listener.Count(PlannerEventSource.PlanStartEventId, operationId)); + } + + private static RequestDelegate CreateGateMiddleware( + RequestDelegate next, + RequestGate gate) + => async context => + { + await gate.SignalAndWaitAsync(context.RequestAborted); + await next(context); + }; + + private static RequestDelegate CreateOperationIdCaptureMiddleware( + RequestDelegate next, + ConcurrentBag operationIds) + => async context => + { + operationIds.Add(context.GetOperationId()); + await next(context); + }; + + private static RequestDelegate CreateSingleFlightLeaderDelayMiddleware( + RequestDelegate next, + TimeSpan delay) + => async context => + { + if (context.Features.Get>() is not null) + { + await Task.Delay(delay, context.RequestAborted); + } + + await next(context); + }; + + private static RequestDelegate CreatePlanCaptureMiddleware() + => context => + { + context.Result = + new OperationResult( + ImmutableOrderedDictionary.Empty.Add("operationPlan", context.GetOperationPlan())); + return ValueTask.CompletedTask; + }; + + private sealed class RequestGate(int expectedRequests) + { + private readonly TaskCompletionSource _allArrived = + new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _arrived; + + public ValueTask SignalAndWaitAsync(CancellationToken cancellationToken) + { + if (Interlocked.Increment(ref _arrived) >= expectedRequests) + { + _allArrived.TrySetResult(); + } + + return new ValueTask(_allArrived.Task.WaitAsync(cancellationToken)); + } + } + + private sealed class BlockingPlannerInterceptor : IOperationPlannerInterceptor + { + private readonly ManualResetEventSlim _entered = new(false); + private readonly TaskCompletionSource _release = + new(TaskCreationOptions.RunContinuationsAsynchronously); + + public bool WaitForEntry(TimeSpan timeout) + => _entered.Wait(timeout); + + public void Release() + => _release.TrySetResult(); + + public void OnAfterPlanCompleted( + OperationDocumentInfo operationDocumentInfo, + OperationPlan operationPlan) + { + _entered.Set(); + _release.Task.GetAwaiter().GetResult(); + } + } + + private sealed class PlannerEventListener : EventListener + { + private readonly ConcurrentQueue _events = []; + + protected override void OnEventSourceCreated(EventSource eventSource) + { + if (eventSource.Name.Equals(PlannerEventSource.EventSourceName, StringComparison.Ordinal)) + { + EnableEvents(eventSource, EventLevel.Informational, EventKeywords.All); + } + } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + if (!eventData.EventSource.Name.Equals(PlannerEventSource.EventSourceName, StringComparison.Ordinal)) + { + return; + } + + _events.Enqueue( + new CapturedEvent( + eventData.EventId, + eventData.Payload is null + ? [] + : [.. eventData.Payload])); + } + + public int Count(int eventId, string operationId) + => _events.Count(t => t.EventId == eventId && t.HasOperationId(operationId)); + } + + private sealed record CapturedEvent( + int EventId, + IReadOnlyList Payload) + { + public bool HasOperationId(string operationId) + => Payload.Count > 0 + && Payload[0] is string payloadOperationId + && payloadOperationId.Equals(operationId, StringComparison.Ordinal); + } +}