diff --git a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs index e9ff147bbe5f03..35364a0af8476b 100644 --- a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs +++ b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs @@ -712,56 +712,54 @@ public abstract class StandaloneStreamConformanceTests : StreamConformanceTests return stream; } - protected async IAsyncEnumerable GetStreamsForValidation() - { - yield return await CreateReadOnlyStream(); - yield return await CreateReadOnlyStream(new byte[4]); - - yield return await CreateWriteOnlyStream(); - yield return await CreateWriteOnlyStream(new byte[4]); - - yield return await CreateReadWriteStream(); - yield return await CreateReadWriteStream(new byte[4]); - } + protected async Task> GetStreamsForValidationAsync() => + [ + await CreateReadOnlyStream(), + await CreateReadOnlyStream(new byte[4]), + await CreateWriteOnlyStream(), + await CreateWriteOnlyStream(new byte[4]), + await CreateReadWriteStream(), + await CreateReadWriteStream(new byte[4]), + ]; [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/107981", TestPlatforms.Wasi)] public virtual async Task ArgumentValidation_ThrowsExpectedException() { - await foreach (Stream? stream in GetStreamsForValidation()) - { - if (stream != null) + await Task.WhenAll( + from stream in await GetStreamsForValidationAsync() + where stream is not null + select Task.Run(async () => { using var _ = stream; await ValidateMisuseExceptionsAsync(stream); - } - } + })); } [Fact] public virtual async Task Disposed_ThrowsObjectDisposedException() { - await foreach (Stream? stream in GetStreamsForValidation()) - { - if (stream != null) + await Task.WhenAll( + from stream in await GetStreamsForValidationAsync() + where stream is not null + select Task.Run(async () => { using var _ = stream; await ValidateDisposedExceptionsAsync(stream); - } - } + })); } [Fact] public virtual async Task ReadWriteAsync_Precanceled_ThrowsOperationCanceledException() { - await foreach (Stream? stream in GetStreamsForValidation()) - { - if (stream != null) + await Task.WhenAll( + from stream in await GetStreamsForValidationAsync() + where stream is not null + select Task.Run(async () => { using var _ = stream; await ValidatePrecanceledOperations_ThrowsCancellationException(stream); - } - } + })); } [Theory] @@ -1743,47 +1741,46 @@ public virtual async Task ReadWriteByte_Success() { using StreamPair streams = await CreateConnectedStreamsAsync(); - foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams)) + (Stream writeable, Stream readable) = GetReadWritePair(streams); + + byte[] writerBytes = GetRandomBytes(42); + var readerBytes = new byte[writerBytes.Length]; + + Task writes = Task.Run(() => { - byte[] writerBytes = GetRandomBytes(42); - var readerBytes = new byte[writerBytes.Length]; + foreach (byte b in writerBytes) + { + writeable.WriteByte(b); + } - Task writes = Task.Run(() => + if (FlushRequiredToWriteData) { - foreach (byte b in writerBytes) + if (FlushGuaranteesAllDataWritten) { - writeable.WriteByte(b); + writeable.Flush(); } - - if (FlushRequiredToWriteData) + else { - if (FlushGuaranteesAllDataWritten) - { - writeable.Flush(); - } - else - { - writeable.Dispose(); - } + writeable.Dispose(); } - }); - - for (int i = 0; i < readerBytes.Length; i++) - { - int r = readable.ReadByte(); - Assert.InRange(r, 0, 255); - readerBytes[i] = (byte)r; } + }); - AssertExtensions.SequenceEqual(writerBytes, readerBytes); + for (int i = 0; i < readerBytes.Length; i++) + { + int r = readable.ReadByte(); + Assert.InRange(r, 0, 255); + readerBytes[i] = (byte)r; + } - await writes; + AssertExtensions.SequenceEqual(writerBytes, readerBytes); - if (!FlushGuaranteesAllDataWritten) - { - break; - } + if (FlushRequiredToWriteData && !FlushGuaranteesAllDataWritten) + { + await readable.DisposeAsync(); } + + await writes; } public static IEnumerable ReadWrite_Modes => @@ -1824,51 +1821,50 @@ public virtual async Task ReadWrite_Success(ReadWriteMode mode, int writeSize, b { using StreamPair streams = await CreateConnectedStreamsAsync(); - foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams)) + (Stream writeable, Stream readable) = GetReadWritePair(streams); + + if (startWithFlush) { - if (startWithFlush) - { - await FlushAsync(mode, writeable, nonCanceledToken); - } + await FlushAsync(mode, writeable, nonCanceledToken); + } + + byte[] writerBytes = GetRandomBytes(writeSize); + var readerBytes = new byte[writerBytes.Length]; - byte[] writerBytes = GetRandomBytes(writeSize); - var readerBytes = new byte[writerBytes.Length]; + Task writes = Task.Run(async () => + { + await WriteAsync(mode, writeable, writerBytes, 0, writerBytes.Length, nonCanceledToken); - Task writes = Task.Run(async () => + if (FlushRequiredToWriteData) { - await WriteAsync(mode, writeable, writerBytes, 0, writerBytes.Length, nonCanceledToken); - - if (FlushRequiredToWriteData) + if (FlushGuaranteesAllDataWritten) { - if (FlushGuaranteesAllDataWritten) - { - await writeable.FlushAsync(); - } - else - { - await writeable.DisposeAsync(); - } + await writeable.FlushAsync(); + } + else + { + await writeable.DisposeAsync(); } - }); - - int n = 0; - while (n < readerBytes.Length) - { - int r = await ReadAsync(mode, readable, readerBytes, n, readerBytes.Length - n).WaitAsync(TimeSpan.FromSeconds(30)); - Assert.InRange(r, 1, readerBytes.Length - n); - n += r; } + }); - Assert.Equal(readerBytes.Length, n); - AssertExtensions.SequenceEqual(writerBytes, readerBytes); + int n = 0; + while (n < readerBytes.Length) + { + int r = await ReadAsync(mode, readable, readerBytes, n, readerBytes.Length - n).WaitAsync(TimeSpan.FromSeconds(30)); + Assert.InRange(r, 1, readerBytes.Length - n); + n += r; + } - await writes; + Assert.Equal(readerBytes.Length, n); + AssertExtensions.SequenceEqual(writerBytes, readerBytes); - if (!FlushGuaranteesAllDataWritten) - { - break; - } + if (FlushRequiredToWriteData && !FlushGuaranteesAllDataWritten) + { + await readable.DisposeAsync(); } + + await writes; } } @@ -1946,8 +1942,7 @@ public virtual async Task Read_Eof_Returns0(ReadWriteMode mode, bool dataAvailab } else { - writeable.Dispose(); - write = Task.CompletedTask; + write = writeable.DisposeAsync().AsTask(); } if (dataAvailableFirst) @@ -1959,9 +1954,10 @@ public virtual async Task Read_Eof_Returns0(ReadWriteMode mode, bool dataAvailab Assert.Equal('o', readable.ReadByte()); } - await write; - Assert.Equal(0, await ReadAsync(mode, readable, new byte[1], 0, 1)); + + await readable.DisposeAsync(); + await write; } [Theory] @@ -1990,6 +1986,7 @@ public virtual async Task Read_DataStoredAtDesiredOffset(ReadWriteMode mode) Assert.Equal(1, await ReadAsync(mode, readable, buffer, offset, buffer.Length - offset)); + await readable.DisposeAsync(); await write; for (int i = 0; i < buffer.Length; i++) @@ -2021,8 +2018,10 @@ public virtual async Task Write_DataReadFromDesiredOffset(ReadWriteMode mode) writeable.Dispose(); }); - using StreamReader reader = new StreamReader(readable); - Assert.Equal("hello", reader.ReadToEnd()); + using (StreamReader reader = new StreamReader(readable)) + { + Assert.Equal("hello", reader.ReadToEnd()); + } await write; } @@ -2038,7 +2037,7 @@ public virtual async Task WriteWithBrokenPipe_Throws() using StreamPair streams = await CreateConnectedStreamsAsync(); (Stream writeable, Stream readable) = GetReadWritePair(streams); - readable.Dispose(); + await readable.DisposeAsync(); byte[] buffer = new byte[4]; Assert.Throws(() => writeable.WriteByte(123)); @@ -2356,49 +2355,49 @@ public virtual async Task ZeroByteWrite_OtherDataReceivedSuccessfully(ReadWriteM byte[][] buffers = new[] { Array.Empty(), "hello"u8.ToArray(), Array.Empty(), "world"u8.ToArray() }; using StreamPair streams = await CreateConnectedStreamsAsync(); - foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams)) + (Stream writeable, Stream readable) = GetReadWritePair(streams); + + Task writes = Task.Run(async () => { - Task writes = Task.Run(async () => + foreach (byte[] buffer in buffers) { - foreach (byte[] buffer in buffers) - { - await WriteAsync(mode, writeable, buffer, 0, buffer.Length); - } - }); + await WriteAsync(mode, writeable, buffer, 0, buffer.Length); + } + }); - if (FlushRequiredToWriteData) + if (FlushRequiredToWriteData) + { + writes = writes.ContinueWith(t => { - writes = writes.ContinueWith(t => + t.GetAwaiter().GetResult(); + if (FlushGuaranteesAllDataWritten) { - t.GetAwaiter().GetResult(); - if (FlushGuaranteesAllDataWritten) - { - writeable.Flush(); - } - else - { - writeable.Dispose(); - } - }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default); - } + writeable.Flush(); + } + else + { + writeable.Dispose(); + } + }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default); + } - var readBytes = new byte[buffers.Sum(b => b.Length)]; - int count = 0; - while (count < readBytes.Length) - { - int n = await readable.ReadAsync(readBytes.AsMemory(count)); - Assert.InRange(n, 1, readBytes.Length - count); - count += n; - } + var readBytes = new byte[buffers.Sum(b => b.Length)]; + int count = 0; + while (count < readBytes.Length) + { + int n = await readable.ReadAsync(readBytes.AsMemory(count)); + Assert.InRange(n, 1, readBytes.Length - count); + count += n; + } - Assert.Equal("helloworld", Encoding.UTF8.GetString(readBytes)); - await writes; + Assert.Equal("helloworld", Encoding.UTF8.GetString(readBytes)); - if (!FlushGuaranteesAllDataWritten) - { - break; - } + if (FlushRequiredToWriteData && !FlushGuaranteesAllDataWritten) + { + await readable.DisposeAsync(); } + + await writes; } [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] @@ -2409,63 +2408,62 @@ public virtual async Task ZeroByteWrite_OtherDataReceivedSuccessfully(ReadWriteM public virtual async Task ReadWrite_CustomMemoryManager_Success(bool useAsync) { using StreamPair streams = await CreateConnectedStreamsAsync(); - foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams)) - { - using MemoryManager writeBuffer = new NativeMemoryManager(1024); - using MemoryManager readBuffer = new NativeMemoryManager(writeBuffer.Memory.Length); + (Stream writeable, Stream readable) = GetReadWritePair(streams); - Assert.Equal(1024, writeBuffer.Memory.Length); - Assert.Equal(writeBuffer.Memory.Length, readBuffer.Memory.Length); + using MemoryManager writeBuffer = new NativeMemoryManager(1024); + using MemoryManager readBuffer = new NativeMemoryManager(writeBuffer.Memory.Length); - Random.Shared.NextBytes(writeBuffer.Memory.Span); - readBuffer.Memory.Span.Clear(); + Assert.Equal(1024, writeBuffer.Memory.Length); + Assert.Equal(writeBuffer.Memory.Length, readBuffer.Memory.Length); - Task write = useAsync ? - writeable.WriteAsync(writeBuffer.Memory).AsTask() : - Task.Run(() => writeable.Write(writeBuffer.Memory.Span)); - if (FlushRequiredToWriteData) - { - write = write.ContinueWith(t => - { - t.GetAwaiter().GetResult(); - if (FlushGuaranteesAllDataWritten) - { - writeable.Flush(); - } - else - { - writeable.Dispose(); - } - }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default); - } + Random.Shared.NextBytes(writeBuffer.Memory.Span); + readBuffer.Memory.Span.Clear(); - try + Task write = useAsync ? + writeable.WriteAsync(writeBuffer.Memory).AsTask() : + Task.Run(() => writeable.Write(writeBuffer.Memory.Span)); + if (FlushRequiredToWriteData) + { + write = write.ContinueWith(t => { - int bytesRead = 0; - while (bytesRead < readBuffer.Memory.Length) + t.GetAwaiter().GetResult(); + if (FlushGuaranteesAllDataWritten) { - int n = useAsync ? - await readable.ReadAsync(readBuffer.Memory.Slice(bytesRead)) : - readable.Read(readBuffer.Memory.Span.Slice(bytesRead)); - if (n == 0) - { - break; - } - Assert.InRange(n, 1, readBuffer.Memory.Length - bytesRead); - bytesRead += n; + writeable.Flush(); } + else + { + writeable.Dispose(); + } + }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default); + } - Assert.True(writeBuffer.Memory.Span.SequenceEqual(readBuffer.Memory.Span)); - } - finally + try + { + int bytesRead = 0; + while (bytesRead < readBuffer.Memory.Length) { - await write; + int n = useAsync ? + await readable.ReadAsync(readBuffer.Memory.Slice(bytesRead)) : + readable.Read(readBuffer.Memory.Span.Slice(bytesRead)); + if (n == 0) + { + break; + } + Assert.InRange(n, 1, readBuffer.Memory.Length - bytesRead); + bytesRead += n; } - if (!FlushGuaranteesAllDataWritten) + Assert.True(writeBuffer.Memory.Span.SequenceEqual(readBuffer.Memory.Span)); + } + finally + { + if (FlushRequiredToWriteData && !FlushGuaranteesAllDataWritten) { - break; + await readable.DisposeAsync(); } + + await write; } } @@ -2542,12 +2540,20 @@ public virtual async Task CopyToAsync_AllDataCopied(int byteCount, bool useAsync Task copyTask; if (useAsync) { - copyTask = readable.CopyToAsync(results); + copyTask = readable.CopyToAsync(results).ContinueWith(t => + { + t.GetAwaiter().GetResult(); + readable.Dispose(); + }, TaskScheduler.Default); await writeable.WriteAsync(dataToCopy); } else { - copyTask = Task.Run(() => readable.CopyTo(results)); + copyTask = Task.Run(() => + { + readable.CopyTo(results); + readable.Dispose(); + }); writeable.Write(new ReadOnlySpan(dataToCopy)); } @@ -2713,7 +2719,8 @@ public virtual async Task ClosedConnection_WritesFailImmediately_ThrowException( using StreamPair streams = await CreateConnectedStreamsAsync(); (Stream writeable, Stream readable) = GetReadWritePair(streams); - readable.Dispose(); + await readable.DisposeAsync(); + Assert.Throws(() => writeable.WriteByte(1)); Assert.Throws(() => writeable.Write(new byte[1], 0, 1)); Assert.Throws(() => writeable.Write(new byte[1])); @@ -2740,7 +2747,10 @@ public virtual async Task ReadAsync_DuringReadAsync_ThrowsIfUnsupported() await Assert.ThrowsAsync(UnsupportedConcurrentExceptionType, async () => await readable.ReadAsync(new byte[1])); writeable.WriteByte(1); - writeable.Dispose(); + if (FlushRequiredToWriteData) + { + writeable.Flush(); + } Assert.Equal(1, await read); } @@ -2792,7 +2802,7 @@ public virtual async Task Dispose_ClosesStream(int disposeMode) using StreamPair streams = await CreateConnectedStreamsAsync(); - foreach (Stream stream in streams) + await Task.WhenAll(streams.Select(stream => Task.Run(async () => { switch (disposeMode) { @@ -2803,7 +2813,7 @@ public virtual async Task Dispose_ClosesStream(int disposeMode) Assert.False(stream.CanRead); Assert.False(stream.CanWrite); - } + }))); } } @@ -2869,18 +2879,25 @@ public virtual async Task Dispose_Flushes(bool useAsync, bool leaveOpen) using StreamPair wrapper = await CreateWrappedConnectedStreamsAsync(streams, leaveOpen); (Stream writeable, Stream readable) = GetReadWritePair(wrapper); - writeable.WriteByte(1); - - if (useAsync) - { - await writeable.DisposeAsync(); - } - else - { - writeable.Dispose(); - } + await Task.WhenAll( + Task.Run(async () => + { + writeable.WriteByte(1); - Assert.Equal(1, readable.ReadByte()); + if (useAsync) + { + await writeable.DisposeAsync(); + } + else + { + writeable.Dispose(); + } + }), + Task.Run(() => + { + Assert.Equal(1, readable.ReadByte()); + readable.Dispose(); + })); } [Theory] @@ -2900,14 +2917,29 @@ public virtual async Task Dispose_ClosesInnerStreamIfDesired(bool useAsync, bool using StreamPair wrapper = await CreateWrappedConnectedStreamsAsync((writeable, readable), leaveOpen); (Stream writeableWrapper, Stream readableWrapper) = GetReadWritePair(wrapper); - if (useAsync) - { - await writeableWrapper.DisposeAsync(); - } - else - { - writeableWrapper.Dispose(); - } + await Task.WhenAll( + Task.Run(async () => + { + if (useAsync) + { + await writeableWrapper.DisposeAsync(); + } + else + { + writeableWrapper.Dispose(); + } + }), + Task.Run(async () => + { + if (useAsync) + { + await readableWrapper.DisposeAsync(); + } + else + { + readableWrapper.Dispose(); + } + })); if (leaveOpen) { @@ -2993,6 +3025,7 @@ await WhenAllOrAnyFailed( Assert.Equal(i, readable.ReadByte()); } Assert.Equal(-1, readable.ReadByte()); + readable.Dispose(); })); } } @@ -3164,8 +3197,9 @@ public StreamPair((Stream, Stream) streams) public virtual void Dispose() { - Stream1?.Dispose(); - Stream2?.Dispose(); + Task.WaitAll( + Task.Run(() => Stream1?.Dispose()), + Task.Run(() => Stream2?.Dispose())); } public IEnumerator GetEnumerator() diff --git a/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs index 4840a7ed39dfd7..e4fb8f027c1623 100644 --- a/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs @@ -137,7 +137,7 @@ public async Task EstablishConnectionAsync() Stream stream = null; #if TARGET_BROWSER closableWrapper = new SocketWrapper(_listenSocket); - stream = new WebSocketStream(_listenSocket, ownsSocket: true); + stream = WebSocketStream.Create(_listenSocket, WebSocketMessageType.Binary, ownsWebSocket: true); #else var socket = await _listenSocket.AcceptAsync().ConfigureAwait(false); closableWrapper = new SocketWrapper(socket); diff --git a/src/libraries/Common/tests/System/Net/Prerequisites/LocalEchoServer.props b/src/libraries/Common/tests/System/Net/Prerequisites/LocalEchoServer.props index ca780615fe9b98..b6099056408490 100644 --- a/src/libraries/Common/tests/System/Net/Prerequisites/LocalEchoServer.props +++ b/src/libraries/Common/tests/System/Net/Prerequisites/LocalEchoServer.props @@ -33,7 +33,5 @@ GlobalPropertiesToRemove="TargetOS;TargetArchitecture;RuntimeIdentifier;RunAOTCompilation" AdditionalProperties="_TargetFrameworkForXHarness=$(_TargetFrameworkForXHarness)" ReferenceOutputAssembly="false"/> - diff --git a/src/libraries/Common/tests/System/Net/WebSockets/WebSocketStream.cs b/src/libraries/Common/tests/System/Net/WebSockets/WebSocketStream.cs deleted file mode 100644 index 84a96f15c48582..00000000000000 --- a/src/libraries/Common/tests/System/Net/WebSockets/WebSocketStream.cs +++ /dev/null @@ -1,277 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Threading.Tasks; -using System.IO; -using System.Threading; - -namespace System.Net.WebSockets -{ - public class WebSocketStream : Stream - { - // Used by the class to hold the underlying socket the stream uses. - private readonly WebSocket _streamSocket; - - // Whether the stream should dispose of the socket when the stream is disposed - private readonly bool _ownsSocket; - - // Used by the class to indicate that the stream is m_Readable. - private bool _readable; - - // Used by the class to indicate that the stream is writable. - private bool _writeable; - - // Whether Dispose has been called. - private bool _disposed; - - public WebSocketStream(WebSocket socket) - : this(socket, FileAccess.ReadWrite, ownsSocket: false) - { - } - - public WebSocketStream(WebSocket socket, bool ownsSocket) - : this(socket, FileAccess.ReadWrite, ownsSocket) - { - } - - public WebSocketStream(WebSocket socket, FileAccess access) - : this(socket, access, ownsSocket: false) - { - } - - public WebSocketStream(WebSocket socket, FileAccess access, bool ownsSocket) - { - ArgumentNullException.ThrowIfNull(socket); - if (socket.State != WebSocketState.Open) - { - throw new IOException("The operation is not allowed on non-connected sockets."); - } - - _streamSocket = socket; - _ownsSocket = ownsSocket; - - switch (access) - { - case FileAccess.Read: - _readable = true; - break; - case FileAccess.Write: - _writeable = true; - break; - case FileAccess.ReadWrite: - default: // assume FileAccess.ReadWrite - _readable = true; - _writeable = true; - break; - } - } - - public WebSocket Socket => _streamSocket; - - protected bool Readable - { - get { return _readable; } - set { _readable = value; } - } - - protected bool Writeable - { - get { return _writeable; } - set { _writeable = value; } - } - - public override bool CanRead => _readable; - public override bool CanSeek => false; - public override bool CanWrite => _writeable; - public override bool CanTimeout => true; - public override long Length => throw new NotSupportedException("This stream does not support seek operations."); - - public override long Position - { - get - { - throw new NotSupportedException("This stream does not support seek operations."); - } - set - { - throw new NotSupportedException("This stream does not support seek operations."); - } - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException("This stream does not support seek operations."); - } - - public override int Read(byte[] buffer, int offset, int count) - { - throw new IOException("The operation is not allowed on a non-blocking Socket."); - } - - public override int Read(Span buffer) - { - throw new IOException("The operation is not allowed on a non-blocking Socket."); - } - - public override void Write(byte[] buffer, int offset, int count) - { - throw new IOException("The operation is not allowed on a non-blocking Socket."); - } - - public override void Write(ReadOnlySpan buffer) - { - throw new IOException("The operation is not allowed on a non-blocking Socket."); - } - - private int _closeTimeout = -1; - - public void Close(int timeout) - { - if (timeout < -1) - { - throw new ArgumentOutOfRangeException(nameof(timeout)); - } - _closeTimeout = timeout; - Dispose(); - } - - protected override void Dispose(bool disposing) - { - if (Interlocked.Exchange(ref _disposed, true)) - { - return; - } - - if (disposing) - { - _readable = false; - _writeable = false; - if (_ownsSocket) - { - if (_streamSocket != null && (_streamSocket.State == WebSocketState.Open || _streamSocket.State == WebSocketState.Connecting || _streamSocket.State == WebSocketState.None)) - { - try - { - var task = _streamSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "closing remoteLoop", CancellationToken.None); - Task.WaitAll(task); - } - catch (Exception) - { - } - finally - { - _streamSocket.Dispose(); - } - } - } - } - - base.Dispose(disposing); - } - - ~WebSocketStream() => Dispose(false); - - public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - ValidateBufferArguments(buffer, offset, count); - ThrowIfDisposed(); - if (!CanRead) - { - throw new InvalidOperationException("The stream does not support reading."); - } - - try - { - var res = await _streamSocket.ReceiveAsync(new Memory(buffer, offset, count), cancellationToken); - return res.Count; - } - catch (Exception exception) when (!(exception is OutOfMemoryException)) - { - throw WrapException("Unable to read data from the transport connection", exception); - } - } - - public async override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) - { - bool canRead = CanRead; // Prevent race with Dispose. - ThrowIfDisposed(); - if (!canRead) - { - throw new InvalidOperationException("The stream does not support reading."); - } - - try - { - var res = await _streamSocket.ReceiveAsync(buffer, - cancellationToken); - return res.Count; - } - catch (Exception exception) when (!(exception is OutOfMemoryException)) - { - throw WrapException("Unable to read data from the transport connection", exception); - } - } - - public async override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - ValidateBufferArguments(buffer, offset, count); - ThrowIfDisposed(); - if (!CanWrite) - { - throw new InvalidOperationException("The stream does not support writing."); - } - - try - { - await _streamSocket.SendAsync(new ReadOnlyMemory(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); - } - catch (Exception exception) when (!(exception is OutOfMemoryException)) - { - throw WrapException("Unable to write data to the transport connection", exception); - } - } - - public async override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) - { - bool canWrite = CanWrite; // Prevent race with Dispose. - ThrowIfDisposed(); - if (!canWrite) - { - throw new InvalidOperationException("The stream does not support writing."); - } - - try - { - await _streamSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, cancellationToken); - } - catch (Exception exception) when (!(exception is OutOfMemoryException)) - { - throw WrapException("Unable to write data to the transport connection", exception); - } - } - - public override void Flush() - { - } - - public override Task FlushAsync(CancellationToken cancellationToken) - { - return Task.CompletedTask; - } - - public override void SetLength(long value) - { - throw new NotSupportedException("This stream does not support seek operations."); - } - - private void ThrowIfDisposed() - { - ObjectDisposedException.ThrowIf(_disposed, this); - } - - private static IOException WrapException(string resourceFormatString, Exception innerException) - { - return new IOException(resourceFormatString, innerException); - } - } -} diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index ae5337ec053857..256f16b11211f0 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -46,6 +46,36 @@ public static void RegisterPrefixes() { } public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, System.Net.WebSockets.WebSocketMessageFlags messageFlags, System.Threading.CancellationToken cancellationToken) { throw null; } protected static void ThrowOnInvalidState(System.Net.WebSockets.WebSocketState state, params System.Net.WebSockets.WebSocketState[] validStates) { } } + public partial class WebSocketStream : System.IO.Stream + { + private protected WebSocketStream() { } + public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } + public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } + public override bool CanRead { get { throw null; } } + public override bool CanSeek { get { throw null; } } + public override bool CanWrite { get { throw null; } } + public static System.Net.WebSockets.WebSocketStream Create(WebSocket webSocket, WebSocketMessageType writeMessageType, bool ownsWebSocket = false) { throw null; } + public static System.Net.WebSockets.WebSocketStream Create(WebSocket webSocket, WebSocketMessageType writeMessageType, TimeSpan closeTimeout) { throw null; } + public static System.Net.WebSockets.WebSocketStream CreateWritableMessageStream(WebSocket webSocket, WebSocketMessageType writeMessageType) { throw null; } + public static System.Net.WebSockets.WebSocketStream CreateReadableMessageStream(WebSocket webSocket) { throw null; } + protected override void Dispose(bool disposing) { } + public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } + public override int EndRead(System.IAsyncResult asyncResult) { throw null; } + public override void EndWrite(IAsyncResult asyncResult) { } + public override void Flush() { } + public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken) { throw null; } + public override long Length { get { throw null; } } + public override long Position { get { throw null; } set { } } + public override int Read(byte[] buffer, int offset, int count) { throw null; } + public override System.Threading.Tasks.Task ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.ValueTask ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; } + public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; } + public override void SetLength(long value) { } + public System.Net.WebSockets.WebSocket WebSocket { get { throw null; } } + public override void Write(byte[] buffer, int offset, int count) { throw null; } + public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; } + } public enum WebSocketCloseStatus { NormalClosure = 1000, diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index a57e81b239a92c..3ddd42906a160c 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -138,6 +138,9 @@ The argument must be a value between {0} and {1}. + + The timeout must be a value between non-negative or Timeout.InfiniteTimeSpan. + The WebSocket didn't recieve a Pong frame in response to a Ping frame within the configured KeepAliveTimeout. diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index c02ed3fa73d658..a28435597cac27 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -33,6 +33,7 @@ + buffer, WebSocketMessageType m { if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); - if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) - { - throw new ArgumentException(SR.Format( - SR.net_WebSockets_Argument_InvalidMessageType, - nameof(WebSocketMessageType.Close), nameof(SendAsync), nameof(WebSocketMessageType.Binary), nameof(WebSocketMessageType.Text), nameof(CloseOutputAsync)), - nameof(messageType)); - } + ThrowIfInvalidMessageType(messageType); WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); @@ -318,13 +312,7 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag { if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); - if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) - { - throw new ArgumentException(SR.Format( - SR.net_WebSockets_Argument_InvalidMessageType, - nameof(WebSocketMessageType.Close), nameof(SendAsync), nameof(WebSocketMessageType.Binary), nameof(WebSocketMessageType.Text), nameof(CloseOutputAsync)), - nameof(messageType)); - } + ThrowIfInvalidMessageType(messageType); try { @@ -1901,6 +1889,20 @@ static void LogFaulted(Task task, object? thisObj) } } + internal static void ThrowIfInvalidMessageType(WebSocketMessageType messageType, [CallerArgumentExpression(nameof(messageType))] string? paramName = null) + { + if (messageType is not (WebSocketMessageType.Text or WebSocketMessageType.Binary)) + { + ThrowInvalidMessageType(paramName); + } + + static void ThrowInvalidMessageType(string? paramName) => + throw new ArgumentException(SR.Format( + SR.net_WebSockets_Argument_InvalidMessageType, + nameof(WebSocketMessageType.Close), nameof(SendAsync), nameof(WebSocketMessageType.Binary), nameof(WebSocketMessageType.Text), nameof(CloseOutputAsync)), + paramName); + } + private sealed class Utf8MessageState { internal bool SequenceInProgress; diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketStream.cs new file mode 100644 index 00000000000000..2f350d4846a6d8 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketStream.cs @@ -0,0 +1,383 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + /// Provides a that delegates to a wrapped . + public class WebSocketStream : Stream + { + /// The default number of seconds before canceling CloseAsync operation issued during stream disposal. + private const int DefaultCloseTimeoutSeconds = 16; + + /// Whether the stream has been disposed. + private bool _disposed; + + /// + /// Initializes a new instance of the class using a specified instance. + /// + /// The wrapped by this instance. + private WebSocketStream(WebSocket webSocket) => WebSocket = webSocket; + + /// Creates a that delegates to a wrapped . + /// The wrapped . + /// The type of messages that should be written as part of calls. Each write produces a message. + /// + /// if disposing the should close the underlying ; otherwise, . Defaults to . + /// + /// A new instance of that forwards reads and writes on the to the underlying . + public static WebSocketStream Create(WebSocket webSocket, WebSocketMessageType writeMessageType, bool ownsWebSocket = false) + { + ArgumentNullException.ThrowIfNull(webSocket); + ManagedWebSocket.ThrowIfInvalidMessageType(writeMessageType); + + return new ReadWriteStream( + webSocket, + writeMessageType, + closeTimeout: ownsWebSocket ? TimeSpan.FromSeconds(DefaultCloseTimeoutSeconds) : null); + } + + /// Creates a that delegates to a wrapped . + /// The wrapped . + /// The type of messages that should be written as part of calls. Each write produces a message. + /// The amount of time that disposing the will wait for a graceful closing of the 's output. + /// A new instance of that forwards reads and writes on the to the underlying . + public static WebSocketStream Create(WebSocket webSocket, WebSocketMessageType writeMessageType, TimeSpan closeTimeout) + { + ArgumentNullException.ThrowIfNull(webSocket); + ManagedWebSocket.ThrowIfInvalidMessageType(writeMessageType); + if (closeTimeout < TimeSpan.Zero && closeTimeout != Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException(nameof(closeTimeout), SR.net_WebSockets_TimeoutOutOfRange); + } + + return new ReadWriteStream(webSocket, writeMessageType, closeTimeout); + } + + /// Creates a that writes a single message to the underlying . + /// The wrapped . + /// + /// The type of messages that should be written as part of calls. + /// Each write on the results in writing a partial message to the underlying . + /// When the is disposed, it will write an empty message to the underlying to signal the end of the message. + /// + /// A new instance of that forwards writes on the to the underlying . + public static WebSocketStream CreateWritableMessageStream(WebSocket webSocket, WebSocketMessageType writeMessageType) + { + ArgumentNullException.ThrowIfNull(webSocket); + ManagedWebSocket.ThrowIfInvalidMessageType(writeMessageType); + + return new WriteMessageStream(webSocket, writeMessageType); + } + + /// Creates a that reads a single message from the underlying . + /// The wrapped . + /// A new instance of that forwards reads on the to the underlying . + /// + /// Reads on the will read a single message from the underlying . This means that reads will start returning + /// 0 bytes read once all data has been consumed from the next message received in the . + /// + public static WebSocketStream CreateReadableMessageStream(WebSocket webSocket) + { + ArgumentNullException.ThrowIfNull(webSocket); + + return new ReadMessageStream(webSocket); + } + + /// Gets the underlying wrapped by this . + /// The used to construct this instance. + public WebSocket WebSocket { get; } + + /// + public override bool CanRead => !_disposed && WebSocket.State is WebSocketState.Open or WebSocketState.CloseSent; + + /// + public override bool CanWrite => !_disposed && WebSocket.State is WebSocketState.Open or WebSocketState.CloseReceived; + + /// + public override bool CanSeek => false; + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + // There are no synchronous operations on WebSocket, so we're forced to do sync-over-async. + DisposeAsync().AsTask().GetAwaiter().GetResult(); + } + } + + /// + public override void Flush() { } + + /// + public override Task FlushAsync(CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : + Task.CompletedTask; + + /// + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBufferArguments(buffer, offset, count); + + return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + /// + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => + ValueTask.FromException(new NotSupportedException()); + + /// + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => + TaskToAsyncResult.Begin(ReadAsync(buffer, offset, count), callback, state); + + /// + public override int EndRead(IAsyncResult asyncResult) => + TaskToAsyncResult.End(asyncResult); + + /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBufferArguments(buffer, offset, count); + + return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + /// + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => + ValueTask.FromException(new NotSupportedException()); + + /// + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => + TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state); + + /// + public override void EndWrite(IAsyncResult asyncResult) => + TaskToAsyncResult.End(asyncResult); + + /// + public override int Read(byte[] buffer, int offset, int count) => + ReadAsync(buffer, offset, count, default).GetAwaiter().GetResult(); + + /// + public override void Write(byte[] buffer, int offset, int count) => + WriteAsync(buffer, offset, count, default).GetAwaiter().GetResult(); + + /// + public override long Length => throw new NotSupportedException(); + + /// + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + /// + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + /// + public override void SetLength(long value) => throw new NotSupportedException(); + + /// Provides stream that wraps a and forwards reads/writes. + private sealed class ReadWriteStream(WebSocket webSocket, WebSocketMessageType writeMessageType, TimeSpan? closeTimeout) : WebSocketStream(webSocket) + { + private readonly WebSocketMessageType _messageType = writeMessageType; + private readonly TimeSpan? _closeTimeout = closeTimeout; + + /// + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (_disposed) + { + return ValueTask.FromException(new ObjectDisposedException(GetType().FullName)); + } + + if (!CanWrite) + { + return ValueTask.FromException(new NotSupportedException(SR.NotWriteableStream)); + } + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + return WebSocket.SendAsync(buffer, _messageType, endOfMessage: true, cancellationToken); + } + + /// + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + + if (!CanRead) + { + throw new NotSupportedException(SR.NotReadableStream); + } + + cancellationToken.ThrowIfCancellationRequested(); + + while (WebSocket.State < WebSocketState.CloseReceived) + { + ValueWebSocketReceiveResult result = await WebSocket.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false); + if (result.MessageType is WebSocketMessageType.Close) + { + break; + } + + if (result.Count > 0 || buffer.IsEmpty) + { + return result.Count; + } + } + + return 0; + } + + /// + public override async ValueTask DisposeAsync() + { + if (!_disposed) + { + _disposed = true; + + if (_closeTimeout is { } timeout) + { + if (WebSocket.State is < WebSocketState.Closed) + { + CancellationTokenSource? cts = null; + CancellationToken ct; + + if (timeout == default) + { + ct = new CancellationToken(canceled: true); + } + else if (timeout == Timeout.InfiniteTimeSpan) + { + ct = CancellationToken.None; + } + else + { + cts = new CancellationTokenSource(timeout); + ct = cts.Token; + } + + try + { + await WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, ct).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } + finally + { + cts?.Dispose(); + } + } + + WebSocket.Dispose(); + } + } + } + } + + /// Provides a stream that wraps a and writes a single message. + private sealed class WriteMessageStream(WebSocket webSocket, WebSocketMessageType writeMessageType) : WebSocketStream(webSocket) + { + private readonly WebSocketMessageType _messageType = writeMessageType; + + /// + public override bool CanRead => false; + + /// + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (_disposed) + { + return ValueTask.FromException(new ObjectDisposedException(GetType().FullName)); + } + + if (!CanWrite) + { + return ValueTask.FromException(new NotSupportedException(SR.NotWriteableStream)); + } + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + return WebSocket.SendAsync(buffer, _messageType, endOfMessage: false, cancellationToken); + } + + public override ValueTask DisposeAsync() + { + if (!_disposed) + { + _disposed = true; + return WebSocket.SendAsync(ReadOnlyMemory.Empty, _messageType, endOfMessage: true, CancellationToken.None); + } + + return default; + } + } + + /// Provides a stream that wraps a and reads a single message. + private sealed class ReadMessageStream(WebSocket webSocket) : WebSocketStream(webSocket) + { + /// Whether we've seen and end-of-message marker. + private bool _eof; + + /// + public override bool CanWrite => false; + + /// + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + + if (!CanRead) + { + throw new NotSupportedException(SR.NotReadableStream); + } + + cancellationToken.ThrowIfCancellationRequested(); + + while (!_eof && WebSocket.State < WebSocketState.CloseReceived) + { + ValueWebSocketReceiveResult result = await WebSocket.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false); + if (result.MessageType is WebSocketMessageType.Close) + { + break; + } + + if (result.EndOfMessage) + { + _eof = true; + } + + if (result.Count > 0 || buffer.IsEmpty) + { + return result.Count; + } + } + + return 0; + } + + /// + public override ValueTask DisposeAsync() + { + _disposed = true; + if (!_eof && WebSocket.State < WebSocketState.CloseReceived) + { + WebSocket.Abort(); + } + return default; + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index a7f09ff31db29e..51b03c7dbed890 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -8,14 +8,15 @@ - + + + - + + - - @@ -24,4 +25,7 @@ + + + diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStreamTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStreamTests.cs new file mode 100644 index 00000000000000..7c77a98928dd51 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStreamTests.cs @@ -0,0 +1,419 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.IO.Tests; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + public abstract class WebSocketStreamTests : ConnectedStreamConformanceTests + { + protected override bool BlocksOnZeroByteReads => true; + protected override bool FlushRequiredToWriteData => false; + protected override bool ReadsReadUntilSizeOrEof => false; + protected override bool UsableAfterCanceledReads => false; + protected override Type UnsupportedConcurrentExceptionType => null; + + protected static (WebSocket webSocket1, WebSocket webSocket2) CreateWebSockets() + { + (Stream stream1, Stream stream2) = ConnectedStreams.CreateBidirectional(); + + WebSocket webSocket1 = WebSocket.CreateFromStream(stream1, isServer: false, null, Timeout.InfiniteTimeSpan); + WebSocket webSocket2 = WebSocket.CreateFromStream(stream2, isServer: true, null, Timeout.InfiniteTimeSpan); + + return (webSocket1, webSocket2); + } + } + + public sealed class WebSocketStreamCreateTests : WebSocketStreamTests + { + protected override Task CreateConnectedStreamsAsync() + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + return Task.FromResult(new StreamPair( + WebSocketStream.Create(webSocket1, WebSocketMessageType.Binary, TimeSpan.FromSeconds(120)), + WebSocketStream.Create(webSocket2, WebSocketMessageType.Binary, TimeSpan.FromSeconds(120)))); + } + + [Fact] + public void Create_InvalidArgs_Throws() + { + AssertExtensions.Throws("webSocket", () => WebSocketStream.Create(null, WebSocketMessageType.Binary)); + AssertExtensions.Throws("webSocket", () => WebSocketStream.Create(null, WebSocketMessageType.Text, ownsWebSocket: true)); + + AssertExtensions.Throws("webSocket", () => WebSocketStream.Create(null, WebSocketMessageType.Text, TimeSpan.FromSeconds(30))); + + WebSocket webSocket = WebSocket.CreateFromStream(new MemoryStream(), new()); + + AssertExtensions.Throws("closeTimeout", () => WebSocketStream.Create(webSocket, WebSocketMessageType.Text, TimeSpan.FromSeconds(-2))); + AssertExtensions.Throws("closeTimeout", () => WebSocketStream.Create(webSocket, WebSocketMessageType.Text, TimeSpan.FromSeconds(-1))); + AssertExtensions.Throws("writeMessageType", () => WebSocketStream.CreateWritableMessageStream(webSocket, WebSocketMessageType.Close)); + + Assert.NotNull(WebSocketStream.Create(webSocket, WebSocketMessageType.Text, Timeout.InfiniteTimeSpan)); + Assert.NotNull(WebSocketStream.Create(webSocket, WebSocketMessageType.Text, TimeSpan.Zero)); + Assert.NotNull(WebSocketStream.Create(webSocket, WebSocketMessageType.Text, TimeSpan.FromSeconds(1))); + } + + [Theory] + [InlineData(null)] + [InlineData(false)] + [InlineData(true)] + public void Create_Roundtrips(bool? ownsWebSocket) + { + (WebSocket webSocket1, _) = CreateWebSockets(); + + WebSocketStream stream = ownsWebSocket is not null ? + WebSocketStream.Create(webSocket1, WebSocketMessageType.Text, ownsWebSocket.Value) : + WebSocketStream.Create(webSocket1, WebSocketMessageType.Text); + + Assert.Same(webSocket1, stream.WebSocket); + } + + [Theory] + [InlineData(null)] + [InlineData(false)] + [InlineData(true)] + public async Task Dispose_ClosesWebSocketIfOwned(bool? ownsWebSocket) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + WebSocketStream stream = ownsWebSocket is not null ? + WebSocketStream.Create(webSocket1, WebSocketMessageType.Text, ownsWebSocket.Value) : + WebSocketStream.Create(webSocket1, WebSocketMessageType.Text); + Assert.Equal(WebSocketState.Open, webSocket1.State); + + if (ownsWebSocket is true) + { + await Task.WhenAll( + stream.DisposeAsync().AsTask(), + webSocket2.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None)); + + Assert.Equal(WebSocketState.Closed, webSocket1.State); + } + else + { + stream.Dispose(); + Assert.Equal(WebSocketState.Open, webSocket1.State); + } + } + + [Fact] + public async Task DisposeWebSocket_CantReadOrWrite() + { + (WebSocket webSocket, _) = CreateWebSockets(); + + Stream stream1 = WebSocketStream.Create(webSocket, WebSocketMessageType.Text, ownsWebSocket: false); + Stream stream2 = WebSocketStream.CreateWritableMessageStream(webSocket, WebSocketMessageType.Binary); + Stream stream3 = WebSocketStream.CreateReadableMessageStream(webSocket); + + Assert.False(stream1.CanSeek); + Assert.True(stream1.CanRead); + Assert.True(stream1.CanWrite); + + Assert.False(stream1.CanSeek); + Assert.False(stream2.CanRead); + Assert.True(stream2.CanWrite); + + Assert.False(stream1.CanSeek); + Assert.True(stream3.CanRead); + Assert.False(stream3.CanWrite); + + webSocket.Dispose(); + + foreach (Stream stream in new[] { stream1, stream2, stream3 }) + { + Assert.False(stream.CanSeek); + Assert.False(stream.CanRead); + Assert.False(stream.CanWrite); + + Assert.Throws(() => stream.Read(new byte[1], 0, 1)); + Assert.Throws(() => stream.Write(new byte[1], 0, 1)); + Assert.Throws(() => stream.ReadByte()); + Assert.Throws(() => stream.WriteByte(0)); + await Assert.ThrowsAsync(async () => await stream.ReadAsync(new byte[1], 0, 1, CancellationToken.None)); + await Assert.ThrowsAsync(async () => await stream.WriteAsync(new byte[1], 0, 1, CancellationToken.None)); + await Assert.ThrowsAsync(async () => await stream.ReadAsync(new Memory(new byte[1]), CancellationToken.None)); + await Assert.ThrowsAsync(async () => await stream.WriteAsync(new ReadOnlyMemory(new byte[1]), CancellationToken.None)); + } + } + + [Theory] + [InlineData(WebSocketMessageType.Binary)] + [InlineData(WebSocketMessageType.Text)] + public async Task Write_EveryWriteProducesMessage(WebSocketMessageType messageType) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + WebSocketStream stream1 = WebSocketStream.Create(webSocket1, messageType); + + Memory buffer = new byte[10]; + for (int i = 0; i < 3; i++) + { + buffer.Span.Clear(); + + stream1.Write("hello"u8); + ValueWebSocketReceiveResult message = await webSocket2.ReceiveAsync(buffer, default); + Assert.True(message.EndOfMessage); + Assert.Equal(messageType, message.MessageType); + Assert.Equal(5, message.Count); + Assert.Equal("hello"u8, buffer.Span.Slice(0, 5)); + } + } + + [Fact] + public async Task ClosedSocket_Reads0() + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + using WebSocketStream stream2 = WebSocketStream.Create(webSocket2, WebSocketMessageType.Text, ownsWebSocket: true); + + var read = stream2.ReadAsync(new byte[1], 0, 1, CancellationToken.None); + + await webSocket1.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, default); + + Assert.Equal(0, await read); + } + + [Theory] + [InlineData(false, 0)] + [InlineData(false, 1)] + [InlineData(true, 0)] + [InlineData(true, 1)] + public async Task Dispose_TimeoutApplies(bool useAsync, int timeoutSeconds) + { + (WebSocket webSocket1, _) = CreateWebSockets(); + + WebSocketStream stream = WebSocketStream.Create(webSocket1, WebSocketMessageType.Text, TimeSpan.FromSeconds(timeoutSeconds)); + + if (useAsync) + { + await stream.DisposeAsync(); + } + else + { + stream.Dispose(); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Dispose_InfiniteTimeout(bool useAsync) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + WebSocketStream stream = WebSocketStream.Create(webSocket1, WebSocketMessageType.Text, Timeout.InfiniteTimeSpan); + + Task disposeTask = Task.Run(async () => + { + if (useAsync) + { + await stream.DisposeAsync(); + } + else + { + stream.Dispose(); + } + }); + + await Task.Delay(100); + Assert.False(disposeTask.IsCompleted); + + await webSocket2.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None); + await disposeTask; + } + } + + public sealed class WebSocketStreamCreateMessageTests : WebSocketStreamTests + { + protected override Task CreateConnectedStreamsAsync() + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + return Task.FromResult(new StreamPair( + WebSocketStream.CreateWritableMessageStream(webSocket2, WebSocketMessageType.Binary), + WebSocketStream.CreateReadableMessageStream(webSocket1))); + } + + [Fact] + public void Create_InvalidArgs_Throws() + { + AssertExtensions.Throws("webSocket", () => WebSocketStream.CreateWritableMessageStream(null, WebSocketMessageType.Binary)); + AssertExtensions.Throws("webSocket", () => WebSocketStream.CreateReadableMessageStream(null)); + + WebSocket webSocket = WebSocket.CreateFromStream(new MemoryStream(), new()); + AssertExtensions.Throws("writeMessageType", () => WebSocketStream.CreateWritableMessageStream(webSocket, WebSocketMessageType.Close)); + } + + [Fact] + public void Create_Roundtrips() + { + (WebSocket webSocket, _) = CreateWebSockets(); + WebSocketStream stream; + + stream = WebSocketStream.CreateWritableMessageStream(webSocket, WebSocketMessageType.Text); + Assert.Same(webSocket, stream.WebSocket); + stream.Dispose(); + Assert.Same(webSocket, stream.WebSocket); + Assert.Equal(WebSocketState.Open, webSocket.State); + + stream = WebSocketStream.CreateReadableMessageStream(webSocket); + Assert.Same(webSocket, stream.WebSocket); + stream.Dispose(); // For read message stream, disposing is equal to cancelling a read operation + Assert.Same(webSocket, stream.WebSocket); + Assert.Equal(WebSocketState.Aborted, webSocket.State); + } + + [Theory] + [InlineData(WebSocketMessageType.Binary)] + [InlineData(WebSocketMessageType.Text)] + public async Task Write_EveryStreamProducesMessage(WebSocketMessageType messageType) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + ValueWebSocketReceiveResult message; + Memory buffer = new byte[10]; + for (int i = 0; i < 3; i++) + { + buffer.Span.Clear(); + + using (WebSocketStream stream1 = WebSocketStream.CreateWritableMessageStream(webSocket1, messageType)) + { + foreach (byte b in "hello"u8.ToArray()) + { + stream1.WriteByte(b); + + message = await webSocket2.ReceiveAsync(buffer, default); + Assert.False(message.EndOfMessage); + Assert.Equal(messageType, message.MessageType); + Assert.Equal(1, message.Count); + Assert.Equal(b, buffer.Span[0]); + } + } + + message = await webSocket2.ReceiveAsync(buffer, default); + Assert.True(message.EndOfMessage); + Assert.Equal(messageType, message.MessageType); + Assert.Equal(0, message.Count); + } + } + + [Theory] + [InlineData(WebSocketMessageType.Binary)] + [InlineData(WebSocketMessageType.Text)] + public async Task Read_EveryStreamConsumesMessage(WebSocketMessageType messageType) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + Memory buffer = new byte[10]; + for (int i = 0; i < 3; i++) + { + buffer.Span.Clear(); + + using (WebSocketStream stream1 = WebSocketStream.CreateReadableMessageStream(webSocket2)) + { + foreach (byte b in "hello"u8.ToArray()) + { + await webSocket1.SendAsync(new[] { b }, messageType, endOfMessage: false, default); + Assert.Equal(b, stream1.ReadByte()); + } + + await webSocket1.SendAsync(Array.Empty(), messageType, endOfMessage: true, default); + Assert.Equal(-1, stream1.ReadByte()); + } + } + } + + [Theory] + [InlineData(false, false, WebSocketState.Aborted)] // abortive: read canceled + [InlineData(true, false, WebSocketState.Open)] // graceful: EOF consumed + [InlineData(false, true, WebSocketState.CloseReceived)] // graceful: Close frame consumed + [InlineData(true, true, WebSocketState.Open)] // graceful: EOF consumed, Close frame NOT consumed (no reads after EOF) + public async Task Read_DisposeBeforeEofOrCloseIsAbortive(bool eof, bool close, WebSocketState expectedWebSocketState) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + byte[] data = "hello"u8.ToArray(); + await webSocket1.SendAsync(data, WebSocketMessageType.Binary, endOfMessage: false, default); + + if (eof) + { + await webSocket1.SendAsync(Array.Empty(), WebSocketMessageType.Binary, endOfMessage: true, default); + } + + if (close) + { + await webSocket1.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, default); + } + + WebSocketStream stream2 = WebSocketStream.CreateReadableMessageStream(webSocket2); + Memory buffer = new byte[10]; + + await stream2.ReadExactlyAsync(buffer[..data.Length], default); + Assert.Equal(data, buffer[..data.Length].ToArray()); + Assert.Equal(WebSocketState.Open, webSocket2.State); + + if (eof || close) + { + Assert.Equal(-1, stream2.ReadByte()); // consuming EOF or Close + } + + stream2.Dispose(); + Assert.Equal(expectedWebSocketState, webSocket2.State); + } + + [Theory] + [InlineData(WebSocketMessageType.Binary, 0)] + [InlineData(WebSocketMessageType.Binary, 1)] + [InlineData(WebSocketMessageType.Binary, 5)] + [InlineData(WebSocketMessageType.Text, 0)] + [InlineData(WebSocketMessageType.Text, 1)] + [InlineData(WebSocketMessageType.Text, 5)] + public async Task WriteRead_StreamPairPerMessage(WebSocketMessageType messageType, int length) + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + IEnumerable source = Enumerable.Range('a', length).Select(c => (byte)c); + + for (int i = 0; i < 3; i++) + { + await Task.WhenAll( + Task.Run(async () => + { + using WebSocketStream stream1 = WebSocketStream.CreateWritableMessageStream(webSocket1, messageType); + foreach (byte b in source) + { + await stream1.WriteAsync([b], 0, 1, default); + await Task.Delay(1); + } + }), + Task.Run(async () => + { + using WebSocketStream stream2 = WebSocketStream.CreateReadableMessageStream(webSocket2); + Memory buffer = new byte[length * 2 + 1]; + int bytesRead = await stream2.ReadAtLeastAsync(buffer, buffer.Length, throwOnEndOfStream: false); + + Assert.Equal(length, bytesRead); + Assert.Equal(source.ToArray(), buffer.Slice(0, bytesRead).ToArray()); + })); + } + } + + [Fact] + public async Task ClosedSocket_Reads0() + { + (WebSocket webSocket1, WebSocket webSocket2) = CreateWebSockets(); + + using WebSocketStream stream2 = WebSocketStream.CreateReadableMessageStream(webSocket2); + + var read = stream2.ReadAsync(new byte[1], 0, 1, CancellationToken.None); + + await webSocket1.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, default); + + Assert.Equal(0, await read); + } + } +}