diff --git a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx index a01244712e108..72a5478e850ac 100644 --- a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx @@ -294,9 +294,6 @@ The result of the operation was already consumed and may not be used again. - - Another continuation was already registered. - The FileStream must have been opened for asynchronous reading and writing. diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 7519dbd6d685d..9631bb66e8ca5 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -918,26 +918,13 @@ internal AsyncTaskMethodBuilder GetCompletionResponsibility(out bool re /// A SocketAsyncEventArgs that can be awaited to get the result of an operation. internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource { - private static readonly Action s_completedSentinel = new Action(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel)))); /// The owning socket. private readonly Socket _owner; /// Whether this should be cached as a read or a write on the private readonly bool _isReadForCaching; - /// - /// if it has completed. Another delegate if OnCompleted was called before the operation could complete, - /// in which case it's the delegate to invoke when the operation does complete. - /// - private Action? _continuation; - private ExecutionContext? _executionContext; - private object? _scheduler; - /// Current token value given to a ValueTask and then verified against the value it passes back to us. - /// - /// This is not meant to be a completely reliable mechanism, doesn't require additional synchronization, etc. - /// It's purely a best effort attempt to catch misuse, including awaiting for a value task twice and after - /// it's already being reused by someone else. - /// - private short _token; - /// The cancellation token used for the current operation. + /// Core logic for the IValueTaskSource implementations. + private ManualResetValueTaskSourceCore _mrvtsc; + /// The cancellation token used for the current operation. Stored to propagate the most relevant exception. private CancellationToken _cancellationToken; /// Initializes the event args. @@ -950,12 +937,18 @@ public AwaitableSocketAsyncEventArgs(Socket owner, bool isReceiveForCaching) : public bool WrapExceptionsForNetworkStream { get; set; } - private void Release() + /// Resets this instance after an asynchronous completion and puts it back into the pool. + private void ReleaseForAsyncCompletion() { _cancellationToken = default; - _token++; - _continuation = null; + _mrvtsc.Reset(); + ReleaseForSyncCompletion(); + } + /// Resets this instance after a synchronous completion and puts it back into the pool. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ReleaseForSyncCompletion() + { ref AwaitableSocketAsyncEventArgs? cache = ref _isReadForCaching ? ref _owner._singleBufferReceiveEventArgs : ref _owner._singleBufferSendEventArgs; if (Interlocked.CompareExchange(ref cache, this, null) != null) { @@ -963,49 +956,16 @@ private void Release() } } - protected override void OnCompleted(SocketAsyncEventArgs _) - { - // When the operation completes, see if OnCompleted was already called to hook up a continuation. - // If it was, invoke the continuation. - Action? c = _continuation; - if (c != null || (c = Interlocked.CompareExchange(ref _continuation, s_completedSentinel, null)) != null) - { - Debug.Assert(c != s_completedSentinel, "The delegate should not have been the completed sentinel."); - - object? continuationState = UserToken; - UserToken = null; - _continuation = s_completedSentinel; // in case someone's polling IsCompleted - - ExecutionContext? ec = _executionContext; - if (ec == null) - { - InvokeContinuation(c, continuationState, forceAsync: false, requiresExecutionContextFlow: false); - } - else - { - // This case should be relatively rare, as the async Task/ValueTask method builders - // use the awaiter's UnsafeOnCompleted, so this will only happen with code that - // explicitly uses the awaiter's OnCompleted instead. - _executionContext = null; - ExecutionContext.Run(ec, runState => - { - var t = ((AwaitableSocketAsyncEventArgs, Action, object))runState!; - t.Item1.InvokeContinuation(t.Item2, t.Item3, forceAsync: false, requiresExecutionContextFlow: false); - }, (this, c, continuationState)); - } - } - } + protected override void OnCompleted(SocketAsyncEventArgs _) => _mrvtsc.SetResult(true); /// Initiates an accept operation on the associated socket. /// This instance. public ValueTask AcceptAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.AcceptAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } Socket acceptSocket = AcceptSocket!; @@ -1013,7 +973,7 @@ public ValueTask AcceptAsync(Socket socket, CancellationToken cancellati AcceptSocket = null; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? new ValueTask(acceptSocket) : @@ -1024,18 +984,16 @@ public ValueTask AcceptAsync(Socket socket, CancellationToken cancellati /// This instance. public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.ReceiveAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } int bytesTransferred = BytesTransferred; SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? new ValueTask(bytesTransferred) : @@ -1044,19 +1002,17 @@ public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellation public ValueTask ReceiveFromAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.ReceiveFromAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } int bytesTransferred = BytesTransferred; EndPoint remoteEndPoint = RemoteEndPoint!; SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? new ValueTask(new SocketReceiveFromResult() { ReceivedBytes = bytesTransferred, RemoteEndPoint = remoteEndPoint }) : @@ -1065,12 +1021,10 @@ public ValueTask ReceiveFromAsync(Socket socket, Cancel public ValueTask ReceiveMessageFromAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.ReceiveMessageFromAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } int bytesTransferred = BytesTransferred; @@ -1079,7 +1033,7 @@ public ValueTask ReceiveMessageFromAsync(Socket IPPacketInformation packetInformation = ReceiveMessageFromPacketInfo; SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? new ValueTask(new SocketReceiveMessageFromResult() { ReceivedBytes = bytesTransferred, RemoteEndPoint = remoteEndPoint, SocketFlags = socketFlags, PacketInformation = packetInformation }) : @@ -1090,18 +1044,16 @@ public ValueTask ReceiveMessageFromAsync(Socket /// This instance. public ValueTask SendAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.SendAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } int bytesTransferred = BytesTransferred; SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? new ValueTask(bytesTransferred) : @@ -1110,17 +1062,15 @@ public ValueTask SendAsync(Socket socket, CancellationToken cancellationTok public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.SendAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? default : @@ -1129,17 +1079,15 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc public ValueTask SendPacketsAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.SendPacketsAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? default : @@ -1148,18 +1096,16 @@ public ValueTask SendPacketsAsync(Socket socket, CancellationToken cancellationT public ValueTask SendToAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - if (socket.SendToAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } int bytesTransferred = BytesTransferred; SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? new ValueTask(bytesTransferred) : @@ -1168,24 +1114,22 @@ public ValueTask SendToAsync(Socket socket, CancellationToken cancellationT public ValueTask ConnectAsync(Socket socket) { - Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - try { if (socket.ConnectAsync(this, userSocket: true, saeaCancelable: false)) { - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } } catch { - Release(); + ReleaseForSyncCompletion(); throw; } SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? default : @@ -1194,17 +1138,15 @@ public ValueTask ConnectAsync(Socket socket) public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); - if (socket.DisconnectAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; - return new ValueTask(this, _token); + return new ValueTask(this, _mrvtsc.Version); } SocketError error = SocketError; - Release(); + ReleaseForSyncCompletion(); return error == SocketError.Success ? ValueTask.CompletedTask : @@ -1212,108 +1154,12 @@ public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationTo } /// Gets the status of the operation. - public ValueTaskSourceStatus GetStatus(short token) - { - if (token != _token) - { - ThrowIncorrectTokenException(); - } - - return - !ReferenceEquals(_continuation, s_completedSentinel) ? ValueTaskSourceStatus.Pending : - SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded : - ValueTaskSourceStatus.Faulted; - } + public ValueTaskSourceStatus GetStatus(short token) => + _mrvtsc.GetStatus(token); /// Queues the provided continuation to be executed once the operation has completed. - public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) - { - if (token != _token) - { - ThrowIncorrectTokenException(); - } - - if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) - { - _executionContext = ExecutionContext.Capture(); - } - - if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) - { - SynchronizationContext? sc = SynchronizationContext.Current; - if (sc != null && sc.GetType() != typeof(SynchronizationContext)) - { - _scheduler = sc; - } - else - { - TaskScheduler ts = TaskScheduler.Current; - if (ts != TaskScheduler.Default) - { - _scheduler = ts; - } - } - } - - UserToken = state; // Use UserToken to carry the continuation state around - Action? prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); - if (ReferenceEquals(prevContinuation, s_completedSentinel)) - { - // Lost the race condition and the operation has now already completed. - // We need to invoke the continuation, but it must be asynchronously to - // avoid a stack dive. However, since all of the queueing mechanisms flow - // ExecutionContext, and since we're still in the same context where we - // captured it, we can just ignore the one we captured. - bool requiresExecutionContextFlow = _executionContext != null; - _executionContext = null; - UserToken = null; // we have the state in "state"; no need for the one in UserToken - InvokeContinuation(continuation, state, forceAsync: true, requiresExecutionContextFlow); - } - else if (prevContinuation != null) - { - // Flag errors with the continuation being hooked up multiple times. - // This is purely to help alert a developer to a bug they need to fix. - ThrowMultipleContinuationsException(); - } - } - - private void InvokeContinuation(Action continuation, object? state, bool forceAsync, bool requiresExecutionContextFlow) - { - object? scheduler = _scheduler; - _scheduler = null; - - if (scheduler != null) - { - if (scheduler is SynchronizationContext sc) - { - sc.Post(s => - { - var t = ((Action, object))s!; - t.Item1(t.Item2); - }, (continuation, state)); - } - else - { - Debug.Assert(scheduler is TaskScheduler, $"Expected TaskScheduler, got {scheduler}"); - Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, (TaskScheduler)scheduler); - } - } - else if (forceAsync) - { - if (requiresExecutionContextFlow) - { - ThreadPool.QueueUserWorkItem(continuation, state, preferLocal: true); - } - else - { - ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true); - } - } - else - { - continuation(state); - } - } + public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => + _mrvtsc.OnCompleted(continuation, state, token, flags); /// Gets the result of the completion operation. /// Number of bytes transferred. @@ -1323,7 +1169,7 @@ private void InvokeContinuation(Action continuation, object? state, boo /// int IValueTaskSource.GetResult(short token) { - if (token != _token) + if (token != _mrvtsc.Version) { ThrowIncorrectTokenException(); } @@ -1332,7 +1178,7 @@ int IValueTaskSource.GetResult(short token) int bytes = BytesTransferred; CancellationToken cancellationToken = _cancellationToken; - Release(); + ReleaseForAsyncCompletion(); if (error != SocketError.Success) { @@ -1343,7 +1189,7 @@ int IValueTaskSource.GetResult(short token) void IValueTaskSource.GetResult(short token) { - if (token != _token) + if (token != _mrvtsc.Version) { ThrowIncorrectTokenException(); } @@ -1351,7 +1197,7 @@ void IValueTaskSource.GetResult(short token) SocketError error = SocketError; CancellationToken cancellationToken = _cancellationToken; - Release(); + ReleaseForAsyncCompletion(); if (error != SocketError.Success) { @@ -1361,7 +1207,7 @@ void IValueTaskSource.GetResult(short token) Socket IValueTaskSource.GetResult(short token) { - if (token != _token) + if (token != _mrvtsc.Version) { ThrowIncorrectTokenException(); } @@ -1372,7 +1218,7 @@ Socket IValueTaskSource.GetResult(short token) AcceptSocket = null; - Release(); + ReleaseForAsyncCompletion(); if (error != SocketError.Success) { @@ -1383,7 +1229,7 @@ Socket IValueTaskSource.GetResult(short token) SocketReceiveFromResult IValueTaskSource.GetResult(short token) { - if (token != _token) + if (token != _mrvtsc.Version) { ThrowIncorrectTokenException(); } @@ -1393,7 +1239,7 @@ SocketReceiveFromResult IValueTaskSource.GetResult(shor EndPoint remoteEndPoint = RemoteEndPoint!; CancellationToken cancellationToken = _cancellationToken; - Release(); + ReleaseForAsyncCompletion(); if (error != SocketError.Success) { @@ -1405,7 +1251,7 @@ SocketReceiveFromResult IValueTaskSource.GetResult(shor SocketReceiveMessageFromResult IValueTaskSource.GetResult(short token) { - if (token != _token) + if (token != _mrvtsc.Version) { ThrowIncorrectTokenException(); } @@ -1417,7 +1263,7 @@ SocketReceiveMessageFromResult IValueTaskSource. IPPacketInformation packetInformation = ReceiveMessageFromPacketInfo; CancellationToken cancellationToken = _cancellationToken; - Release(); + ReleaseForAsyncCompletion(); if (error != SocketError.Success) { @@ -1429,14 +1275,12 @@ SocketReceiveMessageFromResult IValueTaskSource. private static void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken); - private static void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations); - private void ThrowException(SocketError error, CancellationToken cancellationToken) { // Most operations will report OperationAborted when canceled. // On Windows, SendFileAsync will report ConnectionAborted. // There's a race here anyway, so there's no harm in also checking for ConnectionAborted in all cases. - if (error == SocketError.OperationAborted || error == SocketError.ConnectionAborted) + if (error is SocketError.OperationAborted or SocketError.ConnectionAborted) { cancellationToken.ThrowIfCancellationRequested(); }