Skip to content

Commit

Permalink
add AcceptAsync cancellation overloads (#53340)
Browse files Browse the repository at this point in the history
* add AcceptAsync cancellation overloads

* pass cancellationToken to AcceptAsync in Unix NamedPipe impl

* add TcpListener overloads too

* enable pipe cancellation test on Unix

Co-authored-by: Geoffrey Kizer <[email protected]>
  • Loading branch information
geoffkizer and Geoffrey Kizer authored May 29, 2021
1 parent f881cf3 commit b5c91e4
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public Task WaitForConnectionAsync(CancellationToken cancellationToken)
WaitForConnectionAsyncCore();

async Task WaitForConnectionAsyncCore() =>
HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync().ConfigureAwait(false));
HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync(cancellationToken).ConfigureAwait(false));
}

private void HandleAcceptedSocket(Socket acceptedSocket)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,10 @@ public async Task CancelTokenOn_ServerWaitForConnectionAsync_Throws_OperationCan

var ctx = new CancellationTokenSource();

if (OperatingSystem.IsWindows()) // cancellation token after the operation has been initiated
{
Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token);
ctx.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => serverWaitTimeout);
}

Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token);
ctx.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => serverWaitTimeout);

Assert.True(server.WaitForConnectionAsync(ctx.Token).IsCanceled);
}

Expand Down
4 changes: 4 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.Proto
public bool UseOnlyOverlappedIO { get { throw null; } set { } }
public System.Net.Sockets.Socket Accept() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket, System.Threading.CancellationToken cancellationToken) { throw null; }
public bool AcceptAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.IAsyncResult BeginAccept(System.AsyncCallback? callback, object? state) { throw null; }
public System.IAsyncResult BeginAccept(int receiveSize, System.AsyncCallback? callback, object? state) { throw null; }
Expand Down Expand Up @@ -691,8 +693,10 @@ public TcpListener(System.Net.IPEndPoint localEP) { }
public System.Net.Sockets.Socket Server { get { throw null; } }
public System.Net.Sockets.Socket AcceptSocket() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptSocketAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptSocketAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Net.Sockets.TcpClient AcceptTcpClient() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.TcpClient> AcceptTcpClientAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.TcpClient> AcceptTcpClientAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
public void AllowNatTraversal(bool allowed) { }
public System.IAsyncResult BeginAcceptSocket(System.AsyncCallback? callback, object? state) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ namespace System.Net.Sockets
{
public partial class Socket
{
/// <summary>Cached instance for accept operations.</summary>
private TaskSocketAsyncEventArgs<Socket>? _acceptEventArgs;

/// <summary>Cached instance for receive operations that return <see cref="ValueTask{Int32}"/>. Also used for ConnectAsync operations.</summary>
private AwaitableSocketAsyncEventArgs? _singleBufferReceiveEventArgs;
/// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>.</summary>
/// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>. Also used for AcceptAsync operations.</summary>
private AwaitableSocketAsyncEventArgs? _singleBufferSendEventArgs;

/// <summary>Cached instance for receive operations that return <see cref="Task{Int32}"/>.</summary>
Expand All @@ -32,54 +29,44 @@ public partial class Socket
/// Accepts an incoming connection.
/// </summary>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null);
public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null, CancellationToken.None).AsTask();

/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync(Socket? acceptSocket)
{
// Get any cached SocketAsyncEventArg we may have.
TaskSocketAsyncEventArgs<Socket>? saea = Interlocked.Exchange(ref _acceptEventArgs, null);
if (saea is null)
{
saea = new TaskSocketAsyncEventArgs<Socket>();
saea.Completed += (s, e) => CompleteAccept((Socket)s!, (TaskSocketAsyncEventArgs<Socket>)e);
}
public ValueTask<Socket> AcceptAsync(CancellationToken cancellationToken) => AcceptAsync((Socket?)null, cancellationToken);

// Configure the SAEA.
saea.AcceptSocket = acceptSocket;
/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync(Socket? acceptSocket) => AcceptAsync(acceptSocket, CancellationToken.None).AsTask();

// Initiate the accept operation.
Task<Socket> t;
if (AcceptAsync(saea))
/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public ValueTask<Socket> AcceptAsync(Socket? acceptSocket, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
// The operation is completing asynchronously (it may have already completed).
// Get the task for the operation, with appropriate synchronization to coordinate
// with the async callback that'll be completing the task.
bool responsibleForReturningToPool;
t = saea.GetCompletionResponsibility(out responsibleForReturningToPool).Task;
if (responsibleForReturningToPool)
{
// We're responsible for returning it only if the callback has already been invoked
// and gotten what it needs from the SAEA; otherwise, the callback will return it.
ReturnSocketAsyncEventArgs(saea);
}
return ValueTask.FromCanceled<Socket>(cancellationToken);
}
else
{
// The operation completed synchronously. Get a task for it.
t = saea.SocketError == SocketError.Success ?
Task.FromResult(saea.AcceptSocket!) :
Task.FromException<Socket>(GetException(saea.SocketError));

// There won't be a callback, and we're done with the SAEA, so return it to the pool.
ReturnSocketAsyncEventArgs(saea);
}
AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

return t;
Debug.Assert(saea.BufferList == null);
saea.SetBuffer(null, 0, 0);
saea.AcceptSocket = acceptSocket;
saea.WrapExceptionsForNetworkStream = false;
return saea.AcceptAsync(this, cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -738,34 +725,6 @@ private Task<int> GetTaskForSendReceive(bool pending, TaskSocketAsyncEventArgs<i
return t;
}

/// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
private static void CompleteAccept(Socket s, TaskSocketAsyncEventArgs<Socket> saea)
{
// Pull the relevant state off of the SAEA
SocketError error = saea.SocketError;
Socket? acceptSocket = saea.AcceptSocket;

// Synchronize with the initiating thread. If the synchronous caller already got what
// it needs from the SAEA, then we can return it to the pool now. Otherwise, it'll be
// responsible for returning it once it's gotten what it needs from it.
bool responsibleForReturningToPool;
AsyncTaskMethodBuilder<Socket> builder = saea.GetCompletionResponsibility(out responsibleForReturningToPool);
if (responsibleForReturningToPool)
{
s.ReturnSocketAsyncEventArgs(saea);
}

// Complete the builder/task with the results.
if (error == SocketError.Success)
{
builder.SetResult(acceptSocket!);
}
else
{
builder.SetException(GetException(error));
}
}

/// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
private static void CompleteSendReceive(Socket s, TaskSocketAsyncEventArgs<int> saea, bool isReceive)
{
Expand Down Expand Up @@ -824,29 +783,9 @@ private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<int> saea, bool
}
}

/// <summary>Returns a <see cref="TaskSocketAsyncEventArgs{TResult}"/> instance for reuse.</summary>
/// <param name="saea">The instance to return.</param>
private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<Socket> saea)
{
// Reset state on the SAEA before returning it. But do not reset buffer state. That'll be done
// if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse
// and the costs associated with changing them.
saea.AcceptSocket = null;
saea._accessed = false;
saea._builder = default;

// Write this instance back as a cached instance, only if there isn't currently one cached.
if (Interlocked.CompareExchange(ref _acceptEventArgs, saea, null) != null)
{
// Couldn't return it, so dispose it.
saea.Dispose();
}
}

/// <summary>Dispose of any cached <see cref="TaskSocketAsyncEventArgs{TResult}"/> instances.</summary>
private void DisposeCachedTaskSocketAsyncEventArgs()
{
Interlocked.Exchange(ref _acceptEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _multiBufferReceiveEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _multiBufferSendEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null)?.Dispose();
Expand Down Expand Up @@ -907,7 +846,7 @@ internal AsyncTaskMethodBuilder<TResult> GetCompletionResponsibility(out bool re
}

/// <summary>A SocketAsyncEventArgs that can be awaited to get the result of an operation.</summary>
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<Socket>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
{
private static readonly Action<object?> s_completedSentinel = new Action<object?>(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel))));
/// <summary>The owning socket.</summary>
Expand Down Expand Up @@ -987,6 +926,28 @@ protected override void OnCompleted(SocketAsyncEventArgs _)
}
}

/// <summary>Initiates an accept operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<Socket> AcceptAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");

if (socket.AcceptAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<Socket>(this, _token);
}

Socket acceptSocket = AcceptSocket!;
SocketError error = SocketError;

Release();

return error == SocketError.Success ?
new ValueTask<Socket>(acceptSocket) :
ValueTask.FromException<Socket>(CreateException(error));
}

/// <summary>Initiates a receive operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1288,7 +1249,7 @@ private void InvokeContinuation(Action<object?> continuation, object? state, boo
/// Unlike TaskAwaiter's GetResult, this does not block until the operation completes: it must only
/// be used once the operation has completed. This is handled implicitly by await.
/// </remarks>
public int GetResult(short token)
int IValueTaskSource<int>.GetResult(short token)
{
if (token != _token)
{
Expand Down Expand Up @@ -1326,6 +1287,26 @@ void IValueTaskSource.GetResult(short token)
}
}

Socket IValueTaskSource<Socket>.GetResult(short token)
{
if (token != _token)
{
ThrowIncorrectTokenException();
}

SocketError error = SocketError;
Socket acceptSocket = AcceptSocket!;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error, cancellationToken);
}
return acceptSocket;
}

SocketReceiveFromResult IValueTaskSource<SocketReceiveFromResult>.GetResult(short token)
{
if (token != _token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,9 @@ public void Shutdown(SocketShutdown how)
// Async methods
//

public bool AcceptAsync(SocketAsyncEventArgs e)
public bool AcceptAsync(SocketAsyncEventArgs e) => AcceptAsync(e, CancellationToken.None);

private bool AcceptAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
ThrowIfDisposed();

Expand Down Expand Up @@ -2689,7 +2691,7 @@ public bool AcceptAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationAccept(this, _handle, acceptHandle);
socketError = e.DoOperationAccept(this, _handle, acceptHandle, cancellationToken);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out In
return operation.ErrorCode;
}

public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback)
public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback, CancellationToken cancellationToken)
{
Debug.Assert(socketAddress != null, "Expected non-null socketAddress");
Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}");
Expand All @@ -1456,7 +1456,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o
operation.SocketAddress = socketAddress;
operation.SocketAddressLen = socketAddressLen;

if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
socketAddressLen = operation.SocketAddressLen;
acceptedFd = operation.AcceptedFileDescriptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, byte[] socke
_acceptAddressBufferCount = socketAddressSize;
}

internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle)
internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken)
{
if (!_buffer.Equals(default))
{
Expand All @@ -64,7 +64,7 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha

IntPtr acceptedFd;
int socketAddressLen = _acceptAddressBufferCount / 2;
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback);
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken);

if (socketError != SocketError.IOPending)
{
Expand Down
Loading

0 comments on commit b5c91e4

Please sign in to comment.