Skip to content

Commit

Permalink
implement ConnectAsync overloads with CancellationToken on Socket and…
Browse files Browse the repository at this point in the history
… TcpClient (#40750)

* implement ConnectAsync overloads with CancellationToken on Socket and TcpClient
  • Loading branch information
geoffkizer authored Aug 16, 2020
1 parent 91d7cd0 commit 1341c35
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ internal override async ValueTask ConnectAsync(CancellationToken cancellationTok
}

Socket socket = new Socket(_remoteEndPoint!.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(_remoteEndPoint).ConfigureAwait(false);
await socket.ConnectAsync(_remoteEndPoint, cancellationToken).ConfigureAwait(false);
socket.NoDelay = true;

_localEndPoint = (IPEndPoint?)socket.LocalEndPoint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
<Compile Include="System\Net\Connections\IConnectionProperties.cs" />
<Compile Include="System\Net\Connections\Sockets\SocketConnection.cs" />
<Compile Include="System\Net\Connections\Sockets\SocketsConnectionFactory.cs" />
<Compile Include="System\Net\Connections\Sockets\TaskSocketAsyncEventArgs.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="System.Runtime" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,7 @@ public override async ValueTask<Connection> ConnectAsync(

try
{
using var args = new TaskSocketAsyncEventArgs();
args.RemoteEndPoint = endPoint;

if (socket.ConnectAsync(args))
{
using (cancellationToken.UnsafeRegister(static o => Socket.CancelConnectAsync((SocketAsyncEventArgs)o!), args))
{
await args.Task.ConfigureAwait(false);
}
}

if (args.SocketError != SocketError.Success)
{
if (args.SocketError == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}

throw NetworkErrorHelper.MapSocketException(new SocketException((int)args.SocketError));
}

await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);
return new SocketConnection(socket);
}
catch (SocketException socketException)
Expand Down

This file was deleted.

7 changes: 7 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,13 @@ public static partial class SocketTaskExtensions
public static System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(this System.Net.Sockets.Socket socket) { throw null; }
public static System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(this System.Net.Sockets.Socket socket, System.Net.Sockets.Socket? acceptSocket) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.EndPoint remoteEP) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress address, int port) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress address, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress[] addresses, int port) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, string host, int port) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task<int> ReceiveAsync(this System.Net.Sockets.Socket socket, System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags) { throw null; }
public static System.Threading.Tasks.Task<int> ReceiveAsync(this System.Net.Sockets.Socket socket, System.Collections.Generic.IList<System.ArraySegment<byte>> buffers, System.Net.Sockets.SocketFlags socketFlags) { throw null; }
public static System.Threading.Tasks.ValueTask<int> ReceiveAsync(this System.Net.Sockets.Socket socket, System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down Expand Up @@ -606,6 +610,9 @@ public void Connect(string hostname, int port) { }
public System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress address, int port) { throw null; }
public System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress[] addresses, int port) { throw null; }
public System.Threading.Tasks.Task ConnectAsync(string host, int port) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress address, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
public void EndConnect(System.IAsyncResult asyncResult) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,59 @@ internal Task<Socket> AcceptAsync(Socket? acceptSocket)
return t;
}

internal Task ConnectAsync(EndPoint remoteEP)
internal Task ConnectAsync(EndPoint remoteEP) => ConnectAsync(remoteEP, default).AsTask();

internal ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

// Use _singleBufferReceiveEventArgs so the AwaitableSocketAsyncEventArgs can be re-used later for receives.
AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true);

saea.RemoteEndPoint = remoteEP;
return saea.ConnectAsync(this).AsTask();

ValueTask connectTask = saea.ConnectAsync(this);
if (connectTask.IsCompleted || !cancellationToken.CanBeCanceled)
{
// Avoid async invocation overhead
return connectTask;
}
else
{
return WaitForConnectWithCancellation(saea, connectTask, cancellationToken);
}

async ValueTask WaitForConnectWithCancellation(AwaitableSocketAsyncEventArgs saea, ValueTask connectTask, CancellationToken cancellationToken)
{
Debug.Assert(cancellationToken.CanBeCanceled);
try
{
using (cancellationToken.UnsafeRegister(o => CancelConnectAsync((SocketAsyncEventArgs)o!), saea))
{
await connectTask.ConfigureAwait(false);
}
}
catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
throw;
}
}

}

internal Task ConnectAsync(IPAddress address, int port) => ConnectAsync(new IPEndPoint(address, port));

internal Task ConnectAsync(IPAddress[] addresses, int port)
internal ValueTask ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken) => ConnectAsync(new IPEndPoint(address, port), cancellationToken);

internal Task ConnectAsync(IPAddress[] addresses, int port) => ConnectAsync(addresses, port, CancellationToken.None).AsTask();

internal ValueTask ConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken)
{
if (addresses == null)
{
Expand All @@ -98,20 +137,20 @@ internal Task ConnectAsync(IPAddress[] addresses, int port)
throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses));
}

return DoConnectAsync(addresses, port);
return DoConnectAsync(addresses, port, cancellationToken);
}

private async Task DoConnectAsync(IPAddress[] addresses, int port)
private async ValueTask DoConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken)
{
Exception? lastException = null;
foreach (IPAddress address in addresses)
{
try
{
await ConnectAsync(address, port).ConfigureAwait(false);
await ConnectAsync(address, port, cancellationToken).ConfigureAwait(false);
return;
}
catch (Exception ex)
catch (Exception ex) when (ex is not OperationCanceledException)
{
lastException = ex;
}
Expand All @@ -121,7 +160,9 @@ private async Task DoConnectAsync(IPAddress[] addresses, int port)
ExceptionDispatchInfo.Throw(lastException);
}

internal Task ConnectAsync(string host, int port)
internal Task ConnectAsync(string host, int port) => ConnectAsync(host, port, default).AsTask();

internal ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken)
{
if (host == null)
{
Expand All @@ -131,7 +172,7 @@ internal Task ConnectAsync(string host, int port)
EndPoint ep = IPAddress.TryParse(host, out IPAddress? parsedAddress) ? (EndPoint)
new IPEndPoint(parsedAddress, port) :
new DnsEndPoint(host, port);
return ConnectAsync(ep);
return ConnectAsync(ep, cancellationToken);
}

internal Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags, bool fromNetworkStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@ public static Task<Socket> AcceptAsync(this Socket socket, Socket? acceptSocket)

public static Task ConnectAsync(this Socket socket, EndPoint remoteEP) =>
socket.ConnectAsync(remoteEP);
public static ValueTask ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken) =>
socket.ConnectAsync(remoteEP, cancellationToken);
public static Task ConnectAsync(this Socket socket, IPAddress address, int port) =>
socket.ConnectAsync(address, port);
public static ValueTask ConnectAsync(this Socket socket, IPAddress address, int port, CancellationToken cancellationToken) =>
socket.ConnectAsync(address, port, cancellationToken);
public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port) =>
socket.ConnectAsync(addresses, port);
public static ValueTask ConnectAsync(this Socket socket, IPAddress[] addresses, int port, CancellationToken cancellationToken) =>
socket.ConnectAsync(addresses, port, cancellationToken);
public static Task ConnectAsync(this Socket socket, string host, int port) =>
socket.ConnectAsync(host, port);
public static ValueTask ConnectAsync(this Socket socket, string host, int port, CancellationToken cancellationToken) =>
socket.ConnectAsync(host, port, cancellationToken);

public static Task<int> ReceiveAsync(this Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) =>
socket.ReceiveAsync(buffer, socketFlags, fromNetworkStream: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,8 @@ public void Connect(IPAddress[] ipAddresses, int port)
_active = true;
}

public Task ConnectAsync(IPAddress address, int port)
{

Task result = CompleteConnectAsync(Client.ConnectAsync(address, port));

return result;
}
public Task ConnectAsync(IPAddress address, int port) =>
CompleteConnectAsync(Client.ConnectAsync(address, port));

public Task ConnectAsync(string host, int port) =>
CompleteConnectAsync(Client.ConnectAsync(host, port));
Expand All @@ -279,6 +274,21 @@ private async Task CompleteConnectAsync(Task task)
_active = true;
}

public ValueTask ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken) =>
CompleteConnectAsync(Client.ConnectAsync(address, port, cancellationToken));

public ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken) =>
CompleteConnectAsync(Client.ConnectAsync(host, port, cancellationToken));

public ValueTask ConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken) =>
CompleteConnectAsync(Client.ConnectAsync(addresses, port, cancellationToken));

private async ValueTask CompleteConnectAsync(ValueTask task)
{
await task.ConfigureAwait(false);
_active = true;
}

public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? requestCallback, object? state) =>
Client.BeginConnect(address, port, requestCallback, state);

Expand Down
Loading

0 comments on commit 1341c35

Please sign in to comment.