Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement ConnectAsync overloads with CancellationToken on Socket and TcpClient #40750

Merged
merged 8 commits into from
Aug 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ internal override async ValueTask ConnectAsync(CancellationToken cancellationTok
}

Socket socket = new Socket(_remoteEndPoint!.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(_remoteEndPoint).ConfigureAwait(false);
await socket.ConnectAsync(_remoteEndPoint, cancellationToken).ConfigureAwait(false);
socket.NoDelay = true;

_localEndPoint = (IPEndPoint?)socket.LocalEndPoint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
<Compile Include="System\Net\Connections\IConnectionProperties.cs" />
<Compile Include="System\Net\Connections\Sockets\SocketConnection.cs" />
<Compile Include="System\Net\Connections\Sockets\SocketsConnectionFactory.cs" />
<Compile Include="System\Net\Connections\Sockets\TaskSocketAsyncEventArgs.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="System.Runtime" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,7 @@ public override async ValueTask<Connection> ConnectAsync(

try
{
using var args = new TaskSocketAsyncEventArgs();
args.RemoteEndPoint = endPoint;

if (socket.ConnectAsync(args))
{
using (cancellationToken.UnsafeRegister(static o => Socket.CancelConnectAsync((SocketAsyncEventArgs)o!), args))
{
await args.Task.ConfigureAwait(false);
}
}

if (args.SocketError != SocketError.Success)
{
if (args.SocketError == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}

throw NetworkErrorHelper.MapSocketException(new SocketException((int)args.SocketError));
}

await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);
return new SocketConnection(socket);
}
catch (SocketException socketException)
Expand Down

This file was deleted.

7 changes: 7 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,13 @@ public static partial class SocketTaskExtensions
public static System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(this System.Net.Sockets.Socket socket) { throw null; }
public static System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(this System.Net.Sockets.Socket socket, System.Net.Sockets.Socket? acceptSocket) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.EndPoint remoteEP) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress address, int port) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress address, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress[] addresses, int port) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, string host, int port) { throw null; }
public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public static System.Threading.Tasks.Task<int> ReceiveAsync(this System.Net.Sockets.Socket socket, System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags) { throw null; }
public static System.Threading.Tasks.Task<int> ReceiveAsync(this System.Net.Sockets.Socket socket, System.Collections.Generic.IList<System.ArraySegment<byte>> buffers, System.Net.Sockets.SocketFlags socketFlags) { throw null; }
public static System.Threading.Tasks.ValueTask<int> ReceiveAsync(this System.Net.Sockets.Socket socket, System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down Expand Up @@ -606,6 +610,9 @@ public void Connect(string hostname, int port) { }
public System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress address, int port) { throw null; }
public System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress[] addresses, int port) { throw null; }
public System.Threading.Tasks.Task ConnectAsync(string host, int port) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress address, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
public void EndConnect(System.IAsyncResult asyncResult) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,59 @@ internal Task<Socket> AcceptAsync(Socket? acceptSocket)
return t;
}

internal Task ConnectAsync(EndPoint remoteEP)
internal Task ConnectAsync(EndPoint remoteEP) => ConnectAsync(remoteEP, default).AsTask();

internal ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

// Use _singleBufferReceiveEventArgs so the AwaitableSocketAsyncEventArgs can be re-used later for receives.
AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true);

saea.RemoteEndPoint = remoteEP;
return saea.ConnectAsync(this).AsTask();

ValueTask connectTask = saea.ConnectAsync(this);
if (connectTask.IsCompleted || !cancellationToken.CanBeCanceled)
{
// Avoid async invocation overhead
return connectTask;
}
else
{
return WaitForConnectWithCancellation(saea, connectTask, cancellationToken);
}

async ValueTask WaitForConnectWithCancellation(AwaitableSocketAsyncEventArgs saea, ValueTask connectTask, CancellationToken cancellationToken)
{
Debug.Assert(cancellationToken.CanBeCanceled);
try
{
using (cancellationToken.UnsafeRegister(o => CancelConnectAsync((SocketAsyncEventArgs)o!), saea))
{
await connectTask.ConfigureAwait(false);
}
}
catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
throw;
}
}

}

internal Task ConnectAsync(IPAddress address, int port) => ConnectAsync(new IPEndPoint(address, port));

internal Task ConnectAsync(IPAddress[] addresses, int port)
internal ValueTask ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken) => ConnectAsync(new IPEndPoint(address, port), cancellationToken);

internal Task ConnectAsync(IPAddress[] addresses, int port) => ConnectAsync(addresses, port, CancellationToken.None).AsTask();

internal ValueTask ConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken)
{
if (addresses == null)
{
Expand All @@ -98,20 +137,20 @@ internal Task ConnectAsync(IPAddress[] addresses, int port)
throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses));
}

return DoConnectAsync(addresses, port);
return DoConnectAsync(addresses, port, cancellationToken);
}

private async Task DoConnectAsync(IPAddress[] addresses, int port)
private async ValueTask DoConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken)
{
Exception? lastException = null;
foreach (IPAddress address in addresses)
{
try
{
await ConnectAsync(address, port).ConfigureAwait(false);
await ConnectAsync(address, port, cancellationToken).ConfigureAwait(false);
geoffkizer marked this conversation as resolved.
Show resolved Hide resolved
return;
}
catch (Exception ex)
catch (Exception ex) when (ex is not OperationCanceledException)
{
lastException = ex;
}
Expand All @@ -121,7 +160,9 @@ private async Task DoConnectAsync(IPAddress[] addresses, int port)
ExceptionDispatchInfo.Throw(lastException);
}

internal Task ConnectAsync(string host, int port)
internal Task ConnectAsync(string host, int port) => ConnectAsync(host, port, default).AsTask();

internal ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken)
geoffkizer marked this conversation as resolved.
Show resolved Hide resolved
{
if (host == null)
{
Expand All @@ -131,7 +172,7 @@ internal Task ConnectAsync(string host, int port)
EndPoint ep = IPAddress.TryParse(host, out IPAddress? parsedAddress) ? (EndPoint)
new IPEndPoint(parsedAddress, port) :
new DnsEndPoint(host, port);
return ConnectAsync(ep);
return ConnectAsync(ep, cancellationToken);
}

internal Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags, bool fromNetworkStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@ public static Task<Socket> AcceptAsync(this Socket socket, Socket? acceptSocket)

public static Task ConnectAsync(this Socket socket, EndPoint remoteEP) =>
socket.ConnectAsync(remoteEP);
public static ValueTask ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken) =>
socket.ConnectAsync(remoteEP, cancellationToken);
public static Task ConnectAsync(this Socket socket, IPAddress address, int port) =>
socket.ConnectAsync(address, port);
public static ValueTask ConnectAsync(this Socket socket, IPAddress address, int port, CancellationToken cancellationToken) =>
socket.ConnectAsync(address, port, cancellationToken);
public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port) =>
socket.ConnectAsync(addresses, port);
public static ValueTask ConnectAsync(this Socket socket, IPAddress[] addresses, int port, CancellationToken cancellationToken) =>
socket.ConnectAsync(addresses, port, cancellationToken);
public static Task ConnectAsync(this Socket socket, string host, int port) =>
socket.ConnectAsync(host, port);
public static ValueTask ConnectAsync(this Socket socket, string host, int port, CancellationToken cancellationToken) =>
socket.ConnectAsync(host, port, cancellationToken);

public static Task<int> ReceiveAsync(this Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) =>
socket.ReceiveAsync(buffer, socketFlags, fromNetworkStream: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,8 @@ public void Connect(IPAddress[] ipAddresses, int port)
_active = true;
}

public Task ConnectAsync(IPAddress address, int port)
{

Task result = CompleteConnectAsync(Client.ConnectAsync(address, port));

return result;
}
public Task ConnectAsync(IPAddress address, int port) =>
CompleteConnectAsync(Client.ConnectAsync(address, port));

public Task ConnectAsync(string host, int port) =>
CompleteConnectAsync(Client.ConnectAsync(host, port));
Expand All @@ -279,6 +274,21 @@ private async Task CompleteConnectAsync(Task task)
_active = true;
}

public ValueTask ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken) =>
CompleteConnectAsync(Client.ConnectAsync(address, port, cancellationToken));

public ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken) =>
CompleteConnectAsync(Client.ConnectAsync(host, port, cancellationToken));

public ValueTask ConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken) =>
CompleteConnectAsync(Client.ConnectAsync(addresses, port, cancellationToken));

private async ValueTask CompleteConnectAsync(ValueTask task)
{
await task.ConfigureAwait(false);
_active = true;
}

public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? requestCallback, object? state) =>
Client.BeginConnect(address, port, requestCallback, state);

Expand Down
Loading