From 0945e3554329b270f1a4b84e9da127fe614ae5b0 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Thu, 19 Feb 2026 16:05:10 +0000 Subject: [PATCH] Fix SSE cancellation issue --- .../DefaultHttpGraphQLSubscriptionClient.cs | 7 +- ...faultHttpGraphQLSubscriptionClientTests.cs | 160 ++++++++++++++++++ 2 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 src/HotChocolate/Fusion/test/Core.Tests/DefaultHttpGraphQLSubscriptionClientTests.cs diff --git a/src/HotChocolate/Fusion/src/Core/Clients/DefaultHttpGraphQLSubscriptionClient.cs b/src/HotChocolate/Fusion/src/Core/Clients/DefaultHttpGraphQLSubscriptionClient.cs index 04e3c682ee6..bb3215c9b12 100644 --- a/src/HotChocolate/Fusion/src/Core/Clients/DefaultHttpGraphQLSubscriptionClient.cs +++ b/src/HotChocolate/Fusion/src/Core/Clients/DefaultHttpGraphQLSubscriptionClient.cs @@ -43,9 +43,12 @@ private async IAsyncEnumerable SubscribeInternalAsync( var request = new GraphQLHttpRequest(subgraphRequest, _config.EndpointUri); using var response = await _client.SendAsync(request, cancellationToken).ConfigureAwait(false); - await foreach (var result in response.ReadAsResultStreamAsync(cancellationToken).ConfigureAwait(false)) + var resultStream = response.ReadAsResultStreamAsync(cancellationToken); + await using var resultEnumerator = resultStream.GetAsyncEnumerator(cancellationToken); + + while (await resultEnumerator.MoveNextAsync().ConfigureAwait(false)) { - yield return new GraphQLResponse(result); + yield return new GraphQLResponse(resultEnumerator.Current); } } diff --git a/src/HotChocolate/Fusion/test/Core.Tests/DefaultHttpGraphQLSubscriptionClientTests.cs b/src/HotChocolate/Fusion/test/Core.Tests/DefaultHttpGraphQLSubscriptionClientTests.cs new file mode 100644 index 00000000000..8bbbff1367a --- /dev/null +++ b/src/HotChocolate/Fusion/test/Core.Tests/DefaultHttpGraphQLSubscriptionClientTests.cs @@ -0,0 +1,160 @@ +using System.Net; +using HotChocolate.Fusion.Clients; +using HotChocolate.Fusion.Metadata; + +namespace HotChocolate.Fusion; + +public class DefaultHttpGraphQLSubscriptionClientTests +{ + [Fact] + public async Task SubscribeAsync_Passes_CancellationToken_To_Sse_Enumeration() + { + var sseStream = new ObservingSseStream(); + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StreamContent(sseStream), + }; + response.Content.Headers.ContentType = new("text/event-stream"); + + using var httpClient = new HttpClient(new StaticResponseHandler(response)); + + var config = new HttpClientConfiguration( + clientName: "test", + subgraphName: "reviews", + endpointUri: new Uri("http://localhost/graphql")); + + await using var client = new DefaultHttpGraphQLSubscriptionClient(config, httpClient); + + var request = new SubgraphGraphQLRequest( + subgraph: "reviews", + document: "subscription OnNewReview { onNewReview { body } }", + variableValues: null, + extensions: null); + + using var cts = new CancellationTokenSource(); + await using var stream = client.SubscribeAsync(request, cts.Token).GetAsyncEnumerator(); + + var moveNext = stream.MoveNextAsync().AsTask(); + await sseStream.ReadStarted.Task.WaitAsync(TimeSpan.FromSeconds(2)); + + cts.Cancel(); + + var linked = await WaitUntilAsync( + () => sseStream.CapturedToken.IsCancellationRequested, + TimeSpan.FromSeconds(1)); + + Assert.True(linked, "SSE enumeration token is not linked to the caller cancellation token."); + + sseStream.Release(); + await Task.WhenAny(moveNext, Task.Delay(TimeSpan.FromSeconds(2))); + } + + private static async Task WaitUntilAsync(Func condition, TimeSpan timeout) + { + var end = DateTime.UtcNow + timeout; + + while (DateTime.UtcNow < end) + { + if (condition()) + { + return true; + } + + await Task.Delay(20); + } + + return condition(); + } + + private sealed class StaticResponseHandler(HttpResponseMessage response) : HttpMessageHandler + { + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken) + => Task.FromResult(response); + } + + private sealed class ObservingSseStream : Stream + { + private readonly CancellationTokenSource _release = new(); + + public TaskCompletionSource ReadStarted { get; } = + new(TaskCreationOptions.RunContinuationsAsynchronously); + + public CancellationToken CapturedToken { get; private set; } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + + public override void SetLength(long value) + => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + => BlockUntilCanceledOrReleasedAsync(cancellationToken); + + public override Task ReadAsync( + byte[] buffer, + int offset, + int count, + CancellationToken cancellationToken) + => ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + public void Release() => _release.Cancel(); + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _release.Cancel(); + _release.Dispose(); + } + + base.Dispose(disposing); + } + + private async ValueTask BlockUntilCanceledOrReleasedAsync(CancellationToken cancellationToken) + { + CapturedToken = cancellationToken; + ReadStarted.TrySetResult(); + + using var linked = CancellationTokenSource.CreateLinkedTokenSource( + cancellationToken, + _release.Token); + + try + { + await Task.Delay(Timeout.InfiniteTimeSpan, linked.Token); + } + catch (OperationCanceledException) + { + // Cancellation is expected in this test. + } + + return 0; + } + } +}