From ad418094ab11c0192067aa6b11dacb0ff41de2cc Mon Sep 17 00:00:00 2001 From: wfurt Date: Tue, 11 Jul 2023 20:05:51 -0700 Subject: [PATCH 01/18] check --- .../Interop/Windows/WinSock/Interop.accept.cs | 2 +- .../Windows/WinSock/Interop.recvfrom.cs | 4 +- .../Interop/Windows/WinSock/Interop.sendto.cs | 2 +- .../Common/src/System/Net/SocketAddress.cs | 13 ++ .../ref/System.Net.Primitives.cs | 3 +- .../ref/System.Net.Sockets.cs | 10 +- .../src/System/Net/Sockets/Socket.cs | 102 ++++++++- .../Net/Sockets/SocketAsyncContext.Unix.cs | 213 +++++++++--------- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 60 ++--- .../src/System/Net/Sockets/SocketPal.Unix.cs | 162 +++++++------ .../System/Net/Sockets/SocketPal.Windows.cs | 25 +- 11 files changed, 350 insertions(+), 246 deletions(-) diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs index 2202c0815ff17..531327f0e5ca9 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.accept.cs @@ -12,7 +12,7 @@ internal static partial class Winsock [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] internal static partial IntPtr accept( SafeSocketHandle socketHandle, - byte[] socketAddress, + Span socketAddress, ref int socketAddressSize); } } diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs index 8c470c30268ef..70d76733825ac 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.recvfrom.cs @@ -13,10 +13,10 @@ internal static partial class Winsock [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] internal static unsafe partial int recvfrom( SafeSocketHandle socketHandle, - byte* pinnedBuffer, + Span pinnedBuffer, int len, SocketFlags socketFlags, - byte[] socketAddress, + Span socketAddress, ref int socketAddressSize); } } diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs index 01e036664049e..e93d609401513 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.sendto.cs @@ -15,7 +15,7 @@ internal static unsafe partial int sendto( byte* pinnedBuffer, int len, SocketFlags socketFlags, - byte[] socketAddress, + ReadOnlySpan socketAddress, int socketAddressSize); } } diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 62323be556499..ef58290645883 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -50,6 +50,11 @@ public int Size { return InternalSize; } + set + { + ArgumentOutOfRangeException.ThrowIfGreaterThan(value, Buffer.Length); + InternalSize = value; + } } // Access to unmanaged serialized data. This doesn't @@ -139,6 +144,14 @@ internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) SocketAddressPal.SetAddressFamily(Buffer, addressFamily); } + public Memory SocketBuffer + { + get + { + return new Memory(Buffer, 0, InternalSize); + } + } + internal IPAddress GetIPAddress() { if (Family == AddressFamily.InterNetworkV6) diff --git a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs index 2ed166f381b9b..3b97047c98a7e 100644 --- a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs +++ b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs @@ -355,7 +355,8 @@ public SocketAddress(System.Net.Sockets.AddressFamily family) { } public SocketAddress(System.Net.Sockets.AddressFamily family, int size) { } public System.Net.Sockets.AddressFamily Family { get { throw null; } } public byte this[int offset] { get { throw null; } set { } } - public int Size { get { throw null; } } + public int Size { get { throw null; } set { } } + public System.Memory SocketBuffer { get { throw null; } } public override bool Equals(object? comparand) { throw null; } public override int GetHashCode() { throw null; } public override string ToString() { throw null; } diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index c9b67c966f7a0..ec2eab3363091 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -105,6 +105,8 @@ public partial class LingerOption public LingerOption(bool enable, int seconds) { } public bool Enabled { get { throw null; } set { } } public int LingerTime { get { throw null; } set { } } + public override bool Equals(object? comparand) { throw null; } + public override int GetHashCode() { throw null; } } public partial class MulticastOption { @@ -223,7 +225,7 @@ public enum ProtocolType public sealed partial class SafeSocketHandle : Microsoft.Win32.SafeHandles.SafeHandleMinusOneIsInvalid { public SafeSocketHandle() : base (default(bool)) { } - public SafeSocketHandle(System.IntPtr preexistingHandle, bool ownsHandle) : base (default(bool)) { } + public SafeSocketHandle(nint preexistingHandle, bool ownsHandle) : base (default(bool)) { } public override bool IsInvalid { get { throw null; } } protected override bool ReleaseHandle() { throw null; } } @@ -272,7 +274,7 @@ public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.Proto public bool DualMode { get { throw null; } set { } } public bool EnableBroadcast { get { throw null; } set { } } public bool ExclusiveAddressUse { get { throw null; } set { } } - public System.IntPtr Handle { get { throw null; } } + public nint Handle { get { throw null; } } public bool IsBound { get { throw null; } } [System.Diagnostics.CodeAnalysis.DisallowNullAttribute] public System.Net.Sockets.LingerOption? LingerState { get { throw null; } set { } } @@ -398,6 +400,7 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } + public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress remoteSA) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -442,6 +445,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan preBuffer, Syst public int SendTo(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(System.ReadOnlySpan buffer, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } + public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress remoteSA) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } @@ -811,6 +815,8 @@ public sealed partial class UnixDomainSocketEndPoint : System.Net.EndPoint public UnixDomainSocketEndPoint(string path) { } public override System.Net.Sockets.AddressFamily AddressFamily { get { throw null; } } public override System.Net.EndPoint Create(System.Net.SocketAddress socketAddress) { throw null; } + public override bool Equals([System.Diagnostics.CodeAnalysis.NotNullWhenAttribute(true)] object? obj) { throw null; } + public override int GetHashCode() { throw null; } public override System.Net.SocketAddress Serialize() { throw null; } public override string ToString() { throw null; } } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 6b724c9a1709c..e95fab725709b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1014,13 +1014,15 @@ public Socket Accept() // This may throw ObjectDisposedException. SafeSocketHandle acceptedSocketHandle; SocketError errorCode; + int socketAddressLen; try { errorCode = SocketPal.Accept( _handle, - socketAddress.Buffer, - ref socketAddress.InternalSize, + socketAddress.SocketBuffer, + out socketAddressLen, out acceptedSocketHandle); + socketAddress.Size = socketAddressLen; } catch (Exception ex) { @@ -1282,7 +1284,7 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, Internals.SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, socketAddress.Size, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.SocketBuffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1354,7 +1356,7 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r Internals.SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, socketAddress.Size, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.SocketBuffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1375,6 +1377,42 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r return bytesTransferred; } + /// + /// Sends data to a specific endpoint using the specified . + /// + /// A span of bytes that contains the data to be sent. + /// A bitwise combination of the values. + /// The that represents the destination for the data. + /// The number of bytes sent. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, SocketAddress remoteSA) + { + ThrowIfDisposed(); + ArgumentNullException.ThrowIfNull(remoteSA); + + ValidateBlockingMode(); + + int bytesTransferred; + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, remoteSA.SocketBuffer, out bytesTransferred); + + // Throw an appropriate SocketException if the native call fails. + if (errorCode != SocketError.Success) + { + UpdateSendSocketErrorForDisposed(ref errorCode); + + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesSent(bytesTransferred); + if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent(); + } + + return bytesTransferred; + } + // Receives data from a connected socket. public int Receive(byte[] buffer, int size, SocketFlags socketFlags) { @@ -1681,7 +1719,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, ref socketAddress.InternalSize, out bytesTransferred); + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, out int socketAddressLength, out bytesTransferred); UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. @@ -1703,6 +1741,8 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } + socketAddress.Size = socketAddressLength; + if (!socketAddressOriginal.Equals(socketAddress)) { try @@ -1788,7 +1828,7 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.Buffer, ref socketAddress.InternalSize, out bytesTransferred); + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.Buffer, out int socketAddressLength, out bytesTransferred); UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. @@ -1810,6 +1850,7 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } + socketAddress.Size = socketAddressLength; if (!socketAddressOriginal.Equals(socketAddress)) { try @@ -1840,6 +1881,55 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint return bytesTransferred; } + /// + /// Receives a datagram into the data buffer, using the specified , and stores the endpoint. + /// + /// A span of bytes that is the storage location for received data. + /// A bitwise combination of the values. + /// An , passed by reference, that represents the remote server. + /// The number of bytes received. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress remoteSA) + { + ThrowIfDisposed(); + + ValidateBlockingMode(); + + int bytesTransferred; + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, remoteSA.SocketBuffer, out int socketAddressSize, out bytesTransferred); + + UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); + // If the native call fails we'll throw a SocketException. + SocketException? socketException = null; + if (errorCode != SocketError.Success) + { + socketException = new SocketException((int)errorCode); + UpdateStatusAfterSocketError(socketException); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, socketException); + + if (socketException.SocketErrorCode != SocketError.MessageSize) + { + throw socketException; + } + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesReceived(bytesTransferred); + if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); + } + + if (socketException != null) + { + throw socketException; + } + + remoteSA.Size = socketAddressSize; + + return bytesTransferred; + } + public int IOControl(int ioControlCode, byte[]? optionInValue, byte[]? optionOutValue) { ThrowIfDisposed(); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 1926097b8e45d..255ecc8b2300d 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -108,7 +108,7 @@ private BufferListSendOperation RentBufferListSendOperation() => Interlocked.Exchange(ref _cachedBufferListSendOperation, null) ?? new BufferListSendOperation(this); - private abstract class AsyncOperation : IThreadPoolWorkItem + private abstract unsafe class AsyncOperation : IThreadPoolWorkItem { private enum State { @@ -128,7 +128,7 @@ private enum State public readonly SocketAsyncContext AssociatedContext; public AsyncOperation Next = null!; // initialized by helper called from ctor public SocketError ErrorCode; - public byte[]? SocketAddress; + public Memory SocketAddress; public int SocketAddressLen; public CancellationTokenRegistration CancellationRegistration; @@ -348,7 +348,7 @@ public WriteOperation(SocketAsyncContext context) : base(context) { } void IThreadPoolWorkItem.Execute() => AssociatedContext.ProcessAsyncWriteOperation(this); } - private abstract class SendOperation : WriteOperation + private abstract unsafe class SendOperation : WriteOperation { public SocketFlags Flags; public int BytesTransferred; @@ -357,10 +357,10 @@ private abstract class SendOperation : WriteOperation public SendOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, SocketError>? Callback { get; set; } - public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress, SocketAddressLen, SocketFlags.None, ErrorCode); + public override unsafe void InvokeCallback(bool allowPooling) => + Callback!(BytesTransferred, SocketAddress, SocketFlags.None, ErrorCode); } private sealed class BufferMemorySendOperation : SendOperation @@ -372,15 +372,14 @@ public BufferMemorySendOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) { int bufferIndex = 0; - return SocketPal.TryCompleteSendTo(context._socket, Buffer.Span, null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode); + return SocketPal.TryCompleteSendTo(context._socket, Buffer.Span, null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress.Span, ref BytesTransferred, out ErrorCode); } - public override void InvokeCallback(bool allowPooling) + public override unsafe void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketError ec = ErrorCode; if (allowPooling) @@ -388,7 +387,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, SocketFlags.None, ec); + cb(bt, sa, SocketFlags.None, ec); } } @@ -401,15 +400,14 @@ public BufferListSendOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) { - return SocketPal.TryCompleteSendTo(context._socket, default(ReadOnlySpan), Buffers, ref BufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode); + return SocketPal.TryCompleteSendTo(context._socket, default(ReadOnlySpan), Buffers, ref BufferIndex, ref Offset, ref Count, Flags, SocketAddress.Span, ref BytesTransferred, out ErrorCode); } public override void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketError ec = ErrorCode; if (allowPooling) @@ -417,7 +415,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, SocketFlags.None, ec); + cb(bt, sa, SocketFlags.None, ec); } } @@ -431,7 +429,7 @@ protected override bool DoTryComplete(SocketAsyncContext context) { int bufferIndex = 0; int bufferLength = Offset + Count; // TryCompleteSendTo expects the entire buffer, which it then indexes into with the ref Offset and ref Count arguments - return SocketPal.TryCompleteSendTo(context._socket, new ReadOnlySpan(BufferPtr, bufferLength), null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode); + return SocketPal.TryCompleteSendTo(context._socket, new ReadOnlySpan(BufferPtr, bufferLength), null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress.Span, ref BytesTransferred, out ErrorCode); } } @@ -443,10 +441,10 @@ private abstract class ReceiveOperation : ReadOperation public ReceiveOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, SocketError>? Callback { get; set; } public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress, SocketAddressLen, ReceivedFlags, ErrorCode); + Callback!(BytesTransferred, SocketAddress, ReceivedFlags, ErrorCode); } private sealed class BufferMemoryReceiveOperation : ReceiveOperation @@ -460,7 +458,7 @@ protected override bool DoTryComplete(SocketAsyncContext context) { // Zero byte read is performed to know when data is available. // We don't have to call receive, our caller is interested in the event. - if (Buffer.Length == 0 && Flags == SocketFlags.None && SocketAddress == null) + if (Buffer.Length == 0 && Flags == SocketFlags.None && SocketAddress.Length == 0) { BytesTransferred = 0; ReceivedFlags = SocketFlags.None; @@ -471,14 +469,14 @@ protected override bool DoTryComplete(SocketAsyncContext context) { if (!SetReceivedFlags) { - Debug.Assert(SocketAddress == null); + Debug.Assert(SocketAddress.Length == 0); ReceivedFlags = SocketFlags.None; return SocketPal.TryCompleteReceive(context._socket, Buffer.Span, Flags, out BytesTransferred, out ErrorCode); } else { - return SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + return SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress.Span, out SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); } } } @@ -487,7 +485,7 @@ public override void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; + Memory sa = SocketAddress; int sal = SocketAddressLen; SocketFlags rf = ReceivedFlags; SocketError ec = ErrorCode; @@ -497,7 +495,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, rf, ec); + cb(bt, sa, rf, ec); } } @@ -508,14 +506,13 @@ private sealed class BufferListReceiveOperation : ReceiveOperation public BufferListReceiveOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveFrom(context._socket, default(Span), Buffers, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + SocketPal.TryCompleteReceiveFrom(context._socket, default(Span), Buffers, Flags, SocketAddress.Span, out SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); public override void InvokeCallback(bool allowPooling) { var cb = Callback!; int bt = BytesTransferred; - byte[]? sa = SocketAddress; - int sal = SocketAddressLen; + Memory sa = SocketAddress; SocketFlags rf = ReceivedFlags; SocketError ec = ErrorCode; @@ -524,7 +521,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(bt, sa, sal, rf, ec); + cb(bt, sa, rf, ec); } } @@ -536,7 +533,7 @@ private sealed unsafe class BufferPtrReceiveOperation : ReceiveOperation public BufferPtrReceiveOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + SocketPal.TryCompleteReceiveFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress.Span, out SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); } private sealed class ReceiveMessageFromOperation : ReadOperation @@ -553,13 +550,13 @@ private sealed class ReceiveMessageFromOperation : ReadOperation public ReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress, out SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); + Callback!(BytesTransferred, SocketAddress, ReceivedFlags, IPPacketInformation, ErrorCode); } private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation @@ -576,13 +573,13 @@ private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, out SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); + Callback!(BytesTransferred, SocketAddress, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); } private sealed class AcceptOperation : ReadOperation @@ -591,11 +588,11 @@ private sealed class AcceptOperation : ReadOperation public AcceptOperation(SocketAsyncContext context) : base(context) { } - public Action? Callback { get; set; } + public Action, SocketError>? Callback { get; set; } protected override bool DoTryComplete(SocketAsyncContext context) { - bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress!, ref SocketAddressLen, out AcceptedFileDescriptor, out ErrorCode); + bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress, out SocketAddressLen, out AcceptedFileDescriptor, out ErrorCode); Debug.Assert(ErrorCode == SocketError.Success || AcceptedFileDescriptor == (IntPtr)(-1), $"Unexpected values: ErrorCode={ErrorCode}, AcceptedFileDescriptor={AcceptedFileDescriptor}"); return completed; } @@ -604,7 +601,7 @@ public override void InvokeCallback(bool allowPooling) { var cb = Callback!; IntPtr fd = AcceptedFileDescriptor; - byte[] sa = SocketAddress!; + Memory sa = SocketAddress; int sal = SocketAddressLen; SocketError ec = ErrorCode; @@ -613,7 +610,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(fd, sa, sal, ec); + cb(fd, sa.Slice(0, sal), ec); } } @@ -1383,15 +1380,14 @@ private bool ShouldRetrySyncOperation(out SocketError errorCode) private void ProcessAsyncWriteOperation(WriteOperation op) => _sendQueue.ProcessAsyncOperation(op); - public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd) + public SocketError Accept(Memory socketAddress, out int socketAddressLen, out IntPtr acceptedFd) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteAccept(_socket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) + SocketPal.TryCompleteAccept(_socket, socketAddress, out socketAddressLen, out acceptedFd, out errorCode)) { Debug.Assert(errorCode == SocketError.Success || acceptedFd == (IntPtr)(-1), $"Unexpected values: errorCode={errorCode}, acceptedFd={acceptedFd}"); return errorCode; @@ -1400,7 +1396,7 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out In var operation = new AcceptOperation(this) { SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, -1, observedSequenceNumber); @@ -1410,10 +1406,9 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out In return operation.ErrorCode; } - public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action callback, CancellationToken cancellationToken) + public SocketError AcceptAsync(Memory socketAddress, out int socketAddressLen, out IntPtr acceptedFd, Action, SocketError> callback, CancellationToken cancellationToken) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); Debug.Assert(callback != null, "Expected non-null callback"); SetHandleNonBlocking(); @@ -1421,7 +1416,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteAccept(_socket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) + SocketPal.TryCompleteAccept(_socket, socketAddress, out socketAddressLen, out acceptedFd, out errorCode)) { Debug.Assert(errorCode == SocketError.Success || acceptedFd == (IntPtr)(-1), $"Unexpected values: errorCode={errorCode}, acceptedFd={acceptedFd}"); @@ -1431,7 +1426,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o AcceptOperation operation = RentAcceptOperation(); operation.Callback = callback; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; + operation.SocketAddressLen = socketAddress.Length; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { @@ -1444,6 +1439,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o } acceptedFd = (IntPtr)(-1); + socketAddressLen = 0; return SocketError.IOPending; } @@ -1514,23 +1510,21 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti public SocketError Receive(Memory buffer, SocketFlags flags, int timeout, out int bytesReceived) { - int socketAddressLen = 0; - return ReceiveFrom(buffer, ref flags, null, ref socketAddressLen, timeout, out bytesReceived); + //int socketAddressLen = 0; + return ReceiveFrom(buffer, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } public SocketError Receive(Span buffer, SocketFlags flags, int timeout, out int bytesReceived) { - int socketAddressLen = 0; - return ReceiveFrom(buffer, ref flags, null, ref socketAddressLen, timeout, out bytesReceived); + return ReceiveFrom(buffer, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } - public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action callback, CancellationToken cancellationToken) + public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken) { - int socketAddressLen = 0; - return ReceiveFromAsync(buffer, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback, cancellationToken); + return ReceiveFromAsync(buffer, flags, Memory.Empty, out int _, out bytesReceived, out receivedFlags, callback, cancellationToken); } - public SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived) + public unsafe SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, int timeout, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1538,7 +1532,7 @@ public SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, byte[ SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || + (SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1551,23 +1545,24 @@ public SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags, byte[ Flags = flags, SetReceivedFlags = true, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; + socketAddressLen = operation.SocketAddressLen; return operation.ErrorCode; } - public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived) + public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, int timeout, out int bytesReceived) { SocketFlags receivedFlags; SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveFrom(_socket, buffer, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || + (SocketPal.TryCompleteReceiveFrom(_socket, buffer, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1582,18 +1577,19 @@ public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, Length = buffer.Length, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; + socketAddressLen = operation.SocketAddressLen; return operation.ErrorCode; } } - public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, Action callback, CancellationToken cancellationToken = default) + public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int bytesReceived, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); @@ -1626,15 +1622,16 @@ public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int return SocketError.IOPending; } - public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action callback, CancellationToken cancellationToken = default) + public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, Memory socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) + SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { + //ocketAddressLen = socketAddressLength; return errorCode; } @@ -1644,35 +1641,36 @@ public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, byte operation.Buffer = buffer; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; + operation.SocketAddressLen = socketAddress.Length; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { receivedFlags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; errorCode = operation.ErrorCode; + socketAddressLen = operation.SocketAddressLen; ReturnOperation(operation); return errorCode; } bytesReceived = 0; + socketAddressLen = 0; receivedFlags = SocketFlags.None; return SocketError.IOPending; } public SocketError Receive(IList> buffers, SocketFlags flags, int timeout, out int bytesReceived) { - return ReceiveFrom(buffers, ref flags, null, 0, timeout, out bytesReceived); + return ReceiveFrom(buffers, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } - public SocketError ReceiveAsync(IList> buffers, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action callback) + public SocketError ReceiveAsync(IList> buffers, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback) { - int socketAddressLen = 0; - return ReceiveFromAsync(buffers, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback); + return ReceiveFromAsync(buffers, flags, Memory.Empty, out int _, out bytesReceived, out receivedFlags, callback); } - public SocketError ReceiveFrom(IList> buffers, ref SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesReceived) + public unsafe SocketError ReceiveFrom(IList> buffers, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, int timeout, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1680,7 +1678,7 @@ public SocketError ReceiveFrom(IList> buffers, ref SocketFlag SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || + (SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1692,7 +1690,7 @@ public SocketError ReceiveFrom(IList> buffers, ref SocketFlag Buffers = buffers, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen + SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); @@ -1703,14 +1701,14 @@ public SocketError ReceiveFrom(IList> buffers, ref SocketFlag return operation.ErrorCode; } - public SocketError ReceiveFromAsync(IList> buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action callback) + public SocketError ReceiveFromAsync(IList> buffers, SocketFlags flags, Memory socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action, SocketFlags, SocketError> callback) { SetHandleNonBlocking(); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) + SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { // Synchronous success or failure return errorCode; @@ -1721,7 +1719,7 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla operation.Buffers = buffers; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; + operation.SocketAddressLen = socketAddress.Length; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) { @@ -1735,12 +1733,13 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla } receivedFlags = SocketFlags.None; + socketAddressLen = 0; bytesReceived = 0; return SocketError.IOPending; } public SocketError ReceiveMessageFrom( - Memory buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) + Memory buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1748,7 +1747,7 @@ public SocketError ReceiveMessageFrom( SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || + (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, out socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1761,7 +1760,7 @@ public SocketError ReceiveMessageFrom( Buffers = null, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, IsIPv4 = isIPv4, IsIPv6 = isIPv6, }; @@ -1776,7 +1775,7 @@ public SocketError ReceiveMessageFrom( } public unsafe SocketError ReceiveMessageFrom( - Span buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) + Span buffer, ref SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1784,7 +1783,7 @@ public unsafe SocketError ReceiveMessageFrom( SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || + (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, out socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1799,7 +1798,7 @@ public unsafe SocketError ReceiveMessageFrom( Length = buffer.Length, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, IsIPv4 = isIPv4, IsIPv6 = isIPv6, }; @@ -1814,14 +1813,14 @@ public unsafe SocketError ReceiveMessageFrom( } } - public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action callback, CancellationToken cancellationToken = default) + public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action, SocketFlags, IPPacketInformation, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode)) + SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, out socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode)) { return errorCode; } @@ -1833,7 +1832,7 @@ public SocketError ReceiveMessageFromAsync(Memory buffer, IList buffer, IList buffer, SocketFlags flags, int timeout, out int bytesSent) => - SendTo(buffer, flags, null, 0, timeout, out bytesSent); + SendTo(buffer, flags, Memory.Empty, timeout, out bytesSent); public SocketError Send(byte[] buffer, int offset, int count, SocketFlags flags, int timeout, out int bytesSent) { - return SendTo(buffer, offset, count, flags, null, 0, timeout, out bytesSent); + return SendTo(buffer, offset, count, flags, Memory.Empty, timeout, out bytesSent); } - public SocketError SendAsync(Memory buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action callback, CancellationToken cancellationToken) + public SocketError SendAsync(Memory buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken) { - int socketAddressLen = 0; - return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback, cancellationToken); + return SendToAsync(buffer, offset, count, flags, Memory.Empty, out bytesSent, callback, cancellationToken); } - public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesSent) + public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, Memory socketAddress, int timeout, out int bytesSent) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1875,7 +1874,7 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteSendTo(_socket, buffer, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode) || + (SocketPal.TryCompleteSendTo(_socket, buffer, ref offset, ref count, flags, socketAddress.Span, ref bytesSent, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { return errorCode; @@ -1888,7 +1887,7 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag Count = count, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, BytesTransferred = bytesSent }; @@ -1898,7 +1897,7 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag return operation.ErrorCode; } - public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesSent) + public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, Memory socketAddress, int timeout, out int bytesSent) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1907,7 +1906,7 @@ public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, b int bufferIndexIgnored = 0, offset = 0, count = buffer.Length; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteSendTo(_socket, buffer, null, ref bufferIndexIgnored, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode) || + (SocketPal.TryCompleteSendTo(_socket, buffer, null, ref bufferIndexIgnored, ref offset, ref count, flags, socketAddress.Span, ref bytesSent, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { return errorCode; @@ -1922,7 +1921,7 @@ public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, b Count = count, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, BytesTransferred = bytesSent }; @@ -1933,7 +1932,7 @@ public unsafe SocketError SendTo(ReadOnlySpan buffer, SocketFlags flags, b } } - public SocketError SendToAsync(Memory buffer, int offset, int count, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesSent, Action callback, CancellationToken cancellationToken = default) + public SocketError SendToAsync(Memory buffer, int offset, int count, SocketFlags flags, Memory socketAddress, out int bytesSent, Action, SocketFlags, SocketError> callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); @@ -1941,7 +1940,7 @@ public SocketError SendToAsync(Memory buffer, int offset, int count, Socke SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteSendTo(_socket, buffer.Span, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode)) + SocketPal.TryCompleteSendTo(_socket, buffer.Span, ref offset, ref count, flags, socketAddress.Span, ref bytesSent, out errorCode)) { return errorCode; } @@ -1953,7 +1952,6 @@ public SocketError SendToAsync(Memory buffer, int offset, int count, Socke operation.Count = count; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; operation.BytesTransferred = bytesSent; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) @@ -1970,16 +1968,15 @@ public SocketError SendToAsync(Memory buffer, int offset, int count, Socke public SocketError Send(IList> buffers, SocketFlags flags, int timeout, out int bytesSent) { - return SendTo(buffers, flags, null, 0, timeout, out bytesSent); + return SendTo(buffers, flags, Memory.Empty, timeout, out bytesSent); } - public SocketError SendAsync(IList> buffers, SocketFlags flags, out int bytesSent, Action callback) + public SocketError SendAsync(IList> buffers, SocketFlags flags, out int bytesSent, Action, SocketFlags, SocketError> callback) { - int socketAddressLen = 0; - return SendToAsync(buffers, flags, null, ref socketAddressLen, out bytesSent, callback); + return SendToAsync(buffers, flags, Memory.Empty, out bytesSent, callback); } - public SocketError SendTo(IList> buffers, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, int timeout, out int bytesSent) + public SocketError SendTo(IList> buffers, SocketFlags flags, Memory socketAddress, int timeout, out int bytesSent) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1989,7 +1986,7 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode) || + (SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress.Span, ref bytesSent, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { return errorCode; @@ -2002,7 +1999,7 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, Offset = offset, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, BytesTransferred = bytesSent }; @@ -2012,7 +2009,7 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, return operation.ErrorCode; } - public SocketError SendToAsync(IList> buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesSent, Action callback) + public SocketError SendToAsync(IList> buffers, SocketFlags flags, Memory socketAddress, out int bytesSent, Action, SocketFlags, SocketError> callback) { SetHandleNonBlocking(); @@ -2022,7 +2019,7 @@ public SocketError SendToAsync(IList> buffers, SocketFlags fl SocketError errorCode; int observedSequenceNumber; if (_sendQueue.IsReady(this, out observedSequenceNumber) && - SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode)) + SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress.Span, ref bytesSent, out errorCode)) { return errorCode; } @@ -2034,7 +2031,7 @@ public SocketError SendToAsync(IList> buffers, SocketFlags fl operation.Offset = offset; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddressLen; + operation.SocketAddressLen = socketAddress.Length; operation.BytesTransferred = bytesSent; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 1a09d751708e6..62be7145580b4 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -13,7 +13,7 @@ public partial class SocketAsyncEventArgs : EventArgs, IDisposable private IntPtr _acceptedFileDescriptor; private int _socketAddressSize; private SocketFlags _receivedFlags; - private Action? _transferCompletionCallback; + private Action, SocketFlags, SocketError>? _transferCompletionCallback; partial void InitializeInternals(); @@ -23,18 +23,18 @@ public partial class SocketAsyncEventArgs : EventArgs, IDisposable partial void CompleteCore(); - private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError) + private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, Memory socketAddress, SocketError socketError) { - CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketAddressSize); + CompleteAcceptOperation(acceptedFileDescriptor, socketAddress); CompletionCallback(0, SocketFlags.None, socketError); } - private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize) + private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, Memory socketAddress) { _acceptedFileDescriptor = acceptedFileDescriptor; - Debug.Assert(socketAddress == null || socketAddress == _acceptBuffer, $"Unexpected socketAddress: {socketAddress}"); - _acceptAddressBufferCount = socketAddressSize; + Debug.Assert(socketAddress.Length > 0); + _acceptAddressBufferCount = socketAddress.Length; } internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken) @@ -49,12 +49,12 @@ internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHan Debug.Assert(acceptHandle == null, $"Unexpected acceptHandle: {acceptHandle}"); IntPtr acceptedFd; - int socketAddressLen = _acceptAddressBufferCount / 2; - SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken); + //int socketAddressLen = _acceptAddressBufferCount / 2; + SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, out int socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken); if (socketError != SocketError.IOPending) { - CompleteAcceptOperation(acceptedFd, _acceptBuffer!, socketAddressLen); + CompleteAcceptOperation(acceptedFd, new Memory(_acceptBuffer, 0, socketAddressLen)); FinishOperationSync(socketError, 0, SocketFlags.None); } @@ -86,19 +86,19 @@ internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handl return socketError; } - private Action TransferCompletionCallback => + private Action, SocketFlags, SocketError> TransferCompletionCallback => _transferCompletionCallback ??= TransferCompletionCallbackCore; - private void TransferCompletionCallbackCore(int bytesTransferred, byte[]? socketAddress, int socketAddressSize, SocketFlags receivedFlags, SocketError socketError) + private void TransferCompletionCallbackCore(int bytesTransferred, Memory socketAddress, SocketFlags receivedFlags, SocketError socketError) { - CompleteTransferOperation(socketAddress, socketAddressSize, receivedFlags); + CompleteTransferOperation(socketAddress, socketAddress.Length, receivedFlags); CompletionCallback(bytesTransferred, receivedFlags, socketError); } - private void CompleteTransferOperation(byte[]? socketAddress, int socketAddressSize, SocketFlags receivedFlags) + private void CompleteTransferOperation(Memory _, int socketAddressSize, SocketFlags receivedFlags) { - Debug.Assert(socketAddress == null || socketAddress == _socketAddress!.Buffer, $"Unexpected socketAddress: {socketAddress}"); + //Debug.Assert(socketAddress == null || socketAddress == _socketAddress!.Buffer, $"Unexpected socketAddress: {socketAddress}"); _socketAddressSize = socketAddressSize; _receivedFlags = receivedFlags; } @@ -147,14 +147,14 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, Canc SocketFlags flags; SocketError errorCode; int bytesReceived; - int socketAddressLen = _socketAddress!.Size; + int socketAddressLen; // = _socketAddress!.Size; if (_bufferList == null) { - errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); + errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress!.SocketBuffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); } else { - errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress!.SocketBuffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); } if (errorCode != SocketError.IOPending) @@ -166,19 +166,20 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, Canc return errorCode; } - private void ReceiveMessageFromCompletionCallback(int bytesTransferred, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) + private void ReceiveMessageFromCompletionCallback(int bytesTransferred, Memory socketAddress, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) { - CompleteReceiveMessageFromOperation(socketAddress, socketAddressSize, receivedFlags, ipPacketInformation); + CompleteReceiveMessageFromOperation(socketAddress, socketAddress.Length, receivedFlags, ipPacketInformation); CompletionCallback(bytesTransferred, receivedFlags, errorCode); } - private void CompleteReceiveMessageFromOperation(byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation) + private void CompleteReceiveMessageFromOperation(Memory socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation) { - Debug.Assert(_socketAddress != null, "Expected non-null _socketAddress"); - Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socketAddress: {socketAddress}"); + //Debug.Assert(_socketAddress != null, "Expected non-null _socketAddress"); + //Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socketAddress: {socketAddress}"); + Debug.Assert(socketAddress.Length == socketAddressSize); - _socketAddressSize = socketAddressSize; + _socketAddressSize = socketAddress.Length; _receivedFlags = receivedFlags; _receiveMessageFromPacketInfo = ipPacketInformation; } @@ -196,10 +197,11 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc int bytesReceived; SocketFlags receivedFlags; IPPacketInformation ipPacketInformation; - SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback, cancellationToken); + SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, out socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback, cancellationToken); if (socketError != SocketError.IOPending) { - CompleteReceiveMessageFromOperation(_socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation); + _socketAddress.Size = socketAddressSize; + CompleteReceiveMessageFromOperation(_socketAddress.SocketBuffer, socketAddressSize, receivedFlags, ipPacketInformation); FinishOperationSync(socketError, bytesReceived, receivedFlags); } return socketError; @@ -295,20 +297,20 @@ internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToke _socketAddressSize = 0; int bytesSent; - int socketAddressLen = _socketAddress!.Size; + //int socketAddressLen = _socketAddress!.Size; SocketError errorCode; if (_bufferList == null) { - errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback, cancellationToken); + errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.SocketBuffer, out bytesSent, TransferCompletionCallback, cancellationToken); } else { - errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress!.SocketBuffer, out bytesSent, TransferCompletionCallback); } if (errorCode != SocketError.IOPending) { - CompleteTransferOperation(_socketAddress.Buffer, socketAddressLen, SocketFlags.None); + CompleteTransferOperation(_socketAddress.SocketBuffer, _socketAddress.Size, SocketFlags.None); FinishOperationSync(errorCode, bytesSent, SocketFlags.None); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 3c81f6e4c81ce..5cc88ce0a91cc 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -138,16 +138,13 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, return received; } - private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span buffer, byte[]? socketAddress, ref int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) + private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span buffer, Span socketAddress, out int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) { Debug.Assert(socket.IsSocket); - Debug.Assert(socketAddress != null || socketAddressLen == 0, $"Unexpected values: socketAddress={socketAddress}, socketAddressLen={socketAddressLen}"); - long received = 0; - int sockAddrLen = socketAddress != null ? socketAddressLen : 0; - fixed (byte* sockAddr = socketAddress) + fixed (byte* sockAddr = &MemoryMarshal.GetReference(socketAddress)) fixed (byte* b = &MemoryMarshal.GetReference(buffer)) { var iov = new Interop.Sys.IOVector { @@ -157,7 +154,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = sockAddrLen, + SocketAddressLen = socketAddress.Length, IOVectors = &iov, IOVectorCount = 1 }; @@ -169,7 +166,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, &received); receivedFlags = messageHeader.Flags; - sockAddrLen = messageHeader.SocketAddressLen; + socketAddressLen = messageHeader.SocketAddressLen; } if (errno != Interop.Error.SUCCESS) @@ -177,7 +174,6 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, return -1; } - socketAddressLen = sockAddrLen; return checked((int)received); } @@ -237,7 +233,7 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re return sent; } - private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, ReadOnlySpan buffer, ref int offset, ref int count, byte[] socketAddress, int socketAddressLen, out Interop.Error errno) + private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, ReadOnlySpan buffer, ref int offset, ref int count, ReadOnlySpan socketAddress, out Interop.Error errno) { Debug.Assert(socket.IsSocket); @@ -256,7 +252,7 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = socketAddress != null ? socketAddressLen : 0, + SocketAddressLen = socketAddress.Length, IOVectors = &iov, IOVectorCount = 1 }; @@ -281,19 +277,13 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re return sent; } - private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IList> buffers, ref int bufferIndex, ref int offset, byte[]? socketAddress, int socketAddressLen, out Interop.Error errno) + private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IList> buffers, ref int bufferIndex, ref int offset, ReadOnlySpan socketAddress, out Interop.Error errno) { Debug.Assert(socket.IsSocket); // Pin buffers and set up iovecs. int startIndex = bufferIndex, startOffset = offset; - int sockAddrLen = 0; - if (socketAddress != null) - { - sockAddrLen = socketAddressLen; - } - int maxBuffers = buffers.Count - startIndex; bool allocOnStack = maxBuffers <= IovStackThreshold; Span handles = allocOnStack ? stackalloc GCHandle[IovStackThreshold] : new GCHandle[maxBuffers]; @@ -320,7 +310,7 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, IL { var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = sockAddrLen, + SocketAddressLen = socketAddress.Length, IOVectors = iov, IOVectorCount = iovCount }; @@ -377,7 +367,7 @@ private static unsafe long SendFile(SafeSocketHandle socket, SafeFileHandle file return bytesSent; } - private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, IList> buffers, byte[]? socketAddress, ref int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) + private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, IList> buffers, Span socketAddress, out int socketAddressLen, out SocketFlags receivedFlags, out Interop.Error errno) { Debug.Assert(socket.IsSocket); @@ -392,6 +382,7 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, if (errno != Interop.Error.SUCCESS) { receivedFlags = 0; + socketAddressLen = 0; return -1; } if (available == 0) @@ -405,11 +396,11 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span handles = allocOnStack ? stackalloc GCHandle[IovStackThreshold] : new GCHandle[maxBuffers]; Span iovecs = allocOnStack ? stackalloc Interop.Sys.IOVector[IovStackThreshold] : new Interop.Sys.IOVector[maxBuffers]; - int sockAddrLen = 0; - if (socketAddress != null) - { - sockAddrLen = socketAddressLen; - } + int sockAddrLen = socketAddress.Length; + //if (socketAddress != null) + //{ + // sockAddrLen = socketAddressLen; + //} long received = 0; int toReceive = 0, iovCount = 0; @@ -469,16 +460,17 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, } } + socketAddressLen = sockAddrLen; + if (errno != Interop.Error.SUCCESS) { return -1; } - socketAddressLen = sockAddrLen; return checked((int)received); } - private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketFlags flags, Span buffer, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno) + private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketFlags flags, Span buffer, Span socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno) { Debug.Assert(socket.IsSocket); Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); @@ -486,7 +478,7 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF int cmsgBufferLen = Interop.Sys.GetControlMessageBufferSize(Convert.ToInt32(isIPv4), Convert.ToInt32(isIPv6)); byte* cmsgBuffer = stackalloc byte[cmsgBufferLen]; - int sockAddrLen = socketAddressLen; + int sockAddrLen = socketAddress.Length; Interop.Sys.MessageHeader messageHeader; @@ -518,6 +510,8 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF sockAddrLen = messageHeader.SocketAddressLen; } + socketAddressLen = sockAddrLen; + if (errno != Interop.Error.SUCCESS) { ipPacketInformation = default(IPPacketInformation); @@ -525,13 +519,12 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF } ipPacketInformation = GetIPPacketInformation(&messageHeader, isIPv4, isIPv6); - socketAddressLen = sockAddrLen; return checked((int)received); } private static unsafe int SysReceiveMessageFrom( SafeSocketHandle socket, SocketFlags flags, IList> buffers, - byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, + Span socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno) { Debug.Assert(socket.IsSocket); @@ -566,7 +559,7 @@ private static unsafe int SysReceiveMessageFrom( var messageHeader = new Interop.Sys.MessageHeader { SocketAddress = sockAddr, - SocketAddressLen = socketAddressLen, + SocketAddressLen = socketAddress.Length, IOVectors = iov, IOVectorCount = iovCount, ControlBuffer = cmsgBuffer, @@ -581,12 +574,11 @@ private static unsafe int SysReceiveMessageFrom( &received); receivedFlags = messageHeader.Flags; - int sockAddrLen = messageHeader.SocketAddressLen; + socketAddressLen = messageHeader.SocketAddressLen; if (errno == Interop.Error.SUCCESS) { ipPacketInformation = GetIPPacketInformation(&messageHeader, isIPv4, isIPv6); - socketAddressLen = sockAddrLen; return checked((int)received); } else @@ -606,22 +598,24 @@ private static unsafe int SysReceiveMessageFrom( } } - public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, out SocketError errorCode) + public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, Memory socketAddress, out int socketAddressLen, out IntPtr acceptedFd, out SocketError errorCode) { IntPtr fd = IntPtr.Zero; Interop.Error errno; - int sockAddrLen = socketAddressLen; - fixed (byte* rawSocketAddress = socketAddress) + int sockAddrLen = socketAddress.Length; + fixed (byte* rawSocketAddress = socketAddress.Span) { try { errno = Interop.Sys.Accept(socket, rawSocketAddress, &sockAddrLen, &fd); + socketAddressLen = sockAddrLen; } catch (ObjectDisposedException) { // The socket was closed, or is closing. errorCode = SocketError.OperationAborted; acceptedFd = (IntPtr)(-1); + socketAddressLen = 0; return true; } } @@ -630,7 +624,6 @@ public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, byte[] sock { Debug.Assert(fd != (IntPtr)(-1), "Expected fd != -1"); - socketAddressLen = sockAddrLen; errorCode = SocketError.Success; acceptedFd = fd; @@ -735,11 +728,11 @@ public static unsafe bool TryCompleteConnect(SafeSocketHandle socket, out Socket return true; } - public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => - TryCompleteReceiveFrom(socket, buffer, null, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); + public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, SocketFlags flags, Span socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => + TryCompleteReceiveFrom(socket, buffer, null, flags, socketAddress, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); - public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, IList> buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => - TryCompleteReceiveFrom(socket, default(Span), buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); + public static bool TryCompleteReceiveFrom(SafeSocketHandle socket, IList> buffers, SocketFlags flags, Span socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) => + TryCompleteReceiveFrom(socket, default(Span), buffers, flags, socketAddress, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode); public static unsafe bool TryCompleteReceive(SafeSocketHandle socket, Span buffer, SocketFlags flags, out int bytesReceived, out SocketError errorCode) { @@ -800,12 +793,13 @@ public static unsafe bool TryCompleteReceive(SafeSocketHandle socket, Span } } - public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, byte[]? socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) + public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Span socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) { try { Interop.Error errno; int received; + int socketAddressLength = 0; if (!socket.IsSocket) { @@ -819,7 +813,7 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span(&oneBytePeekBuffer, 1), socketAddress, ref socketAddressLen, out receivedFlags, out errno); + received = SysReceive(socket, flags | SocketFlags.Peek, new Span(&oneBytePeekBuffer, 1), socketAddress, out socketAddressLength, out receivedFlags, out errno); if (received > 0) { // Peeked for 1-byte, but the actual request was for 0. @@ -838,17 +832,19 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span 0 bytes into a single buffer - received = SysReceive(socket, flags, buffer, socketAddress, ref socketAddressLen, out receivedFlags, out errno); + received = SysReceive(socket, flags, buffer, socketAddress, out socketAddressLength, out receivedFlags, out errno); } if (received != -1) { bytesReceived = received; errorCode = SocketError.Success; + socketAddressLen = socketAddressLength; return true; } bytesReceived = 0; + socketAddressLen = 0; if (errno != Interop.Error.EAGAIN && errno != Interop.Error.EWOULDBLOCK) { @@ -864,20 +860,21 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out SocketError errorCode) + public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out SocketError errorCode) { try { Interop.Error errno; int received = buffers == null ? - SysReceiveMessageFrom(socket, flags, buffer, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno) : - SysReceiveMessageFrom(socket, flags, buffers, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno); + SysReceiveMessageFrom(socket, flags, buffer, socketAddress.Span, out socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno) : + SysReceiveMessageFrom(socket, flags, buffers, socketAddress.Span, out socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno); if (received != -1) { @@ -902,31 +899,32 @@ public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, // The socket was closed, or is closing. bytesReceived = 0; receivedFlags = 0; + socketAddressLen = 0; ipPacketInformation = default(IPPacketInformation); errorCode = SocketError.OperationAborted; return true; } } - public static bool TryCompleteSendTo(SafeSocketHandle socket, Span buffer, ref int offset, ref int count, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, Span buffer, ref int offset, ref int count, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { int bufferIndex = 0; - return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode); + return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, ref bytesSent, out errorCode); } - public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { int bufferIndex = 0, offset = 0, count = buffer.Length; - return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode); + return TryCompleteSendTo(socket, buffer, null, ref bufferIndex, ref offset, ref count, flags, socketAddress, ref bytesSent, out errorCode); } - public static bool TryCompleteSendTo(SafeSocketHandle socket, IList> buffers, ref int bufferIndex, ref int offset, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, IList> buffers, ref int bufferIndex, ref int offset, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { int count = 0; - return TryCompleteSendTo(socket, default(ReadOnlySpan), buffers, ref bufferIndex, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode); + return TryCompleteSendTo(socket, default(ReadOnlySpan), buffers, ref bufferIndex, ref offset, ref count, flags, socketAddress, ref bytesSent, out errorCode); } - public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, IList>? buffers, ref int bufferIndex, ref int offset, ref int count, SocketFlags flags, byte[]? socketAddress, int socketAddressLen, ref int bytesSent, out SocketError errorCode) + public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan buffer, IList>? buffers, ref int bufferIndex, ref int offset, ref int count, SocketFlags flags, ReadOnlySpan socketAddress, ref int bytesSent, out SocketError errorCode) { bool successfulSend = false; long start = socket.IsUnderlyingHandleBlocking && socket.SendTimeout > 0 ? Environment.TickCount64 : 0; // Get ticks only if timeout is set and socket is blocking. @@ -946,9 +944,9 @@ public static bool TryCompleteSendTo(SafeSocketHandle socket, ReadOnlySpan else { sent = buffers != null ? - SysSend(socket, flags, buffers, ref bufferIndex, ref offset, socketAddress, socketAddressLen, out errno) : + SysSend(socket, flags, buffers, ref bufferIndex, ref offset, socketAddress, out errno) : socketAddress == null ? SysSend(socket, flags, buffer, ref offset, ref count, out errno) : - SysSend(socket, flags, buffer, ref offset, ref count, socketAddress, socketAddressLen, out errno); + SysSend(socket, flags, buffer, ref offset, ref count, socketAddress, out errno); } } catch (ObjectDisposedException) @@ -1094,7 +1092,7 @@ public static SocketError Listen(SafeSocketHandle handle, int backlog) return err == Interop.Error.SUCCESS ? SocketError.Success : GetSocketErrorForErrorCode(err); } - public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAddress, ref int socketAddressLen, out SafeSocketHandle socket) + public static SocketError Accept(SafeSocketHandle listenSocket, Memory socketAddress, out int socketAddressLen, out SafeSocketHandle socket) { socket = new SafeSocketHandle(); @@ -1102,11 +1100,11 @@ public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAdd SocketError errorCode; if (!listenSocket.IsNonBlocking) { - errorCode = listenSocket.AsyncContext.Accept(socketAddress, ref socketAddressLen, out acceptedFd); + errorCode = listenSocket.AsyncContext.Accept(socketAddress, out socketAddressLen, out acceptedFd); } else { - if (!TryCompleteAccept(listenSocket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) + if (!TryCompleteAccept(listenSocket, socketAddress, out socketAddressLen, out acceptedFd, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1151,7 +1149,7 @@ public static SocketError Send(SafeSocketHandle handle, IList int bufferIndex = 0; int offset = 0; SocketError errorCode; - TryCompleteSendTo(handle, bufferList, ref bufferIndex, ref offset, socketFlags, null, 0, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, bufferList, ref bufferIndex, ref offset, socketFlags, ReadOnlySpan.Empty, ref bytesTransferred, out errorCode); return errorCode; } @@ -1164,7 +1162,7 @@ public static SocketError Send(SafeSocketHandle handle, byte[] buffer, int offse bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, null, 0, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, ReadOnlySpan.Empty, ref bytesTransferred, out errorCode); return errorCode; } @@ -1177,7 +1175,7 @@ public static SocketError Send(SafeSocketHandle handle, ReadOnlySpan buffe bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, socketFlags, null, 0, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, socketFlags, ReadOnlySpan.Empty, ref bytesTransferred, out errorCode); return errorCode; } @@ -1197,29 +1195,29 @@ public static SocketError SendFile(SafeSocketHandle handle, SafeFileHandle fileH return completed ? errorCode : SocketError.WouldBlock; } - public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, int socketAddressLen, out int bytesTransferred) + public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Memory socketAddress, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.SendTo(buffer, offset, count, socketFlags, socketAddress, socketAddressLen, handle.SendTimeout, out bytesTransferred); + return handle.AsyncContext.SendTo(buffer, offset, count, socketFlags, socketAddress, handle.SendTimeout, out bytesTransferred); } bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, socketAddress, socketAddressLen, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, ref offset, ref count, socketFlags, socketAddress.Span, ref bytesTransferred, out errorCode); return errorCode; } - public static SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, byte[] socketAddress, int socketAddressLen, out int bytesTransferred) + public static SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, Memory socketAddress, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.SendTo(buffer, socketFlags, socketAddress, socketAddressLen, handle.SendTimeout, out bytesTransferred); + return handle.AsyncContext.SendTo(buffer, socketFlags, socketAddress, handle.SendTimeout, out bytesTransferred); } bytesTransferred = 0; SocketError errorCode; - TryCompleteSendTo(handle, buffer, socketFlags, socketAddress, socketAddressLen, ref bytesTransferred, out errorCode); + TryCompleteSendTo(handle, buffer, socketFlags, socketAddress.Span, ref bytesTransferred, out errorCode); return errorCode; } @@ -1232,8 +1230,7 @@ public static SocketError Receive(SafeSocketHandle handle, IList buffer, So public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { - byte[] socketAddressBuffer = socketAddress.Buffer; - int socketAddressLen = socketAddress.Size; + int socketAddressLen; // = socketAddress.Size; bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6); @@ -1277,11 +1273,11 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han SocketError errorCode; if (!handle.IsNonBlocking) { - errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress.SocketBuffer, out socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); } else { - if (!TryCompleteReceiveMessageFrom(handle, new Span(buffer, offset, count), null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) + if (!TryCompleteReceiveMessageFrom(handle, new Span(buffer, offset, count), null, socketFlags, socketAddress.SocketBuffer, out socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1296,7 +1292,7 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { byte[] socketAddressBuffer = socketAddress.Buffer; - int socketAddressLen = socketAddress.Size; + int socketAddressLen; bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6); @@ -1304,11 +1300,11 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han SocketError errorCode; if (!handle.IsNonBlocking) { - errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, out socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); } else { - if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) + if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, out socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1319,27 +1315,27 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han return errorCode; } - public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred) + public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Memory socketAddress, out int socketAddressLen, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.ReceiveFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress, ref socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); + return handle.AsyncContext.ReceiveFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress, out socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); } SocketError errorCode; - bool completed = TryCompleteReceiveFrom(handle, new Span(buffer, offset, count), socketFlags, socketAddress, ref socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); + bool completed = TryCompleteReceiveFrom(handle, new Span(buffer, offset, count), socketFlags, socketAddress.Span, out socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); return completed ? errorCode : SocketError.WouldBlock; } - public static SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred) + public static SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, Memory socketAddress, out int socketAddressLen, out int bytesTransferred) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.ReceiveFrom(buffer, ref socketFlags, socketAddress, ref socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); + return handle.AsyncContext.ReceiveFrom(buffer, ref socketFlags, socketAddress, out socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); } SocketError errorCode; - bool completed = TryCompleteReceiveFrom(handle, buffer, socketFlags, socketAddress, ref socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); + bool completed = TryCompleteReceiveFrom(handle, buffer, socketFlags, socketAddress.Span, out socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); return completed ? errorCode : SocketError.WouldBlock; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 749c26a608ebc..8d92a17740f28 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -176,10 +176,11 @@ public static SocketError Listen(SafeSocketHandle handle, int backlog) return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success; } - public static SocketError Accept(SafeSocketHandle listenSocket, byte[] socketAddress, ref int socketAddressSize, out SafeSocketHandle socket) + public static SocketError Accept(SafeSocketHandle listenSocket, Memory socketAddress, out int socketAddressSize, out SafeSocketHandle socket) { socket = new SafeSocketHandle(); - Marshal.InitHandle(socket, Interop.Winsock.accept(listenSocket, socketAddress, ref socketAddressSize)); + socketAddressSize = socketAddress.Length; + Marshal.InitHandle(socket, Interop.Winsock.accept(listenSocket, socketAddress.Span, ref socketAddressSize)); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, socket); @@ -300,15 +301,15 @@ public static unsafe SocketError SendFile(SafeSocketHandle handle, SafeFileHandl } } - public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) => - SendTo(handle, buffer.AsSpan(offset, size), socketFlags, peerAddress, peerAddressSize, out bytesTransferred); + public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, ReadOnlyMemory peerAddress, out int bytesTransferred) => + SendTo(handle, buffer.AsSpan(offset, size), socketFlags, peerAddress, out bytesTransferred); - public static unsafe SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) + public static unsafe SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, ReadOnlyMemory peerAddress, out int bytesTransferred) { int bytesSent; fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) { - bytesSent = Interop.Winsock.sendto(handle, bufferPtr, buffer.Length, socketFlags, peerAddress, peerAddressSize); + bytesSent = Interop.Winsock.sendto(handle, bufferPtr, buffer.Length, socketFlags, peerAddress.Span, peerAddress.Length); } if (bytesSent == (int)SocketError.SocketError) @@ -512,17 +513,15 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan return SocketError.Success; } - public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags _ /*socketFlags*/, byte[] socketAddress, ref int addressLength, out int bytesTransferred) => - ReceiveFrom(handle, buffer.AsSpan(offset, size), SocketFlags.None, socketAddress, ref addressLength, out bytesTransferred); + public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags _ /*socketFlags*/, Memory socketAddress, out int addressLength, out int bytesTransferred) => + ReceiveFrom(handle, buffer.AsSpan(offset, size), SocketFlags.None, socketAddress, out addressLength, out bytesTransferred); - public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, byte[] socketAddress, ref int addressLength, out int bytesTransferred) + public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, Memory socketAddress, out int addressLength, out int bytesTransferred) { int bytesReceived; - fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) - { - bytesReceived = Interop.Winsock.recvfrom(handle, bufferPtr, buffer.Length, socketFlags, socketAddress, ref addressLength); - } + addressLength = socketAddress.Length; + bytesReceived = Interop.Winsock.recvfrom(handle, buffer, buffer.Length, socketFlags, socketAddress.Span, ref addressLength); if (bytesReceived == (int)SocketError.SocketError) { From a09f0f34e541aedd1d999be83e4434d46e2478fe Mon Sep 17 00:00:00 2001 From: wfurt Date: Sun, 16 Jul 2023 14:20:11 -0700 Subject: [PATCH 02/18] update --- .../Windows/WinSock/Interop.WSAConnect.cs | 2 +- .../src/System/Net/Sockets/Socket.cs | 2 +- .../Net/Sockets/SocketAsyncContext.Unix.cs | 129 ++++++++++-------- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 2 +- .../Sockets/SocketAsyncEventArgs.Windows.cs | 2 +- .../src/System/Net/Sockets/SocketPal.Unix.cs | 15 +- .../System/Net/Sockets/SocketPal.Windows.cs | 6 +- .../tests/FunctionalTests/ReceiveFrom.cs | 46 +++++++ 8 files changed, 130 insertions(+), 74 deletions(-) diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs index 77c846af58f8d..aef925573e60f 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs @@ -12,7 +12,7 @@ internal static partial class Winsock [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] internal static partial SocketError WSAConnect( SafeSocketHandle socketHandle, - byte[] socketAddress, + Span socketAddress, int socketAddressSize, IntPtr inBuffer, IntPtr outBuffer, diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index e95fab725709b..83514bb1acfc1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -3172,7 +3172,7 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket SocketError errorCode; try { - errorCode = SocketPal.Connect(_handle, socketAddress.Buffer, socketAddress.Size); + errorCode = SocketPal.Connect(_handle, socketAddress.SocketBuffer); } catch (Exception ex) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 255ecc8b2300d..c3c8f7bbabdcb 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -48,7 +48,7 @@ private void ReturnOperation(AcceptOperation operation) { operation.Reset(); operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedAcceptOperation, operation); // benign race condition } @@ -57,7 +57,7 @@ private void ReturnOperation(BufferMemoryReceiveOperation operation) operation.Reset(); operation.Buffer = default; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferMemoryReceiveOperation, operation); // benign race condition } @@ -66,7 +66,7 @@ private void ReturnOperation(BufferListReceiveOperation operation) operation.Reset(); operation.Buffers = null; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferListReceiveOperation, operation); // benign race condition } @@ -75,7 +75,7 @@ private void ReturnOperation(BufferMemorySendOperation operation) operation.Reset(); operation.Buffer = default; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferMemorySendOperation, operation); // benign race condition } @@ -84,7 +84,7 @@ private void ReturnOperation(BufferListSendOperation operation) operation.Reset(); operation.Buffers = null; operation.Callback = null; - operation.SocketAddress = null; + operation.SocketAddress = default; Volatile.Write(ref _cachedBufferListSendOperation, operation); // benign race condition } @@ -108,7 +108,7 @@ private BufferListSendOperation RentBufferListSendOperation() => Interlocked.Exchange(ref _cachedBufferListSendOperation, null) ?? new BufferListSendOperation(this); - private abstract unsafe class AsyncOperation : IThreadPoolWorkItem + private abstract class AsyncOperation : IThreadPoolWorkItem { private enum State { @@ -129,7 +129,6 @@ private enum State public AsyncOperation Next = null!; // initialized by helper called from ctor public SocketError ErrorCode; public Memory SocketAddress; - public int SocketAddressLen; public CancellationTokenRegistration CancellationRegistration; public ManualResetEventSlim? Event { get; set; } @@ -476,7 +475,9 @@ protected override bool DoTryComplete(SocketAsyncContext context) } else { - return SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress.Span, out SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + bool result = SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + return result; } } } @@ -486,7 +487,6 @@ public override void InvokeCallback(bool allowPooling) var cb = Callback!; int bt = BytesTransferred; Memory sa = SocketAddress; - int sal = SocketAddressLen; SocketFlags rf = ReceivedFlags; SocketError ec = ErrorCode; @@ -505,8 +505,15 @@ private sealed class BufferListReceiveOperation : ReceiveOperation public BufferListReceiveOperation(SocketAsyncContext context) : base(context) { } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveFrom(context._socket, default(Span), Buffers, Flags, SocketAddress.Span, out SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, default(Span), Buffers, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } public override void InvokeCallback(bool allowPooling) { @@ -532,8 +539,15 @@ private sealed unsafe class BufferPtrReceiveOperation : ReceiveOperation public BufferPtrReceiveOperation(SocketAsyncContext context) : base(context) { } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress.Span, out SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } } private sealed class ReceiveMessageFromOperation : ReadOperation @@ -552,8 +566,15 @@ public ReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { public Action, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress, out SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress, out int socketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } public override void InvokeCallback(bool allowPooling) => Callback!(BytesTransferred, SocketAddress, ReceivedFlags, IPPacketInformation, ErrorCode); @@ -573,13 +594,20 @@ private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { } - public Action, int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } + public Action, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; } - protected override bool DoTryComplete(SocketAsyncContext context) => - SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, out SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + protected override bool DoTryComplete(SocketAsyncContext context) + { + bool completed = SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, out int socketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } + return completed; + } public override void InvokeCallback(bool allowPooling) => - Callback!(BytesTransferred, SocketAddress, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); + Callback!(BytesTransferred, SocketAddress, ReceivedFlags, IPPacketInformation, ErrorCode); } private sealed class AcceptOperation : ReadOperation @@ -592,8 +620,12 @@ public AcceptOperation(SocketAsyncContext context) : base(context) { } protected override bool DoTryComplete(SocketAsyncContext context) { - bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress, out SocketAddressLen, out AcceptedFileDescriptor, out ErrorCode); + bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress, out int socketAddressLen, out AcceptedFileDescriptor, out ErrorCode); Debug.Assert(ErrorCode == SocketError.Success || AcceptedFileDescriptor == (IntPtr)(-1), $"Unexpected values: ErrorCode={ErrorCode}, AcceptedFileDescriptor={AcceptedFileDescriptor}"); + if (ErrorCode == SocketError.Success) + { + SocketAddress = SocketAddress.Slice(0, socketAddressLen); + } return completed; } @@ -602,7 +634,6 @@ public override void InvokeCallback(bool allowPooling) var cb = Callback!; IntPtr fd = AcceptedFileDescriptor; Memory sa = SocketAddress; - int sal = SocketAddressLen; SocketError ec = ErrorCode; if (allowPooling) @@ -610,7 +641,7 @@ public override void InvokeCallback(bool allowPooling) AssociatedContext.ReturnOperation(this); } - cb(fd, sa.Slice(0, sal), ec); + cb(fd, sa, ec); } } @@ -1396,12 +1427,11 @@ public SocketError Accept(Memory socketAddress, out int socketAddressLen, var operation = new AcceptOperation(this) { SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, -1, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; acceptedFd = operation.AcceptedFileDescriptor; return operation.ErrorCode; } @@ -1426,11 +1456,10 @@ public SocketError AcceptAsync(Memory socketAddress, out int socketAddress AcceptOperation operation = RentAcceptOperation(); operation.Callback = callback; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddress.Length; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; acceptedFd = operation.AcceptedFileDescriptor; errorCode = operation.ErrorCode; @@ -1443,10 +1472,9 @@ public SocketError AcceptAsync(Memory socketAddress, out int socketAddress return SocketError.IOPending; } - public SocketError Connect(byte[] socketAddress, int socketAddressLen) + public SocketError Connect(Memory socketAddress) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); // Connect is different than the usual "readiness" pattern of other operations. // We need to call TryStartConnect to initiate the connect with the OS, @@ -1455,7 +1483,7 @@ public SocketError Connect(byte[] socketAddress, int socketAddressLen) SocketError errorCode; int observedSequenceNumber; _sendQueue.IsReady(this, out observedSequenceNumber); - if (SocketPal.TryStartConnect(_socket, socketAddress, socketAddressLen, out errorCode) || + if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode) || !ShouldRetrySyncOperation(out errorCode)) { _socket.RegisterConnectResult(errorCode); @@ -1465,7 +1493,6 @@ public SocketError Connect(byte[] socketAddress, int socketAddressLen) var operation = new ConnectOperation(this) { SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen }; PerformSyncOperation(ref _sendQueue, operation, -1, observedSequenceNumber); @@ -1473,10 +1500,9 @@ public SocketError Connect(byte[] socketAddress, int socketAddressLen) return operation.ErrorCode; } - public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Action callback) + public SocketError ConnectAsync(Memory socketAddress, Action callback) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); Debug.Assert(callback != null, "Expected non-null callback"); SetHandleNonBlocking(); @@ -1487,7 +1513,7 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti SocketError errorCode; int observedSequenceNumber; _sendQueue.IsReady(this, out observedSequenceNumber); - if (SocketPal.TryStartConnect(_socket, socketAddress, socketAddressLen, out errorCode)) + if (SocketPal.TryStartConnect(_socket, socketAddress, out errorCode)) { _socket.RegisterConnectResult(errorCode); return errorCode; @@ -1497,7 +1523,6 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti { Callback = callback, SocketAddress = socketAddress, - SocketAddressLen = socketAddressLen }; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) @@ -1545,14 +1570,13 @@ public unsafe SocketError ReceiveFrom(Memory buffer, ref SocketFlags flags Flags = flags, SetReceivedFlags = true, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; return operation.ErrorCode; } @@ -1577,14 +1601,13 @@ public unsafe SocketError ReceiveFrom(Span buffer, ref SocketFlags flags, Length = buffer.Length, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; return operation.ErrorCode; } } @@ -1606,8 +1629,7 @@ public SocketError ReceiveAsync(Memory buffer, SocketFlags flags, out int operation.Callback = callback; operation.Buffer = buffer; operation.Flags = flags; - operation.SocketAddress = null; - operation.SocketAddressLen = 0; + operation.SocketAddress = default; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { @@ -1631,7 +1653,6 @@ public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, Memo if (_receiveQueue.IsReady(this, out observedSequenceNumber) && SocketPal.TryCompleteReceiveFrom(_socket, buffer.Span, flags, socketAddress.Span, out socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { - //ocketAddressLen = socketAddressLength; return errorCode; } @@ -1641,14 +1662,13 @@ public SocketError ReceiveFromAsync(Memory buffer, SocketFlags flags, Memo operation.Buffer = buffer; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddress.Length; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { receivedFlags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; errorCode = operation.ErrorCode; - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; ReturnOperation(operation); return errorCode; @@ -1690,12 +1710,11 @@ public unsafe SocketError ReceiveFrom(IList> buffers, ref Soc Buffers = buffers, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; flags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; return operation.ErrorCode; @@ -1719,11 +1738,10 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla operation.Buffers = buffers; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddress.Length; if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) { - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; receivedFlags = operation.ReceivedFlags; bytesReceived = operation.BytesTransferred; errorCode = operation.ErrorCode; @@ -1760,14 +1778,13 @@ public SocketError ReceiveMessageFrom( Buffers = null, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, IsIPv4 = isIPv4, IsIPv6 = isIPv6, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; flags = operation.ReceivedFlags; ipPacketInformation = operation.IPPacketInformation; bytesReceived = operation.BytesTransferred; @@ -1798,14 +1815,13 @@ public unsafe SocketError ReceiveMessageFrom( Length = buffer.Length, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, IsIPv4 = isIPv4, IsIPv6 = isIPv6, }; PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); - socketAddressLen = operation.SocketAddressLen; + socketAddressLen = operation.SocketAddress.Length; flags = operation.ReceivedFlags; ipPacketInformation = operation.IPPacketInformation; bytesReceived = operation.BytesTransferred; @@ -1832,14 +1848,13 @@ public SocketError ReceiveMessageFromAsync(Memory buffer, IList buffer, SocketFlags flags, M Count = count, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, BytesTransferred = bytesSent }; @@ -1999,7 +2012,6 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, Offset = offset, Flags = flags, SocketAddress = socketAddress, - SocketAddressLen = socketAddress.Length, BytesTransferred = bytesSent }; @@ -2031,7 +2043,6 @@ public SocketError SendToAsync(IList> buffers, SocketFlags fl operation.Offset = offset; operation.Flags = flags; operation.SocketAddress = socketAddress; - operation.SocketAddressLen = socketAddress.Length; operation.BytesTransferred = bytesSent; if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 62be7145580b4..52bc8d6792d9a 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -71,7 +71,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocket internal unsafe SocketError DoOperationConnect(SafeSocketHandle handle) { - SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, _socketAddress.Size, ConnectCompletionCallback); + SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.SocketBuffer, ConnectCompletionCallback); if (socketError != SocketError.IOPending) { FinishOperationSync(socketError, 0, SocketFlags.None); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 004b8dab0342b..dc56421349b60 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -284,7 +284,7 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha internal SocketError DoOperationConnect(SafeSocketHandle handle) { // Called for connectionless protocols. - SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer, _socketAddress.Size); + SocketError socketError = SocketPal.Connect(handle, _socketAddress!.SocketBuffer); FinishOperationSync(socketError, 0, SocketFlags.None); return socketError; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 5cc88ce0a91cc..1ade09259eb2a 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -641,10 +641,9 @@ public static unsafe bool TryCompleteAccept(SafeSocketHandle socket, Memory socketAddress, out SocketError errorCode) { - Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); - Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); + Debug.Assert(socketAddress.Length > 0, $"Unexpected socketAddressLen: {socketAddress.Length}"); if (socket.IsDisconnected) { @@ -653,9 +652,9 @@ public static unsafe bool TryStartConnect(SafeSocketHandle socket, byte[] socket } Interop.Error err; - fixed (byte* rawSocketAddress = socketAddress) + fixed (byte* rawSocketAddress = socketAddress.Span) { - err = Interop.Sys.Connect(socket, rawSocketAddress, socketAddressLen); + err = Interop.Sys.Connect(socket, rawSocketAddress, socketAddress.Length); } if (err == Interop.Error.SUCCESS) @@ -1117,15 +1116,15 @@ public static SocketError Accept(SafeSocketHandle listenSocket, Memory soc return errorCode; } - public static SocketError Connect(SafeSocketHandle handle, byte[] socketAddress, int socketAddressLen) + public static SocketError Connect(SafeSocketHandle handle, Memory socketAddress) { if (!handle.IsNonBlocking) { - return handle.AsyncContext.Connect(socketAddress, socketAddressLen); + return handle.AsyncContext.Connect(socketAddress); } SocketError errorCode; - bool completed = TryStartConnect(handle, socketAddress, socketAddressLen, out errorCode); + bool completed = TryStartConnect(handle, socketAddress, out errorCode); if (completed) { handle.RegisterConnectResult(errorCode); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 8d92a17740f28..68e4aadf321b8 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -187,12 +187,12 @@ public static SocketError Accept(SafeSocketHandle listenSocket, Memory soc return socket.IsInvalid ? GetLastSocketError() : SocketError.Success; } - public static SocketError Connect(SafeSocketHandle handle, byte[] peerAddress, int peerAddressLen) + public static SocketError Connect(SafeSocketHandle handle, Memory peerAddress) { SocketError errorCode = Interop.Winsock.WSAConnect( handle, - peerAddress, - peerAddressLen, + peerAddress.Span, + peerAddress.Length, IntPtr.Zero, IntPtr.Zero, IntPtr.Zero, diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index 30f899edda040..1f590dec81416 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -154,6 +154,52 @@ public async Task ReceiveSent_UDP_Success(bool ipv4) } } + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ReceiveSent_SocketAddress_Success(bool ipv4) + { + //const int Offset = 10; + const int DatagramSize = 256; + const int DatagramsToSend = 16; + + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + using Socket server = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket client = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + client.BindToAnonymousPort(address); + server.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[DatagramSize]; + byte[] receiveBuffer = new byte[DatagramSize]; + + SocketAddress serverSA = server.LocalEndPoint.Serialize(); + SocketAddress clientSA = client.LocalEndPoint.Serialize(); + SocketAddress sa = new SocketAddress(address.AddressFamily); + + Random rnd = new Random(0); + + for (int i = 0; i < DatagramsToSend; i++) + { + rnd.NextBytes(sendBuffer); + client.SendTo(sendBuffer.AsSpan(), SocketFlags.None, serverSA); + + int readBytes = server.ReceiveFrom(receiveBuffer, SocketFlags.None, sa); + Assert.Equal(sa, clientSA); + Assert.Equal(client.LocalEndPoint, client.LocalEndPoint.Create(sa)); + Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); + + // and send it back to make sure it works. + rnd.NextBytes(sendBuffer); + server.SendTo(sendBuffer, SocketFlags.None, sa); + readBytes = client.ReceiveFrom(receiveBuffer, SocketFlags.None, sa); + Assert.Equal(sa, serverSA); + Assert.Equal(server.LocalEndPoint, server.LocalEndPoint.Create(sa)); + Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); + + } + } + [Theory] [InlineData(true)] [InlineData(false)] From d5bdef4a731ff3dc1dc7901f1f25536d63f579e3 Mon Sep 17 00:00:00 2001 From: wfurt Date: Sun, 16 Jul 2023 14:51:42 -0700 Subject: [PATCH 03/18] update --- .../src/System/Net/Sockets/SocketAsyncContext.Unix.cs | 1 - .../src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs | 7 +------ .../src/System/Net/Sockets/SocketPal.Unix.cs | 7 +------ 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index c3c8f7bbabdcb..3f2e27bc78401 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -1535,7 +1535,6 @@ public SocketError ConnectAsync(Memory socketAddress, Action public SocketError Receive(Memory buffer, SocketFlags flags, int timeout, out int bytesReceived) { - //int socketAddressLen = 0; return ReceiveFrom(buffer, ref flags, Memory.Empty, out int _, timeout, out bytesReceived); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 52bc8d6792d9a..12b94b579988c 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -49,7 +49,6 @@ internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHan Debug.Assert(acceptHandle == null, $"Unexpected acceptHandle: {acceptHandle}"); IntPtr acceptedFd; - //int socketAddressLen = _acceptAddressBufferCount / 2; SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, out int socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken); if (socketError != SocketError.IOPending) @@ -98,7 +97,6 @@ private void TransferCompletionCallbackCore(int bytesTransferred, Memory s private void CompleteTransferOperation(Memory _, int socketAddressSize, SocketFlags receivedFlags) { - //Debug.Assert(socketAddress == null || socketAddress == _socketAddress!.Buffer, $"Unexpected socketAddress: {socketAddress}"); _socketAddressSize = socketAddressSize; _receivedFlags = receivedFlags; } @@ -147,7 +145,7 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, Canc SocketFlags flags; SocketError errorCode; int bytesReceived; - int socketAddressLen; // = _socketAddress!.Size; + int socketAddressLen; if (_bufferList == null) { errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress!.SocketBuffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); @@ -175,8 +173,6 @@ private void ReceiveMessageFromCompletionCallback(int bytesTransferred, Memory socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation) { - //Debug.Assert(_socketAddress != null, "Expected non-null _socketAddress"); - //Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socketAddress: {socketAddress}"); Debug.Assert(socketAddress.Length == socketAddressSize); _socketAddressSize = socketAddress.Length; @@ -297,7 +293,6 @@ internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToke _socketAddressSize = 0; int bytesSent; - //int socketAddressLen = _socketAddress!.Size; SocketError errorCode; if (_bufferList == null) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 1ade09259eb2a..ba596ffa18d5e 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -397,11 +397,6 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags, Span iovecs = allocOnStack ? stackalloc Interop.Sys.IOVector[IovStackThreshold] : new Interop.Sys.IOVector[maxBuffers]; int sockAddrLen = socketAddress.Length; - //if (socketAddress != null) - //{ - // sockAddrLen = socketAddressLen; - //} - long received = 0; int toReceive = 0, iovCount = 0; try @@ -1264,7 +1259,7 @@ public static SocketError Receive(SafeSocketHandle handle, Span buffer, So public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { - int socketAddressLen; // = socketAddress.Size; + int socketAddressLen; bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6); From 5cdfde4020463996772fe396265d5438e4f38708 Mon Sep 17 00:00:00 2001 From: wfurt Date: Sun, 16 Jul 2023 19:35:28 -0700 Subject: [PATCH 04/18] set length --- src/libraries/Common/src/System/Net/SocketAddress.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index ef58290645883..17186b1c0bb9d 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -101,6 +101,7 @@ public SocketAddress(AddressFamily family, int size) size = (size + IntPtr.Size - 1) / IntPtr.Size * IntPtr.Size + IntPtr.Size; #endif Buffer = new byte[size]; + Buffer[0] = (byte)InternalSize; SocketAddressPal.SetAddressFamily(Buffer, family); } From b49e3d97d2085ee1bc271d2649fb60174674e164 Mon Sep 17 00:00:00 2001 From: wfurt Date: Tue, 18 Jul 2023 07:29:25 -0700 Subject: [PATCH 05/18] feedback --- .../Common/src/System/Net/SocketAddress.cs | 5 +++++ .../System.Net.Sockets/ref/System.Net.Sockets.cs | 4 ++-- .../src/System/Net/Sockets/Socket.cs | 16 ++++++++-------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 17186b1c0bb9d..3eb56bd26cbfb 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -53,6 +53,7 @@ public int Size set { ArgumentOutOfRangeException.ThrowIfGreaterThan(value, Buffer.Length); + ArgumentOutOfRangeException.ThrowIfLessThan(value, MinSize); InternalSize = value; } } @@ -145,6 +146,10 @@ internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) SocketAddressPal.SetAddressFamily(Buffer, addressFamily); } + /// This represents underlying memory that can be passed to native OS calls. + /// + /// This memory can be invalidated if is changed or if the SocketAddress is used in another receive call. + /// public Memory SocketBuffer { get diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index ec2eab3363091..1dd563333c80e 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -400,7 +400,7 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } - public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress remoteSA) { throw null; } + public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress socketAddress) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -445,7 +445,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan preBuffer, Syst public int SendTo(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(System.ReadOnlySpan buffer, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } - public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress remoteSA) { throw null; } + public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress socketAddress) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 83514bb1acfc1..bf556c40e24ed 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1382,20 +1382,20 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r /// /// A span of bytes that contains the data to be sent. /// A bitwise combination of the values. - /// The that represents the destination for the data. + /// The that represents the destination for the data. /// The number of bytes sent. /// remoteEP is . /// An error occurred when attempting to access the socket. /// The has been closed. - public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, SocketAddress remoteSA) + public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, SocketAddress socketAddress) { ThrowIfDisposed(); - ArgumentNullException.ThrowIfNull(remoteSA); + ArgumentNullException.ThrowIfNull(socketAddress); ValidateBlockingMode(); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, remoteSA.SocketBuffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.SocketBuffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1886,19 +1886,19 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint /// /// A span of bytes that is the storage location for received data. /// A bitwise combination of the values. - /// An , passed by reference, that represents the remote server. + /// An , passed by reference, that represents the remote server. /// The number of bytes received. /// remoteEP is . /// An error occurred when attempting to access the socket. /// The has been closed. - public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress remoteSA) + public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress socketAddress) { ThrowIfDisposed(); ValidateBlockingMode(); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, remoteSA.SocketBuffer, out int socketAddressSize, out bytesTransferred); + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.SocketBuffer, out int socketAddressSize, out bytesTransferred); UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. @@ -1925,7 +1925,7 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress throw socketException; } - remoteSA.Size = socketAddressSize; + socketAddress.Size = socketAddressSize; return bytesTransferred; } From b6d0ebf60256d8d2763bcc4f485f04987886d32d Mon Sep 17 00:00:00 2001 From: wfurt Date: Tue, 18 Jul 2023 22:19:16 -0600 Subject: [PATCH 06/18] update to match approved API --- .../Common/src/System/Net/SocketAddress.cs | 58 +++++++++---------- .../src/System/Net/SocketAddressPal.Unix.cs | 2 + .../System/Net/SocketAddressPal.Windows.cs | 2 + .../ref/System.Net.Primitives.cs | 2 +- .../ref/System.Net.Sockets.cs | 2 +- .../src/System/Net/Sockets/Socket.Windows.cs | 2 +- .../src/System/Net/Sockets/Socket.cs | 24 ++++---- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 16 ++--- .../Sockets/SocketAsyncEventArgs.Windows.cs | 18 +++--- .../src/System/Net/Sockets/SocketPal.Unix.cs | 6 +- .../System/Net/Sockets/SocketPal.Windows.cs | 2 +- 11 files changed, 69 insertions(+), 65 deletions(-) diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 3eb56bd26cbfb..14c2500b78d26 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -28,7 +28,7 @@ class SocketAddress #pragma warning restore CA1802 internal int InternalSize; - internal byte[] Buffer; + internal byte[] InternalBuffer; private const int MinSize = 2; private const int MaxSize = 32; // IrDA requires 32 bytes @@ -40,7 +40,7 @@ public AddressFamily Family { get { - return SocketAddressPal.GetAddressFamily(Buffer); + return SocketAddressPal.GetAddressFamily(InternalBuffer); } } @@ -52,7 +52,7 @@ public int Size } set { - ArgumentOutOfRangeException.ThrowIfGreaterThan(value, Buffer.Length); + ArgumentOutOfRangeException.ThrowIfGreaterThan(value, InternalBuffer.Length); ArgumentOutOfRangeException.ThrowIfLessThan(value, MinSize); InternalSize = value; } @@ -70,7 +70,7 @@ public byte this[int offset] { throw new IndexOutOfRangeException(); } - return Buffer[offset]; + return InternalBuffer[offset]; } set { @@ -78,11 +78,11 @@ public byte this[int offset] { throw new IndexOutOfRangeException(); } - if (Buffer[offset] != value) + if (InternalBuffer[offset] != value) { _changed = true; } - Buffer[offset] = value; + InternalBuffer[offset] = value; } } @@ -101,10 +101,10 @@ public SocketAddress(AddressFamily family, int size) // The following formula will extend 'size' to the alignment boundary then add IntPtr.Size more bytes. size = (size + IntPtr.Size - 1) / IntPtr.Size * IntPtr.Size + IntPtr.Size; #endif - Buffer = new byte[size]; - Buffer[0] = (byte)InternalSize; + InternalBuffer = new byte[size]; + InternalBuffer[0] = (byte)InternalSize; - SocketAddressPal.SetAddressFamily(Buffer, family); + SocketAddressPal.SetAddressFamily(InternalBuffer, family); } internal SocketAddress(IPAddress ipAddress) @@ -112,7 +112,7 @@ internal SocketAddress(IPAddress ipAddress) ((ipAddress.AddressFamily == AddressFamily.InterNetwork) ? IPv4AddressSize : IPv6AddressSize)) { // No Port. - SocketAddressPal.SetPort(Buffer, 0); + SocketAddressPal.SetPort(InternalBuffer, 0); if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) { @@ -120,7 +120,7 @@ internal SocketAddress(IPAddress ipAddress) ipAddress.TryWriteBytes(addressBytes, out int bytesWritten); Debug.Assert(bytesWritten == IPAddressParserStatics.IPv6AddressBytes); - SocketAddressPal.SetIPv6Address(Buffer, addressBytes, (uint)ipAddress.ScopeId); + SocketAddressPal.SetIPv6Address(InternalBuffer, addressBytes, (uint)ipAddress.ScopeId); } else { @@ -129,32 +129,32 @@ internal SocketAddress(IPAddress ipAddress) #pragma warning restore CS0618 Debug.Assert(ipAddress.AddressFamily == AddressFamily.InterNetwork); - SocketAddressPal.SetIPv4Address(Buffer, address); + SocketAddressPal.SetIPv4Address(InternalBuffer, address); } } internal SocketAddress(IPAddress ipaddress, int port) : this(ipaddress) { - SocketAddressPal.SetPort(Buffer, unchecked((ushort)port)); + SocketAddressPal.SetPort(InternalBuffer, unchecked((ushort)port)); } internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) { - Buffer = buffer.ToArray(); + InternalBuffer = buffer.ToArray(); InternalSize = Buffer.Length; - SocketAddressPal.SetAddressFamily(Buffer, addressFamily); + SocketAddressPal.SetAddressFamily(InternalBuffer, addressFamily); } /// This represents underlying memory that can be passed to native OS calls. /// /// This memory can be invalidated if is changed or if the SocketAddress is used in another receive call. /// - public Memory SocketBuffer + public Memory Buffer { get { - return new Memory(Buffer, 0, InternalSize); + return new Memory(InternalBuffer, 0, InternalSize); } } @@ -166,14 +166,14 @@ internal IPAddress GetIPAddress() Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; uint scope; - SocketAddressPal.GetIPv6Address(Buffer, address, out scope); + SocketAddressPal.GetIPv6Address(InternalBuffer, address, out scope); return new IPAddress(address, (long)scope); } else if (Family == AddressFamily.InterNetwork) { Debug.Assert(Size >= IPv4AddressSize); - long address = (long)SocketAddressPal.GetIPv4Address(Buffer) & 0x0FFFFFFFF; + long address = (long)SocketAddressPal.GetIPv4Address(InternalBuffer) & 0x0FFFFFFFF; return new IPAddress(address); } else @@ -186,7 +186,7 @@ internal IPAddress GetIPAddress() } } - internal int GetPort() => (int)SocketAddressPal.GetPort(Buffer); + internal int GetPort() => (int)SocketAddressPal.GetPort(InternalBuffer); internal IPEndPoint GetIPEndPoint() { @@ -198,22 +198,22 @@ internal IPEndPoint GetIPEndPoint() internal void CopyAddressSizeIntoBuffer() { int addressSizeOffset = GetAddressSizeOffset(); - Buffer[addressSizeOffset] = unchecked((byte)(InternalSize)); - Buffer[addressSizeOffset + 1] = unchecked((byte)(InternalSize >> 8)); - Buffer[addressSizeOffset + 2] = unchecked((byte)(InternalSize >> 16)); - Buffer[addressSizeOffset + 3] = unchecked((byte)(InternalSize >> 24)); + InternalBuffer[addressSizeOffset] = unchecked((byte)(InternalSize)); + InternalBuffer[addressSizeOffset + 1] = unchecked((byte)(InternalSize >> 8)); + InternalBuffer[addressSizeOffset + 2] = unchecked((byte)(InternalSize >> 16)); + InternalBuffer[addressSizeOffset + 3] = unchecked((byte)(InternalSize >> 24)); } // Can be called after the above method did work. internal int GetAddressSizeOffset() { - return Buffer.Length - IntPtr.Size; + return InternalBuffer.Length - IntPtr.Size; } #endif public override bool Equals(object? comparand) => comparand is SocketAddress other && - Buffer.AsSpan(0, Size).SequenceEqual(other.Buffer.AsSpan(0, other.Size)); + InternalBuffer.AsSpan(0, Size).SequenceEqual(other.InternalBuffer.AsSpan(0, other.Size)); public override int GetHashCode() { @@ -227,7 +227,7 @@ public override int GetHashCode() for (i = 0; i < size; i += 4) { - _hash ^= BinaryPrimitives.ReadInt32LittleEndian(Buffer.AsSpan(i)); + _hash ^= BinaryPrimitives.ReadInt32LittleEndian(InternalBuffer.AsSpan(i)); } if ((Size & 3) != 0) { @@ -236,7 +236,7 @@ public override int GetHashCode() for (; i < Size; ++i) { - remnant |= ((int)Buffer[i]) << shift; + remnant |= ((int)InternalBuffer[i]) << shift; shift += 8; } _hash ^= remnant; @@ -276,7 +276,7 @@ public override string ToString() result[length++] = ':'; result[length++] = '{'; - byte[] buffer = Buffer; + byte[] buffer = InternalBuffer; for (int i = DataOffset; i < Size; i++) { if (i > DataOffset) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 3aa6d95337c54..099eeed576adb 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -13,6 +13,8 @@ internal static class SocketAddressPal { public static readonly int IPv6AddressSize = GetIPv6AddressSize(); public static readonly int IPv4AddressSize = GetIPv4AddressSize(); + //public static readonly int IPv4AddressSize = GetUdsAddressSize(); + //public static readonly int IPv4AddressSize = GetMaxAddressSize(); private static unsafe int GetIPv6AddressSize() { diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs index d9bdfb1279704..4b0765065a82e 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -10,6 +10,8 @@ internal static class SocketAddressPal { public const int IPv6AddressSize = 28; public const int IPv4AddressSize = 16; + public const int UdsAddressSize = 110; + public const int MaxAddressSize = 128; public static AddressFamily GetAddressFamily(ReadOnlySpan buffer) { diff --git a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs index 3b97047c98a7e..1b15b74b24525 100644 --- a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs +++ b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs @@ -356,7 +356,7 @@ public SocketAddress(System.Net.Sockets.AddressFamily family, int size) { } public System.Net.Sockets.AddressFamily Family { get { throw null; } } public byte this[int offset] { get { throw null; } set { } } public int Size { get { throw null; } set { } } - public System.Memory SocketBuffer { get { throw null; } } + public System.Memory Buffer { get { throw null; } } public override bool Equals(object? comparand) { throw null; } public override int GetHashCode() { throw null; } public override string ToString() { throw null; } diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 1dd563333c80e..57de68d185656 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -400,7 +400,7 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } - public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress socketAddress) { throw null; } + public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedSocketAddress) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index c7467ee9d7b41..9b93f56b23cbc 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -73,7 +73,7 @@ public Socket(SocketInformation socketInformation) Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep); unsafe { - fixed (byte* bufferPtr = socketAddress.Buffer) + fixed (byte* bufferPtr = socketAddress.InternalBuffer) fixed (int* sizePtr = &socketAddress.InternalSize) { errorCode = SocketPal.GetSockName(_handle, bufferPtr, sizePtr); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index bf556c40e24ed..3063aea8cc5c6 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -304,7 +304,7 @@ public EndPoint? LocalEndPoint unsafe { - fixed (byte* buffer = socketAddress.Buffer) + fixed (byte* buffer = socketAddress.InternalBuffer) fixed (int* bufferSize = &socketAddress.InternalSize) { // This may throw ObjectDisposedException. @@ -346,7 +346,7 @@ public EndPoint? RemoteEndPoint // This may throw ObjectDisposedException. SocketError errorCode = SocketPal.GetPeerName( _handle, - socketAddress.Buffer, + socketAddress.InternalBuffer, ref socketAddress.InternalSize); if (errorCode != SocketError.Success) @@ -765,7 +765,7 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd SocketError errorCode = SocketPal.Bind( _handle, _protocolType, - socketAddress.Buffer, + socketAddress.InternalBuffer, socketAddress.Size); // Throw an appropriate SocketException if the native call fails. @@ -1019,7 +1019,7 @@ public Socket Accept() { errorCode = SocketPal.Accept( _handle, - socketAddress.SocketBuffer, + socketAddress.Buffer, out socketAddressLen, out acceptedSocketHandle); socketAddress.Size = socketAddressLen; @@ -1284,7 +1284,7 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, Internals.SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.SocketBuffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1356,7 +1356,7 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r Internals.SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.SocketBuffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1395,7 +1395,7 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, SocketAddr ValidateBlockingMode(); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.SocketBuffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1886,19 +1886,19 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint /// /// A span of bytes that is the storage location for received data. /// A bitwise combination of the values. - /// An , passed by reference, that represents the remote server. + /// An , that will be updated with value of the remote peer. /// The number of bytes received. /// remoteEP is . /// An error occurred when attempting to access the socket. /// The has been closed. - public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress socketAddress) + public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress receivedSocketAddress) { ThrowIfDisposed(); ValidateBlockingMode(); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.SocketBuffer, out int socketAddressSize, out bytesTransferred); + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, receivedSocketAddress.Buffer, out int socketAddressSize, out bytesTransferred); UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. @@ -1925,7 +1925,7 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress throw socketException; } - socketAddress.Size = socketAddressSize; + receivedSocketAddress.Size = socketAddressSize; return bytesTransferred; } @@ -3172,7 +3172,7 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket SocketError errorCode; try { - errorCode = SocketPal.Connect(_handle, socketAddress.SocketBuffer); + errorCode = SocketPal.Connect(_handle, socketAddress.Buffer); } catch (Exception ex) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 12b94b579988c..a8efcdab76322 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -70,7 +70,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocket internal unsafe SocketError DoOperationConnect(SafeSocketHandle handle) { - SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.SocketBuffer, ConnectCompletionCallback); + SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback); if (socketError != SocketError.IOPending) { FinishOperationSync(socketError, 0, SocketFlags.None); @@ -148,11 +148,11 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, Canc int socketAddressLen; if (_bufferList == null) { - errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress!.SocketBuffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); + errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress!.Buffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken); } else { - errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress!.SocketBuffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress!.Buffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); } if (errorCode != SocketError.IOPending) @@ -197,7 +197,7 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc if (socketError != SocketError.IOPending) { _socketAddress.Size = socketAddressSize; - CompleteReceiveMessageFromOperation(_socketAddress.SocketBuffer, socketAddressSize, receivedFlags, ipPacketInformation); + CompleteReceiveMessageFromOperation(_socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation); FinishOperationSync(socketError, bytesReceived, receivedFlags); } return socketError; @@ -296,16 +296,16 @@ internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToke SocketError errorCode; if (_bufferList == null) { - errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.SocketBuffer, out bytesSent, TransferCompletionCallback, cancellationToken); + errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.Buffer, out bytesSent, TransferCompletionCallback, cancellationToken); } else { - errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress!.SocketBuffer, out bytesSent, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress!.Buffer, out bytesSent, TransferCompletionCallback); } if (errorCode != SocketError.IOPending) { - CompleteTransferOperation(_socketAddress.SocketBuffer, _socketAddress.Size, SocketFlags.None); + CompleteTransferOperation(_socketAddress.Buffer, _socketAddress.Size, SocketFlags.None); FinishOperationSync(errorCode, bytesSent, SocketFlags.None); } @@ -331,7 +331,7 @@ internal void LogBuffer(int size) private SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAddress) { - System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.Buffer, 0, _acceptAddressBufferCount); + System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.InternalBuffer, 0, _acceptAddressBufferCount); Socket acceptedSocket = _currentSocket!.CreateAcceptSocket( SocketPal.CreateSocket(_acceptedFileDescriptor), _currentSocket._rightEndPoint!.Create(remoteSocketAddress)); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index dc56421349b60..192a8a1b30bf6 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -284,7 +284,7 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha internal SocketError DoOperationConnect(SafeSocketHandle handle) { // Called for connectionless protocols. - SocketError socketError = SocketPal.Connect(handle, _socketAddress!.SocketBuffer); + SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer); FinishOperationSync(socketError, 0, SocketFlags.None); return socketError; } @@ -303,7 +303,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle { bool success = socket.ConnectEx( handle, - _socketAddress!.Buffer.AsSpan(), + _socketAddress!.InternalBuffer.AsSpan(), (IntPtr)(bufferPtr + _offset), _count, out int bytesTransferred, @@ -763,7 +763,7 @@ internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handl 1, out int bytesTransferred, _socketFlags, - _socketAddress!.Buffer.AsSpan(), + _socketAddress!.InternalBuffer.AsSpan(), overlapped, IntPtr.Zero); @@ -790,7 +790,7 @@ internal unsafe SocketError DoOperationSendToMultiBuffer(SafeSocketHandle handle _bufferListInternal!.Count, out int bytesTransferred, _socketFlags, - _socketAddress!.Buffer.AsSpan(), + _socketAddress!.InternalBuffer.AsSpan(), overlapped, IntPtr.Zero); @@ -883,11 +883,11 @@ private unsafe IntPtr PtrSocketAddressBuffer get { Debug.Assert(_pinnedSocketAddress != null); - Debug.Assert(_pinnedSocketAddress.Buffer != null); - Debug.Assert(_pinnedSocketAddress.Buffer.Length > 0); + Debug.Assert(_pinnedSocketAddress.InternalBuffer != null); + Debug.Assert(_pinnedSocketAddress.InternalBuffer.Length > 0); Debug.Assert(_socketAddressGCHandle.IsAllocated); - Debug.Assert(_socketAddressGCHandle.Target == _pinnedSocketAddress.Buffer); - fixed (void* ptrSocketAddressBuffer = &_pinnedSocketAddress.Buffer[0]) + Debug.Assert(_socketAddressGCHandle.Target == _pinnedSocketAddress.InternalBuffer); + fixed (void* ptrSocketAddressBuffer = &_pinnedSocketAddress.InternalBuffer[0]) { return (IntPtr)ptrSocketAddressBuffer; } @@ -1075,7 +1075,7 @@ private unsafe SocketError FinishOperationAccept(Internals.SocketAddress remoteS out remoteSocketAddress.InternalSize ); - Marshal.Copy(remoteAddr, remoteSocketAddress.Buffer, 0, remoteSocketAddress.Size); + Marshal.Copy(remoteAddr, remoteSocketAddress.InternalBuffer, 0, remoteSocketAddress.Size); } socketError = Interop.Winsock.setsockopt( diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index ba596ffa18d5e..b43bfbde09ae1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1267,11 +1267,11 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han SocketError errorCode; if (!handle.IsNonBlocking) { - errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress.SocketBuffer, out socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddress.Buffer, out socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); } else { - if (!TryCompleteReceiveMessageFrom(handle, new Span(buffer, offset, count), null, socketFlags, socketAddress.SocketBuffer, out socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) + if (!TryCompleteReceiveMessageFrom(handle, new Span(buffer, offset, count), null, socketFlags, socketAddress.Buffer, out socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) { errorCode = SocketError.WouldBlock; } @@ -1285,7 +1285,7 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { - byte[] socketAddressBuffer = socketAddress.Buffer; + byte[] socketAddressBuffer = socketAddress.InternalBuffer; int socketAddressLen; bool isIPv4, isIPv6; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 68e4aadf321b8..38528df066b83 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -448,7 +448,7 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan receiveAddress = socketAddress; ipPacketInformation = default(IPPacketInformation); fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) - fixed (byte* ptrSocketAddress = socketAddress.Buffer) + fixed (byte* ptrSocketAddress = &MemoryMarshal.GetReference(socketAddress.Buffer.Span)) { Interop.Winsock.WSAMsg wsaMsg; wsaMsg.socketAddress = (IntPtr)ptrSocketAddress; From bb44fd171df1fd3915018d3224aa97c6e9a46e30 Mon Sep 17 00:00:00 2001 From: wfurt Date: Wed, 19 Jul 2023 07:45:03 -0700 Subject: [PATCH 07/18] quic --- .../src/System/Net/Quic/Internal/MsQuicHelpers.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs index bf454c047c3cd..44f6c2d76661a 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -51,7 +51,7 @@ internal static unsafe QuicAddr ToQuicAddr(this IPEndPoint ipEndPoint) Internals.SocketAddress address = IPEndPointExtensions.Serialize(ipEndPoint); Debug.Assert(address.Size <= rawAddress.Length); - address.Buffer.AsSpan(0, address.Size).CopyTo(rawAddress); + address.InternalBuffer.AsSpan(0, address.Size).CopyTo(rawAddress); return result; } From 54ff2be19fc61cd58c8c31b246eaa50ef84c47d3 Mon Sep 17 00:00:00 2001 From: wfurt Date: Wed, 19 Jul 2023 21:42:06 -0600 Subject: [PATCH 08/18] update --- .../src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs | 2 +- src/libraries/Common/src/System/Net/SocketAddress.cs | 3 ++- .../System/Net/NetworkInformation/SystemNetworkInterface.cs | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs index 71e67116e4ebe..0152ae15b496e 100644 --- a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs +++ b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs @@ -64,7 +64,7 @@ internal IPAddress MarshalIPAddress() AddressFamily family = (addressLength > Internals.SocketAddress.IPv4AddressSize) ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork; Internals.SocketAddress sockAddress = new Internals.SocketAddress(family, addressLength); - Marshal.Copy(address, sockAddress.Buffer, 0, addressLength); + Marshal.Copy(address, sockAddress.InternalBuffer, 0, addressLength); return sockAddress.GetIPAddress(); } diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 14c2500b78d26..965ede3948aae 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -111,6 +111,7 @@ internal SocketAddress(IPAddress ipAddress) : this(ipAddress.AddressFamily, ((ipAddress.AddressFamily == AddressFamily.InterNetwork) ? IPv4AddressSize : IPv6AddressSize)) { + // No Port. SocketAddressPal.SetPort(InternalBuffer, 0); @@ -142,7 +143,7 @@ internal SocketAddress(IPAddress ipaddress, int port) internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) { InternalBuffer = buffer.ToArray(); - InternalSize = Buffer.Length; + InternalSize = InternalBuffer.Length; SocketAddressPal.SetAddressFamily(InternalBuffer, addressFamily); } diff --git a/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs b/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs index 4c62c3770d96a..630cfd68c2d56 100644 --- a/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs +++ b/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs @@ -45,7 +45,7 @@ private static unsafe int GetBestInterfaceForAddress(IPAddress addr) { int index; Internals.SocketAddress address = new Internals.SocketAddress(addr); - fixed (byte* buffer = address.Buffer) + fixed (byte* buffer = address.InternalBuffer) { int error = (int)Interop.IpHlpApi.GetBestInterfaceEx(buffer, &index); if (error != 0) From 04c4ad8a65ca756a0d89206fc7e31c66d5d1f0fe Mon Sep 17 00:00:00 2001 From: wfurt Date: Thu, 20 Jul 2023 20:51:34 -0700 Subject: [PATCH 09/18] fixes --- .../src/System/Net/SocketAddressPal.Unix.cs | 12 +++++-- .../System/Net/SocketAddressPal.Windows.cs | 11 +++++-- .../Net/NetworkInformation/Ping.Windows.cs | 2 +- .../src/System/Net/Sockets/SocketPal.Unix.cs | 33 +++++++++++++++---- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 099eeed576adb..00a59928db7d4 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -13,8 +13,6 @@ internal static class SocketAddressPal { public static readonly int IPv6AddressSize = GetIPv6AddressSize(); public static readonly int IPv4AddressSize = GetIPv4AddressSize(); - //public static readonly int IPv4AddressSize = GetUdsAddressSize(); - //public static readonly int IPv4AddressSize = GetMaxAddressSize(); private static unsafe int GetIPv6AddressSize() { @@ -66,7 +64,7 @@ public static unsafe AddressFamily GetAddressFamily(ReadOnlySpan buffer) return family; } - public static unsafe void SetAddressFamily(byte[] buffer, AddressFamily family) + public static unsafe void SetAddressFamily(Span buffer, AddressFamily family) { Interop.Error err; @@ -167,5 +165,13 @@ public static unsafe void SetIPv6Address(byte[] buffer, byte* address, int addre ThrowOnFailure(err); } + + public static unsafe void Clear(Span buffer) + { + AddressFamily family = GetAddressFamily(buffer); + buffer.Clear(); + buffer[0] = (byte)buffer.Length; + SetAddressFamily(buffer, family); + } } } diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs index 4b0765065a82e..f5774a7030bd8 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -10,15 +10,13 @@ internal static class SocketAddressPal { public const int IPv6AddressSize = 28; public const int IPv4AddressSize = 16; - public const int UdsAddressSize = 110; - public const int MaxAddressSize = 128; public static AddressFamily GetAddressFamily(ReadOnlySpan buffer) { return (AddressFamily)BitConverter.ToInt16(buffer); } - public static void SetAddressFamily(byte[] buffer, AddressFamily family) + public static void SetAddressFamily(Span buffer, AddressFamily family) { if ((int)(family) > ushort.MaxValue) { @@ -68,5 +66,12 @@ public static void SetIPv6Address(byte[] buffer, Span address, uint scope) // Address serialization address.CopyTo(buffer.AsSpan(8)); } + + public static unsafe void Clear(Span buffer) + { + AddressFamily family = GetAddressFamily(buffer); + buffer.Clear(); + SetAddressFamily(buffer, family); + } } } diff --git a/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs b/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs index 9a1b3075282e3..4efdb16fa15b2 100644 --- a/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs +++ b/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs @@ -194,7 +194,7 @@ private unsafe int SendEcho(IPAddress address, byte[] buffer, int timeout, PingO IntPtr.Zero, IntPtr.Zero, sourceAddr, - remoteAddr.Buffer, + remoteAddr.InternalBuffer, _requestBuffer!, (ushort)buffer.Length, ref ipOptions, diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index b43bfbde09ae1..8884f966ae234 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -473,8 +473,6 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF int cmsgBufferLen = Interop.Sys.GetControlMessageBufferSize(Convert.ToInt32(isIPv4), Convert.ToInt32(isIPv6)); byte* cmsgBuffer = stackalloc byte[cmsgBufferLen]; - int sockAddrLen = socketAddress.Length; - Interop.Sys.MessageHeader messageHeader; long received = 0; @@ -488,7 +486,7 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF messageHeader = new Interop.Sys.MessageHeader { SocketAddress = rawSocketAddress, - SocketAddressLen = sockAddrLen, + SocketAddressLen = socketAddress.Length, IOVectors = &iov, IOVectorCount = 1, ControlBuffer = cmsgBuffer, @@ -502,17 +500,21 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF &received); receivedFlags = messageHeader.Flags; - sockAddrLen = messageHeader.SocketAddressLen; + socketAddressLen = messageHeader.SocketAddressLen; } - socketAddressLen = sockAddrLen; - if (errno != Interop.Error.SUCCESS) { ipPacketInformation = default(IPPacketInformation); return -1; } + if (socketAddressLen == 0) + { + // We can fail to get peer address on TCP + socketAddressLen = socketAddress.Length; + SocketAddressPal.Clear(socketAddress); + } ipPacketInformation = GetIPPacketInformation(&messageHeader, isIPv4, isIPv6); return checked((int)received); } @@ -574,6 +576,13 @@ private static unsafe int SysReceiveMessageFrom( if (errno == Interop.Error.SUCCESS) { ipPacketInformation = GetIPPacketInformation(&messageHeader, isIPv4, isIPv6); + if (socketAddressLen == 0) + { + // We can fail to get peer address on TCP + socketAddressLen = socketAddress.Length; + SocketAddressPal.Clear(socketAddress); + } + return checked((int)received); } else @@ -834,6 +843,12 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span 0 && socketAddressLen == 0) + { + // We can fail to get peer address on TCP + socketAddressLen = socketAddress.Length; + SocketAddressPal.Clear(socketAddress); + } return true; } @@ -872,6 +887,12 @@ public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, if (received != -1) { + if (socketAddress.Length > 0 && socketAddressLen == 0) + { + // We can fail to get peer address on TCP + socketAddressLen = socketAddress.Length; + SocketAddressPal.Clear(socketAddress.Span); + } bytesReceived = received; errorCode = SocketError.Success; return true; From 3fa720745b399133a5dcfe01b74e67f57c27e013 Mon Sep 17 00:00:00 2001 From: wfurt Date: Thu, 20 Jul 2023 21:01:25 -0700 Subject: [PATCH 10/18] GetHashCode --- .../Common/src/System/Net/SocketAddress.cs | 37 ++----------------- 1 file changed, 4 insertions(+), 33 deletions(-) diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 965ede3948aae..70f9f3c11ec74 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -33,8 +33,6 @@ class SocketAddress private const int MinSize = 2; private const int MaxSize = 32; // IrDA requires 32 bytes private const int DataOffset = 2; - private bool _changed = true; - private int _hash; public AddressFamily Family { @@ -78,10 +76,6 @@ public byte this[int offset] { throw new IndexOutOfRangeException(); } - if (InternalBuffer[offset] != value) - { - _changed = true; - } InternalBuffer[offset] = value; } } @@ -214,36 +208,13 @@ internal int GetAddressSizeOffset() public override bool Equals(object? comparand) => comparand is SocketAddress other && - InternalBuffer.AsSpan(0, Size).SequenceEqual(other.InternalBuffer.AsSpan(0, other.Size)); + Buffer.Span.SequenceEqual(other.Buffer.Span); public override int GetHashCode() { - if (_changed) - { - _changed = false; - _hash = 0; - - int i; - int size = Size & ~3; - - for (i = 0; i < size; i += 4) - { - _hash ^= BinaryPrimitives.ReadInt32LittleEndian(InternalBuffer.AsSpan(i)); - } - if ((Size & 3) != 0) - { - int remnant = 0; - int shift = 0; - - for (; i < Size; ++i) - { - remnant |= ((int)InternalBuffer[i]) << shift; - shift += 8; - } - _hash ^= remnant; - } - } - return _hash; + HashCode hash = default; + hash.AddBytes(Buffer.Span); + return hash.ToHashCode(); } public override string ToString() From a1e932e6f13c103d066fd0b716ca3746130bf7db Mon Sep 17 00:00:00 2001 From: wfurt Date: Fri, 21 Jul 2023 17:14:13 -0700 Subject: [PATCH 11/18] cleanup --- .../Unix/System.Native/Interop.Socket.cs | 5 +- .../Interop/Windows/IpHlpApi/Interop.ICMP.cs | 2 +- .../IpHlpApi/Interop.NetworkInformation.cs | 13 ++--- .../src/System/Net/IPEndPointExtensions.cs | 54 +++++++++++++++++++ .../src/System/Net/SocketAddressPal.Unix.cs | 11 ++-- .../System/Net/SocketAddressPal.Windows.cs | 16 +++--- .../Net/SocketProtocolSupportPal.Unix.cs | 5 +- .../src/System.Net.NetworkInformation.csproj | 2 +- .../SystemNetworkInterface.cs | 13 +++-- .../src/System.Net.Ping.csproj | 14 ++--- .../Net/NetworkInformation/Ping.Windows.cs | 11 ++-- .../src/System.Net.Quic.csproj | 3 +- .../System/Net/Quic/Internal/MsQuicHelpers.cs | 15 +++--- .../Sockets/SocketAsyncEventArgs.Windows.cs | 2 +- 14 files changed, 106 insertions(+), 60 deletions(-) create mode 100644 src/libraries/Common/src/System/Net/IPEndPointExtensions.cs diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs index 56d00ccb3a98c..a593fd34fe458 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs @@ -3,7 +3,6 @@ using System; using System.Net; -using System.Net.Internals; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -11,7 +10,11 @@ internal static partial class Interop { internal static partial class Sys { +#if SYSTEM_NET_SOCKETS_DLL [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Socket")] internal static unsafe partial Error Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, IntPtr* socket); +#endif + [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Socket")] + internal static unsafe partial Error Socket(int addressFamily, int socketType, int protocolType, IntPtr* socket); } } diff --git a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs index 090bf72a31a39..1f71754b486f8 100644 --- a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs +++ b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.ICMP.cs @@ -104,6 +104,6 @@ internal static partial uint IcmpSendEcho2(SafeCloseIcmpHandle icmpHandle, SafeW [LibraryImport(Interop.Libraries.IpHlpApi, SetLastError = true)] internal static unsafe partial uint Icmp6SendEcho2(SafeCloseIcmpHandle icmpHandle, SafeWaitHandle Event, IntPtr apcRoutine, IntPtr apcContext, - byte* sourceSocketAddress, byte[] destSocketAddress, SafeLocalAllocHandle data, ushort dataSize, ref IP_OPTION_INFORMATION options, SafeLocalAllocHandle replyBuffer, uint replySize, uint timeout); + Span sourceSocketAddress, Span destSocketAddress, SafeLocalAllocHandle data, ushort dataSize, ref IP_OPTION_INFORMATION options, SafeLocalAllocHandle replyBuffer, uint replySize, uint timeout); } } diff --git a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs index 0152ae15b496e..b01bb8ee9d3c9 100644 --- a/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs +++ b/src/libraries/Common/src/Interop/Windows/IpHlpApi/Interop.NetworkInformation.cs @@ -8,7 +8,6 @@ using System.Net.NetworkInformation; using System.Net.Sockets; using System.Runtime.InteropServices; -using Internals = System.Net.Internals; internal static partial class Interop { @@ -53,20 +52,14 @@ internal enum GetAdaptersAddressesFlags } [StructLayout(LayoutKind.Sequential)] - internal struct IpSocketAddress + internal unsafe struct IpSocketAddress { internal IntPtr address; internal int addressLength; internal IPAddress MarshalIPAddress() { - // Determine the address family used to create the IPAddress. - AddressFamily family = (addressLength > Internals.SocketAddress.IPv4AddressSize) - ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork; - Internals.SocketAddress sockAddress = new Internals.SocketAddress(family, addressLength); - Marshal.Copy(address, sockAddress.InternalBuffer, 0, addressLength); - - return sockAddress.GetIPAddress(); + return IPEndPointExtensions.GetIPAddress(new Span((void*)address, addressLength)); } } @@ -511,7 +504,7 @@ internal static unsafe partial uint GetAdaptersAddresses( uint* outBufLen); [LibraryImport(Interop.Libraries.IpHlpApi)] - internal static unsafe partial uint GetBestInterfaceEx(byte* ipAddress, int* index); + internal static unsafe partial uint GetBestInterfaceEx(Span ipAddress, int* index); [LibraryImport(Interop.Libraries.IpHlpApi)] internal static partial uint GetIfEntry2(ref MibIfRow2 pIfRow); diff --git a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs new file mode 100644 index 0000000000000..b62b1d17d9193 --- /dev/null +++ b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace System.Net.Sockets +{ + internal static class IPEndPointExtensions + { + public static IPAddress GetIPAddress(ReadOnlySpan socketAddressBuffer) + { + if (SocketAddressPal.GetAddressFamily(socketAddressBuffer) == AddressFamily.InterNetworkV6) + { + Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + uint scope; + SocketAddressPal.GetIPv6Address(socketAddressBuffer, address, out scope); + return new IPAddress(address, (long)scope); + } + + return new IPAddress((long)SocketAddressPal.GetIPv4Address(socketAddressBuffer) & 0x0FFFFFFFF); + } + + public static void SetIPAddress(Span socketAddressBuffer, IPAddress address) + { + SocketAddressPal.SetAddressFamily(socketAddressBuffer, address.AddressFamily); + SocketAddressPal.SetPort(socketAddressBuffer, 0); + if (address.AddressFamily == AddressFamily.InterNetwork) + { +#pragma warning disable CS0618 + SocketAddressPal.SetIPv4Address(socketAddressBuffer, (uint)address.Address); +#pragma warning restore CS0618 + } + else + { + Span addressBuffer = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + address.TryWriteBytes(addressBuffer, out _); + SocketAddressPal.SetIPv6Address(socketAddressBuffer, addressBuffer, (uint)address.ScopeId); + } + } + + public static IPEndPoint CreateIPEndPoint(ReadOnlySpan socketAddressBuffer) + { + return new IPEndPoint(GetIPAddress(socketAddressBuffer), SocketAddressPal.GetPort(socketAddressBuffer)); + } + + // https://github.com/dotnet/runtime/issues/78993 + public static void Serialize(this IPEndPoint endPoint, Span destination) + { + SocketAddressPal.SetAddressFamily(destination, endPoint.AddressFamily); + SetIPAddress(destination, endPoint.Address); + SocketAddressPal.SetPort(destination, (ushort)endPoint.Port); + } + } +} diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 00a59928db7d4..966a78d39b552 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -92,7 +92,7 @@ public static unsafe ushort GetPort(ReadOnlySpan buffer) return port; } - public static unsafe void SetPort(byte[] buffer, ushort port) + public static unsafe void SetPort(Span buffer, ushort port) { Interop.Error err; fixed (byte* rawAddress = buffer) @@ -130,7 +130,7 @@ public static unsafe void GetIPv6Address(ReadOnlySpan buffer, Span a scope = localScope; } - public static unsafe void SetIPv4Address(byte[] buffer, uint address) + public static unsafe void SetIPv4Address(Span buffer, uint address) { Interop.Error err; fixed (byte* rawAddress = buffer) @@ -141,21 +141,22 @@ public static unsafe void SetIPv4Address(byte[] buffer, uint address) ThrowOnFailure(err); } - public static unsafe void SetIPv4Address(byte[] buffer, byte* address) + public static unsafe void SetIPv4Address(Span buffer, byte* address) { uint addr = (uint)System.Runtime.InteropServices.Marshal.ReadInt32((IntPtr)address); SetIPv4Address(buffer, addr); } - public static unsafe void SetIPv6Address(byte[] buffer, Span address, uint scope) + public static unsafe void SetIPv6Address(Span buffer, Span address, uint scope) { + fixed (byte* rawInput = &MemoryMarshal.GetReference(address)) { SetIPv6Address(buffer, rawInput, address.Length, scope); } } - public static unsafe void SetIPv6Address(byte[] buffer, byte* address, int addressLength, uint scope) + public static unsafe void SetIPv6Address(Span buffer, byte* address, int addressLength, uint scope) { Interop.Error err; fixed (byte* rawAddress = buffer) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs index f5774a7030bd8..362f66b8a325d 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -36,8 +36,8 @@ public static void SetAddressFamily(Span buffer, AddressFamily family) public static ushort GetPort(ReadOnlySpan buffer) => BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2)); - public static void SetPort(byte[] buffer, ushort port) - => BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(2), port); + public static void SetPort(Span buffer, ushort port) + => BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(2), port); public static uint GetIPv4Address(ReadOnlySpan buffer) => BinaryPrimitives.ReadUInt32LittleEndian(buffer.Slice(4)); @@ -49,22 +49,22 @@ public static void GetIPv6Address(ReadOnlySpan buffer, Span address, scope = BinaryPrimitives.ReadUInt32LittleEndian(buffer.Slice(24)); } - public static void SetIPv4Address(byte[] buffer, uint address) + public static void SetIPv4Address(Span buffer, uint address) { // IPv4 Address serialization - BinaryPrimitives.WriteUInt32LittleEndian(buffer.AsSpan(4), address); + BinaryPrimitives.WriteUInt32LittleEndian(buffer.Slice(4), address); } - public static void SetIPv6Address(byte[] buffer, Span address, uint scope) + public static void SetIPv6Address(Span buffer, Span address, uint scope) { // No handling for Flow Information - BinaryPrimitives.WriteUInt32LittleEndian(buffer.AsSpan(4), 0); + BinaryPrimitives.WriteUInt32LittleEndian(buffer.Slice(4), 0); // Scope serialization - BinaryPrimitives.WriteUInt32LittleEndian(buffer.AsSpan(24), scope); + BinaryPrimitives.WriteUInt32LittleEndian(buffer.Slice(24), scope); // Address serialization - address.CopyTo(buffer.AsSpan(8)); + address.CopyTo(buffer.Slice(8)); } public static unsafe void Clear(Span buffer) diff --git a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs index 973254788e542..e85dc64ce6acc 100644 --- a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Net.Internals; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -21,7 +20,11 @@ private static unsafe bool IsSupported(AddressFamily af) IntPtr socket = invalid; try { +#if SYSTEM_NET_SOCKETS_DLL Interop.Error result = Interop.Sys.Socket(af, SocketType.Dgram, 0, &socket); +#else + Interop.Error result = Interop.Sys.Socket((int)af, (int)Internals.SocketType.Dgram, 0, &socket); +#endif // we get EAFNOSUPPORT when family is not supported by Kernel, EPROTONOSUPPORT may come from policy enforcement like FreeBSD jail() return result != Interop.Error.EAFNOSUPPORT && result != Interop.Error.EPROTONOSUPPORT; } diff --git a/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj b/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj index 2a742d79f7144..3cf6d5f597f67 100644 --- a/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj +++ b/src/libraries/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj @@ -82,7 +82,7 @@ - + diff --git a/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs b/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs index 630cfd68c2d56..96256e552a0f1 100644 --- a/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs +++ b/src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/SystemNetworkInterface.cs @@ -44,14 +44,13 @@ internal static int InternalIPv6LoopbackInterfaceIndex private static unsafe int GetBestInterfaceForAddress(IPAddress addr) { int index; - Internals.SocketAddress address = new Internals.SocketAddress(addr); - fixed (byte* buffer = address.InternalBuffer) + Span buffer= stackalloc byte[SocketAddressPal.IPv6AddressSize]; + IPEndPointExtensions.SetIPAddress(buffer, addr); + + int error = (int)Interop.IpHlpApi.GetBestInterfaceEx(buffer, &index); + if (error != 0) { - int error = (int)Interop.IpHlpApi.GetBestInterfaceEx(buffer, &index); - if (error != 0) - { - throw new NetworkInformationException(error); - } + throw new NetworkInformationException(error); } return index; diff --git a/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj b/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj index 9c817f3908a98..aaaf38f5bf51f 100644 --- a/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj +++ b/src/libraries/System.Net.Ping/src/System.Net.Ping.csproj @@ -24,17 +24,14 @@ - - - - + + @@ -96,9 +93,6 @@ Link="Common\Interop\Windows\WinSock\Interop.WSAStartup.cs" /> - - diff --git a/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs b/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs index 4efdb16fa15b2..b5ca9832bdb3e 100644 --- a/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs +++ b/src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs @@ -183,10 +183,11 @@ private unsafe int SendEcho(IPAddress address, byte[] buffer, int timeout, PingO (uint)timeout); } - IPEndPoint ep = new IPEndPoint(address, 0); - Internals.SocketAddress remoteAddr = IPEndPointExtensions.Serialize(ep); - byte* sourceAddr = stackalloc byte[28]; - NativeMemory.Clear(sourceAddr, 28); + Span remoteAddr = stackalloc byte[SocketAddressPal.IPv6AddressSize]; + IPEndPointExtensions.SetIPAddress(remoteAddr, address); + + Span sourceAddr = stackalloc byte[SocketAddressPal.IPv6AddressSize]; + sourceAddr.Clear(); return (int)Interop.IpHlpApi.Icmp6SendEcho2( _handlePingV6!, @@ -194,7 +195,7 @@ private unsafe int SendEcho(IPAddress address, byte[] buffer, int timeout, PingO IntPtr.Zero, IntPtr.Zero, sourceAddr, - remoteAddr.InternalBuffer, + remoteAddr, _requestBuffer!, (ushort)buffer.Length, ref ipOptions, diff --git a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj index 53b39492c39a4..3b94c902e62a9 100644 --- a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj +++ b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj @@ -28,9 +28,8 @@ - - + diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs index 44f6c2d76661a..17886b8f8a1ba 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -38,20 +38,19 @@ internal static bool TryParse(this EndPoint endPoint, out string? host, out IPAd internal static unsafe IPEndPoint ToIPEndPoint(this ref QuicAddr quicAddress, AddressFamily? addressFamilyOverride = null) { // MsQuic always uses storage size as if IPv6 was used - Span addressBytes = new Span((byte*)Unsafe.AsPointer(ref quicAddress), Internals.SocketAddress.IPv6AddressSize); - return new Internals.SocketAddress(addressFamilyOverride ?? SocketAddressPal.GetAddressFamily(addressBytes), addressBytes).GetIPEndPoint(); + Span addressBytes = new Span((byte*)Unsafe.AsPointer(ref quicAddress), SocketAddressPal.IPv6AddressSize); + if (addressFamilyOverride != null) + { + SocketAddressPal.SetAddressFamily(addressBytes, (AddressFamily)addressFamilyOverride!); + } + return IPEndPointExtensions.CreateIPEndPoint(addressBytes); } internal static unsafe QuicAddr ToQuicAddr(this IPEndPoint ipEndPoint) { - // TODO: is the layout same for SocketAddress.Buffer and QuicAddr on all platforms? QuicAddr result = default; Span rawAddress = MemoryMarshal.AsBytes(MemoryMarshal.CreateSpan(ref result, 1)); - - Internals.SocketAddress address = IPEndPointExtensions.Serialize(ipEndPoint); - Debug.Assert(address.Size <= rawAddress.Length); - - address.InternalBuffer.AsSpan(0, address.Size).CopyTo(rawAddress); + ipEndPoint.Serialize(rawAddress); return result; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 192a8a1b30bf6..6f5f5e856c66d 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -873,7 +873,7 @@ private void PinSocketAddressBuffer() } // Pin down the new one. - _socketAddressGCHandle = GCHandle.Alloc(_socketAddress!.Buffer, GCHandleType.Pinned); + _socketAddressGCHandle = GCHandle.Alloc(_socketAddress!.InternalBuffer, GCHandleType.Pinned); _socketAddress.CopyAddressSizeIntoBuffer(); _pinnedSocketAddress = _socketAddress; } From 845a4a7e21ed3e0be2fb022ee443057e49874354 Mon Sep 17 00:00:00 2001 From: wfurt Date: Fri, 21 Jul 2023 21:43:43 -0700 Subject: [PATCH 12/18] PalTests --- .../tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj b/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj index e0aa8f994b0c0..5256dd8848232 100644 --- a/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj +++ b/src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj @@ -19,9 +19,9 @@ - - From d79cf82a347fef28a80082d86fe8e6b4bd820017 Mon Sep 17 00:00:00 2001 From: wfurt Date: Sun, 23 Jul 2023 11:40:20 -0700 Subject: [PATCH 13/18] GetMaximumAddressSize --- .../System.Native/Interop.SocketAddress.cs | 4 ++-- .../Common/src/System/Net/SocketAddress.cs | 20 ++++++++++++++----- .../src/System/Net/SocketAddressPal.Unix.cs | 20 ++++++------------- .../System/Net/SocketAddressPal.Windows.cs | 2 ++ .../ref/System.Net.Primitives.cs | 4 +++- .../FunctionalTests/SocketAddressTest.cs | 2 +- src/native/libs/System.Native/entrypoints.c | 2 +- .../libs/System.Native/pal_networking.c | 6 ++++-- .../libs/System.Native/pal_networking.h | 2 +- 9 files changed, 35 insertions(+), 27 deletions(-) diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs index 4c03a64e1beb3..03c46ddcd684e 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs @@ -9,9 +9,9 @@ internal static partial class Interop { internal static partial class Sys { - [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetIPSocketAddressSizes")] + [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetSocketAddressSizes")] [SuppressGCTransition] - internal static unsafe partial Error GetIPSocketAddressSizes(int* ipv4SocketAddressSize, int* ipv6SocketAddressSize); + internal static partial Error GetSocketAddressSizes(ref int ipv4SocketAddressSize, ref int ipv6SocketAddressSize, ref int udsSocketAddressSize, ref int maxSocketAddressSize); [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetAddressFamily")] [SuppressGCTransition] diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 70f9f3c11ec74..c58515bc9f57b 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -20,18 +20,19 @@ namespace System.Net.Internals #else internal sealed #endif - class SocketAddress + class SocketAddress : System.IEquatable { #pragma warning disable CA1802 // these could be const on Windows but need to be static readonly for Unix internal static readonly int IPv6AddressSize = SocketAddressPal.IPv6AddressSize; internal static readonly int IPv4AddressSize = SocketAddressPal.IPv4AddressSize; + internal static readonly int UdsAddressSize = SocketAddressPal.UdsAddressSize; + internal static readonly int MaxAddressSize = SocketAddressPal.MaxAddressSize; #pragma warning restore CA1802 internal int InternalSize; internal byte[] InternalBuffer; private const int MinSize = 2; - private const int MaxSize = 32; // IrDA requires 32 bytes private const int DataOffset = 2; public AddressFamily Family @@ -80,7 +81,15 @@ public byte this[int offset] } } - public SocketAddress(AddressFamily family) : this(family, MaxSize) + public static int GetMaximumAddressSize(AddressFamily addressFamily) => addressFamily switch + { + AddressFamily.InterNetwork => IPv4AddressSize, + AddressFamily.InterNetworkV6 => IPv6AddressSize, + AddressFamily.Unix => UdsAddressSize, + _ => MaxAddressSize + }; + + public SocketAddress(AddressFamily family) : this(family, GetMaximumAddressSize(family)) { } @@ -207,8 +216,9 @@ internal int GetAddressSizeOffset() #endif public override bool Equals(object? comparand) => - comparand is SocketAddress other && - Buffer.Span.SequenceEqual(other.Buffer.Span); + comparand is SocketAddress other && Equals(other); + + public bool Equals(SocketAddress? comparand) => comparand != null && Buffer.Span.SequenceEqual(comparand.Buffer.Span); public override int GetHashCode() { diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 966a78d39b552..9efd65240d425 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -11,23 +11,15 @@ namespace System.Net { internal static class SocketAddressPal { - public static readonly int IPv6AddressSize = GetIPv6AddressSize(); - public static readonly int IPv4AddressSize = GetIPv4AddressSize(); + public static readonly int IPv6AddressSize; + public static readonly int IPv4AddressSize; + public static readonly int UdsAddressSize; + public static readonly int MaxAddressSize; - private static unsafe int GetIPv6AddressSize() + static SocketAddressPal() { - int ipv6AddressSize, unused; - Interop.Error err = Interop.Sys.GetIPSocketAddressSizes(&unused, &ipv6AddressSize); + Interop.Error err = Interop.Sys.GetSocketAddressSizes(ref IPv4AddressSize, ref IPv6AddressSize, ref UdsAddressSize, ref MaxAddressSize); Debug.Assert(err == Interop.Error.SUCCESS, $"Unexpected err: {err}"); - return ipv6AddressSize; - } - - private static unsafe int GetIPv4AddressSize() - { - int ipv4AddressSize, unused; - Interop.Error err = Interop.Sys.GetIPSocketAddressSizes(&ipv4AddressSize, &unused); - Debug.Assert(err == Interop.Error.SUCCESS, $"Unexpected err: {err}"); - return ipv4AddressSize; } private static void ThrowOnFailure(Interop.Error err) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs index 362f66b8a325d..a563675dab62f 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -10,6 +10,8 @@ internal static class SocketAddressPal { public const int IPv6AddressSize = 28; public const int IPv4AddressSize = 16; + public const int UdsAddressSize = 110; + public const int MaxAddressSize = 128; public static AddressFamily GetAddressFamily(ReadOnlySpan buffer) { diff --git a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs index 1b15b74b24525..480548a33f1b7 100644 --- a/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs +++ b/src/libraries/System.Net.Primitives/ref/System.Net.Primitives.cs @@ -349,15 +349,17 @@ public NetworkCredential(string? userName, string? password, string? domain) { } public System.Net.NetworkCredential GetCredential(string? host, int port, string? authenticationType) { throw null; } public System.Net.NetworkCredential GetCredential(System.Uri? uri, string? authenticationType) { throw null; } } - public partial class SocketAddress + public partial class SocketAddress : System.IEquatable { public SocketAddress(System.Net.Sockets.AddressFamily family) { } public SocketAddress(System.Net.Sockets.AddressFamily family, int size) { } public System.Net.Sockets.AddressFamily Family { get { throw null; } } public byte this[int offset] { get { throw null; } set { } } public int Size { get { throw null; } set { } } + public static int GetMaximumAddressSize(System.Net.Sockets.AddressFamily addressFamily) { throw null; } public System.Memory Buffer { get { throw null; } } public override bool Equals(object? comparand) { throw null; } + public bool Equals(System.Net.SocketAddress? comparand) { throw null; } public override int GetHashCode() { throw null; } public override string ToString() { throw null; } } diff --git a/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs b/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs index 70ef816d1d0af..ec79ac2976b9b 100644 --- a/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs +++ b/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs @@ -14,7 +14,7 @@ public static void Ctor_AddressFamily_Success() { SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork); Assert.Equal(AddressFamily.InterNetwork, sa.Family); - Assert.Equal(32, sa.Size); + Assert.Equal(16, sa.Size); } [Fact] diff --git a/src/native/libs/System.Native/entrypoints.c b/src/native/libs/System.Native/entrypoints.c index 394a39b0d0036..0b719b1545fae 100644 --- a/src/native/libs/System.Native/entrypoints.c +++ b/src/native/libs/System.Native/entrypoints.c @@ -137,7 +137,7 @@ static const Entry s_sysNative[] = DllImportEntry(SystemNative_GetNameInfo) DllImportEntry(SystemNative_GetDomainName) DllImportEntry(SystemNative_GetHostName) - DllImportEntry(SystemNative_GetIPSocketAddressSizes) + DllImportEntry(SystemNative_GetSocketAddressSizes) DllImportEntry(SystemNative_GetAddressFamily) DllImportEntry(SystemNative_SetAddressFamily) DllImportEntry(SystemNative_GetPort) diff --git a/src/native/libs/System.Native/pal_networking.c b/src/native/libs/System.Native/pal_networking.c index ffd548835901f..8dfc133f2ed78 100644 --- a/src/native/libs/System.Native/pal_networking.c +++ b/src/native/libs/System.Native/pal_networking.c @@ -659,15 +659,17 @@ static bool IsInBounds(const void* void_baseAddr, size_t len, const void* void_v return valueAddr >= baseAddr && (valueAddr + valueSize) <= (baseAddr + len); } -int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize) +int32_t SystemNative_GetSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize, int32_t* udsSocketAddressSize, int32_t* maxSocketAddressSize) { - if (ipv4SocketAddressSize == NULL || ipv6SocketAddressSize == NULL) + if (ipv4SocketAddressSize == NULL || ipv6SocketAddressSize == NULL || udsSocketAddressSize == NULL || maxSocketAddressSize == NULL) { return Error_EFAULT; } *ipv4SocketAddressSize = sizeof(struct sockaddr_in); *ipv6SocketAddressSize = sizeof(struct sockaddr_in6); + *udsSocketAddressSize = sizeof(struct sockaddr_un); + *maxSocketAddressSize = sizeof(struct sockaddr_storage); return Error_SUCCESS; } diff --git a/src/native/libs/System.Native/pal_networking.h b/src/native/libs/System.Native/pal_networking.h index 65148d67a646b..0a46f1490aab9 100644 --- a/src/native/libs/System.Native/pal_networking.h +++ b/src/native/libs/System.Native/pal_networking.h @@ -312,7 +312,7 @@ PALEXPORT int32_t SystemNative_GetDomainName(uint8_t* name, int32_t nameLength); PALEXPORT int32_t SystemNative_GetHostName(uint8_t* name, int32_t nameLength); -PALEXPORT int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize); +PALEXPORT int32_t SystemNative_GetSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize, int32_t* udsSocketAddressSize, int32_t* maxSocketAddressSize); PALEXPORT int32_t SystemNative_GetAddressFamily(const uint8_t* socketAddress, int32_t socketAddressLen, int32_t* addressFamily); From e1931323d34fe83544e9da7ac520d2ba98939ad4 Mon Sep 17 00:00:00 2001 From: wfurt Date: Sun, 23 Jul 2023 14:01:18 -0700 Subject: [PATCH 14/18] wasi --- src/native/libs/System.Native/pal_networking_wasi.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/native/libs/System.Native/pal_networking_wasi.c b/src/native/libs/System.Native/pal_networking_wasi.c index 1cbd365519410..baeb6494b31e7 100644 --- a/src/native/libs/System.Native/pal_networking_wasi.c +++ b/src/native/libs/System.Native/pal_networking_wasi.c @@ -61,7 +61,7 @@ int32_t SystemNative_GetHostName(uint8_t* name, int32_t nameLength) return gethostname((char*)name, unsignedSize); } -int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize) +int32_t SystemNative_GetSocketAddressSizes(int32_t* ipv4SocketAddressSize, int32_t* ipv6SocketAddressSize, int32_t*udsSocketAddressSize, int* maxSocketAddressSize) { return Error_EFAULT; } From 377ae221aef36f01092618e97d96d727970c5597 Mon Sep 17 00:00:00 2001 From: wfurt Date: Wed, 26 Jul 2023 22:41:02 -0700 Subject: [PATCH 15/18] feedback --- .../Unix/System.Native/Interop.Socket.cs | 4 ---- .../System.Native/Interop.SocketAddress.cs | 2 +- .../src/System/Net/IPEndPointExtensions.cs | 15 ++++++++++---- .../Common/src/System/Net/SocketAddress.cs | 8 ++++---- .../src/System/Net/SocketAddressPal.Unix.cs | 20 ++++++++++++++++--- .../Net/SocketProtocolSupportPal.Unix.cs | 7 ++----- .../src/System/Net/Sockets/Socket.cs | 2 +- .../src/System/Net/Sockets/SocketPal.Unix.cs | 6 +++--- 8 files changed, 39 insertions(+), 25 deletions(-) diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs index a593fd34fe458..01b5fb9f32f73 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Socket.cs @@ -10,10 +10,6 @@ internal static partial class Interop { internal static partial class Sys { -#if SYSTEM_NET_SOCKETS_DLL - [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Socket")] - internal static unsafe partial Error Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, IntPtr* socket); -#endif [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Socket")] internal static unsafe partial Error Socket(int addressFamily, int socketType, int protocolType, IntPtr* socket); } diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs index 03c46ddcd684e..7a1dceee2fd86 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.SocketAddress.cs @@ -11,7 +11,7 @@ internal static partial class Sys { [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetSocketAddressSizes")] [SuppressGCTransition] - internal static partial Error GetSocketAddressSizes(ref int ipv4SocketAddressSize, ref int ipv6SocketAddressSize, ref int udsSocketAddressSize, ref int maxSocketAddressSize); + internal static unsafe partial Error GetSocketAddressSizes(int* ipv4SocketAddressSize, int* ipv6SocketAddressSize, int* udsSocketAddressSize, int* maxSocketAddressSize); [LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetAddressFamily")] [SuppressGCTransition] diff --git a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs index b62b1d17d9193..7b9a6a697a4e3 100644 --- a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs +++ b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs @@ -9,15 +9,21 @@ internal static class IPEndPointExtensions { public static IPAddress GetIPAddress(ReadOnlySpan socketAddressBuffer) { - if (SocketAddressPal.GetAddressFamily(socketAddressBuffer) == AddressFamily.InterNetworkV6) + AddressFamily family = SocketAddressPal.GetAddressFamily(socketAddressBuffer); + + if (family == AddressFamily.InterNetworkV6) { Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; uint scope; SocketAddressPal.GetIPv6Address(socketAddressBuffer, address, out scope); return new IPAddress(address, (long)scope); } + else if (family == AddressFamily.InterNetwork) + { + return new IPAddress((long)SocketAddressPal.GetIPv4Address(socketAddressBuffer) & 0x0FFFFFFFF); + } - return new IPAddress((long)SocketAddressPal.GetIPv4Address(socketAddressBuffer) & 0x0FFFFFFFF); + throw new SocketException((int)SocketError.AddressFamilyNotSupported); } public static void SetIPAddress(Span socketAddressBuffer, IPAddress address) @@ -33,7 +39,8 @@ public static void SetIPAddress(Span socketAddressBuffer, IPAddress addres else { Span addressBuffer = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; - address.TryWriteBytes(addressBuffer, out _); + address.TryWriteBytes(addressBuffer, out int written); + Debug.Assert(written == IPAddressParserStatics.IPv6AddressBytes); SocketAddressPal.SetIPv6Address(socketAddressBuffer, addressBuffer, (uint)address.ScopeId); } } @@ -43,7 +50,7 @@ public static IPEndPoint CreateIPEndPoint(ReadOnlySpan socketAddressBuffer return new IPEndPoint(GetIPAddress(socketAddressBuffer), SocketAddressPal.GetPort(socketAddressBuffer)); } - // https://github.com/dotnet/runtime/issues/78993 + // suggestion from https://github.com/dotnet/runtime/issues/78993 public static void Serialize(this IPEndPoint endPoint, Span destination) { SocketAddressPal.SetAddressFamily(destination, endPoint.AddressFamily); diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index c58515bc9f57b..c7ad378555516 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -65,7 +65,7 @@ public byte this[int offset] { get { - if (offset < 0 || offset >= Size) + if ((uint)offset >= (uint)Size) { throw new IndexOutOfRangeException(); } @@ -73,7 +73,7 @@ public byte this[int offset] } set { - if (offset < 0 || offset >= Size) + if ((uint)offset >= (uint)Size) { throw new IndexOutOfRangeException(); } @@ -152,7 +152,7 @@ internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) /// This represents underlying memory that can be passed to native OS calls. /// - /// This memory can be invalidated if is changed or if the SocketAddress is used in another receive call. + /// Content of the memory can be invalidated if is changed or if the SocketAddress is used in another receive call. /// public Memory Buffer { @@ -223,7 +223,7 @@ public override bool Equals(object? comparand) => public override int GetHashCode() { HashCode hash = default; - hash.AddBytes(Buffer.Span); + hash.AddBytes(new ReadOnlySpan(InternalBuffer, 0, InternalSize)); return hash.ToHashCode(); } diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 9efd65240d425..8ebe41f61eae0 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -11,16 +11,30 @@ namespace System.Net { internal static class SocketAddressPal { - public static readonly int IPv6AddressSize; public static readonly int IPv4AddressSize; + public static readonly int IPv6AddressSize; public static readonly int UdsAddressSize; public static readonly int MaxAddressSize; - static SocketAddressPal() +#pragma warning disable CA1810 + static unsafe SocketAddressPal() { - Interop.Error err = Interop.Sys.GetSocketAddressSizes(ref IPv4AddressSize, ref IPv6AddressSize, ref UdsAddressSize, ref MaxAddressSize); + int ipv4 = 0; + int ipv6 = 0; + int uds = 0; + int max = 0; + Interop.Error err = Interop.Sys.GetSocketAddressSizes(&ipv4, &ipv6, &uds, &max); Debug.Assert(err == Interop.Error.SUCCESS, $"Unexpected err: {err}"); + Debug.Assert(ipv4 > 0); + Debug.Assert(ipv6 > 0); + Debug.Assert(uds > 0); + Debug.Assert(max >= ipv4 && max >= ipv6 && max >= uds); + IPv4AddressSize = ipv4; + IPv6AddressSize =ipv6; + UdsAddressSize = uds; + MaxAddressSize =max; } +#pragma warning restore CA1810 private static void ThrowOnFailure(Interop.Error err) { diff --git a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs index e85dc64ce6acc..0e6e96a4dd20d 100644 --- a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs @@ -8,6 +8,7 @@ namespace System.Net { internal static partial class SocketProtocolSupportPal { + private const int DgramSocketType = 2; private static unsafe bool IsSupported(AddressFamily af) { // Check for AF_UNIX on iOS/tvOS. The OS claims to support this, but returns EPERM on bind. @@ -20,11 +21,7 @@ private static unsafe bool IsSupported(AddressFamily af) IntPtr socket = invalid; try { -#if SYSTEM_NET_SOCKETS_DLL - Interop.Error result = Interop.Sys.Socket(af, SocketType.Dgram, 0, &socket); -#else - Interop.Error result = Interop.Sys.Socket((int)af, (int)Internals.SocketType.Dgram, 0, &socket); -#endif + Interop.Error result = Interop.Sys.Socket((int)af, DgramSocketType, 0, &socket); // we get EAFNOSUPPORT when family is not supported by Kernel, EPROTONOSUPPORT may come from policy enforcement like FreeBSD jail() return result != Interop.Error.EAFNOSUPPORT && result != Interop.Error.EPROTONOSUPPORT; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 3063aea8cc5c6..e01ec8da41f37 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1382,7 +1382,7 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r /// /// A span of bytes that contains the data to be sent. /// A bitwise combination of the values. - /// The that represents the destination for the data. + /// The that represents the destination for the data. /// The number of bytes sent. /// remoteEP is . /// An error occurred when attempting to access the socket. diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 8884f966ae234..27415a5e440a0 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -60,7 +60,7 @@ public static unsafe SocketError CreateSocket(AddressFamily addressFamily, Socke IntPtr fd; SocketError errorCode; - Interop.Error error = Interop.Sys.Socket(addressFamily, socketType, protocolType, &fd); + Interop.Error error = Interop.Sys.Socket((int)addressFamily, (int)socketType, (int)protocolType, &fd); if (error == Interop.Error.SUCCESS) { Debug.Assert(fd != (IntPtr)(-1), "fd should not be -1"); @@ -579,7 +579,7 @@ private static unsafe int SysReceiveMessageFrom( if (socketAddressLen == 0) { // We can fail to get peer address on TCP - socketAddressLen = socketAddress.Length; + socketAddressLen = socketAddress.Length; SocketAddressPal.Clear(socketAddress); } @@ -846,7 +846,7 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span 0 && socketAddressLen == 0) { // We can fail to get peer address on TCP - socketAddressLen = socketAddress.Length; + socketAddressLen = socketAddress.Length; SocketAddressPal.Clear(socketAddress); } return true; From 2ffb7e6873e80c1522ecddb90aaf14b06d4764df Mon Sep 17 00:00:00 2001 From: wfurt Date: Thu, 27 Jul 2023 15:31:54 -0700 Subject: [PATCH 16/18] feedback --- .../src/System/Net/IPEndPointExtensions.cs | 1 + .../FunctionalTests/SocketAddressTest.cs | 22 +++++++++++++++++++ .../src/Resources/Strings.resx | 3 +++ .../src/System/Net/Sockets/Socket.cs | 5 +++++ .../tests/FunctionalTests/ReceiveFrom.cs | 15 ++++++++++++- 5 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs index 7b9a6a697a4e3..2308d48c29671 100644 --- a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs +++ b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Net; namespace System.Net.Sockets diff --git a/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs b/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs index ec79ac2976b9b..0a185b7b42296 100644 --- a/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs +++ b/src/libraries/System.Net.Primitives/tests/FunctionalTests/SocketAddressTest.cs @@ -31,6 +31,28 @@ public static void Ctor_AddressFamilySize_Invalid() Assert.Throws(() => new SocketAddress(AddressFamily.InterNetwork, 1)); //Size < MinSize (32) } + [Theory] + [InlineData(AddressFamily.InterNetwork)] + [InlineData(AddressFamily.InterNetworkV6)] + [InlineData(AddressFamily.Unix)] + public static void Ctor_AddressFamilySize_Correct(AddressFamily addressFamily) + { + SocketAddress sa = new SocketAddress(addressFamily); + Assert.Equal(SocketAddress.GetMaximumAddressSize(addressFamily), sa.Size); + Assert.Equal(SocketAddress.GetMaximumAddressSize(addressFamily), sa.Buffer.Length); + Assert.True(sa.Size <= SocketAddress.GetMaximumAddressSize(AddressFamily.Unknown)); + } + + [Fact] + public static void AddressFamily_Size_Correct() + { + SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork); + Assert.Throws(() => sa.Size = sa.Size + 1); + + sa.Size = 4; + Assert.Equal(4, sa.Buffer.Length); + } + [Fact] public static void Equals_Compare_Success() { diff --git a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx index 72a5478e850ac..7a0feca077c03 100644 --- a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx @@ -315,4 +315,7 @@ Handle is already used by another Socket. + + Provided SocketAddress is too small for given AddressFamily. + diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index e01ec8da41f37..b9a0a90350354 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1895,6 +1895,11 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress { ThrowIfDisposed(); + if (receivedSocketAddress.Size < SocketAddress.GetMaximumAddressSize(AddressFamily)) + { + throw new ArgumentOutOfRangeException(nameof(receivedSocketAddress), SR.net_sockets_address_small); + } + ValidateBlockingMode(); int bytesTransferred; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index 1f590dec81416..1a720df27d250 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -159,7 +159,6 @@ public async Task ReceiveSent_UDP_Success(bool ipv4) [InlineData(true)] public void ReceiveSent_SocketAddress_Success(bool ipv4) { - //const int Offset = 10; const int DatagramSize = 256; const int DatagramsToSend = 16; @@ -200,6 +199,20 @@ public void ReceiveSent_SocketAddress_Success(bool ipv4) } } + [Fact] + public void ReceiveSent_SmallSocketAddress_Throws() + { + using Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + server.BindToAnonymousPort(IPAddress.Loopback); + + byte[] receiveBuffer = new byte[1]; + + SocketAddress serverSA = server.LocalEndPoint.Serialize(); + + SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork, 2); + Assert.Throws(() => server.ReceiveFrom(receiveBuffer, SocketFlags.None, sa)); + } + [Theory] [InlineData(true)] [InlineData(false)] From fc26610e684041df1aba32165ea455973285ef81 Mon Sep 17 00:00:00 2001 From: wfurt Date: Thu, 27 Jul 2023 18:44:16 -0700 Subject: [PATCH 17/18] loose ends --- .../src/System/Net/SocketAddressPal.Unix.cs | 3 +- .../Net/Quic/Internal/MsQuicExtensions.cs | 8 ++--- .../System/Net/Quic/Internal/MsQuicHelpers.cs | 4 +-- .../src/System/Net/Quic/QuicConnection.cs | 12 +++---- .../src/System/Net/Quic/QuicListener.cs | 4 +-- .../src/System/Net/Sockets/Socket.cs | 20 ++++-------- .../src/System/Net/Sockets/SocketPal.Unix.cs | 31 +++++++++---------- 7 files changed, 37 insertions(+), 45 deletions(-) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 8ebe41f61eae0..fc73ddca17989 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -177,7 +177,8 @@ public static unsafe void Clear(Span buffer) { AddressFamily family = GetAddressFamily(buffer); buffer.Clear(); - buffer[0] = (byte)buffer.Length; + // platforms where this matters (OSXLike & BSD) use uint8 for SA length + buffer[0] = (byte)Math.Min(buffer.Length, 255); SetAddressFamily(buffer, family); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs index fbade293cfda3..a3d7bc6f3f7d3 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs @@ -8,7 +8,7 @@ namespace Microsoft.Quic; internal unsafe partial struct QUIC_NEW_CONNECTION_INFO { public override string ToString() - => $"{{ {nameof(QuicVersion)} = {QuicVersion}, {nameof(LocalAddress)} = {LocalAddress->ToIPEndPoint()}, {nameof(RemoteAddress)} = {RemoteAddress->ToIPEndPoint()} }}"; + => $"{{ {nameof(QuicVersion)} = {QuicVersion}, {nameof(LocalAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(LocalAddress)}, {nameof(RemoteAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(RemoteAddress)} }}"; } internal unsafe partial struct QUIC_LISTENER_EVENT @@ -17,7 +17,7 @@ public override string ToString() => Type switch { QUIC_LISTENER_EVENT_TYPE.NEW_CONNECTION - => $"{{ {nameof(NEW_CONNECTION.Info)} = {{ {nameof(QUIC_NEW_CONNECTION_INFO.QuicVersion)} = {NEW_CONNECTION.Info->QuicVersion}, {nameof(QUIC_NEW_CONNECTION_INFO.LocalAddress)} = {NEW_CONNECTION.Info->LocalAddress->ToIPEndPoint()}, {nameof(QUIC_NEW_CONNECTION_INFO.RemoteAddress)} = {NEW_CONNECTION.Info->RemoteAddress->ToIPEndPoint()} }} }}", + => $"{{ {nameof(NEW_CONNECTION.Info)} = {{ {nameof(QUIC_NEW_CONNECTION_INFO.QuicVersion)} = {NEW_CONNECTION.Info->QuicVersion}, {nameof(QUIC_NEW_CONNECTION_INFO.LocalAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(NEW_CONNECTION.Info->LocalAddress)}, {nameof(QUIC_NEW_CONNECTION_INFO.RemoteAddress)} = {MsQuicHelpers.QuicAddrToIPEndPoint(NEW_CONNECTION.Info->RemoteAddress)} }} }}", _ => string.Empty }; } @@ -36,9 +36,9 @@ public override string ToString() QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_COMPLETE => $"{{ {nameof(SHUTDOWN_COMPLETE.HandshakeCompleted)} = {SHUTDOWN_COMPLETE.HandshakeCompleted}, {nameof(SHUTDOWN_COMPLETE.PeerAcknowledgedShutdown)} = {SHUTDOWN_COMPLETE.PeerAcknowledgedShutdown}, {nameof(SHUTDOWN_COMPLETE.AppCloseInProgress)} = {SHUTDOWN_COMPLETE.AppCloseInProgress} }}", QUIC_CONNECTION_EVENT_TYPE.LOCAL_ADDRESS_CHANGED - => $"{{ {nameof(LOCAL_ADDRESS_CHANGED.Address)} = {LOCAL_ADDRESS_CHANGED.Address->ToIPEndPoint()} }}", + => $"{{ {nameof(LOCAL_ADDRESS_CHANGED.Address)} = {MsQuicHelpers.QuicAddrToIPEndPoint(LOCAL_ADDRESS_CHANGED.Address)} }}", QUIC_CONNECTION_EVENT_TYPE.PEER_ADDRESS_CHANGED - => $"{{ {nameof(PEER_ADDRESS_CHANGED.Address)} = {PEER_ADDRESS_CHANGED.Address->ToIPEndPoint()} }}", + => $"{{ {nameof(PEER_ADDRESS_CHANGED.Address)} = {MsQuicHelpers.QuicAddrToIPEndPoint(PEER_ADDRESS_CHANGED.Address)} }}", QUIC_CONNECTION_EVENT_TYPE.PEER_STREAM_STARTED => $"{{ {nameof(PEER_STREAM_STARTED.Flags)} = {PEER_STREAM_STARTED.Flags} }}", QUIC_CONNECTION_EVENT_TYPE.PEER_CERTIFICATE_RECEIVED diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs index 17886b8f8a1ba..ad3ed5c134599 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -35,10 +35,10 @@ internal static bool TryParse(this EndPoint endPoint, out string? host, out IPAd return false; } - internal static unsafe IPEndPoint ToIPEndPoint(this ref QuicAddr quicAddress, AddressFamily? addressFamilyOverride = null) + internal static unsafe IPEndPoint QuicAddrToIPEndPoint(QuicAddr* quicAddress, AddressFamily? addressFamilyOverride = null) { // MsQuic always uses storage size as if IPv6 was used - Span addressBytes = new Span((byte*)Unsafe.AsPointer(ref quicAddress), SocketAddressPal.IPv6AddressSize); + Span addressBytes = new Span(quicAddress, SocketAddressPal.IPv6AddressSize); if (addressFamilyOverride != null) { SocketAddressPal.SetAddressFamily(addressBytes, (AddressFamily)addressFamilyOverride!); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index b3ac83567cee4..4c82110c11836 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -239,8 +239,8 @@ internal unsafe QuicConnection(QUIC_HANDLE* handle, QUIC_NEW_CONNECTION_INFO* in throw; } - _remoteEndPoint = info->RemoteAddress->ToIPEndPoint(); - _localEndPoint = info->LocalAddress->ToIPEndPoint(); + _remoteEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(info->RemoteAddress); + _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(info->LocalAddress); #if DEBUG _tlsSecret = MsQuicTlsSecret.Create(_handle); #endif @@ -478,10 +478,10 @@ private unsafe int HandleEventConnected(ref CONNECTED_DATA data) _negotiatedApplicationProtocol = new SslApplicationProtocol(new Span(data.NegotiatedAlpn, data.NegotiatedAlpnLength).ToArray()); QuicAddr remoteAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS); - _remoteEndPoint = remoteAddress.ToIPEndPoint(); + _remoteEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&remoteAddress); QuicAddr localAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS); - _localEndPoint = localAddress.ToIPEndPoint(); + _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&localAddress); if (NetEventSource.Log.IsEnabled()) { @@ -514,12 +514,12 @@ private unsafe int HandleEventShutdownComplete() } private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data) { - _localEndPoint = data.Address->ToIPEndPoint(); + _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(data.Address); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventPeerAddressChanged(ref PEER_ADDRESS_CHANGED_DATA data) { - _remoteEndPoint = data.Address->ToIPEndPoint(); + _remoteEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(data.Address); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventPeerStreamStarted(ref PEER_STREAM_STARTED_DATA data) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index 7b4c6613afabf..4116d4a609f79 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -157,7 +157,7 @@ private unsafe QuicListener(QuicListenerOptions options) // Get the actual listening endpoint. address = GetMsQuicParameter(_handle, QUIC_PARAM_LISTENER_LOCAL_ADDRESS); - LocalEndPoint = address.ToIPEndPoint(options.ListenEndPoint.AddressFamily); + LocalEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&address, options.ListenEndPoint.AddressFamily); } /// @@ -281,7 +281,7 @@ private unsafe int HandleEventNewConnection(ref NEW_CONNECTION_DATA data) { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Info(this, $"{this} Refusing connection from {data.Info->RemoteAddress->ToIPEndPoint()} due to backlog limit"); + NetEventSource.Info(this, $"{this} Refusing connection from {MsQuicHelpers.QuicAddrToIPEndPoint(data.Info->RemoteAddress)} due to backlog limit"); } Interlocked.Increment(ref _pendingConnectionsCapacity); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index b9a0a90350354..3f450bb2544ed 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1904,20 +1904,19 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress int bytesTransferred; SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, receivedSocketAddress.Buffer, out int socketAddressSize, out bytesTransferred); - + if (socketAddressSize > 0) + { + receivedSocketAddress.Size = socketAddressSize; + } UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. - SocketException? socketException = null; if (errorCode != SocketError.Success) { - socketException = new SocketException((int)errorCode); + SocketException socketException = new SocketException((int)errorCode); UpdateStatusAfterSocketError(socketException); if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, socketException); - if (socketException.SocketErrorCode != SocketError.MessageSize) - { - throw socketException; - } + throw socketException; } else if (SocketsTelemetry.Log.IsEnabled()) { @@ -1925,13 +1924,6 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } - if (socketException != null) - { - throw socketException; - } - - receivedSocketAddress.Size = socketAddressSize; - return bytesTransferred; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 27415a5e440a0..06ce39380a4b8 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -796,13 +796,12 @@ public static unsafe bool TryCompleteReceive(SafeSocketHandle socket, Span } } - public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Span socketAddress, out int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) + public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Span socketAddress, out int receivedSocketAddressLength, out int bytesReceived, out SocketFlags receivedFlags, out SocketError errorCode) { try { Interop.Error errno; int received; - int socketAddressLength = 0; if (!socket.IsSocket) { @@ -812,11 +811,12 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span(&oneBytePeekBuffer, 1), socketAddress, out socketAddressLength, out receivedFlags, out errno); + received = SysReceive(socket, flags | SocketFlags.Peek, new Span(&oneBytePeekBuffer, 1), socketAddress, out receivedSocketAddressLength, out receivedFlags, out errno); if (received > 0) { // Peeked for 1-byte, but the actual request was for 0. @@ -835,25 +835,24 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span 0 bytes into a single buffer - received = SysReceive(socket, flags, buffer, socketAddress, out socketAddressLength, out receivedFlags, out errno); + received = SysReceive(socket, flags, buffer, socketAddress, out receivedSocketAddressLength, out receivedFlags, out errno); } if (received != -1) { bytesReceived = received; errorCode = SocketError.Success; - socketAddressLen = socketAddressLength; - if (socketAddress.Length > 0 && socketAddressLen == 0) + if (socketAddress.Length > 0 && receivedSocketAddressLength == 0) { // We can fail to get peer address on TCP - socketAddressLen = socketAddress.Length; + receivedSocketAddressLength = socketAddress.Length; SocketAddressPal.Clear(socketAddress); } return true; } bytesReceived = 0; - socketAddressLen = 0; + receivedSocketAddressLength = 0; if (errno != Interop.Error.EAGAIN && errno != Interop.Error.EWOULDBLOCK) { @@ -869,28 +868,28 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Memory socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out SocketError errorCode) + public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, Span buffer, IList>? buffers, SocketFlags flags, Memory socketAddress, out int receivedSocketAddressLength, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out SocketError errorCode) { try { Interop.Error errno; int received = buffers == null ? - SysReceiveMessageFrom(socket, flags, buffer, socketAddress.Span, out socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno) : - SysReceiveMessageFrom(socket, flags, buffers, socketAddress.Span, out socketAddressLen, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno); + SysReceiveMessageFrom(socket, flags, buffer, socketAddress.Span, out receivedSocketAddressLength, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno) : + SysReceiveMessageFrom(socket, flags, buffers, socketAddress.Span, out receivedSocketAddressLength, isIPv4, isIPv6, out receivedFlags, out ipPacketInformation, out errno); if (received != -1) { - if (socketAddress.Length > 0 && socketAddressLen == 0) + if (socketAddress.Length > 0 && receivedSocketAddressLength == 0) { // We can fail to get peer address on TCP - socketAddressLen = socketAddress.Length; + receivedSocketAddressLength = socketAddress.Length; SocketAddressPal.Clear(socketAddress.Span); } bytesReceived = received; @@ -914,7 +913,7 @@ public static unsafe bool TryCompleteReceiveMessageFrom(SafeSocketHandle socket, // The socket was closed, or is closing. bytesReceived = 0; receivedFlags = 0; - socketAddressLen = 0; + receivedSocketAddressLength = 0; ipPacketInformation = default(IPPacketInformation); errorCode = SocketError.OperationAborted; return true; From 981f659522f83481d2c8d8e2e03efff9e8f5e3ee Mon Sep 17 00:00:00 2001 From: wfurt Date: Mon, 31 Jul 2023 20:05:34 -0700 Subject: [PATCH 18/18] feedback --- .../Interop/Windows/WinSock/Interop.WSAConnect.cs | 13 +++++++++++-- .../Common/src/System/Net/SocketAddressPal.Unix.cs | 4 ++-- .../src/System/Net/Sockets/SocketPal.Windows.cs | 1 - 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs index aef925573e60f..6f3faacef3279 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAConnect.cs @@ -10,13 +10,22 @@ internal static partial class Interop internal static partial class Winsock { [LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)] - internal static partial SocketError WSAConnect( + private static partial SocketError WSAConnect( SafeSocketHandle socketHandle, - Span socketAddress, + ReadOnlySpan socketAddress, int socketAddressSize, IntPtr inBuffer, IntPtr outBuffer, IntPtr sQOS, IntPtr gQOS); + + internal static SocketError WSAConnect( + SafeSocketHandle socketHandle, + ReadOnlySpan socketAddress, + IntPtr inBuffer, + IntPtr outBuffer, + IntPtr sQOS, + IntPtr gQOS) => + WSAConnect(socketHandle, socketAddress, socketAddress.Length, inBuffer, outBuffer, sQOS, gQOS); } } diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index fc73ddca17989..107d1d5e65967 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -30,9 +30,9 @@ static unsafe SocketAddressPal() Debug.Assert(uds > 0); Debug.Assert(max >= ipv4 && max >= ipv6 && max >= uds); IPv4AddressSize = ipv4; - IPv6AddressSize =ipv6; + IPv6AddressSize = ipv6; UdsAddressSize = uds; - MaxAddressSize =max; + MaxAddressSize = max; } #pragma warning restore CA1810 diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 38528df066b83..7ab0a867e64c7 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -192,7 +192,6 @@ public static SocketError Connect(SafeSocketHandle handle, Memory peerAddr SocketError errorCode = Interop.Winsock.WSAConnect( handle, peerAddress.Span, - peerAddress.Length, IntPtr.Zero, IntPtr.Zero, IntPtr.Zero,