Skip to content

Commit c47a53d

Browse files
authored
[WinHTTP] Make concurrent IO check thread safe (#111750)
* Check for concurrent IO is now thread safe. * Add test * Use compare exchange, fix cancellation token registration * Typo * Fixed unit tests. * Test exception narrow down * Narrow down the exception * Don't dispose handle outside of lock
1 parent 2e42b6e commit c47a53d

File tree

6 files changed

+51
-27
lines changed

6 files changed

+51
-27
lines changed

src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,14 @@ private async Task StartRequestAsync(WinHttpRequestState state)
941941
// will have the side-effect of WinHTTP cancelling any pending I/O and accelerating its callbacks
942942
// on the handle and thus releasing the awaiting tasks in the loop below. This helps to provide
943943
// a more timely, cooperative, cancellation pattern.
944-
using (state.CancellationToken.Register(s => ((WinHttpRequestState)s!).RequestHandle!.Dispose(), state))
944+
using (state.CancellationToken.Register(static s =>
945+
{
946+
var state = (WinHttpRequestState)s!;
947+
lock (state.Lock)
948+
{
949+
state.RequestHandle?.Dispose();
950+
}
951+
}, state))
945952
{
946953
do
947954
{

src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestState.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public WinHttpTransportContext TransportContext
150150

151151
public RendezvousAwaitable<int> LifecycleAwaitable { get; set; } = new RendezvousAwaitable<int>();
152152
public TaskCompletionSource<bool>? TcsInternalWriteDataToRequestStream { get; set; }
153-
public bool AsyncReadInProgress { get; set; }
153+
public volatile int AsyncReadInProgress;
154154

155155
// WinHttpResponseStream state.
156156
public long? ExpectedBytesToRead { get; set; }

src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public static HttpResponseMessage CreateResponseMessage(
6464

6565
// Create response stream and wrap it in a StreamContent object.
6666
var responseStream = new WinHttpResponseStream(requestHandle, state, response);
67-
state.RequestHandle = null; // ownership successfully transferred to WinHttpResponseStram.
67+
state.RequestHandle = null; // ownership successfully transferred to WinHttpResponseStream.
6868
Stream decompressedStream = responseStream;
6969

7070
if (manuallyProcessedDecompressionMethods != DecompressionMethods.None)

src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseStream.cs

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,6 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio
9494
// Validate arguments as would base CopyToAsync
9595
StreamHelpers.ValidateCopyToArgs(this, destination, bufferSize);
9696

97-
// Check that there are no other pending read operations
98-
if (_state.AsyncReadInProgress)
99-
{
100-
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
101-
}
102-
10397
// Early check for cancellation
10498
if (cancellationToken.IsCancellationRequested)
10599
{
@@ -112,11 +106,15 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio
112106

113107
private async Task CopyToAsyncCore(Stream destination, byte[] buffer, CancellationToken cancellationToken)
114108
{
115-
_state.PinReceiveBuffer(buffer);
116-
CancellationTokenRegistration ctr = cancellationToken.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
117-
_state.AsyncReadInProgress = true;
109+
// Check that there are no other pending read operations
110+
if (Interlocked.CompareExchange(ref _state.AsyncReadInProgress, 1, 0) == 1)
111+
{
112+
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
113+
}
118114
try
119115
{
116+
using var ctr = cancellationToken.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
117+
_state.PinReceiveBuffer(buffer);
120118
// Loop until there's no more data to be read
121119
while (true)
122120
{
@@ -163,8 +161,7 @@ private async Task CopyToAsyncCore(Stream destination, byte[] buffer, Cancellati
163161
}
164162
finally
165163
{
166-
_state.AsyncReadInProgress = false;
167-
ctr.Dispose();
164+
_state.AsyncReadInProgress = 0;
168165
ArrayPool<byte>.Shared.Return(buffer);
169166
}
170167

@@ -201,11 +198,6 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
201198

202199
CheckDisposed();
203200

204-
if (_state.AsyncReadInProgress)
205-
{
206-
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
207-
}
208-
209201
return ReadAsyncCore(buffer, offset, count, token);
210202
}
211203

@@ -221,12 +213,15 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
221213
{
222214
return 0;
223215
}
224-
225-
_state.PinReceiveBuffer(buffer);
226-
var ctr = token.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
227-
_state.AsyncReadInProgress = true;
216+
// Check that there are no other pending read operations
217+
if (Interlocked.CompareExchange(ref _state.AsyncReadInProgress, 1, 0) == 1)
218+
{
219+
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
220+
}
228221
try
229222
{
223+
using var ctr = token.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
224+
_state.PinReceiveBuffer(buffer);
230225
lock (_state.Lock)
231226
{
232227
Debug.Assert(!_requestHandle.IsInvalid);
@@ -262,8 +257,7 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
262257
}
263258
finally
264259
{
265-
_state.AsyncReadInProgress = false;
266-
ctr.Dispose();
260+
_state.AsyncReadInProgress = 0;
267261
}
268262
}
269263

@@ -357,7 +351,7 @@ private void CancelPendingResponseStreamReadOperation()
357351
{
358352
lock (_state.Lock)
359353
{
360-
if (_state.AsyncReadInProgress)
354+
if (_state.AsyncReadInProgress == 1)
361355
{
362356
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info("before dispose");
363357
_requestHandle?.Dispose(); // null check necessary to handle race condition between stream disposal and cancellation

src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,29 @@ public void SendAsync_SimpleGet_Success()
4646
}
4747
}
4848

49+
[OuterLoop("Uses external server.")]
50+
[Fact]
51+
public async Task GetAsync_ConcurrentRead_ThrowsInvalidOperationException()
52+
{
53+
using var client = new HttpClient(new WinHttpHandler());
54+
using var response = await client.GetAsync("https://httpbin.org/stream-bytes/4096", HttpCompletionOption.ResponseHeadersRead);
55+
using var stream = await response.Content.ReadAsStreamAsync();
56+
var tasks = new Task[1_000];
57+
for (int i = 0; i < tasks.Length; ++i)
58+
{
59+
tasks[i] = Task.Run(async () =>
60+
{
61+
try
62+
{
63+
await stream.ReadAsync(new byte[5]);
64+
}
65+
catch (InvalidOperationException ioe) when (ioe.Message.Contains("concurrent I/O")) // Expected exception for concurrent IO
66+
{ }
67+
});
68+
}
69+
await Task.WhenAll(tasks);
70+
}
71+
4972
[OuterLoop]
5073
[Theory]
5174
[InlineData(CookieUsePolicy.UseInternalCookieStoreOnly, "cookieName1", "cookieValue1")]

src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/WinHttpResponseStreamTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ public void ReadAsync_PriorReadInProgress_ThrowsInvalidOperationException()
285285
TestControl.WinHttpReadData.Pause();
286286
Task t1 = stream.ReadAsync(new byte[1], 0, 1);
287287

288-
Assert.Throws<InvalidOperationException>(() => { Task t2 = stream.ReadAsync(new byte[1], 0, 1); });
288+
Assert.ThrowsAsync<InvalidOperationException>(() => stream.ReadAsync(new byte[1], 0, 1));
289289

290290
TestControl.WinHttpReadData.Resume();
291291
t1.Wait();

0 commit comments

Comments
 (0)