Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,33 +97,7 @@ public ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationT

saea.RemoteEndPoint = remoteEP;

ValueTask connectTask = saea.ConnectAsync(this, saeaCancelable: cancellationToken.CanBeCanceled);
if (connectTask.IsCompleted || !cancellationToken.CanBeCanceled)
{
// Avoid async invocation overhead
return connectTask;
}
else
{
return WaitForConnectWithCancellation(saea, connectTask, cancellationToken);
}

static async ValueTask WaitForConnectWithCancellation(AwaitableSocketAsyncEventArgs saea, ValueTask connectTask, CancellationToken cancellationToken)
{
Debug.Assert(cancellationToken.CanBeCanceled);
try
{
using (cancellationToken.UnsafeRegister(o => CancelConnectAsync((SocketAsyncEventArgs)o!), saea))
{
await connectTask.ConfigureAwait(false);
}
}
catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
throw;
}
}
return saea.ConnectAsync(this, cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -1210,12 +1184,13 @@ public ValueTask<int> SendToAsync(Socket socket, CancellationToken cancellationT
ValueTask.FromException<int>(CreateException(error));
}

public ValueTask ConnectAsync(Socket socket, bool saeaCancelable)
public ValueTask ConnectAsync(Socket socket, CancellationToken cancellationToken)
{
try
{
if (socket.ConnectAsync(this, userSocket: true, saeaCancelable: saeaCancelable))
if (socket.ConnectAsync(this, userSocket: true, saeaMultiConnectCancelable: false, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask(this, _mrvtsc.Version);
}
}
Expand Down
22 changes: 13 additions & 9 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2875,11 +2875,14 @@ private bool AcceptAsync(SocketAsyncEventArgs e, CancellationToken cancellationT
}

public bool ConnectAsync(SocketAsyncEventArgs e) =>
ConnectAsync(e, userSocket: true, saeaCancelable: true);
ConnectAsync(e, userSocket: true, saeaMultiConnectCancelable: true);

internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCancelable)
internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaMultiConnectCancelable, CancellationToken cancellationToken = default)
{
bool pending;
// saeaMultiConnectCancelable == true means that this method is being called by a SocketAsyncEventArgs-based top level API.
// In such cases, SocketAsyncEventArgs.StartOperationConnect() will set up an internal cancellation token (_multipleConnectCancellation)
// to support cancelling DNS multi-connect for Socket.CancelConnectAsync().
Debug.Assert(!saeaMultiConnectCancelable || cancellationToken == default);

ThrowIfDisposed();

Expand All @@ -2901,6 +2904,7 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
EndPoint? endPointSnapshot = e.RemoteEndPoint;
DnsEndPoint? dnsEP = endPointSnapshot as DnsEndPoint;

bool pending;
if (dnsEP != null)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.ConnectedAsyncDns(this);
Expand All @@ -2911,10 +2915,10 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
}

e.StartOperationCommon(this, SocketAsyncOperation.Connect);
e.StartOperationConnect(saeaCancelable, userSocket);
e.StartOperationConnect(saeaMultiConnectCancelable, userSocket);
try
{
pending = e.DnsConnectAsync(dnsEP, default, default);
pending = e.DnsConnectAsync(dnsEP, default, default, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -2957,8 +2961,8 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
// ConnectEx supports connection-oriented sockets but not UDS. The socket must be bound before calling ConnectEx.
bool canUseConnectEx = _socketType == SocketType.Stream && endPointSnapshot.AddressFamily != AddressFamily.Unix;
SocketError socketError = canUseConnectEx ?
e.DoOperationConnectEx(this, _handle) :
e.DoOperationConnect(_handle); // For connectionless protocols, Connect is not an I/O call.
e.DoOperationConnectEx(this, _handle, cancellationToken) :
e.DoOperationConnect(_handle, cancellationToken); // For connectionless protocols, Connect is not an I/O call.
pending = socketError == SocketError.IOPending;
}
catch (Exception ex)
Expand Down Expand Up @@ -3001,7 +3005,7 @@ public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType
e.StartOperationConnect(saeaMultiConnectCancelable: true, userSocket: false);
try
{
pending = e.DnsConnectAsync(dnsEP, socketType, protocolType);
pending = e.DnsConnectAsync(dnsEP, socketType, protocolType, cancellationToken: default);
}
catch
{
Expand All @@ -3012,7 +3016,7 @@ public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType
else
{
Socket attemptSocket = new Socket(endPointSnapshot.AddressFamily, socketType, protocolType);
pending = attemptSocket.ConnectAsync(e, userSocket: false, saeaCancelable: true);
pending = attemptSocket.ConnectAsync(e, userSocket: false, saeaMultiConnectCancelable: true);
}

return pending;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,7 @@ public SocketError Connect(Memory<byte> socketAddress)
return operation.ErrorCode;
}

public SocketError ConnectAsync(Memory<byte> socketAddress, Action<int, Memory<byte>, SocketFlags, SocketError> callback, Memory<byte> buffer, out int sentBytes)
public SocketError ConnectAsync(Memory<byte> socketAddress, Action<int, Memory<byte>, SocketFlags, SocketError> callback, Memory<byte> buffer, out int sentBytes, CancellationToken cancellationToken)
{
Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}");
Debug.Assert(callback != null, "Expected non-null callback");
Expand Down Expand Up @@ -1574,7 +1574,7 @@ public SocketError ConnectAsync(Memory<byte> socketAddress, Action<int, Memory<b
BytesTransferred = sentBytes,
};

if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
if (operation.ErrorCode == SocketError.Success)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,19 @@ private void ConnectCompletionCallback(int bytesTransferred, Memory<byte> socket
CompletionCallback(bytesTransferred, SocketFlags.None, socketError);
}

internal SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocketHandle handle)
internal SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocketHandle handle, CancellationToken cancellationToken)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, _buffer.Slice(_offset, _count), out int sentBytes);
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, _buffer.Slice(_offset, _count), out int sentBytes, cancellationToken);
if (socketError != SocketError.IOPending)
{
FinishOperationSync(socketError, sentBytes, SocketFlags.None);
}
return socketError;
}

internal SocketError DoOperationConnect(SafeSocketHandle handle)
internal SocketError DoOperationConnect(SafeSocketHandle handle, CancellationToken cancellationToken)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, Memory<byte>.Empty, out int _);
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, Memory<byte>.Empty, out int _, cancellationToken);
if (socketError != SocketError.IOPending)
{
FinishOperationSync(socketError, 0, SocketFlags.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,17 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha
}
}

internal SocketError DoOperationConnect(SafeSocketHandle handle)
#pragma warning disable IDE0060
internal SocketError DoOperationConnect(SafeSocketHandle handle, CancellationToken cancellationToken)
#pragma warning restore IDE0060
{
// Called for connectionless protocols.
SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer);
FinishOperationSync(socketError, 0, SocketFlags.None);
return socketError;
}

internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle)
internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
Debug.Assert(_asyncCompletionOwnership == 0, $"Expected 0, got {_asyncCompletionOwnership}");

Expand All @@ -313,7 +315,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
out int bytesTransferred,
overlapped);

return ProcessIOCPResult(success, bytesTransferred, ref overlapped, _buffer, cancellationToken: default);
return ProcessIOCPResult(success, bytesTransferred, ref overlapped, _buffer, cancellationToken);
}
catch when (overlapped is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -676,14 +676,20 @@ internal void FinishOperationAsyncFailure(SocketError socketError, int bytesTran
/// <param name="endPoint">The DNS end point to which to connect.</param>
/// <param name="socketType">The SocketType to use to construct new sockets, if necessary.</param>
/// <param name="protocolType">The ProtocolType to use to construct new sockets, if necessary.</param>
/// <param name="cancellationToken">The CancellationToken.</param>
/// <returns>true if the operation is pending; otherwise, false if it's already completed.</returns>
internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, ProtocolType protocolType)
internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, ProtocolType protocolType, CancellationToken cancellationToken)
{
Debug.Assert(endPoint.AddressFamily == AddressFamily.Unspecified ||
endPoint.AddressFamily == AddressFamily.InterNetwork ||
endPoint.AddressFamily == AddressFamily.InterNetworkV6);

CancellationToken cancellationToken = _multipleConnectCancellation?.Token ?? default;
if (_multipleConnectCancellation is not null)
{
Debug.Assert(!cancellationToken.CanBeCanceled, "Task-based connect logic should not use _multipleConnectCancellation for cancellation.");
// We registered a CancellationTokenSource in StartOperationConnect.
cancellationToken = _multipleConnectCancellation.Token;
}

// In .NET 5 and earlier, the APM implementation allowed for synchronous exceptions from this to propagate
// synchronously. This call is made here rather than in the Core async method below to preserve that behavior.
Expand Down Expand Up @@ -774,12 +780,9 @@ async Task Core(MultiConnectSocketAsyncEventArgs internalArgs, Task<IPAddress[]>
}

// Issue the connect. If it pends, wait for it to complete.
if (attemptSocket.ConnectAsync(internalArgs))
if (attemptSocket.ConnectAsync(internalArgs, userSocket: true, saeaMultiConnectCancelable: false, cancellationToken))
{
using (cancellationToken.UnsafeRegister(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s!), internalArgs))
{
await new ValueTask(internalArgs, internalArgs.Version).ConfigureAwait(false);
}
await new ValueTask(internalArgs, internalArgs.Version).ConfigureAwait(false);
}

// If it completed successfully, we're done; cleanup will be handled by the finally.
Expand Down
Loading