diff --git a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs index 93e2da1a1ba..f177284d458 100644 --- a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs +++ b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs @@ -38,41 +38,12 @@ public static void EfficientCopyTo(this Stream input, Stream output) public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { - try - { - using var manualResetEvent = new ManualResetEventSlim(); - var readOperation = stream.BeginRead( - buffer, - offset, - count, - state => ((ManualResetEventSlim)state.AsyncState).Set(), - manualResetEvent); - - if (readOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken)) - { - return stream.EndRead(readOperation); - } - } - catch (OperationCanceledException) - { - // Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed. - } - catch (ObjectDisposedException) - { - throw new IOException(); - } - - try - { - stream.Dispose(); - } - catch - { - // Ignore any exceptions - } - - cancellationToken.ThrowIfCancellationRequested(); - throw new TimeoutException(); + return UseStreamWithTimeout( + stream, + (str, state) => str.Read(state.buffer, state.offset, state.count), + (buffer, offset, count), + timeout, + cancellationToken); } public static async Task ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) @@ -219,43 +190,16 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] destination, public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { - try - { - using var manualResetEvent = new ManualResetEventSlim(); - var writeOperation = stream.BeginWrite( - buffer, - offset, - count, - state => ((ManualResetEventSlim)state.AsyncState).Set(), - manualResetEvent); - - if (writeOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken)) + UseStreamWithTimeout( + stream, + (str, state) => { - stream.EndWrite(writeOperation); - return; - } - } - catch (OperationCanceledException) - { - // Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed. - } - catch (ObjectDisposedException) - { - // It's possible to get ObjectDisposedException when the connection pool was closed with interruptInUseConnections set to true. - throw new IOException(); - } - - try - { - stream.Dispose(); - } - catch - { - // Ignore any exceptions - } - - cancellationToken.ThrowIfCancellationRequested(); - throw new TimeoutException(); + str.Write(state.buffer, state.offset, state.count); + return true; + }, + (buffer, offset, count), + timeout, + cancellationToken); } public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) @@ -325,5 +269,89 @@ public static async Task WriteBytesAsync(this Stream stream, OperationContext op count -= bytesToWrite; } } + + private static TResult UseStreamWithTimeout(Stream stream, Func method, TState state, TimeSpan timeout, CancellationToken cancellationToken) + { + StreamDisposeCallbackState callbackState = null; + Timer timer = null; + CancellationTokenRegistration cancellationSubscription = default; + if (timeout != Timeout.InfiniteTimeSpan) + { + callbackState = new StreamDisposeCallbackState(stream); + timer = new Timer(DisposeStreamCallback, callbackState, timeout, Timeout.InfiniteTimeSpan); + } + + if (cancellationToken.CanBeCanceled) + { + callbackState ??= new StreamDisposeCallbackState(stream); + cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState); + } + + try + { + var result = method(stream, state); + if (callbackState?.TryChangeState(OperationState.Done) == false) + { + // if cannot change the state - then the stream was/will be disposed, throw here + throw new IOException(); + } + + return result; + } + catch (IOException) + { + if (callbackState?.OperationState == OperationState.Cancelled) + { + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); + } + + throw; + } + finally + { + timer?.Dispose(); + cancellationSubscription.Dispose(); + } + + static void DisposeStreamCallback(object state) + { + var disposeCallbackState = (StreamDisposeCallbackState)state; + if (!disposeCallbackState.TryChangeState(OperationState.Cancelled)) + { + // if cannot change the state - then I/O was already succeeded + return; + } + + try + { + disposeCallbackState.Stream.Dispose(); + } + catch (Exception) + { + // callbacks should not fail, suppress any exceptions here + } + } + } + + private record StreamDisposeCallbackState(Stream Stream) + { + private int _operationState = 0; + + public OperationState OperationState + { + get => (OperationState)_operationState; + } + + public bool TryChangeState(OperationState newState) => + Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress; + } + + private enum OperationState + { + InProgress = 0, + Done, + Cancelled, + } } } diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs index 56cebd0858e..a22cdf3e951 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs @@ -811,19 +811,8 @@ public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the private void SetupStreamRead(Mock streamMock, TaskCompletionSource tcs) { - streamMock.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] _, int __, int ___, AsyncCallback callback, object state) => - { - var innerTcs = new TaskCompletionSource(state); - tcs.Task.ContinueWith(t => - { - innerTcs.TrySetException(t.Exception.InnerException); - callback(innerTcs.Task); - }); - return innerTcs.Task; - }); - streamMock.Setup(s => s.EndRead(It.IsAny())) - .Returns(x => ((Task)x).GetAwaiter().GetResult()); + streamMock.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] _, int __, int ___) => tcs.Task.GetAwaiter().GetResult()); streamMock.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(tcs.Task); streamMock.Setup(s => s.Close()).Callback(() => tcs.TrySetException(new ObjectDisposedException("stream"))); diff --git a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs index 8da7e5f7de8..66665e11779 100644 --- a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs @@ -90,20 +90,18 @@ public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_part var bytes = new byte[] { 1, 2, 3 }; var n = 0; var position = 0; - Task ReadPartial (byte[] buffer, int offset, int count) + int ReadPartial (byte[] buffer, int offset, int count) { var length = partition[n++]; Buffer.BlockCopy(bytes, position, buffer, offset, length); position += length; - return Task.FromResult(length); + return length; } mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.EndRead(It.IsAny())) - .Returns(x => ((Task)x).GetAwaiter().GetResult()); + .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count))); + mockStream.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count)); var destination = new byte[3]; if (async) @@ -267,20 +265,18 @@ public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_par var destination = new ByteArrayBuffer(new byte[3], 3); var n = 0; var position = 0; - Task ReadPartial (byte[] buffer, int offset, int count) + int ReadPartial (byte[] buffer, int offset, int count) { var length = partition[n++]; Buffer.BlockCopy(bytes, position, buffer, offset, length); position += length; - return Task.FromResult(length); + return length; } mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.EndRead(It.IsAny())) - .Returns(x => ((Task)x).GetAwaiter().GetResult()); + .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count))); + mockStream.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count)); if (async) { diff --git a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs index 4d61e9c7b60..7bdaba366f9 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs @@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open() var mockStream = new Mock(); mockStream - .Setup(s => s.BeginWrite(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(s => s.Write(It.IsAny(), It.IsAny(), It.IsAny())) .Callback(() => EnqueueEvent(HelloReceivedEvent)) .Throws(new Exception("Stream is closed."));