diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index e2b49083df744d..e55cf47f6df62d 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -111,7 +111,7 @@ private void Dispose(bool disposing) if (disposing) { GC.SuppressFinalize(this); - _stream.Dispose(); + _stream.DisposeAsync().AsTask().GetAwaiter().GetResult(); } } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 70c4c72760a59b..117f147d0d0b07 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -1428,6 +1428,105 @@ public async Task Expect100ContinueTimeout_SetAfterUse_Throws() Assert.Throws(() => handler.Expect100ContinueTimeout = TimeSpan.FromMilliseconds(1)); } } + + [Fact] + public async Task ReadBodyCanceled_CallsDisposeAsyncOnConnectCallbackStream() + { + HalfDuplexStream? stream = null; + + using var handler = new SocketsHttpHandler(); + handler.ConnectCallback = async (context, cancellation) => + { + Assert.Null(stream); + + var serverToClientPipe = new IO.Pipelines.Pipe(); + stream = new HalfDuplexStream(serverToClientPipe); + + var hangingChunkedResponse = "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"u8; + var writeSpan = serverToClientPipe.Writer.GetSpan(hangingChunkedResponse.Length); + hangingChunkedResponse.CopyTo(writeSpan); + serverToClientPipe.Writer.Advance(hangingChunkedResponse.Length); + await serverToClientPipe.Writer.FlushAsync(cancellation); + return stream; + }; + using HttpClient client = CreateHttpClient(handler); + + using HttpRequestMessage request = CreateRequest(HttpMethod.Get, new Uri("http://example.com"), UseVersion, exactVersion: true); + using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead).WaitAsync(TestHelper.PassingTestTimeout); + response.EnsureSuccessStatusCode(); + + using var cts = new CancellationTokenSource(); + var responseBodyReadTask = response.Content.ReadAsStringAsync(cts.Token); + cts.Cancel(); + + var tcException = await Assert.ThrowsAsync(() => responseBodyReadTask).WaitAsync(TestHelper.PassingTestTimeout); + var ioException = Assert.IsType(tcException.InnerException); + Assert.Equal(HttpRequestError.ResponseEnded, ioException.HttpRequestError); + + Assert.NotNull(stream); + Assert.True(stream.DisposeCalled); + Assert.True(stream.DisposeAsyncCalled); + } + + private class HalfDuplexStream(IO.Pipelines.Pipe responsePipe) : Stream + { + private readonly Stream _readStream = responsePipe.Reader.AsStream(); + + public bool DisposeCalled { get; private set; } + public bool DisposeAsyncCalled { get; private set; } + + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _readStream.ReadAsync(buffer, offset, count, cancellationToken); + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + => _readStream.ReadAsync(buffer, cancellationToken); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => Null.WriteAsync(buffer, offset, count, cancellationToken); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + => Null.WriteAsync(buffer, cancellationToken); + + public override Task FlushAsync(CancellationToken cancellationToken) + => Null.FlushAsync(cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) => _readStream.Read(buffer, offset, count); + public override void Write(byte[] buffer, int offset, int count) => Null.Write(buffer, offset, count); + public override void Flush() => Null.Flush(); + + protected override void Dispose(bool disposing) + { + DisposeCalled = true; + responsePipe.Writer.Complete(); + } + + public override async ValueTask DisposeAsync() + { + DisposeAsyncCalled = true; + await base.DisposeAsync(); + } + + // Unsupported stuff + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + } } public abstract class SocketsHttpHandler_HttpClientHandler_MaxResponseHeadersLength : HttpClientHandler_MaxResponseHeadersLength_Test