diff --git a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx
index a01244712e108..72a5478e850ac 100644
--- a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx
+++ b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx
@@ -294,9 +294,6 @@
The result of the operation was already consumed and may not be used again.
-
- Another continuation was already registered.
-
The FileStream must have been opened for asynchronous reading and writing.
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
index 7519dbd6d685d..9631bb66e8ca5 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
@@ -918,26 +918,13 @@ internal AsyncTaskMethodBuilder GetCompletionResponsibility(out bool re
/// A SocketAsyncEventArgs that can be awaited to get the result of an operation.
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource
{
- private static readonly Action s_completedSentinel = new Action(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel))));
/// The owning socket.
private readonly Socket _owner;
/// Whether this should be cached as a read or a write on the
private readonly bool _isReadForCaching;
- ///
- /// if it has completed. Another delegate if OnCompleted was called before the operation could complete,
- /// in which case it's the delegate to invoke when the operation does complete.
- ///
- private Action? _continuation;
- private ExecutionContext? _executionContext;
- private object? _scheduler;
- /// Current token value given to a ValueTask and then verified against the value it passes back to us.
- ///
- /// This is not meant to be a completely reliable mechanism, doesn't require additional synchronization, etc.
- /// It's purely a best effort attempt to catch misuse, including awaiting for a value task twice and after
- /// it's already being reused by someone else.
- ///
- private short _token;
- /// The cancellation token used for the current operation.
+ /// Core logic for the IValueTaskSource implementations.
+ private ManualResetValueTaskSourceCore _mrvtsc;
+ /// The cancellation token used for the current operation. Stored to propagate the most relevant exception.
private CancellationToken _cancellationToken;
/// Initializes the event args.
@@ -950,12 +937,18 @@ public AwaitableSocketAsyncEventArgs(Socket owner, bool isReceiveForCaching) :
public bool WrapExceptionsForNetworkStream { get; set; }
- private void Release()
+ /// Resets this instance after an asynchronous completion and puts it back into the pool.
+ private void ReleaseForAsyncCompletion()
{
_cancellationToken = default;
- _token++;
- _continuation = null;
+ _mrvtsc.Reset();
+ ReleaseForSyncCompletion();
+ }
+ /// Resets this instance after a synchronous completion and puts it back into the pool.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private void ReleaseForSyncCompletion()
+ {
ref AwaitableSocketAsyncEventArgs? cache = ref _isReadForCaching ? ref _owner._singleBufferReceiveEventArgs : ref _owner._singleBufferSendEventArgs;
if (Interlocked.CompareExchange(ref cache, this, null) != null)
{
@@ -963,49 +956,16 @@ private void Release()
}
}
- protected override void OnCompleted(SocketAsyncEventArgs _)
- {
- // When the operation completes, see if OnCompleted was already called to hook up a continuation.
- // If it was, invoke the continuation.
- Action? c = _continuation;
- if (c != null || (c = Interlocked.CompareExchange(ref _continuation, s_completedSentinel, null)) != null)
- {
- Debug.Assert(c != s_completedSentinel, "The delegate should not have been the completed sentinel.");
-
- object? continuationState = UserToken;
- UserToken = null;
- _continuation = s_completedSentinel; // in case someone's polling IsCompleted
-
- ExecutionContext? ec = _executionContext;
- if (ec == null)
- {
- InvokeContinuation(c, continuationState, forceAsync: false, requiresExecutionContextFlow: false);
- }
- else
- {
- // This case should be relatively rare, as the async Task/ValueTask method builders
- // use the awaiter's UnsafeOnCompleted, so this will only happen with code that
- // explicitly uses the awaiter's OnCompleted instead.
- _executionContext = null;
- ExecutionContext.Run(ec, runState =>
- {
- var t = ((AwaitableSocketAsyncEventArgs, Action, object))runState!;
- t.Item1.InvokeContinuation(t.Item2, t.Item3, forceAsync: false, requiresExecutionContextFlow: false);
- }, (this, c, continuationState));
- }
- }
- }
+ protected override void OnCompleted(SocketAsyncEventArgs _) => _mrvtsc.SetResult(true);
/// Initiates an accept operation on the associated socket.
/// This instance.
public ValueTask AcceptAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.AcceptAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
Socket acceptSocket = AcceptSocket!;
@@ -1013,7 +973,7 @@ public ValueTask AcceptAsync(Socket socket, CancellationToken cancellati
AcceptSocket = null;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
new ValueTask(acceptSocket) :
@@ -1024,18 +984,16 @@ public ValueTask AcceptAsync(Socket socket, CancellationToken cancellati
/// This instance.
public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.ReceiveAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
int bytesTransferred = BytesTransferred;
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
new ValueTask(bytesTransferred) :
@@ -1044,19 +1002,17 @@ public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellation
public ValueTask ReceiveFromAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.ReceiveFromAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
int bytesTransferred = BytesTransferred;
EndPoint remoteEndPoint = RemoteEndPoint!;
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
new ValueTask(new SocketReceiveFromResult() { ReceivedBytes = bytesTransferred, RemoteEndPoint = remoteEndPoint }) :
@@ -1065,12 +1021,10 @@ public ValueTask ReceiveFromAsync(Socket socket, Cancel
public ValueTask ReceiveMessageFromAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.ReceiveMessageFromAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
int bytesTransferred = BytesTransferred;
@@ -1079,7 +1033,7 @@ public ValueTask ReceiveMessageFromAsync(Socket
IPPacketInformation packetInformation = ReceiveMessageFromPacketInfo;
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
new ValueTask(new SocketReceiveMessageFromResult() { ReceivedBytes = bytesTransferred, RemoteEndPoint = remoteEndPoint, SocketFlags = socketFlags, PacketInformation = packetInformation }) :
@@ -1090,18 +1044,16 @@ public ValueTask ReceiveMessageFromAsync(Socket
/// This instance.
public ValueTask SendAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.SendAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
int bytesTransferred = BytesTransferred;
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
new ValueTask(bytesTransferred) :
@@ -1110,17 +1062,15 @@ public ValueTask SendAsync(Socket socket, CancellationToken cancellationTok
public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.SendAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
default :
@@ -1129,17 +1079,15 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc
public ValueTask SendPacketsAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.SendPacketsAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
default :
@@ -1148,18 +1096,16 @@ public ValueTask SendPacketsAsync(Socket socket, CancellationToken cancellationT
public ValueTask SendToAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
if (socket.SendToAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
int bytesTransferred = BytesTransferred;
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
new ValueTask(bytesTransferred) :
@@ -1168,24 +1114,22 @@ public ValueTask SendToAsync(Socket socket, CancellationToken cancellationT
public ValueTask ConnectAsync(Socket socket)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
-
try
{
if (socket.ConnectAsync(this, userSocket: true, saeaCancelable: false))
{
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
}
catch
{
- Release();
+ ReleaseForSyncCompletion();
throw;
}
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
default :
@@ -1194,17 +1138,15 @@ public ValueTask ConnectAsync(Socket socket)
public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken)
{
- Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");
-
if (socket.DisconnectAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
- return new ValueTask(this, _token);
+ return new ValueTask(this, _mrvtsc.Version);
}
SocketError error = SocketError;
- Release();
+ ReleaseForSyncCompletion();
return error == SocketError.Success ?
ValueTask.CompletedTask :
@@ -1212,108 +1154,12 @@ public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationTo
}
/// Gets the status of the operation.
- public ValueTaskSourceStatus GetStatus(short token)
- {
- if (token != _token)
- {
- ThrowIncorrectTokenException();
- }
-
- return
- !ReferenceEquals(_continuation, s_completedSentinel) ? ValueTaskSourceStatus.Pending :
- SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded :
- ValueTaskSourceStatus.Faulted;
- }
+ public ValueTaskSourceStatus GetStatus(short token) =>
+ _mrvtsc.GetStatus(token);
/// Queues the provided continuation to be executed once the operation has completed.
- public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
- {
- if (token != _token)
- {
- ThrowIncorrectTokenException();
- }
-
- if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0)
- {
- _executionContext = ExecutionContext.Capture();
- }
-
- if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0)
- {
- SynchronizationContext? sc = SynchronizationContext.Current;
- if (sc != null && sc.GetType() != typeof(SynchronizationContext))
- {
- _scheduler = sc;
- }
- else
- {
- TaskScheduler ts = TaskScheduler.Current;
- if (ts != TaskScheduler.Default)
- {
- _scheduler = ts;
- }
- }
- }
-
- UserToken = state; // Use UserToken to carry the continuation state around
- Action? prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null);
- if (ReferenceEquals(prevContinuation, s_completedSentinel))
- {
- // Lost the race condition and the operation has now already completed.
- // We need to invoke the continuation, but it must be asynchronously to
- // avoid a stack dive. However, since all of the queueing mechanisms flow
- // ExecutionContext, and since we're still in the same context where we
- // captured it, we can just ignore the one we captured.
- bool requiresExecutionContextFlow = _executionContext != null;
- _executionContext = null;
- UserToken = null; // we have the state in "state"; no need for the one in UserToken
- InvokeContinuation(continuation, state, forceAsync: true, requiresExecutionContextFlow);
- }
- else if (prevContinuation != null)
- {
- // Flag errors with the continuation being hooked up multiple times.
- // This is purely to help alert a developer to a bug they need to fix.
- ThrowMultipleContinuationsException();
- }
- }
-
- private void InvokeContinuation(Action continuation, object? state, bool forceAsync, bool requiresExecutionContextFlow)
- {
- object? scheduler = _scheduler;
- _scheduler = null;
-
- if (scheduler != null)
- {
- if (scheduler is SynchronizationContext sc)
- {
- sc.Post(s =>
- {
- var t = ((Action, object))s!;
- t.Item1(t.Item2);
- }, (continuation, state));
- }
- else
- {
- Debug.Assert(scheduler is TaskScheduler, $"Expected TaskScheduler, got {scheduler}");
- Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, (TaskScheduler)scheduler);
- }
- }
- else if (forceAsync)
- {
- if (requiresExecutionContextFlow)
- {
- ThreadPool.QueueUserWorkItem(continuation, state, preferLocal: true);
- }
- else
- {
- ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true);
- }
- }
- else
- {
- continuation(state);
- }
- }
+ public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) =>
+ _mrvtsc.OnCompleted(continuation, state, token, flags);
/// Gets the result of the completion operation.
/// Number of bytes transferred.
@@ -1323,7 +1169,7 @@ private void InvokeContinuation(Action continuation, object? state, boo
///
int IValueTaskSource.GetResult(short token)
{
- if (token != _token)
+ if (token != _mrvtsc.Version)
{
ThrowIncorrectTokenException();
}
@@ -1332,7 +1178,7 @@ int IValueTaskSource.GetResult(short token)
int bytes = BytesTransferred;
CancellationToken cancellationToken = _cancellationToken;
- Release();
+ ReleaseForAsyncCompletion();
if (error != SocketError.Success)
{
@@ -1343,7 +1189,7 @@ int IValueTaskSource.GetResult(short token)
void IValueTaskSource.GetResult(short token)
{
- if (token != _token)
+ if (token != _mrvtsc.Version)
{
ThrowIncorrectTokenException();
}
@@ -1351,7 +1197,7 @@ void IValueTaskSource.GetResult(short token)
SocketError error = SocketError;
CancellationToken cancellationToken = _cancellationToken;
- Release();
+ ReleaseForAsyncCompletion();
if (error != SocketError.Success)
{
@@ -1361,7 +1207,7 @@ void IValueTaskSource.GetResult(short token)
Socket IValueTaskSource.GetResult(short token)
{
- if (token != _token)
+ if (token != _mrvtsc.Version)
{
ThrowIncorrectTokenException();
}
@@ -1372,7 +1218,7 @@ Socket IValueTaskSource.GetResult(short token)
AcceptSocket = null;
- Release();
+ ReleaseForAsyncCompletion();
if (error != SocketError.Success)
{
@@ -1383,7 +1229,7 @@ Socket IValueTaskSource.GetResult(short token)
SocketReceiveFromResult IValueTaskSource.GetResult(short token)
{
- if (token != _token)
+ if (token != _mrvtsc.Version)
{
ThrowIncorrectTokenException();
}
@@ -1393,7 +1239,7 @@ SocketReceiveFromResult IValueTaskSource.GetResult(shor
EndPoint remoteEndPoint = RemoteEndPoint!;
CancellationToken cancellationToken = _cancellationToken;
- Release();
+ ReleaseForAsyncCompletion();
if (error != SocketError.Success)
{
@@ -1405,7 +1251,7 @@ SocketReceiveFromResult IValueTaskSource.GetResult(shor
SocketReceiveMessageFromResult IValueTaskSource.GetResult(short token)
{
- if (token != _token)
+ if (token != _mrvtsc.Version)
{
ThrowIncorrectTokenException();
}
@@ -1417,7 +1263,7 @@ SocketReceiveMessageFromResult IValueTaskSource.
IPPacketInformation packetInformation = ReceiveMessageFromPacketInfo;
CancellationToken cancellationToken = _cancellationToken;
- Release();
+ ReleaseForAsyncCompletion();
if (error != SocketError.Success)
{
@@ -1429,14 +1275,12 @@ SocketReceiveMessageFromResult IValueTaskSource.
private static void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken);
- private static void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations);
-
private void ThrowException(SocketError error, CancellationToken cancellationToken)
{
// Most operations will report OperationAborted when canceled.
// On Windows, SendFileAsync will report ConnectionAborted.
// There's a race here anyway, so there's no harm in also checking for ConnectionAborted in all cases.
- if (error == SocketError.OperationAborted || error == SocketError.ConnectionAborted)
+ if (error is SocketError.OperationAborted or SocketError.ConnectionAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}