Skip to content

Commit

Permalink
fix SendTo with SocketAsyncEventArgs (#98134)
Browse files Browse the repository at this point in the history
* fix SendTo with SocketAsyncEventArgs

* feedback
  • Loading branch information
wfurt authored Feb 27, 2024
1 parent f729653 commit f9637f1
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,6 @@ public ValueTask<int> SendToAsync(ReadOnlyMemory<byte> buffer, SocketFlags socke
Debug.Assert(saea.BufferList == null);
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea._socketAddress = null;
saea.RemoteEndPoint = remoteEP;
saea.WrapExceptionsForNetworkStream = false;
return saea.SendToAsync(this, cancellationToken);
Expand Down Expand Up @@ -709,8 +708,17 @@ public ValueTask<int> SendToAsync(ReadOnlyMemory<byte> buffer, SocketFlags socke
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea._socketAddress = socketAddress;
saea.RemoteEndPoint = null;
saea.WrapExceptionsForNetworkStream = false;
return saea.SendToAsync(this, cancellationToken);
try
{
return saea.SendToAsync(this, cancellationToken);
}
finally
{
// detach user provided SA so we do not accidentally stomp on it later.
saea._socketAddress = null;
}
}

/// <summary>
Expand Down
20 changes: 14 additions & 6 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3095,14 +3095,22 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT
ArgumentNullException.ThrowIfNull(e);

EndPoint? endPointSnapshot = e.RemoteEndPoint;
if (e._socketAddress == null)

// RemoteEndPoint should be set unless somebody used SendTo with their own SA.
// In that case RemoteEndPoint will be null and we take provided SA as given.
if (endPointSnapshot == null && e._socketAddress == null)
{
if (endPointSnapshot == null)
{
throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e));
}
throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e));
}

// Prepare SocketAddress
if (e._socketAddress != null && endPointSnapshot is IPEndPoint ipep && e._socketAddress.Family == endPointSnapshot?.AddressFamily)
{
// we have matching SocketAddress. Since this is only used internally, it is ok to overwrite it without
ipep.Serialize(e._socketAddress.Buffer.Span);
}
else if (endPointSnapshot != null)
{
// Prepare new SocketAddress
e._socketAddress = Serialize(ref endPointSnapshot);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,12 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags
case SocketAsyncOperation.ReceiveFrom:
// Deal with incoming address.
UpdateReceivedSocketAddress(_socketAddress!);
if (_remoteEndPoint != null && !SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint))
if (_remoteEndPoint == null)
{
// detach user provided SA as it was updated in place.
_socketAddress = null;
}
else if (!SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint))
{
try
{
Expand Down
29 changes: 29 additions & 0 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,35 @@ public void SendToAsync_NullAsyncEventArgs_Throws_ArgumentNullException()
public sealed class SendTo_Task : SendTo<SocketHelperTask>
{
public SendTo_Task(ITestOutputHelper output) : base(output) { }

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task SendTo_DifferentEP_Success(bool ipv4)
{
IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;
IPEndPoint remoteEp = new IPEndPoint(address, 0);

using Socket receiver1 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket receiver2 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);

receiver1.BindToAnonymousPort(address);
receiver2.BindToAnonymousPort(address);

byte[] sendBuffer = new byte[32];
var receiveInternalBuffer = new byte[sendBuffer.Length];
ArraySegment<byte> receiveBuffer = new ArraySegment<byte>(receiveInternalBuffer, 0, receiveInternalBuffer.Length);


await sender.SendToAsync(sendBuffer, SocketFlags.None, receiver1.LocalEndPoint);
SocketReceiveFromResult result = await ReceiveFromAsync(receiver1, receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout);
Assert.Equal(sendBuffer.Length, result.ReceivedBytes);

await sender.SendToAsync(sendBuffer, SocketFlags.None, receiver2.LocalEndPoint);
result = await ReceiveFromAsync(receiver2, receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout);
Assert.Equal(sendBuffer.Length, result.ReceivedBytes);
}
}

public sealed class SendTo_CancellableTask : SendTo<SocketHelperCancellableTask>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,5 +895,52 @@ void CreateSocketAsyncEventArgs() // separated out so that JIT doesn't extend li
return cwt.Count() == 0; // validate that the cwt becomes empty
}, 30_000));
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task SendTo_DifferentEP_Success(bool ipv4)
{
IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;
IPEndPoint remoteEp = new IPEndPoint(address, 0);

using Socket receiver1 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket receiver2 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);

receiver1.BindToAnonymousPort(address);
receiver2.BindToAnonymousPort(address);

byte[] sendBuffer = new byte[32];
var receiveInternalBuffer = new byte[sendBuffer.Length];
ArraySegment<byte> receiveBuffer = new ArraySegment<byte>(receiveInternalBuffer, 0, receiveInternalBuffer.Length);

using SocketAsyncEventArgs saea = new SocketAsyncEventArgs();
ManualResetEventSlim mres = new ManualResetEventSlim(false);

saea.SetBuffer(sendBuffer);
saea.RemoteEndPoint = receiver1.LocalEndPoint;
saea.Completed += delegate { mres.Set(); };
if (sender.SendToAsync(saea))
{
// did not finish synchronously.
mres.Wait();
}

SocketReceiveFromResult result = await receiver1.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout);
Assert.Equal(sendBuffer.Length, result.ReceivedBytes);
mres.Reset();


saea.RemoteEndPoint = receiver2.LocalEndPoint;
if (sender.SendToAsync(saea))
{
// did not finish synchronously.
mres.Wait();
}

result = await receiver2.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout);
Assert.Equal(sendBuffer.Length, result.ReceivedBytes);
}
}
}

0 comments on commit f9637f1

Please sign in to comment.