Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding span version of ReceiveMessageFrom #46285

Merged
merged 8 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ public void Listen(int backlog) { }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
public int ReceiveMessageFrom(System.Span<byte> buffer, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.SocketReceiveMessageFromResult> ReceiveMessageFromAsync(System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveMessageFromResult> ReceiveMessageFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public bool ReceiveMessageFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
Expand Down
94 changes: 94 additions & 0 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,100 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla
return bytesTransferred;
}

/// <summary>
/// Receives the specified number of bytes of data into the specified location of the data buffer,
/// using the specified <paramref name="socketFlags"/>, and stores the endpoint and packet information.
/// </summary>
/// <param name="buffer">
/// An <see cref="Span{T}"/> of type <see cref="byte"/> that is the storage location for received data.
/// </param>
/// <param name="socketFlags">
/// A bitwise combination of the <see cref="SocketFlags"/> values.
/// </param>
/// <param name="remoteEP">
/// An <see cref="EndPoint"/>, passed by reference, that represents the remote server.
/// </param>
/// <param name="ipPacketInformation">
/// An <see cref="IPPacketInformation"/> holding address and interface information.
/// </param>
/// <returns>
/// The number of bytes received.
/// </returns>
/// <exception cref="ObjectDisposedException">The <see cref="Socket"/> object has been closed.</exception>
/// <exception cref="ArgumentNullException">The <see cref="EndPoint"/> remoteEP is null.</exception>
/// <exception cref="ArgumentException">The <see cref="AddressFamily"/> of the <see cref="EndPoint"/> used in
/// <see cref="Socket.ReceiveMessageFrom(Span{byte}, ref SocketFlags, ref EndPoint, out IPPacketInformation)"/>
/// needs to match the <see cref="AddressFamily"/> of the <see cref="EndPoint"/> used in SendTo.</exception>
/// <exception cref="InvalidOperationException">
/// <para>The <see cref="Socket"/> object is not in blocking mode and cannot accept this synchronous call.</para>
/// <para>You must call the Bind method before performing this operation.</para></exception>
public int ReceiveMessageFrom(Span<byte> buffer, ref SocketFlags socketFlags, ref EndPoint remoteEP, out IPPacketInformation ipPacketInformation)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new practice is to add xmldocs right in the dotnet/runtime PR. See 9f73188 from #45083

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it seems like non of the other methods in this class has it, should I start the it or am I looking at the wrong place?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can check #47230 and #47229 as examples adding new methods to Socket.cs together with xml docs.

Just copy existing docs, and modify them according to the new overload:
https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.receivemessagefrom?view=net-5.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

took a stab at the documentation, copied from the existing, but existing is missing some of the exception actually thrown from the method, so a bit unsure if they should be added.

{
ThrowIfDisposed();

if (remoteEP == null)
{
throw new ArgumentNullException(nameof(remoteEP));
}
if (!CanTryAddressFamily(remoteEP.AddressFamily))
{
throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP));
}
if (_rightEndPoint == null)
{
throw new InvalidOperationException(SR.net_sockets_mustbind);
}

SocketPal.CheckDualModeReceiveSupport(this);
ValidateBlockingMode();

// We don't do a CAS demand here because the contents of remoteEP aren't used by
// WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family.
EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot);

// Save a copy of the original EndPoint.
Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot);

SetReceivingPacketInformation();

Internals.SocketAddress receiveAddress;
int bytesTransferred;
SocketError errorCode = SocketPal.ReceiveMessageFrom(this, _handle, buffer, ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred);

UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred);
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success && errorCode != SocketError.MessageSize)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
}
else if (SocketsTelemetry.Log.IsEnabled())
{
SocketsTelemetry.Log.BytesReceived(bytesTransferred);
if (errorCode == SocketError.Success && SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived();
}

if (!socketAddressOriginal.Equals(receiveAddress))
{
try
{
remoteEP = endPointSnapshot.Create(receiveAddress);
}
catch
{
}
if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
}
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode);
return bytesTransferred;
}

// Receives a datagram into a specific location in the data buffer and stores
// the end point.
public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,31 @@ public override void InvokeCallback(bool allowPooling) =>
Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode);
}

private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation
{
public byte* BufferPtr;
public int Length;
public SocketFlags Flags;
public int BytesTransferred;
public SocketFlags ReceivedFlags;

public bool IsIPv4;
public bool IsIPv6;
public IPPacketInformation IPPacketInformation;

public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { }

protected sealed override void Abort() { }

public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; }

protected override bool DoTryComplete(SocketAsyncContext context) =>
SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span<byte>(BufferPtr, Length), null, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode);

public override void InvokeCallback(bool allowPooling) =>
Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode);
}

private sealed class AcceptOperation : ReadOperation
{
public IntPtr AcceptedFileDescriptor;
Expand Down Expand Up @@ -1696,15 +1721,15 @@ public SocketError ReceiveFromAsync(IList<ArraySegment<byte>> buffers, SocketFla
}

public SocketError ReceiveMessageFrom(
Memory<byte> buffer, IList<ArraySegment<byte>>? buffers, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived)
Memory<byte> buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived)
{
Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}");

SocketFlags receivedFlags;
SocketError errorCode;
int observedSequenceNumber;
if (_receiveQueue.IsReady(this, out observedSequenceNumber) &&
(SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) ||
(SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) ||
!ShouldRetrySyncOperation(out errorCode)))
{
flags = receivedFlags;
Expand All @@ -1714,7 +1739,7 @@ public SocketError ReceiveMessageFrom(
var operation = new ReceiveMessageFromOperation(this)
{
Buffer = buffer,
Buffers = buffers,
Buffers = null,
Flags = flags,
SocketAddress = socketAddress,
SocketAddressLen = socketAddressLen,
Expand All @@ -1731,6 +1756,45 @@ public SocketError ReceiveMessageFrom(
return operation.ErrorCode;
}

public unsafe SocketError ReceiveMessageFrom(
Span<byte> buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived)
{
Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}");

SocketFlags receivedFlags;
SocketError errorCode;
int observedSequenceNumber;
if (_receiveQueue.IsReady(this, out observedSequenceNumber) &&
(SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) ||
!ShouldRetrySyncOperation(out errorCode)))
{
flags = receivedFlags;
return errorCode;
}

fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer))
{
var operation = new BufferPtrReceiveMessageFromOperation(this)
{
BufferPtr = bufferPtr,
Length = buffer.Length,
Flags = flags,
SocketAddress = socketAddress,
SocketAddressLen = socketAddressLen,
IsIPv4 = isIPv4,
IsIPv6 = isIPv6,
};

PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber);

socketAddressLen = operation.SocketAddressLen;
flags = operation.ReceivedFlags;
ipPacketInformation = operation.IPPacketInformation;
bytesReceived = operation.BytesTransferred;
return operation.ErrorCode;
}
}

public SocketError ReceiveMessageFromAsync(Memory<byte> buffer, IList<ArraySegment<byte>>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError> callback, CancellationToken cancellationToken = default)
{
SetNonBlocking();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han
SocketError errorCode;
if (!handle.IsNonBlocking)
{
errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory<byte>(buffer, offset, count), null, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred);
errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory<byte>(buffer, offset, count), ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred);
}
else
{
Expand All @@ -1187,6 +1187,33 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han
return errorCode;
}


public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span<byte> buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred)
{
byte[] socketAddressBuffer = socketAddress.Buffer;
int socketAddressLen = socketAddress.Size;

bool isIPv4, isIPv6;
Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6);

SocketError errorCode;
if (!handle.IsNonBlocking)
{
errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred);
}
else
{
if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode))
{
errorCode = SocketError.WouldBlock;
}
}

socketAddress.InternalSize = socketAddressLen;
receiveAddress = socketAddress;
return errorCode;
}

public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred)
{
if (!handle.IsNonBlocking)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,14 +451,19 @@ public static unsafe IPPacketInformation GetIPPacketInformation(Interop.Winsock.
}

public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int size, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred)
{
return ReceiveMessageFrom(socket, handle, new Span<byte>(buffer, offset, size), ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred);
}

public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span<byte> buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred)
ovebastiansen marked this conversation as resolved.
Show resolved Hide resolved
{
bool ipv4, ipv6;
Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out ipv4, out ipv6);

bytesTransferred = 0;
receiveAddress = socketAddress;
ipPacketInformation = default(IPPacketInformation);
fixed (byte* ptrBuffer = buffer)
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer))
fixed (byte* ptrSocketAddress = socketAddress.Buffer)
{
Interop.Winsock.WSAMsg wsaMsg;
Expand All @@ -467,8 +472,8 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan
wsaMsg.flags = socketFlags;

WSABuffer wsaBuffer;
wsaBuffer.Length = size;
wsaBuffer.Pointer = (IntPtr)(ptrBuffer + offset);
wsaBuffer.Length = buffer.Length;
wsaBuffer.Pointer = (IntPtr)bufferPtr;
wsaMsg.buffers = (IntPtr)(&wsaBuffer);
wsaMsg.count = 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public async Task ReceiveSent_TCP_Success(bool ipv6)
[InlineData(true)]
public async Task ReceiveSentMessages_UDP_Success(bool ipv4)
{
// [ActiveIssue("https://github.com/dotnet/runtime/issues/47637")]
int Offset = UsesSync || !PlatformDetection.IsWindows ? 10 : 0;
const int DatagramSize = 256;
const int DatagramsToSend = 16;

Expand All @@ -69,7 +71,9 @@ public async Task ReceiveSentMessages_UDP_Success(bool ipv4)
sender.BindToAnonymousPort(address);

byte[] sendBuffer = new byte[DatagramSize];
byte[] receiveBuffer = new byte[DatagramSize];
var receiveInternalBuffer = new byte[DatagramSize + Offset];
var emptyBuffer = new byte[Offset];
ArraySegment<byte> receiveBuffer = new ArraySegment<byte>(receiveInternalBuffer, Offset, DatagramSize);
Random rnd = new Random(0);

IPEndPoint remoteEp = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0);
Expand All @@ -83,7 +87,8 @@ public async Task ReceiveSentMessages_UDP_Success(bool ipv4)
IPPacketInformation packetInformation = result.PacketInformation;

Assert.Equal(DatagramSize, result.ReceivedBytes);
AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer);
AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan<byte>(receiveInternalBuffer, 0, Offset));
AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan<byte>(receiveInternalBuffer, Offset, DatagramSize));
Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint);
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, packetInformation.Address);
}
Expand Down
Loading