diff --git a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs index f1723a2361454..0e242e592fab1 100644 --- a/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs +++ b/src/libraries/Common/src/System/Net/IPEndPointExtensions.cs @@ -58,5 +58,34 @@ public static void Serialize(this IPEndPoint endPoint, Span destination) SetIPAddress(destination, endPoint.Address); SocketAddressPal.SetPort(destination, (ushort)endPoint.Port); } + + public static bool Equals(this IPEndPoint endPoint, ReadOnlySpan socketAddressBuffer) + { + if (socketAddressBuffer.Length >= SocketAddress.GetMaximumAddressSize(endPoint.AddressFamily) && + endPoint.AddressFamily == SocketAddressPal.GetAddressFamily(socketAddressBuffer) && + endPoint.Port == (int)SocketAddressPal.GetPort(socketAddressBuffer)) + { + if (endPoint.AddressFamily == AddressFamily.InterNetwork) + { +#pragma warning disable CS0618 + return endPoint.Address.Address == (long)SocketAddressPal.GetIPv4Address(socketAddressBuffer); +#pragma warning restore CS0618 + } + else + { + Span addressBuffer1 = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + Span addressBuffer2 = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + SocketAddressPal.GetIPv6Address(socketAddressBuffer, addressBuffer1, out uint scopeid); + if (endPoint.Address.ScopeId != (long)scopeid) + { + return false; + } + endPoint.Address.TryWriteBytes(addressBuffer2, out _); + return addressBuffer1.SequenceEqual(addressBuffer2); + } + } + + return false; + } } } diff --git a/src/libraries/Common/src/System/Net/Internals/IPAddressExtensions.cs b/src/libraries/Common/src/System/Net/Internals/IPAddressExtensions.cs deleted file mode 100644 index acd2aaed82d6c..0000000000000 --- a/src/libraries/Common/src/System/Net/Internals/IPAddressExtensions.cs +++ /dev/null @@ -1,30 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; - -namespace System.Net.Sockets -{ - internal static class IPAddressExtensions - { - public static IPAddress Snapshot(this IPAddress original) - { - switch (original.AddressFamily) - { - case AddressFamily.InterNetwork: -#pragma warning disable CS0618 // IPAddress.Address is obsoleted, but it's the most efficient way to get the Int32 IPv4 address - return new IPAddress(original.Address); -#pragma warning restore CS0618 - - case AddressFamily.InterNetworkV6: - Span addressBytes = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; - original.TryWriteBytes(addressBytes, out int bytesWritten); - Debug.Assert(bytesWritten == IPAddressParserStatics.IPv6AddressBytes); - return new IPAddress(addressBytes, (uint)original.ScopeId); - - default: - throw new InternalException(original.AddressFamily); - } - } - } -} diff --git a/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs deleted file mode 100644 index 590048a08912b..0000000000000 --- a/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs +++ /dev/null @@ -1,72 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; - -namespace System.Net.Sockets -{ - internal static partial class IPEndPointExtensions - { - public static Internals.SocketAddress Serialize(EndPoint endpoint) - { - Debug.Assert(!(endpoint is DnsEndPoint)); - - var ipEndPoint = endpoint as IPEndPoint; - if (ipEndPoint != null) - { - return new Internals.SocketAddress(ipEndPoint.Address, ipEndPoint.Port); - } - - System.Net.SocketAddress address = endpoint.Serialize(); - return GetInternalSocketAddress(address); - } - - public static EndPoint Create(this EndPoint thisObj, Internals.SocketAddress socketAddress) - { - AddressFamily family = socketAddress.Family; - if (family != thisObj.AddressFamily) - { - throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), nameof(socketAddress)); - } - - if (family == AddressFamily.InterNetwork || family == AddressFamily.InterNetworkV6) - { - if (socketAddress.Size < 8) - { - throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), nameof(socketAddress)); - } - - return socketAddress.GetIPEndPoint(); - } - else if (family == AddressFamily.Unknown) - { - return thisObj; - } - - System.Net.SocketAddress address = GetNetSocketAddress(socketAddress); - return thisObj.Create(address); - } - - private static Internals.SocketAddress GetInternalSocketAddress(System.Net.SocketAddress address) - { - var result = new Internals.SocketAddress(address.Family, address.Size); - for (int index = 0; index < address.Size; index++) - { - result[index] = address[index]; - } - - return result; - } - - internal static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address) - { - var result = new System.Net.SocketAddress(address.Family, address.Size); - for (int index = 0; index < address.Size; index++) - { - result[index] = address[index]; - } - - return result; - } - } -} diff --git a/src/libraries/Common/src/System/Net/Internals/readme.md b/src/libraries/Common/src/System/Net/Internals/readme.md deleted file mode 100644 index 58353735bfa45..0000000000000 --- a/src/libraries/Common/src/System/Net/Internals/readme.md +++ /dev/null @@ -1,4 +0,0 @@ -Contracts such as NameResolution and Sockets require internal access to Primitive types. Binary copies of these types have been made within the System.Net.Internals namespace using #ifdef pragmas (source code is reused). - -An adaptation layer between .Internals and public types exists within the Extensions classes. - diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index 8c844c3e0d8f7..f4b6522d32b2a 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -1,26 +1,14 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers.Binary; using System.Diagnostics; -using System.Globalization; using System.Net.Sockets; -using System.Text; -#if SYSTEM_NET_PRIMITIVES_DLL namespace System.Net -#else -namespace System.Net.Internals -#endif { // This class is used when subclassing EndPoint, and provides indication // on how to format the memory buffers that the platform uses for network addresses. -#if SYSTEM_NET_PRIMITIVES_DLL - public -#else - internal sealed -#endif - class SocketAddress : System.IEquatable + public class SocketAddress : IEquatable { #pragma warning disable CA1802 // these could be const on Windows but need to be static readonly for Unix internal static readonly int IPv6AddressSize = SocketAddressPal.IPv6AddressSize; @@ -52,7 +40,7 @@ public int Size set { ArgumentOutOfRangeException.ThrowIfGreaterThan(value, _buffer.Length); - ArgumentOutOfRangeException.ThrowIfLessThan(value, MinSize); + ArgumentOutOfRangeException.ThrowIfLessThan(value, 0); _size = value; } } @@ -137,13 +125,6 @@ internal SocketAddress(IPAddress ipaddress, int port) SocketAddressPal.SetPort(_buffer, unchecked((ushort)port)); } - internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) - { - _buffer = buffer.ToArray(); - _size = _buffer.Length; - SocketAddressPal.SetAddressFamily(_buffer, addressFamily); - } - /// This represents underlying memory that can be passed to native OS calls. /// /// Content of the memory can be invalidated if is changed or if the SocketAddress is used in another receive call. @@ -152,44 +133,10 @@ public Memory Buffer { get { - return new Memory(_buffer, 0, _size); - } - } - - internal IPAddress GetIPAddress() - { - if (Family == AddressFamily.InterNetworkV6) - { - Debug.Assert(Size >= IPv6AddressSize); - - Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; - uint scope; - SocketAddressPal.GetIPv6Address(_buffer, address, out scope); - - return new IPAddress(address, (long)scope); - } - else if (Family == AddressFamily.InterNetwork) - { - Debug.Assert(Size >= IPv4AddressSize); - long address = (long)SocketAddressPal.GetIPv4Address(_buffer) & 0x0FFFFFFFF; - return new IPAddress(address); - } - else - { -#if SYSTEM_NET_PRIMITIVES_DLL - throw new SocketException(SocketError.AddressFamilyNotSupported); -#else - throw new SocketException((int)SocketError.AddressFamilyNotSupported); -#endif + return new Memory(_buffer); } } - internal int GetPort() => (int)SocketAddressPal.GetPort(_buffer); - - internal IPEndPoint GetIPEndPoint() - { - return new IPEndPoint(GetIPAddress(), GetPort()); - } public override bool Equals(object? comparand) => comparand is SocketAddress other && Equals(other); diff --git a/src/libraries/Common/src/System/Net/SocketAddressExtensions.cs b/src/libraries/Common/src/System/Net/SocketAddressExtensions.cs new file mode 100644 index 0000000000000..b260ce1a2c442 --- /dev/null +++ b/src/libraries/Common/src/System/Net/SocketAddressExtensions.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net; + +namespace System.Net.Sockets +{ + internal static partial class SocketAddressExtensions + { + public static IPAddress GetIPAddress(this SocketAddress socketAddress) => IPEndPointExtensions.GetIPAddress(socketAddress.Buffer.Span); + public static int GetPort(this SocketAddress socketAddress) + { + Debug.Assert(socketAddress.Family == AddressFamily.InterNetwork || socketAddress.Family == AddressFamily.InterNetworkV6); + return (int)SocketAddressPal.GetPort(socketAddress.Buffer.Span); + } + + public static IPEndPoint GetIPEndPoint(this SocketAddress socketAddress) + { + return new IPEndPoint(socketAddress.GetIPAddress(), socketAddress.GetPort()); + } + + public static bool Equals(this SocketAddress socketAddress, EndPoint? endPoint) + { + if (socketAddress.Family == endPoint?.AddressFamily && endPoint is IPEndPoint ipe) + { + return ipe.Equals(socketAddress.Buffer.Span); + } + + // We could serialize other EndPoints and compare socket addresses. + // But that would do two allocations and is probably as expensive as + // allocating new EndPoint. + // This may change if https://github.com/dotnet/runtime/issues/78993 is done + return false; + } + } +} diff --git a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs index a7115c9f2539b..cb81412f460ad 100644 --- a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Unix.cs @@ -8,6 +8,8 @@ namespace System.Net { internal static partial class SocketProtocolSupportPal { + private const int DgramSocketType = 2; + private static unsafe bool IsSupported(AddressFamily af) { // Check for AF_UNIX on iOS/tvOS. The OS claims to support this, but returns EPERM on bind. diff --git a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Windows.cs index 4cc5ffc33b45b..063d04fd3a6ba 100644 --- a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Windows.cs @@ -10,13 +10,14 @@ internal static partial class SocketProtocolSupportPal { private static bool IsSupported(AddressFamily af) { + const int StreamSocketType = 1; Interop.Winsock.EnsureInitialized(); IntPtr INVALID_SOCKET = (IntPtr)(-1); IntPtr socket = INVALID_SOCKET; try { - socket = Interop.Winsock.WSASocketW(af, DgramSocketType, 0, IntPtr.Zero, 0, (int)Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT); + socket = Interop.Winsock.WSASocketW(af, StreamSocketType, 0, IntPtr.Zero, 0, (int)Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT); return socket != INVALID_SOCKET || (SocketError)Marshal.GetLastPInvokeError() != SocketError.AddressFamilyNotSupported; diff --git a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.cs b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.cs index d6906e3b5c1d4..a61f47a0fa458 100644 --- a/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.cs +++ b/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.cs @@ -14,8 +14,6 @@ internal static partial class SocketProtocolSupportPal public static bool OSSupportsIPv4 { get; } = IsSupported(AddressFamily.InterNetwork); public static bool OSSupportsUnixDomainSockets { get; } = IsSupported(AddressFamily.Unix); - private const int DgramSocketType = 2; - private static bool IsIPv6Disabled() { // First check for the AppContext switch, giving it priority over the environment variable. diff --git a/src/libraries/Common/src/System/Net/Sockets/ProtocolType.cs b/src/libraries/Common/src/System/Net/Sockets/ProtocolType.cs index c93b254c3f8a2..ce2090e2abb38 100644 --- a/src/libraries/Common/src/System/Net/Sockets/ProtocolType.cs +++ b/src/libraries/Common/src/System/Net/Sockets/ProtocolType.cs @@ -1,15 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -#if SYSTEM_NET_SOCKETS_DLL namespace System.Net.Sockets { public -#else -namespace System.Net.Internals -{ - internal -#endif // Specifies the protocols that the Socket class supports. enum ProtocolType { diff --git a/src/libraries/Common/src/System/Net/Sockets/SocketType.cs b/src/libraries/Common/src/System/Net/Sockets/SocketType.cs index 58dc09a4fe73b..c3f082f3c53b6 100644 --- a/src/libraries/Common/src/System/Net/Sockets/SocketType.cs +++ b/src/libraries/Common/src/System/Net/Sockets/SocketType.cs @@ -1,15 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -#if SYSTEM_NET_SOCKETS_DLL namespace System.Net.Sockets { public -#else -namespace System.Net.Internals -{ - internal -#endif // Specifies the type of socket an instance of the System.Net.Sockets.Socket class represents. enum SocketType { diff --git a/src/libraries/System.Net.Primitives/src/System.Net.Primitives.csproj b/src/libraries/System.Net.Primitives/src/System.Net.Primitives.csproj index e41eda1691a4c..39bf0e8f1437c 100644 --- a/src/libraries/System.Net.Primitives/src/System.Net.Primitives.csproj +++ b/src/libraries/System.Net.Primitives/src/System.Net.Primitives.csproj @@ -60,6 +60,8 @@ Link="Common\System\Net\CookieParser.cs" /> + + (() => sa.Size = sa.Size + 1); - - sa.Size = 4; - Assert.Equal(4, sa.Buffer.Length); + sa.Size = 0; + Assert.Throws(() => sa.Size = - 1); } [Fact] diff --git a/src/libraries/System.Net.Primitives/tests/PalTests/System.Net.Primitives.Pal.Tests.csproj b/src/libraries/System.Net.Primitives/tests/PalTests/System.Net.Primitives.Pal.Tests.csproj index a48a8e9eb222c..03c17baf2ff96 100644 --- a/src/libraries/System.Net.Primitives/tests/PalTests/System.Net.Primitives.Pal.Tests.csproj +++ b/src/libraries/System.Net.Primitives/tests/PalTests/System.Net.Primitives.Pal.Tests.csproj @@ -39,6 +39,10 @@ Link="Common\System\HexConverter.cs" /> + + diff --git a/src/libraries/System.Net.Primitives/tests/UnitTests/System.Net.Primitives.UnitTests.Tests.csproj b/src/libraries/System.Net.Primitives/tests/UnitTests/System.Net.Primitives.UnitTests.Tests.csproj index 3bbac4fb7203e..0510554f5f393 100644 --- a/src/libraries/System.Net.Primitives/tests/UnitTests/System.Net.Primitives.UnitTests.Tests.csproj +++ b/src/libraries/System.Net.Primitives/tests/UnitTests/System.Net.Primitives.UnitTests.Tests.csproj @@ -50,6 +50,10 @@ Link="Common\System\HexConverter.cs" /> + + diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 57de68d185656..fca21b5d7c952 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -400,11 +400,12 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } - public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedSocketAddress) { throw null; } + public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedAddress) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedAddress, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; } public int ReceiveMessageFrom(System.Span buffer, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; } @@ -451,6 +452,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan preBuffer, Syst public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public System.Threading.Tasks.ValueTask SendToAsync(System.ReadOnlyMemory buffer, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public System.Threading.Tasks.ValueTask SendToAsync(System.ReadOnlyMemory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask SendToAsync(System.ReadOnlyMemory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress socketAddress, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")] public void SetIPProtectionLevel(System.Net.Sockets.IPProtectionLevel level) { } public void SetRawSocketOption(int optionLevel, int optionName, System.ReadOnlySpan optionValue) { } diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj index d013d5e706452..41fe91603e17c 100644 --- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -69,17 +69,12 @@ Link="Common\System\Net\ExceptionCheck.cs" /> - + - - - ReceiveFromAsync(Memory buffer, saea.SetBuffer(buffer); saea.SocketFlags = socketFlags; saea.RemoteEndPoint = remoteEndPoint; + saea._socketAddress = new SocketAddress(AddressFamily); + if (remoteEndPoint!.AddressFamily != AddressFamily && AddressFamily == AddressFamily.InterNetworkV6 && IsDualMode) + { + saea.RemoteEndPoint = s_IPEndPointIPv6; + } saea.WrapExceptionsForNetworkStream = false; return saea.ReceiveFromAsync(this, cancellationToken); } + /// + /// Receives data and returns the endpoint of the sending host. + /// + /// The buffer for the received data. + /// A bitwise combination of SocketFlags values that will be used when receiving the data. + /// An , that will be updated with value of the remote peer. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// An asynchronous task that completes with a containing the number of bytes received and the endpoint of the sending host. + public ValueTask ReceiveFromAsync(Memory buffer, SocketFlags socketFlags, SocketAddress receivedAddress, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + ArgumentNullException.ThrowIfNull(receivedAddress, nameof(receivedAddress)); + + if (receivedAddress.Size < SocketAddress.GetMaximumAddressSize(AddressFamily)) + { + throw new ArgumentOutOfRangeException(nameof(receivedAddress), SR.net_sockets_address_small); + } + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true); + + Debug.Assert(saea.BufferList == null); + saea.SetBuffer(buffer); + saea.SocketFlags = socketFlags; + saea.RemoteEndPoint = null; + saea._socketAddress = receivedAddress; + saea.WrapExceptionsForNetworkStream = false; + return saea.ReceiveFromSocketAddressAsync(this, cancellationToken); + } + /// /// Receives data and returns additional information about the sender of the message. /// @@ -636,11 +677,42 @@ public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socke Debug.Assert(saea.BufferList == null); saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); saea.SocketFlags = socketFlags; + saea._socketAddress = null; saea.RemoteEndPoint = remoteEP; saea.WrapExceptionsForNetworkStream = false; return saea.SendToAsync(this, cancellationToken); } + /// + /// Sends data to the specified remote host. + /// + /// The buffer for the data to send. + /// A bitwise combination of SocketFlags values that will be used when sending the data. + /// The remote host to which to send the data. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes with the number of bytes sent. + public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socketFlags, SocketAddress socketAddress, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + ArgumentNullException.ThrowIfNull(socketAddress); + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false); + + Debug.Assert(saea.BufferList == null); + saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); + saea.SocketFlags = socketFlags; + saea._socketAddress = socketAddress; + saea.WrapExceptionsForNetworkStream = false; + return saea.SendToAsync(this, cancellationToken); + } + /// /// Sends the file to a connected object. /// @@ -1019,6 +1091,24 @@ public ValueTask ReceiveFromAsync(Socket socket, Cancel ValueTask.FromException(CreateException(error)); } + internal ValueTask ReceiveFromSocketAddressAsync(Socket socket, CancellationToken cancellationToken) + { + if (socket.ReceiveFromAsync(this, cancellationToken)) + { + _cancellationToken = cancellationToken; + return new ValueTask(this, _mrvtsc.Version); + } + + int bytesTransferred = BytesTransferred; + SocketError error = SocketError; + + ReleaseForSyncCompletion(); + + return error == SocketError.Success ? + new ValueTask(bytesTransferred) : + ValueTask.FromException(CreateException(error)); + } + public ValueTask ReceiveMessageFromAsync(Socket socket, CancellationToken cancellationToken) { if (socket.ReceiveMessageFromAsync(this, cancellationToken)) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index 8eae5f877b9ab..c7fd8394807fe 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -26,12 +26,12 @@ internal void ReplaceHandleIfNecessaryAfterFailedConnect() { /* nop on Windows * private sealed class CachedSerializedEndPoint { public readonly IPEndPoint IPEndPoint; - public readonly Internals.SocketAddress SocketAddress; + public readonly SocketAddress SocketAddress; public CachedSerializedEndPoint(IPAddress address) { IPEndPoint = new IPEndPoint(address, 0); - SocketAddress = IPEndPointExtensions.Serialize(IPEndPoint); + SocketAddress = IPEndPoint.Serialize(); } } @@ -70,7 +70,7 @@ public Socket(SocketInformation socketInformation) IPAddress tempAddress = _addressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any; IPEndPoint ep = new IPEndPoint(tempAddress, 0); - Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep); + SocketAddress socketAddress = ep.Serialize(); int size = socketAddress.Buffer.Length; unsafe { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index d54fac57d6c9c..b918cebb84058 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -7,7 +7,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; -using System.Net.Internals; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; @@ -24,6 +23,7 @@ public partial class Socket : IDisposable internal const int DefaultCloseTimeout = -1; // NOTE: changing this default is a breaking change. private static readonly IPAddress s_IPAddressAnyMapToIPv6 = IPAddress.Any.MapToIPv6(); + private static readonly IPEndPoint s_IPEndPointIPv6 = new IPEndPoint(s_IPAddressAnyMapToIPv6, 0); private SafeSocketHandle _handle; @@ -152,7 +152,6 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) // Try to get the local end point. That will in turn enable the remote // end point to be retrieved on-demand when the property is accessed. - Internals.SocketAddress? socketAddress = null; switch (_addressFamily) { case AddressFamily.InterNetwork: @@ -170,8 +169,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) break; case AddressFamily.Unix: - socketAddress = new Internals.SocketAddress(AddressFamily.Unix, buffer.Slice(0, bufferLength)); - _rightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); + _rightEndPoint = new UnixDomainSocketEndPoint(buffer.Slice(0, bufferLength)); break; } @@ -203,8 +201,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) break; case AddressFamily.Unix: - socketAddress = new Internals.SocketAddress(AddressFamily.Unix, buffer.Slice(0, bufferLength)); - _remoteEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); + _remoteEndPoint = new UnixDomainSocketEndPoint(buffer.Slice(0, bufferLength)); break; } @@ -764,11 +761,11 @@ public void Bind(EndPoint localEP) if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"localEP:{localEP}"); - Internals.SocketAddress socketAddress = Serialize(ref localEP); + SocketAddress socketAddress = Serialize(ref localEP); DoBind(localEP, socketAddress); } - private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAddress) + private void DoBind(EndPoint endPointSnapshot, SocketAddress socketAddress) { // Mitigation for Blue Screen of Death (Win7, maybe others). IPEndPoint? ipEndPoint = endPointSnapshot as IPEndPoint; @@ -781,7 +778,7 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd SocketError errorCode = SocketPal.Bind( _handle, _protocolType, - socketAddress.Buffer.Span); + socketAddress.Buffer.Span.Slice(0, socketAddress.Size)); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -834,7 +831,7 @@ public void Connect(EndPoint remoteEP) ValidateForMultiConnect(isMultiEndpoint: false); - Internals.SocketAddress socketAddress = Serialize(ref remoteEP); + SocketAddress socketAddress = Serialize(ref remoteEP); _pendingConnectRightEndPoint = remoteEP; _nonBlockingConnectInProgress = !Blocking; @@ -1010,10 +1007,7 @@ public Socket Accept() ValidateBlockingMode(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SRC:{LocalEndPoint}"); - Internals.SocketAddress socketAddress = - _addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6 ? - IPEndPointExtensions.Serialize(_rightEndPoint) : - new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size. + SocketAddress socketAddress = new SocketAddress(_addressFamily); if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AcceptStart(socketAddress); @@ -1287,10 +1281,10 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, ValidateBlockingMode(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SRC:{LocalEndPoint} size:{size} remoteEP:{remoteEP}"); - Internals.SocketAddress socketAddress = Serialize(ref remoteEP); + SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer.Slice(0, socketAddress.Size), out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1359,10 +1353,10 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint r ValidateBlockingMode(); - Internals.SocketAddress socketAddress = Serialize(ref remoteEP); + SocketAddress socketAddress = Serialize(ref remoteEP); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer.Slice(0, socketAddress.Size), out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1401,7 +1395,7 @@ public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, SocketAddr ValidateBlockingMode(); int bytesTransferred; - SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, out bytesTransferred); + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer.Slice(0, socketAddress.Size), out bytesTransferred); // Throw an appropriate SocketException if the native call fails. if (errorCode != SocketError.Success) @@ -1578,14 +1572,11 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla // WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress // with the right address family. EndPoint endPointSnapshot = remoteEP; - Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot); - - // Save a copy of the original EndPoint. - Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); + SocketAddress socketAddress = Serialize(ref endPointSnapshot); SetReceivingPacketInformation(); - Internals.SocketAddress receiveAddress; + SocketAddress receiveAddress; int bytesTransferred; SocketError errorCode = SocketPal.ReceiveMessageFrom(this, _handle, buffer, offset, size, ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred); @@ -1601,7 +1592,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla if (errorCode == SocketError.Success && SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } - if (!socketAddressOriginal.Equals(receiveAddress)) + if (!SocketAddressExtensions.Equals(socketAddress, remoteEP)) { try { @@ -1666,14 +1657,11 @@ public int ReceiveMessageFrom(Span buffer, ref SocketFlags socketFlags, re // WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress // with the right address family. EndPoint endPointSnapshot = remoteEP; - Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot); - - // Save a copy of the original EndPoint. - Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); + SocketAddress socketAddress = Serialize(ref endPointSnapshot); SetReceivingPacketInformation(); - Internals.SocketAddress receiveAddress; + SocketAddress receiveAddress; int bytesTransferred; SocketError errorCode = SocketPal.ReceiveMessageFrom(this, _handle, buffer, ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred); @@ -1689,7 +1677,7 @@ public int ReceiveMessageFrom(Span buffer, ref SocketFlags socketFlags, re if (errorCode == SocketError.Success && SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); } - if (!socketAddressOriginal.Equals(receiveAddress)) + if (!SocketAddressExtensions.Equals(socketAddress, remoteEP)) { try { @@ -1721,8 +1709,11 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl // WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress // with the right address family. EndPoint endPointSnapshot = remoteEP; - Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot); - Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); + SocketAddress socketAddress = new SocketAddress(AddressFamily); + if (endPointSnapshot.AddressFamily == AddressFamily.InterNetwork && IsDualMode) + { + endPointSnapshot = s_IPEndPointIPv6; + } int bytesTransferred; SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, out int socketAddressLength, out bytesTransferred); @@ -1749,15 +1740,19 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl socketAddress.Size = socketAddressLength; - if (!socketAddressOriginal.Equals(socketAddress)) + if (socketAddressLength > 0 && !socketAddress.Equals(remoteEP) || remoteEP.AddressFamily != socketAddress.Family) { try { if (endPointSnapshot.AddressFamily == socketAddress.Family) { - remoteEP = _remoteEndPoint != null ? _remoteEndPoint.Create(socketAddress) : socketAddress.GetIPEndPoint(); + remoteEP = endPointSnapshot.Create(socketAddress); } - else if (endPointSnapshot.AddressFamily == AddressFamily.InterNetworkV6 && socketAddress.Family == AddressFamily.InterNetwork) + //else if (socketAddress.Family == AddressFamily.InterNetworkV6 && IsDualMode) + //{ + // remoteEP = socketAddress.GetIPEndPoint(); + //} + else if (AddressFamily == AddressFamily.InterNetworkV6 && socketAddress.Family == AddressFamily.InterNetwork) { // We expect IPv6 on DualMode sockets but we can also get plain old IPv4 remoteEP = new IPEndPoint(socketAddress.GetIPAddress().MapToIPv6(), socketAddress.GetPort()); @@ -1830,8 +1825,11 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint // WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress // with the right address family. EndPoint endPointSnapshot = remoteEP; - Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot); - Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); + SocketAddress socketAddress = new SocketAddress(AddressFamily); + if (endPointSnapshot.AddressFamily == AddressFamily.InterNetwork && IsDualMode) + { + endPointSnapshot = s_IPEndPointIPv6; + } int bytesTransferred; SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.Buffer, out int socketAddressLength, out bytesTransferred); @@ -1857,13 +1855,13 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint } socketAddress.Size = socketAddressLength; - if (!socketAddressOriginal.Equals(socketAddress)) + if (socketAddressLength > 0 && !socketAddress.Equals(remoteEP) || remoteEP.AddressFamily != socketAddress.Family) { try { if (endPointSnapshot.AddressFamily == socketAddress.Family) { - remoteEP = _remoteEndPoint != null ? _remoteEndPoint.Create(socketAddress) : socketAddress.GetIPEndPoint(); + remoteEP = endPointSnapshot.Create(socketAddress); } else if (endPointSnapshot.AddressFamily == AddressFamily.InterNetworkV6 && socketAddress.Family == AddressFamily.InterNetwork) { @@ -1892,28 +1890,27 @@ public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint /// /// A span of bytes that is the storage location for received data. /// A bitwise combination of the values. - /// An , that will be updated with value of the remote peer. + /// An , that will be updated with value of the remote peer. /// The number of bytes received. /// remoteEP is . /// An error occurred when attempting to access the socket. /// The has been closed. - public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress receivedSocketAddress) + public int ReceiveFrom(Span buffer, SocketFlags socketFlags, SocketAddress receivedAddress) { ThrowIfDisposed(); + ArgumentNullException.ThrowIfNull(receivedAddress, nameof(receivedAddress)); - if (receivedSocketAddress.Size < SocketAddress.GetMaximumAddressSize(AddressFamily)) + if (receivedAddress.Size < SocketAddress.GetMaximumAddressSize(AddressFamily)) { - throw new ArgumentOutOfRangeException(nameof(receivedSocketAddress), SR.net_sockets_address_small); + throw new ArgumentOutOfRangeException(nameof(receivedAddress), SR.net_sockets_address_small); } ValidateBlockingMode(); int bytesTransferred; - SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, receivedSocketAddress.Buffer, out int socketAddressSize, out bytesTransferred); - if (socketAddressSize > 0) - { - receivedSocketAddress.Size = socketAddressSize; - } + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, receivedAddress.Buffer, out int socketAddressSize, out bytesTransferred); + receivedAddress.Size = socketAddressSize; + UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); // If the native call fails we'll throw a SocketException. if (errorCode != SocketError.Success) @@ -2934,20 +2931,33 @@ private bool ReceiveFromAsync(SocketAsyncEventArgs e, CancellationToken cancella ThrowIfDisposed(); ArgumentNullException.ThrowIfNull(e); - if (e.RemoteEndPoint == null) - { - throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); - } - if (!CanTryAddressFamily(e.RemoteEndPoint.AddressFamily)) + EndPoint? endPointSnapshot = e.RemoteEndPoint; + if (e._socketAddress == null) { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, e.RemoteEndPoint.AddressFamily, _addressFamily), nameof(e)); - } + if (endPointSnapshot is DnsEndPoint) + { + throw new ArgumentException(SR.Format(SR.net_sockets_invalid_dnsendpoint, "e.RemoteEndPoint"), nameof(e)); + } - // We don't do a CAS demand here because the contents of remoteEP aren't used by - // WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress - // with the right address family. - EndPoint endPointSnapshot = e.RemoteEndPoint; - e._socketAddress = Serialize(ref endPointSnapshot); + if (endPointSnapshot == null) + { + throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); + } + if (!CanTryAddressFamily(endPointSnapshot.AddressFamily)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, endPointSnapshot.AddressFamily, _addressFamily), nameof(e)); + } + + // We don't do a CAS demand here because the contents of remoteEP aren't used by + // WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress + // with the right address family. + + if (endPointSnapshot.AddressFamily == AddressFamily.InterNetwork && IsDualMode) + { + endPointSnapshot = s_IPEndPointIPv6; + } + e._socketAddress ??= new SocketAddress(AddressFamily); + } // DualMode sockets may have updated the endPointSnapshot, and it has to have the same AddressFamily as // e.m_SocketAddres for Create to work later. @@ -3083,14 +3093,18 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT ThrowIfDisposed(); ArgumentNullException.ThrowIfNull(e); - if (e.RemoteEndPoint == null) + + EndPoint? endPointSnapshot = e.RemoteEndPoint; + if (e._socketAddress == null) { - throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); - } + if (endPointSnapshot == null) + { + throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); + } - // Prepare SocketAddress - EndPoint endPointSnapshot = e.RemoteEndPoint; - e._socketAddress = Serialize(ref endPointSnapshot); + // Prepare SocketAddress + e._socketAddress = Serialize(ref endPointSnapshot); + } // Prepare for and make the native call. e.StartOperationCommon(this, SocketAsyncOperation.SendTo); @@ -3131,7 +3145,7 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT // Internal and private methods // - internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6) + internal static void GetIPProtocolInformation(AddressFamily addressFamily, SocketAddress socketAddress, out bool isIPv4, out bool isIPv6) { bool isIPv4MappedToIPv6 = socketAddress.Family == AddressFamily.InterNetworkV6 && socketAddress.GetIPAddress().IsIPv4MappedToIPv6; isIPv4 = addressFamily == AddressFamily.InterNetwork || isIPv4MappedToIPv6; // DualMode @@ -3147,7 +3161,7 @@ internal static int GetAddressSize(EndPoint endPoint) endPoint.Serialize().Size; } - private Internals.SocketAddress Serialize(ref EndPoint remoteEP) + private SocketAddress Serialize(ref EndPoint remoteEP) { if (remoteEP is IPEndPoint ip) { @@ -3163,16 +3177,16 @@ private Internals.SocketAddress Serialize(ref EndPoint remoteEP) throw new ArgumentException(SR.Format(SR.net_sockets_invalid_dnsendpoint, nameof(remoteEP)), nameof(remoteEP)); } - return IPEndPointExtensions.Serialize(remoteEP); + return remoteEP.Serialize(); } - private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socketAddress) + private void DoConnect(EndPoint endPointSnapshot, SocketAddress socketAddress) { SocketsTelemetry.Log.ConnectStart(socketAddress); SocketError errorCode; try { - errorCode = SocketPal.Connect(_handle, socketAddress.Buffer); + errorCode = SocketPal.Connect(_handle, socketAddress.Buffer.Slice(0, socketAddress.Size)); } catch (Exception ex) { @@ -3756,6 +3770,12 @@ private bool CheckErrorAndUpdateStatus(SocketError errorCode) private void ValidateReceiveFromEndpointAndState(EndPoint remoteEndPoint, string remoteEndPointArgumentName) { ArgumentNullException.ThrowIfNull(remoteEndPoint, remoteEndPointArgumentName); + + if (remoteEndPoint is DnsEndPoint) + { + throw new ArgumentException(SR.Format(SR.net_sockets_invalid_dnsendpoint, remoteEndPointArgumentName), remoteEndPointArgumentName); + } + if (!CanTryAddressFamily(remoteEndPoint.AddressFamily)) { throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEndPoint.AddressFamily, _addressFamily), remoteEndPointArgumentName); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index ba34da0e171a6..26b3cf81333b5 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -196,7 +196,7 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, _socketAddress!, out isIPv4, out isIPv6); - int socketAddressSize = _socketAddress!.Size; + int socketAddressSize = _socketAddress!.Buffer.Length; int bytesReceived; SocketFlags receivedFlags; IPPacketInformation ipPacketInformation; @@ -336,7 +336,7 @@ internal void LogBuffer(int size) } } - private SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAddress) + private SocketError FinishOperationAccept(SocketAddress remoteSocketAddress) { new ReadOnlySpan(_acceptBuffer, 0, _acceptAddressBufferCount).CopyTo(remoteSocketAddress.Buffer.Span); remoteSocketAddress.Size = _acceptAddressBufferCount; @@ -366,7 +366,7 @@ private static SocketError FinishOperationConnect() return SocketError.Success; } - private void UpdateReceivedSocketAddress(Internals.SocketAddress socketAddress) + private void UpdateReceivedSocketAddress(SocketAddress socketAddress) { if (_socketAddressSize > 0) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 2aa782d826cac..bd193ee6565ab 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -1029,7 +1029,7 @@ internal unsafe void LogBuffer(int size) } } - private unsafe SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAddress) + private unsafe SocketError FinishOperationAccept(SocketAddress remoteSocketAddress) { SocketError socketError; IntPtr localAddr; @@ -1120,7 +1120,7 @@ private unsafe SocketError FinishOperationConnect() } } - private unsafe void UpdateReceivedSocketAddress(Internals.SocketAddress socketAddress) + private unsafe void UpdateReceivedSocketAddress(SocketAddress socketAddress) { Debug.Assert(_socketAddressPtr != IntPtr.Zero); int size = *((int*)_socketAddressPtr); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index b56873ab4c104..e04739d5fe7a6 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -65,7 +65,7 @@ public partial class SocketAsyncEventArgs : EventArgs, IDisposable private int _acceptAddressBufferCount; // Internal SocketAddress buffer. - internal Internals.SocketAddress? _socketAddress; + internal SocketAddress? _socketAddress; // Misc state variables. private readonly bool _flowExecutionContext; @@ -866,7 +866,7 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { case SocketAsyncOperation.Accept: // Get the endpoint. - Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_currentSocket!._rightEndPoint!); + SocketAddress remoteSocketAddress = _currentSocket!._rightEndPoint!.Serialize(); socketError = FinishOperationAccept(remoteSocketAddress); @@ -923,8 +923,7 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags case SocketAsyncOperation.ReceiveFrom: // Deal with incoming address. UpdateReceivedSocketAddress(_socketAddress!); - Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(_remoteEndPoint!); - if (!socketAddressOriginal.Equals(_socketAddress)) + if (_remoteEndPoint != null && !SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint)) { try { @@ -946,8 +945,7 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags case SocketAsyncOperation.ReceiveMessageFrom: // Deal with incoming address. UpdateReceivedSocketAddress(_socketAddress!); - socketAddressOriginal = IPEndPointExtensions.Serialize(_remoteEndPoint!); - if (!socketAddressOriginal.Equals(_socketAddress)) + if (!SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint)) { try { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index c57639ff13548..0220f9b3334bb 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -842,12 +842,6 @@ public static unsafe bool TryCompleteReceiveFrom(SafeSocketHandle socket, Span 0 && receivedSocketAddressLength == 0) - { - // We can fail to get peer address on TCP - receivedSocketAddressLength = socketAddress.Length; - SocketAddressPal.Clear(socketAddress); - } return true; } @@ -1273,7 +1267,7 @@ public static SocketError Receive(SafeSocketHandle handle, Span buffer, So return completed ? errorCode : SocketError.WouldBlock; } - public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) + public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, ref SocketFlags socketFlags, SocketAddress socketAddress, out SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { int socketAddressLen; @@ -1299,7 +1293,7 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han } - public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) + public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, SocketAddress socketAddress, out SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { int socketAddressLen; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index b474447132d72..d47c3bdf20966 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -433,12 +433,12 @@ public static unsafe IPPacketInformation GetIPPacketInformation(Interop.Winsock. return new IPPacketInformation(address, (int)controlBuffer->index); } - public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int size, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) + public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int size, ref SocketFlags socketFlags, SocketAddress socketAddress, out SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { return ReceiveMessageFrom(socket, handle, new Span(buffer, offset, size), ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred); } - public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) + public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, SocketAddress socketAddress, out SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { bool ipv4, ipv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out ipv4, out ipv6); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketsTelemetry.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketsTelemetry.cs index 9790aea5fd95f..bea730b47b162 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketsTelemetry.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketsTelemetry.cs @@ -77,7 +77,7 @@ private void AcceptFailed(SocketError error, string? exceptionMessage) } [NonEvent] - public void ConnectStart(Internals.SocketAddress address) + public void ConnectStart(SocketAddress address) { Interlocked.Increment(ref _currentOutgoingConnectAttempts); @@ -107,7 +107,7 @@ public void AfterConnect(SocketError error, string? exceptionMessage = null) } [NonEvent] - public void AcceptStart(Internals.SocketAddress address) + public void AcceptStart(SocketAddress address) { if (IsEnabled(EventLevel.Informational, EventKeywords.All)) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/UnixDomainSocketEndPoint.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/UnixDomainSocketEndPoint.cs index 0a06e5d418fda..fab23b52982e1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/UnixDomainSocketEndPoint.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/UnixDomainSocketEndPoint.cs @@ -64,19 +64,13 @@ private UnixDomainSocketEndPoint(string path, string? boundFileName) internal static int MaxAddressSize => s_nativeAddressSize; - internal UnixDomainSocketEndPoint(SocketAddress socketAddress) + internal UnixDomainSocketEndPoint(ReadOnlySpan socketAddress) { - ArgumentNullException.ThrowIfNull(socketAddress); + Debug.Assert(AddressFamily.Unix == SocketAddressPal.GetAddressFamily(socketAddress)); - if (socketAddress.Family != EndPointAddressFamily || - socketAddress.Size > s_nativeAddressSize) + if (socketAddress.Length > s_nativePathOffset) { - throw new ArgumentOutOfRangeException(nameof(socketAddress)); - } - - if (socketAddress.Size > s_nativePathOffset) - { - _encodedPath = new byte[socketAddress.Size - s_nativePathOffset]; + _encodedPath = new byte[socketAddress.Length - s_nativePathOffset]; for (int i = 0; i < _encodedPath.Length; i++) { _encodedPath[i] = socketAddress[s_nativePathOffset + i]; @@ -118,7 +112,7 @@ public override SocketAddress Serialize() /// Creates an instance from a instance. /// The socket address that serves as the endpoint for a connection. /// A new instance that is initialized from the specified instance. - public override EndPoint Create(SocketAddress socketAddress) => new UnixDomainSocketEndPoint(socketAddress); + public override EndPoint Create(SocketAddress socketAddress) => new UnixDomainSocketEndPoint(socketAddress.Buffer.Span.Slice(0, socketAddress.Size)); /// Gets the address family to which the endpoint belongs. /// One of the values. diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs index d9d75bfc5e450..6c5c2d8e2a9ce 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs @@ -975,7 +975,7 @@ public async Task Socket_ReceiveFromDnsEndPoint_Throws() int port = socket.BindToAnonymousPort(IPAddress.IPv6Loopback); EndPoint receivedFrom = new DnsEndPoint("localhost", port, AddressFamily.InterNetworkV6); - await AssertExtensions.ThrowsAsync("remoteEP", () => ReceiveFromAsync(socket, new byte[1], receivedFrom)); + await Assert.ThrowsAsync(() => ReceiveFromAsync(socket, new byte[1], receivedFrom)); } [Fact] diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index 1a720df27d250..1a5ec7d05d28e 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections.Generic; -using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -64,6 +62,16 @@ public async Task NullEndpoint_Throws_ArgumentException() } } + [Fact] + public async Task NullSocketAddress_Throws_ArgumentException() + { + using Socket socket = CreateSocket(); + SocketAddress socketAddress = null; + + Assert.Throws(() => socket.ReceiveFrom(new byte[1], SocketFlags.None, socketAddress)); + await Assert.ThrowsAsync(() => socket.ReceiveFromAsync(new Memory(new byte[1]), SocketFlags.None, socketAddress).AsTask()); + } + [Fact] public async Task AddressFamilyDoesNotMatch_Throws_ArgumentException() { @@ -151,6 +159,12 @@ public async Task ReceiveSent_UDP_Success(bool ipv4) AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan(receiveInternalBuffer, 0, Offset)); AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan(receiveInternalBuffer, Offset, DatagramSize)); Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint); + remoteEp = (IPEndPoint)result.RemoteEndPoint; + if (i > 0) + { + // reference should be same after first round + Assert.True(remoteEp == result.RemoteEndPoint); + } } } @@ -195,7 +209,50 @@ public void ReceiveSent_SocketAddress_Success(bool ipv4) Assert.Equal(sa, serverSA); Assert.Equal(server.LocalEndPoint, server.LocalEndPoint.Create(sa)); Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ReceiveSent_SocketAddressAsync_Success(bool ipv4) + { + const int DatagramSize = 256; + const int DatagramsToSend = 16; + + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + using Socket server = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket client = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + client.BindToAnonymousPort(address); + server.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[DatagramSize]; + byte[] receiveBuffer = new byte[DatagramSize]; + + SocketAddress serverSA = server.LocalEndPoint.Serialize(); + SocketAddress clientSA = client.LocalEndPoint.Serialize(); + SocketAddress sa = new SocketAddress(address.AddressFamily); + + Random rnd = new Random(0); + + for (int i = 0; i < DatagramsToSend; i++) + { + rnd.NextBytes(sendBuffer); + await client.SendToAsync(sendBuffer, SocketFlags.None, serverSA); + + int readBytes = await server.ReceiveFromAsync(receiveBuffer, SocketFlags.None, sa); + Assert.Equal(sa, clientSA); + Assert.Equal(client.LocalEndPoint, client.LocalEndPoint.Create(sa)); + Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); + // and send it back to make sure it works. + rnd.NextBytes(sendBuffer); + await server.SendToAsync(sendBuffer, SocketFlags.None, sa); + readBytes = await client.ReceiveFromAsync(receiveBuffer, SocketFlags.None, sa); + Assert.Equal(sa, serverSA); + Assert.Equal(server.LocalEndPoint, server.LocalEndPoint.Create(sa)); + Assert.True(new Span(receiveBuffer, 0, readBytes).SequenceEqual(sendBuffer)); } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs index f7475703b3da1..bf0ad14658869 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs @@ -63,6 +63,16 @@ public async Task NullEndpoint_Throws_ArgumentException() } } + [Fact] + public async Task NullSocketAddress_Throws_ArgumentException() + { + using Socket socket = CreateSocket(); + SocketAddress socketAddress = null; + + Assert.Throws(() => socket.SendTo(new byte[1], SocketFlags.None, socketAddress)); + await AssertThrowsSynchronously(() => socket.SendToAsync(new byte[1], SocketFlags.None, socketAddress).AsTask()); + } + [Fact] public async Task Datagram_UDP_ShouldImplicitlyBindLocalEndpoint() { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs index ba0fbd7c13edb..3a8139b4b2963 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Reflection; using System.Runtime.InteropServices; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.DotNet.RemoteExecutor; @@ -400,13 +401,13 @@ public void UnixDomainSocketEndPoint_RemoteEndPointEqualsBindAddress(bool abstra } // An abstract socket address starts with a zero byte. serverAddress = '\0' + Guid.NewGuid().ToString(); - clientAddress = '\0' + Guid.NewGuid().ToString(); + clientAddress = '\0' + Guid.NewGuid().ToString() + "ABC"; expectedClientAddress = '@' + clientAddress.Substring(1); } else { serverAddress = GetRandomNonExistingFilePath(); - clientAddress = GetRandomNonExistingFilePath(); + clientAddress = GetRandomNonExistingFilePath() + "ABC"; expectedClientAddress = clientAddress; } @@ -536,6 +537,55 @@ public void FilePathEquality() Assert.NotEqual(endPoint2, endPoint3); } + [ConditionalTheory(typeof(Socket), nameof(Socket.OSSupportsUnixDomainSockets))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/26189", TestPlatforms.Windows)] + [SkipOnPlatform(TestPlatforms.LinuxBionic, "SElinux blocks UNIX sockets in our CI environment")] + [InlineData(true)] + [InlineData(false)] + public async Task ReceiveFrom_EndPoints_Correct(bool useAsync) + { + string serverAddress = GetRandomNonExistingFilePath(); + string clientAddress = GetRandomNonExistingFilePath() + "ABCD"; + + using (Socket server = new Socket(AddressFamily.Unix, SocketType.Dgram, ProtocolType.Unspecified)) + { + server.Bind(new UnixDomainSocketEndPoint(serverAddress)); + using (Socket client = new Socket(AddressFamily.Unix, SocketType.Dgram, ProtocolType.Unspecified)) + { + byte[] data = Encoding.ASCII.GetBytes(nameof(ReceiveFrom_EndPoints_Correct)); + // Bind the client. + client.Bind(new UnixDomainSocketEndPoint(clientAddress)); + + var sender = new UnixDomainSocketEndPoint(GetRandomNonExistingFilePath()); + EndPoint senderRemote = (EndPoint)sender; + int transferredBytes; + if (useAsync) + { + transferredBytes = await client.SendToAsync(data, server.LocalEndPoint); + } + else + { + transferredBytes = client.SendTo(data, server.LocalEndPoint); + } + Assert.Equal(data.Length, transferredBytes); + + byte[] buffer = new byte[data.Length * 2]; + if (useAsync) + { + SocketReceiveFromResult result = await server.ReceiveFromAsync(buffer, senderRemote); + Assert.Equal(clientAddress, result.RemoteEndPoint.ToString()); + Assert.Equal(data.Length, result.ReceivedBytes); + } + else + { + transferredBytes = server.ReceiveFrom(buffer, ref senderRemote); + Assert.Equal(data.Length, transferredBytes); + Assert.Equal(clientAddress, senderRemote.ToString()); + } + } + } + } + private static string GetRandomNonExistingFilePath() { string result;