diff --git a/Directory.Packages.props b/Directory.Packages.props index e456e670d67..490e518213e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -77,7 +77,6 @@ - @@ -126,6 +125,8 @@ + + diff --git a/eng/Versions.props b/eng/Versions.props index d032184c62b..636b494b864 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -50,6 +50,8 @@ 9.0.6 10.0.0-preview.5.25277.114 + + 2.1.1 2.23.32-alpha diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs index 6cc9f92bc46..7a2d1b632e0 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs @@ -4,6 +4,7 @@ using System.Net; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; @@ -12,6 +13,7 @@ internal sealed partial class DnsServiceEndpointProvider( string hostName, IOptionsMonitor options, ILogger logger, + IDnsResolver resolver, TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature { protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; @@ -29,17 +31,14 @@ protected override async Task ResolveAsyncCore() var endpoints = new List(); var ttl = DefaultRefreshPeriod; Log.AddressQuery(logger, ServiceName, hostName); - var addresses = await System.Net.Dns.GetHostAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false); + + var now = _timeProvider.GetUtcNow().DateTime; + var addresses = await resolver.ResolveIPAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false); + foreach (var address in addresses) { - var serviceEndpoint = ServiceEndpoint.Create(new IPEndPoint(address, 0)); - serviceEndpoint.Features.Set(this); - if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint)) - { - serviceEndpoint.Features.Set(this); - } - - endpoints.Add(serviceEndpoint); + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, port: 0))); } if (endpoints.Count == 0) @@ -48,5 +47,23 @@ protected override async Task ResolveAsyncCore() } SetResult(endpoints, ttl); + + static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing) + { + var candidate = expiresAt - now; + return candidate < existing ? candidate : existing; + } + + ServiceEndpoint CreateEndpoint(EndPoint endPoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endPoint); + serviceEndpoint.Features.Set(this); + if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint)) + { + serviceEndpoint.Features.Set(this); + } + + return serviceEndpoint; + } } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs index 6c69cc7a760..311c06f631a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs @@ -14,7 +14,7 @@ internal abstract partial class DnsServiceEndpointProviderBase : IServiceEndpoin private readonly object _lock = new(); private readonly ILogger _logger; private readonly CancellationTokenSource _disposeCancellation = new(); - private readonly TimeProvider _timeProvider; + protected readonly TimeProvider _timeProvider; private long _lastRefreshTimeStamp; private Task _resolveTask = Task.CompletedTask; private bool _hasEndpoints; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs index c241ad89dd3..1da21411e64 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs @@ -4,18 +4,20 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; internal sealed partial class DnsServiceEndpointProviderFactory( IOptionsMonitor options, ILogger logger, + IDnsResolver resolver, TimeProvider timeProvider) : IServiceEndpointProviderFactory { /// public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) { - provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, timeProvider); + provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, resolver, timeProvider); return true; } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs index c174cda4f68..6d5ade5059e 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using DnsClient; -using DnsClient.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; @@ -15,7 +14,7 @@ internal sealed partial class DnsSrvServiceEndpointProvider( string hostName, IOptionsMonitor options, ILogger logger, - IDnsQuery dnsClient, + IDnsResolver resolver, TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature { protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; @@ -35,56 +34,36 @@ protected override async Task ResolveAsyncCore() var endpoints = new List(); var ttl = DefaultRefreshPeriod; Log.SrvQuery(logger, ServiceName, srvQuery); - var result = await dnsClient.QueryAsync(srvQuery, QueryType.SRV, cancellationToken: ShutdownToken).ConfigureAwait(false); - if (result.HasError) - { - throw CreateException(srvQuery, result.ErrorMessage); - } - var lookupMapping = new Dictionary(); - foreach (var record in result.Additionals.Where(x => x is AddressRecord or CNameRecord)) - { - ttl = MinTtl(record, ttl); - lookupMapping[record.DomainName] = record; - } + var now = _timeProvider.GetUtcNow().DateTime; + var result = await resolver.ResolveServiceAsync(srvQuery, cancellationToken: ShutdownToken).ConfigureAwait(false); - var srvRecords = result.Answers.OfType(); - foreach (var record in srvRecords) + foreach (var record in result) { - if (!lookupMapping.TryGetValue(record.Target, out var targetRecord)) - { - continue; - } + ttl = MinTtl(now, record.ExpiresAt, ttl); - ttl = MinTtl(record, ttl); - if (targetRecord is AddressRecord addressRecord) + if (record.Addresses.Length > 0) { - endpoints.Add(CreateEndpoint(new IPEndPoint(addressRecord.Address, record.Port))); + foreach (var address in record.Addresses) + { + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, record.Port))); + } } - else if (targetRecord is CNameRecord canonicalNameRecord) + else { - endpoints.Add(CreateEndpoint(new DnsEndPoint(canonicalNameRecord.CanonicalName.Value.TrimEnd('.'), record.Port))); + endpoints.Add(CreateEndpoint(new DnsEndPoint(record.Target.TrimEnd('.'), record.Port))); } } SetResult(endpoints, ttl); - static TimeSpan MinTtl(DnsResourceRecord record, TimeSpan existing) + static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing) { - var candidate = TimeSpan.FromSeconds(record.TimeToLive); + var candidate = expiresAt - now; return candidate < existing ? candidate : existing; } - InvalidOperationException CreateException(string dnsName, string errorMessage) - { - var msg = errorMessage switch - { - { Length: > 0 } => $"No DNS records were found for service '{ServiceName}' (DNS name: '{dnsName}'): {errorMessage}.", - _ => $"No DNS records were found for service '{ServiceName}' (DNS name: '{dnsName}')." - }; - return new InvalidOperationException(msg); - } - ServiceEndpoint CreateEndpoint(EndPoint endPoint) { var serviceEndpoint = ServiceEndpoint.Create(endPoint); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs index fd0cb28353d..085ee30123b 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs @@ -2,29 +2,29 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; -using DnsClient; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; internal sealed partial class DnsSrvServiceEndpointProviderFactory( IOptionsMonitor options, ILogger logger, - IDnsQuery dnsClient, + IDnsResolver resolver, TimeProvider timeProvider) : IServiceEndpointProviderFactory { private static readonly string s_serviceAccountPath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount"); private static readonly string s_serviceAccountNamespacePath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount", "namespace"); private static readonly string s_resolveConfPath = Path.Combine($"{Path.DirectorySeparatorChar}etc", "resolv.conf"); - private readonly string? _querySuffix = options.CurrentValue.QuerySuffix ?? GetKubernetesHostDomain(); + private readonly string? _querySuffix = options.CurrentValue.QuerySuffix?.TrimStart('.') ?? GetKubernetesHostDomain(); /// public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) { // If a default namespace is not specified, then this provider will attempt to infer the namespace from the service name, but only when running inside Kubernetes. // Kubernetes DNS spec: https://github.com/kubernetes/dns/blob/master/docs/specification.md - // SRV records are available for headless services with named ports. + // SRV records are available for headless services with named ports. // They take the form $"_{portName}._{protocol}.{serviceName}.{namespace}.{suffix}" // The suffix (after the service name) can be parsed from /etc/resolv.conf // Otherwise, the namespace can be read from /var/run/secrets/kubernetes.io/serviceaccount/namespace and combined with an assumed suffix of "svc.cluster.local". @@ -39,7 +39,7 @@ public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] ou var portName = query.EndpointName ?? "default"; var srvQuery = $"_{portName}._tcp.{query.ServiceName}.{_querySuffix}"; - provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, dnsClient, timeProvider); + provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, resolver, timeProvider); return true; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj index 9854c93c01a..3aba9b3aaea 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj @@ -9,7 +9,6 @@ - @@ -23,7 +22,9 @@ - + + + diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs new file mode 100644 index 00000000000..094df3040d1 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -0,0 +1,133 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsDataReader : IDisposable +{ + public ArraySegment MessageBuffer { get; private set; } + bool _returnToPool; + private int _position; + + public DnsDataReader(ArraySegment buffer, bool returnToPool = false) + { + MessageBuffer = buffer; + _position = 0; + _returnToPool = returnToPool; + } + + public bool TryReadHeader(out DnsMessageHeader header) + { + Debug.Assert(_position == 0); + + if (!DnsPrimitives.TryReadMessageHeader(MessageBuffer.AsSpan(), out header, out int bytesRead)) + { + header = default; + return false; + } + + _position += bytesRead; + return true; + } + + internal bool TryReadQuestion(out EncodedDomainName name, out QueryType type, out QueryClass @class) + { + if (!TryReadDomainName(out name) || + !TryReadUInt16(out ushort typeAsInt) || + !TryReadUInt16(out ushort classAsInt)) + { + type = 0; + @class = 0; + return false; + } + + type = (QueryType)typeAsInt; + @class = (QueryClass)classAsInt; + return true; + } + + public bool TryReadUInt16(out ushort value) + { + if (MessageBuffer.Count - _position < 2) + { + value = 0; + return false; + } + + value = BinaryPrimitives.ReadUInt16BigEndian(MessageBuffer.AsSpan(_position)); + _position += 2; + return true; + } + + public bool TryReadUInt32(out uint value) + { + if (MessageBuffer.Count - _position < 4) + { + value = 0; + return false; + } + + value = BinaryPrimitives.ReadUInt32BigEndian(MessageBuffer.AsSpan(_position)); + _position += 4; + return true; + } + + public bool TryReadResourceRecord(out DnsResourceRecord record) + { + if (!TryReadDomainName(out EncodedDomainName name) || + !TryReadUInt16(out ushort type) || + !TryReadUInt16(out ushort @class) || + !TryReadUInt32(out uint ttl) || + !TryReadUInt16(out ushort dataLength) || + MessageBuffer.Count - _position < dataLength) + { + record = default; + return false; + } + + ReadOnlyMemory data = MessageBuffer.AsMemory(_position, dataLength); + _position += dataLength; + + record = new DnsResourceRecord(name, (QueryType)type, (QueryClass)@class, (int)ttl, data); + return true; + } + + public bool TryReadDomainName(out EncodedDomainName name) + { + if (DnsPrimitives.TryReadQName(MessageBuffer, _position, out name, out int bytesRead)) + { + _position += bytesRead; + return true; + } + + return false; + } + + public bool TryReadSpan(int length, out ReadOnlySpan name) + { + if (MessageBuffer.Count - _position < length) + { + name = default; + return false; + } + + name = MessageBuffer.AsSpan(_position, length); + _position += length; + return true; + } + + public void Dispose() + { + if (_returnToPool && MessageBuffer.Array != null) + { + ArrayPool.Shared.Return(MessageBuffer.Array); + } + + _returnToPool = false; + MessageBuffer = default; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs new file mode 100644 index 00000000000..a0a11f0b808 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; +using System.Diagnostics; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal sealed class DnsDataWriter +{ + private readonly Memory _buffer; + private int _position; + + internal DnsDataWriter(Memory buffer) + { + _buffer = buffer; + _position = 0; + } + + public int Position => _position; + + internal bool TryWriteHeader(in DnsMessageHeader header) + { + if (!DnsPrimitives.TryWriteMessageHeader(_buffer.Span.Slice(_position), header, out int written)) + { + return false; + } + + _position += written; + return true; + } + + internal bool TryWriteQuestion(EncodedDomainName name, QueryType type, QueryClass @class) + { + if (!TryWriteDomainName(name) || + !TryWriteUInt16((ushort)type) || + !TryWriteUInt16((ushort)@class)) + { + return false; + } + + return true; + } + + private bool TryWriteDomainName(EncodedDomainName name) + { + foreach (var label in name.Labels) + { + // this should be already validated by the caller + Debug.Assert(label.Length <= 63, "Label length must not exceed 63 bytes."); + + if (!TryWriteByte((byte)label.Length) || + !TryWriteRawData(label.Span)) + { + return false; + } + } + + // root label + return TryWriteByte(0); + } + + internal bool TryWriteDomainName(string name) + { + if (DnsPrimitives.TryWriteQName(_buffer.Span.Slice(_position), name, out int written)) + { + _position += written; + return true; + } + + return false; + } + + internal bool TryWriteByte(byte value) + { + if (_buffer.Length - _position < 1) + { + return false; + } + + _buffer.Span[_position] = value; + _position += 1; + return true; + } + + internal bool TryWriteUInt16(ushort value) + { + if (_buffer.Length - _position < 2) + { + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(_buffer.Span.Slice(_position), value); + _position += 2; + return true; + } + + internal bool TryWriteUInt32(uint value) + { + if (_buffer.Length - _position < 4) + { + return false; + } + + BinaryPrimitives.WriteUInt32BigEndian(_buffer.Span.Slice(_position), value); + _position += 4; + return true; + } + + internal bool TryWriteRawData(ReadOnlySpan value) + { + if (_buffer.Length - _position < value.Length) + { + return false; + } + + value.CopyTo(_buffer.Span.Slice(_position)); + _position += value.Length; + return true; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs new file mode 100644 index 00000000000..b22273a04f2 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +// RFC 1035 4.1.1. Header section format +internal struct DnsMessageHeader +{ + internal const int HeaderLength = 12; + public ushort TransactionId { get; set; } + + internal QueryFlags QueryFlags { get; set; } + + public ushort QueryCount { get; set; } + + public ushort AnswerCount { get; set; } + + public ushort AuthorityCount { get; set; } + + public ushort AdditionalRecordCount { get; set; } + + public QueryResponseCode ResponseCode + { + get => (QueryResponseCode)(QueryFlags & QueryFlags.ResponseCodeMask); + } + + public bool IsResultTruncated + { + get => (QueryFlags & QueryFlags.ResultTruncated) != 0; + } + + public bool IsResponse + { + get => (QueryFlags & QueryFlags.HasResponse) != 0; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs new file mode 100644 index 00000000000..e549abe2576 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -0,0 +1,318 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Globalization; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class DnsPrimitives +{ + // Maximum length of a domain name in ASCII (excluding trailing dot) + internal const int MaxDomainNameLength = 253; + + internal static bool TryReadMessageHeader(ReadOnlySpan buffer, out DnsMessageHeader header, out int bytesRead) + { + // RFC 1035 4.1.1. Header section format + if (buffer.Length < DnsMessageHeader.HeaderLength) + { + header = default; + bytesRead = 0; + return false; + } + + header = new DnsMessageHeader + { + TransactionId = BinaryPrimitives.ReadUInt16BigEndian(buffer), + QueryFlags = (QueryFlags)BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2)), + QueryCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(4)), + AnswerCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(6)), + AuthorityCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(8)), + AdditionalRecordCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(10)) + }; + + bytesRead = DnsMessageHeader.HeaderLength; + return true; + } + + internal static bool TryWriteMessageHeader(Span buffer, DnsMessageHeader header, out int bytesWritten) + { + // RFC 1035 4.1.1. Header section format + if (buffer.Length < DnsMessageHeader.HeaderLength) + { + bytesWritten = 0; + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(buffer, header.TransactionId); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(2), (ushort)header.QueryFlags); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(4), header.QueryCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(6), header.AnswerCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(8), header.AuthorityCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(10), header.AdditionalRecordCount); + + bytesWritten = DnsMessageHeader.HeaderLength; + return true; + } + + // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4 + // labels 63 octets or less + // name 255 octets or less + + private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."); + private static readonly IdnMapping s_idnMapping = new IdnMapping(); + internal static bool TryWriteQName(Span destination, string name, out int written) + { + written = 0; + + // + // RFC 1035 4.1.2. + // + // a domain name represented as a sequence of labels, where + // each label consists of a length octet followed by that + // number of octets. The domain name terminates with the + // zero length octet for the null label of the root. Note + // that this field may be an odd number of octets; no + // padding is used. + // + if (!Ascii.IsValid(name)) + { + // IDN name, apply punycode + try + { + // IdnMapping performs some validation internally (such as label + // and domain name lengths), but is more relaxed than RFC + // 1035 (e.g. allows ~ chars), so even if this conversion does + // not throw, we still need to perform additional validation + name = s_idnMapping.GetAscii(name); + } + catch + { + return false; + } + } + + if (name.Length > MaxDomainNameLength || + name.AsSpan().ContainsAnyExcept(s_domainNameValidChars) || + destination.IsEmpty || + !Encoding.ASCII.TryGetBytes(name, destination.Slice(1), out int length) || + destination.Length < length + 2) + { + // buffer too small + return false; + } + + Span nameBuffer = destination.Slice(0, 1 + length); + Span label; + while (true) + { + // figure out the next label and prepend the length + int index = nameBuffer.Slice(1).IndexOf((byte)'.'); + label = index == -1 ? nameBuffer.Slice(1) : nameBuffer.Slice(1, index); + + if (label.Length == 0) + { + // empty label (explicit root) is only allowed at the end + if (index != -1) + { + written = 0; + return false; + } + } + // Label restrictions: + // - maximum 63 octets long + // - must start with a letter or digit (digit is allowed by RFC 1123) + // - may start with an underscore (underscore may be present only + // at the start of the label to support SRV records) + // - must end with a letter or digit + else if (label.Length > 63 || + !char.IsAsciiLetterOrDigit((char)label[0]) && label[0] != '_' || + label.Slice(1).Contains((byte)'_') || + !char.IsAsciiLetterOrDigit((char)label[^1])) + { + written = 0; + return false; + } + + nameBuffer[0] = (byte)label.Length; + written += label.Length + 1; + + if (index == -1) + { + // this was the last label + break; + } + + nameBuffer = nameBuffer.Slice(index + 1); + } + + // Add root label if wasn't explicitly specified + if (label.Length != 0) + { + destination[written] = 0; + written++; + } + + return true; + } + + private static bool TryReadQNameCore(List> labels, int totalLength, ReadOnlyMemory messageBuffer, int offset, out int bytesRead, bool canStartWithPointer = true) + { + // + // domain name can be either + // - a sequence of labels, where each label consists of a length octet + // followed by that number of octets, terminated by a zero length octet + // (root label) + // - a pointer, where the first two bits are set to 1, and the remaining + // 14 bits are an offset (from the start of the message) to the true + // label + // + // It is not specified by the RFC if pointers must be backwards only, + // the code below prohibits forward (and self) pointers to avoid + // infinite loops. It also allows pointers only to point to a + // label, not to another pointer. + // + + bytesRead = 0; + bool allowPointer = canStartWithPointer; + + if (offset < 0 || offset >= messageBuffer.Length) + { + return false; + } + + int currentOffset = offset; + + while (true) + { + byte length = messageBuffer.Span[currentOffset]; + + if ((length & 0xC0) == 0x00) + { + // length followed by the label + if (length == 0) + { + // end of name + bytesRead = currentOffset - offset + 1; + return true; + } + + if (currentOffset + 1 + length >= messageBuffer.Length) + { + // too many labels or truncated data + break; + } + + // read next label/segment + labels.Add(messageBuffer.Slice(currentOffset + 1, length)); + totalLength += 1 + length; + + // subtract one for the length prefix of the first label + if (totalLength - 1 > MaxDomainNameLength) + { + // domain name is too long + return false; + } + + currentOffset += 1 + length; + bytesRead += 1 + length; + + // we read a label, they can be followed by pointer. + allowPointer = true; + } + else if ((length & 0xC0) == 0xC0) + { + // pointer, together with next byte gives the offset of the true label + if (!allowPointer || currentOffset + 1 >= messageBuffer.Length) + { + // pointer to pointer or truncated data + break; + } + + bytesRead += 2; + int pointer = ((length & 0x3F) << 8) | messageBuffer.Span[currentOffset + 1]; + + // we prohibit self-references and forward pointers to avoid + // infinite loops, we do this by truncating the + // messageBuffer at the offset where we started reading the + // name. We also ignore the bytesRead from the recursive + // call, as we are only interested on how many bytes we read + // from the initial start of the name. + return TryReadQNameCore(labels, totalLength, messageBuffer.Slice(0, offset), pointer, out int _, false); + } + else + { + // top two bits are reserved, this means invalid data + break; + } + } + + return false; + + } + + internal static bool TryReadQName(ReadOnlyMemory messageBuffer, int offset, out EncodedDomainName name, out int bytesRead) + { + List> labels = new List>(); + + if (TryReadQNameCore(labels, 0, messageBuffer, offset, out bytesRead)) + { + name = new EncodedDomainName(labels); + return true; + } + else + { + bytesRead = 0; + name = default; + return false; + } + } + + internal static bool TryReadService(ReadOnlyMemory buffer, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) + { + // https://www.rfc-editor.org/rfc/rfc2782 + if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span, out priority) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(2), out weight) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(4), out port) || + !TryReadQName(buffer.Slice(6), 0, out target, out bytesRead)) + { + target = default; + priority = 0; + weight = 0; + port = 0; + bytesRead = 0; + return false; + } + + bytesRead += 6; + return true; + } + + internal static bool TryReadSoa(ReadOnlyMemory buffer, out EncodedDomainName primaryNameServer, out EncodedDomainName responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) + { + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 + if (!TryReadQName(buffer, 0, out primaryNameServer, out int w1) || + !TryReadQName(buffer.Slice(w1), 0, out responsibleMailAddress, out int w2) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2), out serial) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 4), out refresh) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 8), out retry) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 12), out expire) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 16), out minimum)) + { + primaryNameServer = default; + responsibleMailAddress = default; + serial = 0; + refresh = 0; + retry = 0; + expire = 0; + minimum = 0; + bytesRead = 0; + return false; + } + + bytesRead = w1 + w2 + 20; + return true; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs new file mode 100644 index 00000000000..adab9161737 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System.Net; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal partial class DnsResolver : IDnsResolver, IDisposable +{ + internal static partial class Log + { + [LoggerMessage(1, LogLevel.Debug, "Resolving {QueryType} {QueryName} on {Server} attempt {Attempt}", EventName = "Query")] + public static partial void Query(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(2, LogLevel.Debug, "Result truncated for {QueryType} {QueryName} from {Server} attempt {Attempt}. Restarting over TCP", EventName = "ResultTruncated")] + public static partial void ResultTruncated(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(3, LogLevel.Error, "Server {Server} replied with {ResponseCode} when querying {QueryType} {QueryName}", EventName = "ErrorResponseCode")] + public static partial void ErrorResponseCode(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, QueryResponseCode responseCode); + + [LoggerMessage(4, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")] + public static partial void Timeout(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: no data matching given query type.", EventName = "NoData")] + public static partial void NoData(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(6, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: server indicates given name does not exist.", EventName = "NameError")] + public static partial void NameError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(7, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed to return a valid DNS response.", EventName = "MalformedResponse")] + public static partial void MalformedResponse(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(8, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed due to a network error.", EventName = "NetworkError")] + public static partial void NetworkError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); + + [LoggerMessage(9, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] + public static partial void QueryError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs new file mode 100644 index 00000000000..4be956cede9 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs @@ -0,0 +1,115 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.Metrics; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal partial class DnsResolver +{ + internal static class Telemetry + { + private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"); + private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration"); + + private static bool IsEnabled() => s_queryDuration.Enabled; + + public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType, long startingTimestamp) + { + if (IsEnabled()) + { + return new NameResolutionActivity(hostName, queryType, startingTimestamp); + } + + return default; + } + + public static void StopNameResolution(string hostName, QueryType queryType, in NameResolutionActivity activity, object? answers, SendQueryError error, long endingTimestamp) + { + activity.Stop(answers, error, endingTimestamp, out TimeSpan duration); + + if (!IsEnabled()) + { + return; + } + + var hostNameTag = KeyValuePair.Create("dns.question.name", (object?)hostName); + var queryTypeTag = KeyValuePair.Create("dns.question.type", (object?)queryType); + + if (answers is not null) + { + s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag); + } + else + { + var errorTypeTag = KeyValuePair.Create("error.type", (object?)error.ToString()); + s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag, errorTypeTag); + } + } + } + + internal readonly struct NameResolutionActivity + { + private const string ActivitySourceName = "Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"; + private const string ActivityName = ActivitySourceName + ".Resolve"; + private static readonly ActivitySource s_activitySource = new ActivitySource(ActivitySourceName); + + private readonly long _startingTimestamp; + private readonly Activity? _activity; // null if activity is not started + + public NameResolutionActivity(string hostName, QueryType queryType, long startingTimestamp) + { + _startingTimestamp = startingTimestamp; + _activity = s_activitySource.StartActivity(ActivityName, ActivityKind.Client); + if (_activity is not null) + { + _activity.DisplayName = $"Resolving {hostName}"; + if (_activity.IsAllDataRequested) + { + _activity.SetTag("dns.question.name", hostName); + _activity.SetTag("dns.question.type", queryType.ToString()); + } + } + } + + public void Stop(object? answers, SendQueryError error, long endingTimestamp, out TimeSpan duration) + { + duration = Stopwatch.GetElapsedTime(_startingTimestamp, endingTimestamp); + + if (_activity is null) + { + return; + } + + if (_activity.IsAllDataRequested) + { + if (answers is not null) + { + static string[] ToStringHelper(T[] array) => array.Select(a => a!.ToString()!).ToArray(); + + string[]? answersArray = answers switch + { + ServiceResult[] serviceResults => ToStringHelper(serviceResults), + AddressResult[] addressResults => ToStringHelper(addressResults), + _ => null + }; + + Debug.Assert(answersArray is not null); + _activity.SetTag("dns.answers", answersArray); + } + else + { + _activity.SetTag("error.type", error.ToString()); + } + } + + if (answers is null) + { + _activity.SetStatus(ActivityStatusCode.Error); + } + + _activity.Stop(); + } + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs new file mode 100644 index 00000000000..5722356a1c3 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -0,0 +1,931 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal sealed partial class DnsResolver : IDnsResolver, IDisposable +{ + private const int IPv4Length = 4; + private const int IPv6Length = 16; + + // CancellationTokenSource.CancelAfter has a maximum timeout of Int32.MaxValue milliseconds. + private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); + + private bool _disposed; + private readonly ResolverOptions _options; + private readonly CancellationTokenSource _pendingRequestsCts = new(); + private readonly TimeProvider _timeProvider; + private readonly ILogger _logger; + + public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(timeProvider, logger, OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions()) + { + } + + internal DnsResolver(TimeProvider timeProvider, ILogger logger, ResolverOptions options) + { + _timeProvider = timeProvider; + _logger = logger; + _options = options; + Debug.Assert(_options.Servers.Count > 0); + + if (options.Timeout != Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(options.Timeout, TimeSpan.Zero); + ArgumentOutOfRangeException.ThrowIfGreaterThan(options.Timeout, s_maxTimeout); + } + } + + internal DnsResolver(ResolverOptions options) : this(TimeProvider.System, NullLogger.Instance, options) + { + } + + internal DnsResolver(IEnumerable servers) : this(new ResolverOptions(servers.ToArray())) + { + } + + internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server)) + { + } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + + // dnsSafeName is Disposed by SendQueryWithTelemetry + EncodedDomainName dnsSafeName = GetNormalizedHostName(name); + return SendQueryWithTelemetry(name, dnsSafeName, QueryType.SRV, ProcessResponse, cancellationToken); + + static (SendQueryError, ServiceResult[]) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) + { + var results = new List(response.Answers.Count); + + foreach (var answer in response.Answers) + { + if (answer.Type == QueryType.SRV) + { + if (!DnsPrimitives.TryReadService(answer.Data, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) || bytesRead != answer.Data.Length) + { + return (SendQueryError.MalformedResponse, []); + } + + List addresses = new List(); + foreach (var additional in response.Additionals) + { + // From RFC 2782: + // + // Target + // The domain name of the target host. There MUST be one or more + // address records for this name, the name MUST NOT be an alias (in + // the sense of RFC 1034 or RFC 2181). Implementors are urged, but + // not required, to return the address record(s) in the Additional + // Data section. Unless and until permitted by future standards + // action, name compression is not to be used for this field. + // + // A Target of "." means that the service is decidedly not + // available at this domain. + if (additional.Name.Equals(target) && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) + { + addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); + } + } + + results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target.ToString(), addresses.ToArray())); + } + } + + return (SendQueryError.NoError, results.ToArray()); + } + } + + public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) + { + if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase)) + { + // name localhost exists outside of DNS and can't be resolved by a DNS server + int len = (Socket.OSSupportsIPv4 ? 1 : 0) + (Socket.OSSupportsIPv6 ? 1 : 0); + AddressResult[] res = new AddressResult[len]; + + int index = 0; + if (Socket.OSSupportsIPv6) // prefer IPv6 + { + res[index] = new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback); + index++; + } + if (Socket.OSSupportsIPv4) + { + res[index] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback); + } + + return res; + } + + var ipv4AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetwork, cancellationToken); + var ipv6AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetworkV6, cancellationToken); + + AddressResult[] ipv4Addresses = await ipv4AddressesTask.ConfigureAwait(false); + AddressResult[] ipv6Addresses = await ipv6AddressesTask.ConfigureAwait(false); + + AddressResult[] results = new AddressResult[ipv4Addresses.Length + ipv6Addresses.Length]; + ipv6Addresses.CopyTo(results, 0); + ipv4Addresses.CopyTo(results, ipv6Addresses.Length); + return results; + } + + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + + if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6) + { + throw new ArgumentOutOfRangeException(nameof(addressFamily), addressFamily, "Invalid address family"); + } + + if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase)) + { + // name localhost exists outside of DNS and can't be resolved by a DNS server + if (addressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4) + { + return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.Loopback)]); + } + else if (addressFamily == AddressFamily.InterNetworkV6 && Socket.OSSupportsIPv6) + { + return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback)]); + } + + return ValueTask.FromResult([]); + } + + // dnsSafeName is Disposed by SendQueryWithTelemetry + EncodedDomainName dnsSafeName = GetNormalizedHostName(name); + var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + return SendQueryWithTelemetry(name, dnsSafeName, queryType, ProcessResponse, cancellationToken); + + static (SendQueryError error, AddressResult[] result) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) + { + List results = new List(response.Answers.Count); + + // Servers send back CNAME records together with associated A/AAAA records. Servers + // send only those CNAME records relevant to the query, and if there is a CNAME record, + // there should not be other records associated with the name. Therefore, we simply follow + // the list of CNAME aliases until we get to the primary name and return the A/AAAA records + // associated. + // + // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 + // + // Most of the servers send the CNAME records in order so that we can sequentially scan the + // answers, but nothing prevents the records from being in arbitrary order. Attempt the linear + // scan first and fallback to a slower but more robust method if necessary. + + bool success = true; + EncodedDomainName currentAlias = dnsSafeName; + + foreach (var answer in response.Answers) + { + switch (answer.Type) + { + case QueryType.CNAME: + if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (answer.Name.Equals(currentAlias)) + { + currentAlias = target; + continue; + } + + break; + + case var type when type == queryType: + if (!TryReadAddress(answer, queryType, out IPAddress? address)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (answer.Name.Equals(currentAlias)) + { + results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); + continue; + } + + break; + } + + // unexpected name or record type, fall back to more robust path + results.Clear(); + success = false; + break; + } + + if (success) + { + return (SendQueryError.NoError, results.ToArray()); + } + + // more expensive path for uncommon (but valid) cases where CNAME records are out of order. Use of Dictionary + // allows us to stay within O(n) complexity for the number of answers, but we will use more memory. + Dictionary aliasMap = new(); + Dictionary> aRecordMap = new(); + foreach (var answer in response.Answers) + { + if (answer.Type == QueryType.CNAME) + { + // map the alias to the target name + if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (!aliasMap.TryAdd(answer.Name, target)) + { + // Duplicate CNAME record + return (SendQueryError.MalformedResponse, []); + } + } + + if (answer.Type == queryType) + { + if (!TryReadAddress(answer, queryType, out IPAddress? address)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (!aRecordMap.TryGetValue(answer.Name, out List? addressList)) + { + addressList = new List(); + aRecordMap.Add(answer.Name, addressList); + } + + addressList.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); + } + } + + // follow the CNAME chain, limit the maximum number of iterations to avoid infinite loops. + int i = 0; + currentAlias = dnsSafeName; + while (aliasMap.TryGetValue(currentAlias, out EncodedDomainName nextAlias)) + { + if (i >= aliasMap.Count) + { + // circular CNAME chain + return (SendQueryError.MalformedResponse, []); + } + + i++; + + if (aRecordMap.ContainsKey(currentAlias)) + { + // both CNAME record and A/AAAA records exist for the current alias + return (SendQueryError.MalformedResponse, []); + } + + currentAlias = nextAlias; + } + + // Now we have the final target name, check if we have any A/AAAA records for it. + aRecordMap.TryGetValue(currentAlias, out List? finalAddressList); + return (SendQueryError.NoError, finalAddressList?.ToArray() ?? []); + + static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messageBytes, out EncodedDomainName target) + { + Debug.Assert(record.Type == QueryType.CNAME, "Only CNAME records should be processed here."); + + target = default; + + // some servers use domain name compression even inside CNAME records. In order to decode those + // correctly, we need to pass the entire message to TryReadQName. The Data span inside the record + // should be backed by the array containing the entire DNS message. We just need to account for the + // 2 byte offset in case of TCP fallback. + var gotArray = MemoryMarshal.TryGetArray(record.Data, out ArraySegment segment); + Debug.Assert(gotArray, "Failed to get array segment"); + Debug.Assert(segment.Array == messageBytes.Array, "record data backed by different array than the original message"); + + int messageOffset = messageBytes.Offset; + + bool result = DnsPrimitives.TryReadQName(segment.Array.AsMemory(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset - messageOffset, out EncodedDomainName targetName, out int bytesRead) && bytesRead == record.Data.Length; + if (result) + { + target = targetName; + } + + return result; + } + + static bool TryReadAddress(in DnsResourceRecord record, QueryType type, [NotNullWhen(true)] out IPAddress? target) + { + Debug.Assert(record.Type is QueryType.A or QueryType.AAAA, "Only CNAME records should be processed here."); + + target = null; + if (record.Type == QueryType.A && record.Data.Length != IPv4Length || + record.Type == QueryType.AAAA && record.Data.Length != IPv6Length) + { + return false; + } + + target = new IPAddress(record.Data.Span); + return true; + } + } + } + + private async ValueTask SendQueryWithTelemetry(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + { + NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); + (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, dnsSafeName, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false); + Telemetry.StopNameResolution(name, queryType, activity, null, error, _timeProvider.GetTimestamp()); + dnsSafeName.Dispose(); + + return result; + } + + internal struct SendQueryResult + { + public DnsResponse Response; + public SendQueryError Error; + } + + async ValueTask<(SendQueryError error, TResult[] result)> SendQueryWithRetriesAsync(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + { + SendQueryError lastError = SendQueryError.InternalError; // will be overwritten by the first attempt + for (int index = 0; index < _options.Servers.Count; index++) + { + IPEndPoint serverEndPoint = _options.Servers[index]; + + for (int attempt = 1; attempt <= _options.Attempts; attempt++) + { + DnsResponse response = default; + try + { + TResult[] results = Array.Empty(); + + try + { + SendQueryResult queryResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cancellationToken).ConfigureAwait(false); + lastError = queryResult.Error; + response = queryResult.Response; + + if (lastError == SendQueryError.NoError) + { + // Given that result.Error is NoError, there should be at least one answer. + Debug.Assert(response.Answers.Count > 0); + (lastError, results) = processResponseFunc(dnsSafeName, queryType, queryResult.Response); + } + } + catch (SocketException ex) + { + Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex); + lastError = SendQueryError.NetworkError; + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + // internal error, propagate + Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); + throw; + } + + switch (lastError) + { + // + // Definitive answers, no point retrying + // + case SendQueryError.NoError: + return (lastError, results); + + case SendQueryError.NameError: + // authoritative answer that the name does not exist, no point in retrying + Log.NameError(_logger, queryType, name, serverEndPoint, attempt); + return (lastError, results); + + case SendQueryError.NoData: + // no data available for the name from authoritative server + Log.NoData(_logger, queryType, name, serverEndPoint, attempt); + return (lastError, results); + + // + // Transient errors, retry on the same server + // + case SendQueryError.Timeout: + Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); + continue; + + case SendQueryError.NetworkError: + // TODO: retry with exponential backoff? + continue; + + case SendQueryError.ServerError when response.Header.ResponseCode == QueryResponseCode.ServerFailure: + // ServerFailure may indicate transient failure with upstream DNS servers, retry on the same server + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode); + continue; + + // + // Persistent errors, skip to the next server + // + case SendQueryError.ServerError: + // this should cover all response codes except NoError, NameError which are definite and handled above, and + // ServerFailure which is a transient error and handled above. + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode); + break; + + case SendQueryError.MalformedResponse: + Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt); + break; + + case SendQueryError.InternalError: + // exception logged above. + break; + } + + // actual break that causes skipping to the next server + break; + } + finally + { + response.Dispose(); + } + } + } + + // if we get here, we exhausted all servers and all attempts + return (lastError, []); + } + + internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken) + { + (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); + + try + { + return await SendQueryToServerAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when ( + !cancellationToken.IsCancellationRequested && // not cancelled by the caller + !pendingRequestsCts.IsCancellationRequested) // not cancelled by the global token (dispose) + // the only remaining token that could cancel this is the linked cts from the timeout. + { + Debug.Assert(cts.Token.IsCancellationRequested); + return new SendQueryResult { Error = SendQueryError.Timeout }; + } + catch (OperationCanceledException ex) when (cancellationToken.IsCancellationRequested && ex.CancellationToken != cancellationToken) + { + // cancellation was initiated by the caller, but exception was triggered by a linked token, + // rethrow the exception with the caller's token. + cancellationToken.ThrowIfCancellationRequested(); + throw new UnreachableException(); + } + finally + { + if (disposeTokenSource) + { + cts.Dispose(); + } + } + } + + private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken) + { + Log.Query(_logger, queryType, name, serverEndPoint, attempt); + + SendQueryError sendError = SendQueryError.NoError; + DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; + DnsDataReader responseReader = default; + DnsMessageHeader header; + + try + { + // use transport override if provided + if (_options._transportOverride != null) + { + (responseReader, header, sendError) = SendDnsQueryCustomTransport(_options._transportOverride, dnsSafeName, queryType); + } + else + { + (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + + if (header.IsResultTruncated) + { + Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); + responseReader.Dispose(); + // TCP fallback + (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + } + } + + if (sendError != SendQueryError.NoError) + { + // we failed to get back any response + return new SendQueryResult { Error = sendError }; + } + + if ((uint)header.ResponseCode > (uint)QueryResponseCode.Refused) + { + // Response code is outside of valid range + return new SendQueryResult + { + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + // Recheck that the server echoes back the DNS question + if (header.QueryCount != 1 || + !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || + !dnsSafeName.Equals(qName) || qType != queryType || qClass != QueryClass.Internet) + { + // DNS Question mismatch + return new SendQueryResult + { + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + // Structurally separate the resource records, this will validate only the + // "outside structure" of the resource record, it will not validate the content. + int ttl = int.MaxValue; + if (!TryReadRecords(header.AnswerCount, ref ttl, ref responseReader, out List? answers) || + !TryReadRecords(header.AuthorityCount, ref ttl, ref responseReader, out List? authorities) || + !TryReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader, out List? additionals)) + { + return new SendQueryResult + { + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + DateTime expirationTime = + (answers.Count + authorities.Count + additionals.Count) > 0 ? queryStartedTime.AddSeconds(ttl) : queryStartedTime; + + SendQueryError validationError = ValidateResponse(header.ResponseCode, queryStartedTime, answers, authorities, ref expirationTime); + + // we transfer ownership of RawData to the response + DnsResponse response = new DnsResponse(responseReader.MessageBuffer, header, queryStartedTime, expirationTime, answers, authorities, additionals); + responseReader = default; // avoid disposing (and returning RawData to the pool) + + return new SendQueryResult { Response = response, Error = validationError }; + } + finally + { + responseReader.Dispose(); + } + + static bool TryReadRecords(int count, ref int ttl, ref DnsDataReader reader, out List records) + { + // Since `count` is attacker controlled, limit the initial capacity + // to 32 items to avoid excessive memory allocation. More than 32 + // records are unusual so we don't need to optimize for them. + records = new(Math.Min(count, 32)); + + for (int i = 0; i < count; i++) + { + if (!reader.TryReadResourceRecord(out var record)) + { + return false; + } + + ttl = Math.Min(ttl, record.Ttl); + records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data)); + } + + return true; + } + } + + internal static bool GetNegativeCacheExpiration(DateTime createdAt, List authorities, out DateTime expiration) + { + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // Like normal answers negative answers have a time to live (TTL). As + // there is no record in the answer section to which this TTL can be + // applied, the TTL must be carried by another method. This is done by + // including the SOA record from the zone in the authority section of + // the reply. When the authoritative server creates this record its TTL + // is taken from the minimum of the SOA.MINIMUM field and SOA's TTL. + // This TTL decrements in a similar manner to a normal cached answer and + // upon reaching zero (0) indicates the cached negative answer MUST NOT + // be used again. + // + + DnsResourceRecord? soa = authorities.FirstOrDefault(r => r.Type == QueryType.SOA); + if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data, out _, out _, out _, out _, out _, out _, out uint minimum, out _)) + { + expiration = createdAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl)); + return true; + } + + expiration = default; + return false; + } + + internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, DateTime createdAt, List answers, List authorities, ref DateTime expiration) + { + if (responseCode == QueryResponseCode.NoError) + { + if (answers.Count > 0) + { + return SendQueryError.NoError; + } + // + // RFC 2308 Section 2.2 - No Data + // + // NODATA is indicated by an answer with the RCODE set to NOERROR and no + // relevant answers in the answer section. The authority section will + // contain an SOA record, or there will be no NS records there. + // + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // A negative answer that resulted from a no data error (NODATA) should + // be cached such that it can be retrieved and returned in response to + // another query for the same that resulted in + // the cached negative response. + // + if (!authorities.Any(r => r.Type == QueryType.NS) && GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration)) + { + expiration = newExpiration; + // _cache.TryAdd(name, queryType, expiration, Array.Empty()); + } + return SendQueryError.NoData; + } + + if (responseCode == QueryResponseCode.NameError) + { + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // A negative answer that resulted from a name error (NXDOMAIN) should + // be cached such that it can be retrieved and returned in response to + // another query for the same that resulted in the + // cached negative response. + // + if (GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration)) + { + expiration = newExpiration; + // _cache.TryAddNonexistent(name, expiration); + } + + return SendQueryError.NameError; + } + + return SendQueryError.ServerError; + } + + internal static (DnsDataReader reader, DnsMessageHeader header, SendQueryError sendError) SendDnsQueryCustomTransport(Func, int, int> callback, EncodedDomainName dnsSafeName, QueryType queryType) + { + byte[] buffer = ArrayPool.Shared.Rent(2048); + try + { + (ushort transactionId, int length) = EncodeQuestion(buffer, dnsSafeName, queryType); + length = callback(buffer, length); + + DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 0, length), true); + + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + return (default, default, SendQueryError.MalformedResponse); + } + + // transfer ownership of buffer to the caller + buffer = null!; + return (responseReader, header, SendQueryError.NoError); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(512); + try + { + Memory memory = buffer; + (ushort transactionId, int length) = EncodeQuestion(memory, dnsSafeName, queryType); + + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + await socket.SendToAsync(memory.Slice(0, length), SocketFlags.None, serverEndPoint, cancellationToken).ConfigureAwait(false); + + DnsDataReader responseReader; + DnsMessageHeader header; + + while (true) + { + // Because this is UDP, the response must be in a single packet, + // if the response does not fit into a single UDP packet, the server will + // set the Truncated flag in the header, and we will need to retry with TCP. + int packetLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false); + + if (packetLength < DnsMessageHeader.HeaderLength) + { + continue; + } + + responseReader = new DnsDataReader(new ArraySegment(buffer, 0, packetLength), true); + if (!responseReader.TryReadHeader(out header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + // header mismatch, this is not a response to our query + continue; + } + + // ownership of the buffer is transferred to the reader, caller will dispose. + buffer = null!; + return (responseReader, header); + } + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header, SendQueryError error)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(8 * 1024); + try + { + // When sending over TCP, the message is prefixed by 2B length + (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), dnsSafeName, queryType); + BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length); + + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); + await socket.SendAsync(buffer.AsMemory(0, length + 2), SocketFlags.None, cancellationToken).ConfigureAwait(false); + + int responseLength = -1; + int bytesRead = 0; + while (responseLength < 0 || bytesRead < responseLength + 2) + { + int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); + bytesRead += read; + + if (read == 0) + { + // connection closed before receiving complete response message + return (default, default, SendQueryError.MalformedResponse); + } + + if (responseLength < 0 && bytesRead >= 2) + { + responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); + + if (responseLength + 2 > buffer.Length) + { + // even though this is user-controlled pre-allocation, it is limited to + // 64 kB, so it should be fine. + var largerBuffer = ArrayPool.Shared.Rent(responseLength + 2); + Array.Copy(buffer, largerBuffer, bytesRead); + ArrayPool.Shared.Return(buffer); + buffer = largerBuffer; + } + } + } + + DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 2, responseLength), true); + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + // header mismatch on TCP fallback + return (default, default, SendQueryError.MalformedResponse); + } + + // transfer ownership of buffer to the caller + buffer = null!; + return (responseReader, header, SendQueryError.NoError); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + private static (ushort id, int length) EncodeQuestion(Memory buffer, EncodedDomainName dnsSafeName, QueryType queryType) + { + DnsMessageHeader header = new DnsMessageHeader + { + TransactionId = (ushort)RandomNumberGenerator.GetInt32(ushort.MaxValue + 1), + QueryFlags = QueryFlags.RecursionDesired, + QueryCount = 1 + }; + + DnsDataWriter writer = new DnsDataWriter(buffer); + if (!writer.TryWriteHeader(header) || + !writer.TryWriteQuestion(dnsSafeName, queryType, QueryClass.Internet)) + { + // should never happen since we validated the name length before + throw new InvalidOperationException("Buffer too small"); + } + return (header.TransactionId, writer.Position); + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + + // Cancel all pending requests (if any). Note that we don't call CancelPendingRequests() but cancel + // the CTS directly. The reason is that CancelPendingRequests() would cancel the current CTS and create + // a new CTS. We don't want a new CTS in this case. + _pendingRequestsCts.Cancel(); + _pendingRequestsCts.Dispose(); + } + } + + private (CancellationTokenSource TokenSource, bool DisposeTokenSource, CancellationTokenSource PendingRequestsCts) PrepareCancellationTokenSource(CancellationToken cancellationToken) + { + // We need a CancellationTokenSource to use with the request. We always have the global + // _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may + // have a timeout. If we have a timeout or a caller-provided token, we need to create a new + // CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other + // unrelated operations). Otherwise, we can use the pending requests CTS directly. + + // Snapshot the current pending requests cancellation source. It can change concurrently due to cancellation being requested + // and it being replaced, and we need a stable view of it: if cancellation occurs and the caller's token hasn't been canceled, + // it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas + // it's more approximate checking elapsed time. + CancellationTokenSource pendingRequestsCts = _pendingRequestsCts; + TimeSpan timeout = _options.Timeout; + + bool hasTimeout = timeout != System.Threading.Timeout.InfiniteTimeSpan; + if (hasTimeout || cancellationToken.CanBeCanceled) + { + CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token); + if (hasTimeout) + { + cts.CancelAfter(timeout); + } + + return (cts, DisposeTokenSource: true, pendingRequestsCts); + } + + return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts); + } + + private static EncodedDomainName GetNormalizedHostName(string name) + { + byte[] buffer = ArrayPool.Shared.Rent(256); + try + { + if (!DnsPrimitives.TryWriteQName(buffer, name, out _)) + { + throw new ArgumentException($"'{name}' is not a valid DNS name.", nameof(name)); + } + + List> labels = new(); + Memory memory = buffer.AsMemory(); + while (true) + { + int len = memory.Span[0]; + + if (len == 0) + { + // root label, we are finished + break; + } + + labels.Add(memory.Slice(1, len)); + memory = memory.Slice(len + 1); + } + + buffer = null!; // ownership transferred to the EncodedDomainName + return new EncodedDomainName(labels, buffer); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs new file mode 100644 index 00000000000..914ff9aac17 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsResourceRecord +{ + public EncodedDomainName Name { get; } + public QueryType Type { get; } + public QueryClass Class { get; } + public int Ttl { get; } + public ReadOnlyMemory Data { get; } + + public DnsResourceRecord(EncodedDomainName name, QueryType type, QueryClass @class, int ttl, ReadOnlyMemory data) + { + Name = name; + Type = type; + Class = @class; + Ttl = ttl; + Data = data; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs new file mode 100644 index 00000000000..5a7fc8a0b52 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsResponse : IDisposable +{ + public DnsMessageHeader Header { get; } + public List Answers { get; } + public List Authorities { get; } + public List Additionals { get; } + public DateTime CreatedAt { get; } + public DateTime Expiration { get; } + public ArraySegment RawMessageBytes { get; private set; } + + public DnsResponse(ArraySegment rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) + { + RawMessageBytes = rawData; + + Header = header; + CreatedAt = createdAt; + Expiration = expiration; + Answers = answers; + Authorities = authorities; + Additionals = additionals; + } + + public void Dispose() + { + if (RawMessageBytes.Array != null) + { + ArrayPool.Shared.Return(RawMessageBytes.Array); + } + + RawMessageBytes = default; // prevent further access to the raw data + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs new file mode 100644 index 00000000000..4c258cac3ac --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct EncodedDomainName : IEquatable, IDisposable +{ + public IReadOnlyList> Labels { get; } + private byte[]? _pooledBuffer; + + public EncodedDomainName(List> labels, byte[]? pooledBuffer = null) + { + Labels = labels; + _pooledBuffer = pooledBuffer; + } + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + + foreach (var label in Labels) + { + if (sb.Length > 0) + { + sb.Append('.'); + } + sb.Append(Encoding.ASCII.GetString(label.Span)); + } + + return sb.ToString(); + } + + public bool Equals(EncodedDomainName other) + { + if (Labels.Count != other.Labels.Count) + { + return false; + } + + for (int i = 0; i < Labels.Count; i++) + { + if (!Ascii.EqualsIgnoreCase(Labels[i].Span, other.Labels[i].Span)) + { + return false; + } + } + + return true; + } + + public override bool Equals(object? obj) + { + return obj is EncodedDomainName other && Equals(other); + } + + public override int GetHashCode() + { + HashCode hash = new HashCode(); + + foreach (var label in Labels) + { + foreach (byte b in label.Span) + { + hash.Add((byte)char.ToLower((char)b)); + } + } + + return hash.ToHashCode(); + } + + public void Dispose() + { + if (_pooledBuffer != null) + { + ArrayPool.Shared.Return(_pooledBuffer); + } + + _pooledBuffer = null; + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs new file mode 100644 index 00000000000..e09168d9552 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal interface IDnsResolver +{ + ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default); + ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default); + ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default); +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs new file mode 100644 index 00000000000..c2ef13f922e --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.NetworkInformation; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class NetworkInfo +{ + // basic option to get DNS serves via NetworkInfo. We may get it directly later via proper APIs. + public static ResolverOptions GetOptions() + { + List servers = new List(); + + foreach (NetworkInterface nic in NetworkInterface.GetAllNetworkInterfaces()) + { + IPInterfaceProperties properties = nic.GetIPProperties(); + // avoid loopback, VPN etc. Should be re-visited. + + if (nic.NetworkInterfaceType == NetworkInterfaceType.Ethernet && nic.OperationalStatus == OperationalStatus.Up) + { + foreach (IPAddress server in properties.DnsAddresses) + { + IPEndPoint ep = new IPEndPoint(server, 53); // 53 is standard DNS port + if (!servers.Contains(ep)) + { + servers.Add(ep); + } + } + } + } + + return new ResolverOptions(servers); + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs new file mode 100644 index 00000000000..732ca0216da --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal enum QueryClass +{ + Internet = 1 +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs new file mode 100644 index 00000000000..02474b6cda1 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +[Flags] +internal enum QueryFlags : ushort +{ + RecursionAvailable = 0x0080, + RecursionDesired = 0x0100, + ResultTruncated = 0x0200, + HasAuthorityAnswer = 0x0400, + HasResponse = 0x8000, + ResponseCodeMask = 0x000F, +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs new file mode 100644 index 00000000000..dd51c712112 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +/// +/// The response code (RCODE) in a DNS query response. +/// +internal enum QueryResponseCode : byte +{ + /// + /// No error condition + /// + NoError = 0, + + /// + /// The name server was unable to interpret the query. + /// + FormatError = 1, + + /// + /// The name server was unable to process this query due to a problem with the name server. + /// + ServerFailure = 2, + + /// + /// Meaningful only for responses from an authoritative name server, this + /// code signifies that the domain name referenced in the query does not + /// exist. + /// + NameError = 3, + + /// + /// The name server does not support the requested kind of query. + /// + NotImplemented = 4, + + /// + /// The name server refuses to perform the specified operation for policy reasons. + /// + Refused = 5, +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs new file mode 100644 index 00000000000..2ccc898a5b7 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +/// +/// DNS Query Types. +/// +internal enum QueryType +{ + /// + /// A host address. + /// + A = 1, + + /// + /// An authoritative name server. + /// + NS = 2, + + /// + /// The canonical name for an alias. + /// + CNAME = 5, + + /// + /// Marks the start of a zone of authority. + /// + SOA = 6, + + /// + /// Mail exchange. + /// + MX = 15, + + /// + /// Text strings. + /// + TXT = 16, + + /// + /// IPv6 host address. (RFC 3596) + /// + AAAA = 28, + + /// + /// Location information. (RFC 2782) + /// + SRV = 33, + + /// + /// Wildcard match. + /// + All = 255 +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs new file mode 100644 index 00000000000..fbfdc5ae027 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Runtime.Versioning; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class ResolvConf +{ + [SupportedOSPlatform("linux")] + [SupportedOSPlatform("osx")] + public static ResolverOptions GetOptions() + { + return GetOptions(new StreamReader("/etc/resolv.conf")); + } + + public static ResolverOptions GetOptions(TextReader reader) + { + List serverList = new(); + + while (reader.ReadLine() is string line) + { + string[] tokens = line.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + + if (line.StartsWith("nameserver")) + { + if (tokens.Length >= 2 && IPAddress.TryParse(tokens[1], out IPAddress? address)) + { + serverList.Add(new IPEndPoint(address, 53)); // 53 is standard DNS port + + if (serverList.Count == 3) + { + break; // resolv.conf manpage allow max 3 nameservers anyway + } + } + } + } + + if (serverList.Count == 0) + { + // If no nameservers are configured, fall back to the default behavior of using the system resolver configuration. + return NetworkInfo.GetOptions(); + } + + return new ResolverOptions(serverList); + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs new file mode 100644 index 00000000000..51d03f64bfd --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal sealed class ResolverOptions +{ + public IReadOnlyList Servers; + public int Attempts = 2; + public TimeSpan Timeout = TimeSpan.FromSeconds(3); + + // override for testing purposes + internal Func, int, int>? _transportOverride; + + public ResolverOptions(IReadOnlyList servers) + { + if (servers.Count == 0) + { + throw new ArgumentException("At least one DNS server is required.", nameof(servers)); + } + + Servers = servers; + } + + public ResolverOptions(IPEndPoint server) + { + Servers = new IPEndPoint[] { server }; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs new file mode 100644 index 00000000000..aed799ac8d6 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal record struct AddressResult(DateTime ExpiresAt, IPAddress Address); + +internal record struct ServiceResult(DateTime ExpiresAt, int Priority, int Weight, int Port, string Target, AddressResult[] Addresses); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs new file mode 100644 index 00000000000..3ba5632e207 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal enum SendQueryError +{ + /// + /// DNS query was successful and returned response message with answers. + /// + NoError, + + /// + /// Server failed to respond to the query withing specified timeout. + /// + Timeout, + + /// + /// Server returned a response with an error code. + /// + ServerError, + + /// + /// Server returned a malformed response. + /// + MalformedResponse, + + /// + /// Server returned a response indicating that the name exists, but no data are available. + /// + NoData, + + /// + /// Server returned a response indicating the name does not exist. + /// + NameError, + + /// + /// Network-level error occurred during the query. + /// + NetworkError, + + /// + /// Internal error on part of the implementation. + /// + InternalError, +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs index 98b9de1fd68..7d05243f741 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs @@ -1,11 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using DnsClient; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.ServiceDiscovery; using Microsoft.Extensions.ServiceDiscovery.Dns; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.Hosting; @@ -46,7 +46,7 @@ public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceC ArgumentNullException.ThrowIfNull(configureOptions); services.AddServiceDiscoveryCore(); - services.TryAddSingleton(); + services.TryAddSingleton(); services.AddSingleton(); var options = services.AddOptions(); options.Configure(o => configureOptions?.Invoke(o)); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore new file mode 100644 index 00000000000..0151cc4e360 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore @@ -0,0 +1,2 @@ +# corpuses generated by the fuzzing engine +corpuses/** \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs new file mode 100644 index 00000000000..1b180d74b9d --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class DnsResponseFuzzer : IFuzzer +{ + DnsResolver? _resolver; + byte[]? _buffer; + int _length; + + public void FuzzTarget(ReadOnlySpan data) + { + // lazy init + if (_resolver == null) + { + _buffer = new byte[4096]; + _resolver = new DnsResolver(new ResolverOptions(new IPEndPoint(IPAddress.Loopback, 53)) + { + Timeout = TimeSpan.FromSeconds(5), + Attempts = 1, + _transportOverride = (buffer, length) => + { + // the first two bytes are the random transaction ID, so we keep that + // and use the fuzzing payload for the rest of the DNS response + _buffer.AsSpan(0, Math.Min(_length, buffer.Length - 2)).CopyTo(buffer.Span.Slice(2)); + return _length + 2; + } + }); + } + + data.CopyTo(_buffer!); + _length = data.Length; + + // the _transportOverride makes the execution synchronous + ValueTask task = _resolver!.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork, CancellationToken.None); + Debug.Assert(task.IsCompleted, "Task should be completed synchronously"); + task.GetAwaiter().GetResult(); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs new file mode 100644 index 00000000000..72f84b3c959 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class EncodedDomainNameFuzzer : IFuzzer +{ + public void FuzzTarget(ReadOnlySpan data) + { + byte[] buffer = ArrayPool.Shared.Rent(data.Length); + try + { + data.CopyTo(buffer); + + // attempt to read at any offset to really stress the parser + for (int i = 0; i < data.Length; i++) + { + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, data.Length), i, out EncodedDomainName name, out _)) + { + continue; + } + + // the domain name should be readable + _ = name.ToString(); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs new file mode 100644 index 00000000000..f657245a842 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class WriteDomainNameRoundTripFuzzer : IFuzzer +{ + private static readonly System.Globalization.IdnMapping s_idnMapping = new(); + public void FuzzTarget(ReadOnlySpan data) + { + // first byte is the offset of the domain name, rest is the actual + // (simulated) DNS message payload + + byte[] buffer = ArrayPool.Shared.Rent(data.Length * 2); + + try + { + string domainName = Encoding.UTF8.GetString(data); + if (!DnsPrimitives.TryWriteQName(buffer, domainName, out int written)) + { + return; + } + + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, written), 0, out EncodedDomainName name, out int read)) + { + return; + } + + if (read != written) + { + throw new InvalidOperationException($"Read {read} bytes, but wrote {written} bytes"); + } + + string readName = name.ToString(); + + if (!string.Equals(s_idnMapping.GetAscii(domainName).TrimEnd('.'), readName, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException($"Domain name mismatch: {readName} != {domainName}"); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs new file mode 100644 index 00000000000..2ff9d86b2ce --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs @@ -0,0 +1,5 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +global using System.Buffers; +global using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs new file mode 100644 index 00000000000..4b4c8c99b4b --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +public interface IFuzzer +{ + string Name => GetType().Name; + void FuzzTarget(ReadOnlySpan data); +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj new file mode 100644 index 00000000000..6572c27a1fa --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj @@ -0,0 +1,18 @@ + + + + $(DefaultTargetFramework) + enable + enable + Exe + + + + + + + + + + + diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs new file mode 100644 index 00000000000..22b1580d1ac --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using SharpFuzz; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +public static class Program +{ + public static void Main(string[] args) + { + IFuzzer[] fuzzers = typeof(Program).Assembly.GetTypes() + .Where(t => t.IsClass && !t.IsAbstract) + .Where(t => t.GetInterfaces().Contains(typeof(IFuzzer))) + .Select(t => (IFuzzer)Activator.CreateInstance(t)!) + .OrderBy(f => f.Name, StringComparer.OrdinalIgnoreCase) + .ToArray(); + + void PrintUsage() + { + Console.Error.WriteLine($""" + Usage: + DotnetFuzzing list + DotnetFuzzing [input file/directory] + // DotnetFuzzing prepare-onefuzz + + Available fuzzers: + {string.Join(Environment.NewLine, fuzzers.Select(f => $" {f.Name}"))} + """); + } + + if (args.Length == 0) + { + PrintUsage(); + return; + } + + string arg = args[0]; + IFuzzer? fuzzer = fuzzers.FirstOrDefault(f => string.Equals(f.Name, arg, StringComparison.OrdinalIgnoreCase)); + if (fuzzer == null) + { + Console.Error.WriteLine($"Unknown fuzzer: {arg}"); + PrintUsage(); + return; + } + + string? inputFiles = args.Length > 1 ? args[1] : null; + if (string.IsNullOrEmpty(inputFiles)) + { + // no input files, let the fuzzer generate + Fuzzer.LibFuzzer.Run(fuzzer.FuzzTarget); + return; + } + + string[] files = Directory.Exists(inputFiles) + ? Directory.GetFiles(inputFiles) + : [inputFiles]; + + foreach (string inputFile in files) + { + fuzzer.FuzzTarget(File.ReadAllBytes(inputFile)); + } + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com new file mode 100644 index 00000000000..bb40cd100c6 Binary files /dev/null and b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com differ diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error new file mode 100644 index 00000000000..92a307a0f72 Binary files /dev/null and b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error differ diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 new file mode 100644 index 00000000000..5b37565190e Binary files /dev/null and b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 differ diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data new file mode 100644 index 00000000000..23265fc7a8f Binary files /dev/null and b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data differ diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error new file mode 100644 index 00000000000..27f1054b9b1 Binary files /dev/null and b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error differ diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com new file mode 100644 index 00000000000..c227840c168 Binary files /dev/null and b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com differ diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example new file mode 100644 index 00000000000..8642ba7b94c --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example @@ -0,0 +1 @@ +www.example.com \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii new file mode 100644 index 00000000000..692a5a8aee0 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii @@ -0,0 +1 @@ +www.řffwefw.com \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong new file mode 100644 index 00000000000..bb6d3a722e8 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong @@ -0,0 +1 @@ +aa.efaw.ef.wef.ef.wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww.fafeww.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.fwefefefefwefwf.wzzefwefwefwefwfeewfwefwefw.ffff \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 new file mode 100644 index 00000000000..17b3d27055d --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 @@ -0,0 +1,106 @@ +param( + # Name of the fuzzing target, see Fuzzers/*.cs files + [Parameter(Mandatory = $true, Position = 0)] + [ArgumentCompleter({ + param($commandName, $parameterName, $wordToComplete, $commandAst, $fakeBoundParameters) + $corpusSeedPath = Join-Path $PSScriptRoot "Fuzzers" + if (Test-Path $corpusSeedPath) { + Get-ChildItem -Path $corpusSeedPath -Filter "$wordToComplete*.cs" | ForEach-Object { $_.BaseName } + } + })] + [string] $Target, + + # Number of parallel jobs to run + [int] $Jobs, + + # Maximum length of the input + [int] $MaxLength = 512, + + # Ignore timeouts when running the fuzzer + [switch] $IgnoreTimeouts, + + # Skip the build of the project useful for reruning the fuzzer without recompiling + [switch] $NoBuild, + + # Path to the libfuzzer driver + [string] $LibFuzzer = "libfuzzer-dotnet-windows" +) + +$timeout = 30 +$SharpFuzz = "sharpfuzz" +$dict = $null + +$corpus = Join-Path $PSScriptRoot "corpuses" $Target +$null = New-Item -Path $corpus -ItemType Directory -Force + +$CorpusSeed = Join-Path $PSScriptRoot "corpus-seed" $Target + +if (Test-Path $CorpusSeed -ErrorAction SilentlyContinue) { + Write-Output "Copying corpus seed from $CorpusSeed to $corpus" + Get-ChildItem -Path $CorpusSeed | Copy-Item -Destination $corpus +} + +$project = Join-Path $PSScriptRoot "Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj" + +Set-StrictMode -Version Latest + +$outputDir = "bin" +$projectName = (Get-Item $project).BaseName +$projectDll = "$projectName.dll" +$executable = if ($IsWindows) { Join-Path $outputDir "$projectName.exe" } +else { Join-Path $outputDir "$projectName" } + +if (!$NoBuild) { + + if (Test-Path $outputDir) { + Remove-Item -Recurse -Force $outputDir + } + + dotnet publish $project -c release -o $outputDir + + $exclusions = @( + "dnlib.dll", + "SharpFuzz.dll", + "SharpFuzz.Common.dll", + $projectDll + ) + + $fuzzingTargets = @(Get-Item "$outputDir/Microsoft.Extensions.ServiceDiscovery.Dns.dll") + + if (($fuzzingTargets | Measure-Object).Count -eq 0) { + Write-Error "No fuzzing targets found" + exit 1 + } + + foreach ($fuzzingTarget in $fuzzingTargets) { + Write-Output "Instrumenting $fuzzingTarget" + & $SharpFuzz $fuzzingTarget.FullName + + if ($LastExitCode -ne 0) { + Write-Error "An error occurred while instrumenting $fuzzingTarget" + exit 1 + } + } +} + +$parameters = @( + "-timeout=$timeout" +) + +if ($Jobs) { + $parameters += "-fork=$Jobs" +} + +if ($IgnoreTimeouts) { + $parameters += "-ignore_timeouts=1" +} + +if ($MaxLength) { + $parameters += "-max_len=$MaxLength" +} + +if ($dict) { + $parameters += "-dict=$dict" +} + +& $LibFuzzer @parameters --target_path=$executable --target_arg=$Target $corpus \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs index 2b3a7fd7cd3..b949e713999 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs @@ -4,6 +4,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Time.Testing; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Xunit; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; @@ -16,6 +17,7 @@ public async Task ResolveServiceEndpoint_Dns_MultiShot() var timeProvider = new FakeTimeProvider(); var services = new ServiceCollection() .AddSingleton(timeProvider) + .AddSingleton() .AddServiceDiscoveryCore() .AddDnsServiceEndpointProvider(o => o.DefaultRefreshPeriod = TimeSpan.FromSeconds(30)) .BuildServiceProvider(); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs index b58d9e2f4ec..ec21bf9fa9c 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using DnsClient; -using DnsClient.Protocol; +using System.Net.Sockets; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Configuration.Memory; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Microsoft.Extensions.ServiceDiscovery.Internal; using Xunit; @@ -19,88 +19,38 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; /// public class DnsSrvServiceEndpointResolverTests { - private sealed class FakeDnsClient : IDnsQuery + private sealed class FakeDnsResolver : IDnsResolver { - public Func>? QueryAsyncFunc { get; set; } + public Func>? ResolveIPAddressesAsyncFunc { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc!.Invoke(name, addressFamily, cancellationToken); - public IDnsQueryResponse Query(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryAsync(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) - => QueryAsyncFunc!(query, queryType, queryClass, cancellationToken); - public Task QueryAsync(DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryAsync(DnsQuestion question, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - } + public Func>? ResolveIPAddressesAsyncFunc2 { get; set; } - private sealed class FakeDnsQueryResponse : IDnsQueryResponse - { - public IReadOnlyList? Questions { get; set; } - public IReadOnlyList? Additionals { get; set; } - public IEnumerable? AllRecords { get; set; } - public IReadOnlyList? Answers { get; set; } - public IReadOnlyList? Authorities { get; set; } - public string? AuditTrail { get; set; } - public string? ErrorMessage { get; set; } - public bool HasError { get; set; } - public DnsResponseHeader? Header { get; set; } - public int MessageSize { get; set; } - public NameServer? NameServer { get; set; } - public DnsQuerySettings? Settings { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc2!.Invoke(name, cancellationToken); + + public Func>? ResolveServiceAsyncFunc { get; set; } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) => ResolveServiceAsyncFunc!.Invoke(name, cancellationToken); } [Fact] public async Task ResolveServiceEndpoint_DnsSrv() { - var dnsClientMock = new FakeDnsClient + var dnsClientMock = new FakeDnsResolver { - QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) => + ResolveServiceAsyncFunc = (name, cancellationToken) => { - var response = new FakeDnsQueryResponse - { - Answers = new List - { - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c")) - }, - Additionals = new List - { - new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")), - new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback), - new CNameRecord(new ResourceRecordInfo("srv-c", ResourceRecordType.AAAA, queryClass, 64, 0), DnsString.Parse("remotehost")), - new TxtRecord(new ResourceRecordInfo("srv-a", ResourceRecordType.TXT, queryClass, 64, 0), ["some txt values"], ["some txt utf8 values"]) - } - }; + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", []) + ]; - return Task.FromResult(response); + return ValueTask.FromResult(response); } }; var services = new ServiceCollection() - .AddSingleton(dnsClientMock) + .AddSingleton(dnsClientMock) .AddServiceDiscoveryCore() .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns") .BuildServiceProvider(); @@ -119,7 +69,7 @@ public async Task ResolveServiceEndpoint_DnsSrv() var eps = initialResult.EndpointSource.Endpoints; Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint); Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint); - Assert.Equal(new DnsEndPoint("remotehost", 7777), eps[2].EndPoint); + Assert.Equal(new DnsEndPoint("srv-c", 7777), eps[2].EndPoint); Assert.All(initialResult.EndpointSource.Endpoints, ep => { @@ -137,28 +87,17 @@ public async Task ResolveServiceEndpoint_DnsSrv() [Theory] public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing(bool dnsFirst) { - var dnsClientMock = new FakeDnsClient + var dnsClientMock = new FakeDnsResolver { - QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) => + ResolveServiceAsyncFunc = (name, cancellationToken) => { - var response = new FakeDnsQueryResponse - { - Answers = new List - { - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c")) - }, - Additionals = new List - { - new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")), - new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback), - new CNameRecord(new ResourceRecordInfo("srv-c", ResourceRecordType.AAAA, queryClass, 64, 0), DnsString.Parse("remotehost")), - new TxtRecord(new ResourceRecordInfo("srv-a", ResourceRecordType.TXT, queryClass, 64, 0), ["some txt values"], ["some txt utf8 values"]) - } - }; + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", []) + ]; - return Task.FromResult(response); + return ValueTask.FromResult(response); } }; var configSource = new MemoryConfigurationSource @@ -171,7 +110,7 @@ public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing( }; var config = new ConfigurationBuilder().Add(configSource); var serviceCollection = new ServiceCollection() - .AddSingleton(dnsClientMock) + .AddSingleton(dnsClientMock) .AddSingleton(config.Build()) .AddServiceDiscoveryCore(); if (dnsFirst) @@ -211,7 +150,7 @@ public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing( var eps = initialResult.EndpointSource.Endpoints; Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint); Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint); - Assert.Equal(new DnsEndPoint("remotehost", 7777), eps[2].EndPoint); + Assert.Equal(new DnsEndPoint("srv-c", 7777), eps[2].EndPoint); Assert.All(initialResult.EndpointSource.Endpoints, ep => { diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj index 24faf1d8abe..f6202d17d37 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj @@ -12,6 +12,10 @@ + + + + diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs new file mode 100644 index 00000000000..8c646ac18ee --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class CancellationTests : LoopbackDnsTestBase +{ + public CancellationTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task PreCanceledToken_Throws() + { + CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + var ex = await Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + + Assert.Equal(cts.Token, ex.CancellationToken); + } + + [Fact] + public async Task CancellationInProgress_Throws() + { + CancellationTokenSource cts = new CancellationTokenSource(); + + var task = Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + + await DnsServer.ProcessUdpRequest(_ => + { + cts.Cancel(); + return Task.CompletedTask; + }); + + OperationCanceledException ex = await task; + Assert.Equal(cts.Token, ex.CancellationToken); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs new file mode 100644 index 00000000000..aad32fe785f --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsDataReaderTests +{ + [Fact] + public void ReadResourceRecord_Success() + { + // example A record for example.com + byte[] buffer = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsDataReader reader = new DnsDataReader(buffer); + Assert.True(reader.TryReadResourceRecord(out DnsResourceRecord record)); + + Assert.Equal("www.example.com", record.Name.ToString()); + Assert.Equal(QueryType.A, record.Type); + Assert.Equal(QueryClass.Internet, record.Class); + Assert.Equal(3600, record.Ttl); + Assert.Equal(4, record.Data.Length); + } + + [Fact] + public void ReadResourceRecord_Truncated_Fails() + { + // example A record for example.com + byte[] buffer = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + for (int i = 0; i < buffer.Length; i++) + { + DnsDataReader reader = new DnsDataReader(new ArraySegment(buffer, 0, i)); + Assert.False(reader.TryReadResourceRecord(out _)); + } + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs new file mode 100644 index 00000000000..b2039ce5a4c --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsDataWriterTests +{ + [Fact] + public void WriteResourceRecord_Success() + { + // example A record for example.com + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsResourceRecord record = new DnsResourceRecord(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet, 3600, new byte[4]); + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteResourceRecord(record)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + [Fact] + public void WriteResourceRecord_Truncated_Fails() + { + // example A record for example.com + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsResourceRecord record = new DnsResourceRecord(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet, 3600, new byte[4]); + + byte[] buffer = new byte[512]; + for (int i = 0; i < expected.Length; i++) + { + DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); + Assert.False(writer.TryWriteResourceRecord(record)); + } + } + + [Fact] + public void WriteQuestion_Success() + { + // example question for example.com (A record) + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01 + ]; + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteQuestion(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + [Fact] + public void WriteQuestion_Truncated_Fails() + { + // example question for example.com (A record) + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01 + ]; + + byte[] buffer = new byte[512]; + for (int i = 0; i < expected.Length; i++) + { + DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); + Assert.False(writer.TryWriteQuestion(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet)); + } + } + + [Fact] + public void WriteHeader_Success() + { + // example header + byte[] expected = [ + // ID (0x1234) + 0x12, 0x34, + // Flags (0x5678) + 0x56, 0x78, + // Question count (1) + 0x00, 0x01, + // Answer count (0) + 0x00, 0x02, + // Authority count (0) + 0x00, 0x03, + // Additional count (0) + 0x00, 0x04 + ]; + + DnsMessageHeader header = new() + { + TransactionId = 0x1234, + QueryFlags = (QueryFlags)0x5678, + QueryCount = 1, + AnswerCount = 2, + AuthorityCount = 3, + AdditionalRecordCount = 4, + }; + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteHeader(header)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + private static EncodedDomainName EncodeDomainName(string name) + { + byte[] nameBuffer = new byte[512]; + Assert.True(DnsPrimitives.TryWriteQName(nameBuffer, name, out int nameLength)); + Assert.True(DnsPrimitives.TryReadQName(nameBuffer.AsMemory(0, nameLength), 0, out EncodedDomainName encodedDomainName, out _)); + return encodedDomainName; + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs new file mode 100644 index 00000000000..6733a553bad --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs @@ -0,0 +1,195 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsPrimitivesTests +{ + public static TheoryData QNameData => new() + { + { "www.example.com", "\x0003www\x0007example\x0003com\x0000"u8.ToArray() }, + { "example.com", "\x0007example\x0003com\x0000"u8.ToArray() }, + { "com", "\x0003com\x0000"u8.ToArray() }, + { "example", "\x0007example\x0000"u8.ToArray() }, + { "www", "\x0003www\x0000"u8.ToArray() }, + { "a", "\x0001a\x0000"u8.ToArray() }, + }; + + [Theory] + [MemberData(nameof(QNameData))] + public void TryWriteQName_Success(string name, byte[] expected) + { + byte[] buffer = new byte[512]; + + Assert.True(DnsPrimitives.TryWriteQName(buffer, name, out int written)); + Assert.Equal(name.Length + 2, written); + Assert.Equal(expected, buffer.AsSpan().Slice(0, written).ToArray()); + } + + [Fact] + public void TryWriteQName_LabelTooLong_False() + { + byte[] buffer = new byte[512]; + + Assert.False(DnsPrimitives.TryWriteQName(buffer, new string('a', 70), out _)); + } + + [Fact] + public void TryWriteQName_BufferTooShort_Fails() + { + byte[] buffer = new byte[512]; + string name = "www.example.com"; + + for (int i = 0; i < name.Length + 2; i++) + { + Assert.False(DnsPrimitives.TryWriteQName(buffer.AsSpan(0, i), name, out _)); + } + } + + [Theory] + [InlineData("www.-0.com")] + [InlineData("www.-a.com")] + [InlineData("www.a-.com")] + [InlineData("www.a_a.com")] + [InlineData("www.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.com")] // 64 occurrences of 'a' (too long) + [InlineData("www.a~a.com")] // 64 occurrences of 'a' (too long) + [InlineData("www..com")] + [InlineData("www..")] + public void TryWriteQName_InvalidName_ReturnsFalse(string name) + { + byte[] buffer = new byte[512]; + Assert.False(DnsPrimitives.TryWriteQName(buffer, name, out _)); + } + + [Fact] + public void TryWriteQName_ExplicitRoot_Success() + { + string name1 = "www.example.com"; + string name2 = "www.example.com."; + + byte[] buffer1 = new byte[512]; + byte[] buffer2 = new byte[512]; + + Assert.True(DnsPrimitives.TryWriteQName(buffer1, name1, out int written1)); + Assert.True(DnsPrimitives.TryWriteQName(buffer2, name2, out int written2)); + Assert.Equal(written1, written2); + Assert.Equal(buffer1.AsSpan().Slice(0, written1).ToArray(), buffer2.AsSpan().Slice(0, written2).ToArray()); + } + + [Theory] + [MemberData(nameof(QNameData))] + public void TryReadQName_Success(string expected, byte[] serialized) + { + Assert.True(DnsPrimitives.TryReadQName(serialized, 0, out EncodedDomainName actual, out int bytesRead)); + Assert.Equal(expected, actual.ToString()); + Assert.Equal(serialized.Length, bytesRead); + } + + [Fact] + public void TryReadQName_TruncatedData_Fails() + { + ReadOnlyMemory data = "\x0003www\x0007example\x0003com\x0000"u8.ToArray(); + + for (int i = 0; i < data.Length; i++) + { + Assert.False(DnsPrimitives.TryReadQName(data.Slice(0, i), 0, out _, out _)); + } + } + + [Fact] + public void TryReadQName_Pointer_Success() + { + // [7B padding], example.com. www->[ptr to example.com.] + Memory data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data.Span[^2] = 0xc0; + + Assert.True(DnsPrimitives.TryReadQName(data, data.Length - 6, out EncodedDomainName actual, out int bytesRead)); + Assert.Equal("www.example.com", actual.ToString()); + Assert.Equal(6, bytesRead); + } + + [Fact] + public void TryReadQName_PointerTruncated_Fails() + { + // [7B padding], example.com. www->[ptr to example.com.] + Memory data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data.Span[^2] = 0xc0; + + for (int i = 0; i < data.Length; i++) + { + Assert.False(DnsPrimitives.TryReadQName(data.Slice(0, i), data.Length - 6, out _, out _)); + } + } + + [Fact] + public void TryReadQName_ForwardPointer_Fails() + { + // www->[ptr to example.com], [7B padding], example.com. + Memory data = "\x03www\x00\x000dpadding\x0007example\x0003com\x00"u8.ToArray(); + data.Span[4] = 0xc0; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Fact] + public void TryReadQName_PointerToSelf_Fails() + { + // www->[ptr to www->...] + Memory data = "\x0003www\0\0"u8.ToArray(); + data.Span[4] = 0xc0; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Fact] + public void TryReadQName_PointerToPointer_Fails() + { + // com, example[->com], example2[->[->com]] + Memory data = "\x0003com\0\x0007example\0\0\x0008example2\0\0"u8.ToArray(); + data.Span[13] = 0xc0; + data.Span[14] = 0x00; // -> com + data.Span[24] = 0xc0; + data.Span[25] = 13; // -> -> com + + Assert.False(DnsPrimitives.TryReadQName(data, 15, out _, out _)); + } + + [Fact] + public void TryReadQName_ReservedBits() + { + Memory data = "\x0003www\x00c0"u8.ToArray(); + data.Span[0] = 0x40; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Theory] + [InlineData(253)] + [InlineData(254)] + [InlineData(255)] + public void TryReadQName_NameTooLong(int length) + { + // longest possible label is 63 bytes + 1 byte for length + byte[] labelData = new byte[64]; + Array.Fill(labelData, (byte)'a'); + labelData[0] = 63; + + int remainder = length - 3 * 64; + + byte[] lastLabelData = new byte[remainder + 1]; + Array.Fill(lastLabelData, (byte)'a'); + lastLabelData[0] = (byte)remainder; + + byte[] data = Enumerable.Repeat(labelData, 3).SelectMany(x => x).Concat(lastLabelData).Concat(new byte[1]).ToArray(); + if (length > 253) + { + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + else + { + Assert.True(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs new file mode 100644 index 00000000000..4789e21c575 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -0,0 +1,331 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Globalization; +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +internal sealed class LoopbackDnsServer : IDisposable +{ + private readonly Socket _dnsSocket; + private Socket? _tcpSocket; + + public IPEndPoint DnsEndPoint => (IPEndPoint)_dnsSocket.LocalEndPoint!; + + public LoopbackDnsServer() + { + _dnsSocket = new(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + _dnsSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + } + + public void Dispose() + { + _dnsSocket.Dispose(); + _tcpSocket?.Dispose(); + } + + private static async Task ProcessRequestCore(IPEndPoint remoteEndPoint, ArraySegment message, Func action, Memory responseBuffer) + { + DnsDataReader reader = new DnsDataReader(message); + + if (!reader.TryReadHeader(out DnsMessageHeader header) || + !reader.TryReadQuestion(out var name, out var type, out var @class)) + { + return 0; + } + + LoopbackDnsResponseBuilder responseBuilder = new(name.ToString(), type, @class); + responseBuilder.TransactionId = header.TransactionId; + responseBuilder.Flags = header.QueryFlags | QueryFlags.HasResponse; + responseBuilder.ResponseCode = QueryResponseCode.NoError; + + await action(responseBuilder, remoteEndPoint); + + return responseBuilder.Write(responseBuffer); + } + + public async Task ProcessUdpRequest(Func action) + { + byte[] buffer = ArrayPool.Shared.Rent(512); + try + { + EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); + SocketReceiveFromResult result = await _dnsSocket.ReceiveFromAsync(buffer, remoteEndPoint); + + int bytesWritten = await ProcessRequestCore((IPEndPoint)result.RemoteEndPoint, new ArraySegment(buffer, 0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); + + await _dnsSocket.SendToAsync(buffer.AsMemory(0, bytesWritten), SocketFlags.None, result.RemoteEndPoint); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public Task ProcessUdpRequest(Func action) + { + return ProcessUdpRequest((builder, _) => action(builder)); + } + + public async Task ProcessTcpRequest(Func action) + { + if (_tcpSocket is null) + { + _tcpSocket = new(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _tcpSocket.Bind(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)_dnsSocket.LocalEndPoint!).Port)); + _tcpSocket.Listen(); + } + + using Socket tcpClient = await _tcpSocket.AcceptAsync(); + + byte[] buffer = ArrayPool.Shared.Rent(8 * 1024); + try + { + int bytesRead = 0; + int length = -1; + while (length < 0 || bytesRead < length + 2) + { + int toRead = length < 0 ? 2 : length + 2 - bytesRead; + int read = await tcpClient.ReceiveAsync(buffer.AsMemory(bytesRead, toRead), SocketFlags.None); + bytesRead += read; + + if (length < 0 && bytesRead >= 2) + { + length = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); + } + } + + int bytesWritten = await ProcessRequestCore((IPEndPoint)tcpClient.RemoteEndPoint!, new ArraySegment(buffer, 2, length), action, buffer.AsMemory(2)); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(0, 2), (ushort)bytesWritten); + await tcpClient.SendAsync(buffer.AsMemory(0, bytesWritten + 2), SocketFlags.None); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public Task ProcessTcpRequest(Func action) + { + return ProcessTcpRequest((builder, _) => action(builder)); + } +} + +internal sealed class LoopbackDnsResponseBuilder +{ + private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."); + + public LoopbackDnsResponseBuilder(string name, QueryType type, QueryClass @class) + { + Name = name; + Type = type; + Class = @class; + Questions.Add((name, type, @class)); + + if (name.AsSpan().ContainsAnyExcept(s_domainNameValidChars)) + { + throw new ArgumentException($"Invalid characters in domain name '{name}'"); + } + } + + public ushort TransactionId { get; set; } + public QueryFlags Flags { get; set; } + public QueryResponseCode ResponseCode { get; set; } + + public string Name { get; } + public QueryType Type { get; } + public QueryClass Class { get; } + + public List<(string, QueryType, QueryClass)> Questions { get; } = new List<(string, QueryType, QueryClass)>(); + public List Answers { get; } = new List(); + public List Authorities { get; } = new List(); + public List Additionals { get; } = new List(); + + public int Write(Memory responseBuffer) + { + DnsDataWriter writer = new(responseBuffer); + if (!writer.TryWriteHeader(new DnsMessageHeader + { + TransactionId = TransactionId, + QueryFlags = Flags | (QueryFlags)ResponseCode, + QueryCount = (ushort)Questions.Count, + AnswerCount = (ushort)Answers.Count, + AuthorityCount = (ushort)Authorities.Count, + AdditionalRecordCount = (ushort)Additionals.Count + })) + { + throw new InvalidOperationException("Failed to write header"); + } + + byte[] buffer = ArrayPool.Shared.Rent(512); + foreach (var (questionName, questionType, questionClass) in Questions) + { + if (!DnsPrimitives.TryWriteQName(buffer, questionName, out int length) || + !DnsPrimitives.TryReadQName(buffer.AsMemory(0, length), 0, out EncodedDomainName encodedName, out _)) + { + throw new InvalidOperationException("Failed to encode domain name"); + } + if (!writer.TryWriteQuestion(encodedName, questionType, questionClass)) + { + throw new InvalidOperationException("Failed to write question"); + } + } + ArrayPool.Shared.Return(buffer); + + foreach (var answer in Answers) + { + if (!writer.TryWriteResourceRecord(answer)) + { + throw new InvalidOperationException("Failed to write answer"); + } + } + + foreach (var authority in Authorities) + { + if (!writer.TryWriteResourceRecord(authority)) + { + throw new InvalidOperationException("Failed to write authority"); + } + } + + foreach (var additional in Additionals) + { + if (!writer.TryWriteResourceRecord(additional)) + { + throw new InvalidOperationException("Failed to write additional records"); + } + } + + return writer.Position; + } + + public byte[] GetMessageBytes() + { + byte[] buffer = ArrayPool.Shared.Rent(512); + try + { + int bytesWritten = Write(buffer.AsMemory(0, 512)); + return buffer.AsSpan(0, bytesWritten).ToArray(); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} + +internal static class LoopbackDnsServerExtensions +{ + private static readonly IdnMapping s_idnMapping = new IdnMapping(); + + private static EncodedDomainName EncodeDomainName(string name) + { + var encodedLabels = name.Split('.', StringSplitOptions.RemoveEmptyEntries).Select(label => (ReadOnlyMemory)Encoding.UTF8.GetBytes(s_idnMapping.GetAscii(label))) + .ToList(); + + return new EncodedDomainName(encodedLabels); + } + + public static List AddAddress(this List records, string name, int ttl, IPAddress address) + { + QueryType type = address.AddressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + records.Add(new DnsResourceRecord(EncodeDomainName(name), type, QueryClass.Internet, ttl, address.GetAddressBytes())); + return records; + } + + public static List AddCname(this List records, string name, int ttl, string alias) + { + byte[] buff = new byte[256]; + if (!DnsPrimitives.TryWriteQName(buff, alias, out int length)) + { + throw new InvalidOperationException("Failed to encode domain name"); + } + + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.CNAME, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } + + public static List AddService(this List records, string name, int ttl, ushort priority, ushort weight, ushort port, string target) + { + byte[] buff = new byte[256]; + + // https://www.rfc-editor.org/rfc/rfc2782 + if (!BinaryPrimitives.TryWriteUInt16BigEndian(buff, priority) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buff.AsSpan(2), weight) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buff.AsSpan(4), port) || + !DnsPrimitives.TryWriteQName(buff.AsSpan(6), target, out int length)) + { + throw new InvalidOperationException("Failed to encode SRV record"); + } + + length += 6; + + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.SRV, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } + + public static List AddStartOfAuthority(this List records, string name, int ttl, string mname, string rname, uint serial, uint refresh, uint retry, uint expire, uint minimum) + { + byte[] buff = new byte[256]; + + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 + if (!DnsPrimitives.TryWriteQName(buff, mname, out int w1) || + !DnsPrimitives.TryWriteQName(buff.AsSpan(w1), rname, out int w2) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2), serial) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 4), refresh) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 8), retry) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 12), expire) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 16), minimum)) + { + throw new InvalidOperationException("Failed to encode SOA record"); + } + + int length = w1 + w2 + 20; + + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.SOA, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } +} + +internal static class DnsDataWriterExtensions +{ + internal static bool TryWriteResourceRecord(this DnsDataWriter writer, DnsResourceRecord record) + { + if (!TryWriteDomainName(writer, record.Name) || + !writer.TryWriteUInt16((ushort)record.Type) || + !writer.TryWriteUInt16((ushort)record.Class) || + !writer.TryWriteUInt32((uint)record.Ttl) || + !writer.TryWriteUInt16((ushort)record.Data.Length) || + !writer.TryWriteRawData(record.Data.Span)) + { + return false; + } + + return true; + } + + internal static bool TryWriteDomainName(this DnsDataWriter writer, EncodedDomainName name) + { + foreach (var label in name.Labels) + { + if (label.Length > 63) + { + throw new InvalidOperationException("Label length exceeds maximum of 63 bytes"); + } + + if (!writer.TryWriteByte((byte)label.Length) || + !writer.TryWriteRawData(label.Span)) + { + return false; + } + } + + // root label + return writer.TryWriteByte(0); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs new file mode 100644 index 00000000000..f76621db93c --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using Microsoft.Extensions.Time.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public abstract class LoopbackDnsTestBase : IDisposable +{ + protected readonly ITestOutputHelper Output; + + internal LoopbackDnsServer DnsServer { get; } + private readonly Lazy _resolverLazy; + internal DnsResolver Resolver => _resolverLazy.Value; + internal ResolverOptions Options { get; } + protected readonly FakeTimeProvider TimeProvider; + + public LoopbackDnsTestBase(ITestOutputHelper output) + { + Output = output; + DnsServer = new(); + TimeProvider = new(); + Options = new([DnsServer.DnsEndPoint]) + { + Timeout = TimeSpan.FromSeconds(5), + Attempts = 1, + }; + _resolverLazy = new(InitializeResolver); + } + + DnsResolver InitializeResolver() + { + ServiceCollection services = new(); + services.AddXunitLogging(Output); + + // construct DnsResolver manually via internal constructor which accepts ResolverOptions + var resolver = new DnsResolver(TimeProvider, services.BuildServiceProvider().GetRequiredService>(), Options); + return resolver; + } + + public void Dispose() + { + DnsServer.Dispose(); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs new file mode 100644 index 00000000000..281ffbecd24 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolvConfTests +{ + [Fact] + public void GetOptions() + { + var contents = @" +nameserver 10.96.0.10 +search default.svc.cluster.local svc.cluster.local cluster.local +options ndots:5 +@"; + + var reader = new StringReader(contents); + ResolverOptions options = ResolvConf.GetOptions(reader); + + IPEndPoint ipAddress = Assert.Single(options.Servers); + Assert.Equal(new IPEndPoint(IPAddress.Parse("10.96.0.10"), 53), ipAddress); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs new file mode 100644 index 00000000000..b87e1362f3d --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -0,0 +1,307 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolveAddressesTests : LoopbackDnsTestBase +{ + public ResolveAddressesTests(ITestOutputHelper output) : base(output) + { + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ResolveIPv4_NoData_Success(bool includeSoa) + { + string hostName = "nodata.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + if (includeSoa) + { + builder.Authorities.AddStartOfAuthority("ns.com", 240, "ns.com", "admin.ns.com", 1, 900, 180, 6048000, 3600); + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(results); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) + { + string hostName = "nosuchname.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.ResponseCode = QueryResponseCode.NameError; + if (includeSoa) + { + builder.Authorities.AddStartOfAuthority("ns.com", 240, "ns.com", "admin.ns.com", 1, 900, 180, 6048000, 3600); + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(results); + } + + [Theory] + [InlineData("www.resolveipv4.com")] + [InlineData("www.resolveipv4.com.")] + [InlineData("www.ř.com")] + public async Task ResolveIPv4_Simple_Success(string name) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddAddress(name, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_InOrder_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-in-order.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example3.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_OutOfOrder_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-out-of-order.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example3.com", 3600, address); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() + { + string hostName = "alias-loop2.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example3.com", 3600, hostName); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Aliases_Loop_Reverse_ReturnsEmpty() + { + string hostName = "alias-loop2.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example3.com", 3600, hostName); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Alias_And_Address() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-address.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example2.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_DuplicateAlias() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "duplicate-alias.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example4.com"); + builder.Answers.AddAddress("www.example2.com", 3600, address); + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Aliases_NotFound_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-no-found.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + + // extra address in the answer not connected to the above + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIP_InvalidAddressFamily_Throws() + { + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("invalid-af.test", AddressFamily.Unknown)); + } + + [Theory] + [InlineData(AddressFamily.InterNetwork, "127.0.0.1")] + [InlineData(AddressFamily.InterNetworkV6, "::1")] + public async Task ResolveIP_Localhost_ReturnsLoopback(AddressFamily family, string addressAsString) + { + IPAddress address = IPAddress.Parse(addressAsString); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("localhost", family); + AddressResult result = Assert.Single(results); + + Assert.Equal(address, result.Address); + } + + [Fact] + public async Task Resolve_Timeout_ReturnsEmpty() + { + Options.Timeout = TimeSpan.FromSeconds(1); + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("timeout-empty.test", AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Theory] + [InlineData("not-example.com", (int)QueryType.A, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.AAAA, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.A, 0)] + public async Task Resolve_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) + { + Options.Timeout = TimeSpan.FromSeconds(1); + + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Fact] + public async Task Resolve_HeaderMismatch_Ignores() + { + string name = "header-mismatch.test"; + Options.Timeout = TimeSpan.FromSeconds(5); + + SemaphoreSlim responseSemaphore = new SemaphoreSlim(0, 1); + SemaphoreSlim requestSemaphore = new SemaphoreSlim(0, 1); + + IPEndPoint clientAddress = null!; + + IPAddress address = IPAddress.Parse("172.213.245.111"); + ushort transactionId = 0x1234; + _ = DnsServer.ProcessUdpRequest((builder, clientAddr) => + { + clientAddress = clientAddr; + transactionId = (ushort)(builder.TransactionId + 1); + + builder.Answers.AddAddress(name, 3600, address); + requestSemaphore.Release(); + return responseSemaphore.WaitAsync(); + }); + + ValueTask task = Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + + await requestSemaphore.WaitAsync().WaitAsync(Options.Timeout); + + using Socket socket = new Socket(clientAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + LoopbackDnsResponseBuilder responseBuilder = new LoopbackDnsResponseBuilder(name, QueryType.A, QueryClass.Internet) + { + TransactionId = transactionId, + ResponseCode = QueryResponseCode.NoError + }; + + responseBuilder.Questions.Add((name, QueryType.A, QueryClass.Internet)); + responseBuilder.Answers.AddAddress(name, 3600, IPAddress.Loopback); + socket.SendTo(responseBuilder.GetMessageBytes(), clientAddress); + + responseSemaphore.Release(); + + AddressResult[] results = await task; + AddressResult result = Assert.Single(results); + + Assert.Equal(address, result.Address); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs new file mode 100644 index 00000000000..e1cd1df2959 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolveServiceTests : LoopbackDnsTestBase +{ + public ResolveServiceTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task ResolveService_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddService("_s0._tcp.example.com", 3600, 1, 2, 8080, "www.example.com"); + builder.Additionals.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + ServiceResult[] results = await Resolver.ResolveServiceAsync("_s0._tcp.example.com"); + + ServiceResult result = Assert.Single(results); + Assert.Equal("www.example.com", result.Target); + Assert.Equal(1, result.Priority); + Assert.Equal(2, result.Weight); + Assert.Equal(8080, result.Port); + + AddressResult addressResult = Assert.Single(result.Addresses); + Assert.Equal(address, addressResult.Address); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs new file mode 100644 index 00000000000..800905d1ac5 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -0,0 +1,309 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Sockets; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class RetryTests : LoopbackDnsTestBase +{ + public RetryTests(ITestOutputHelper output) : base(output) + { + Options.Attempts = 3; + } + + private Task SetupUdpProcessFunction(LoopbackDnsServer server, Func func) + { + return Task.Run(async () => + { + try + { + while (true) + { + await server.ProcessUdpRequest(func); + } + } + catch (Exception ex) + { + Output.WriteLine($"UDP server stopped with exception: {ex}"); + // Test teardown closed the socket, ignore + } + }); + } + + private Task SetupUdpProcessFunction(Func func) + { + return SetupUdpProcessFunction(DnsServer, func); + } + + [Fact] + public async Task Retry_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "retry-simple-success.com"; + + int attempt = 0; + + Task t = SetupUdpProcessFunction(builder => + { + attempt++; + if (attempt == Options.Attempts) + { + builder.Answers.AddAddress(hostName, 3600, address); + } + else + { + builder.ResponseCode = QueryResponseCode.ServerFailure; + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum PersistentErrorType + { + NotImplemented, + Refused, + MalformedResponse + } + + [Theory] + [InlineData(PersistentErrorType.NotImplemented)] + [InlineData(PersistentErrorType.Refused)] + [InlineData(PersistentErrorType.MalformedResponse)] + public async Task PersistentErrorsResponseCode_FailoverToNextServer(PersistentErrorType type) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.persistent.com"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + builder => + { + primaryAttempt++; + switch (type) + { + case PersistentErrorType.NotImplemented: + builder.ResponseCode = QueryResponseCode.NotImplemented; + break; + + case PersistentErrorType.Refused: + builder.ResponseCode = QueryResponseCode.Refused; + break; + + case PersistentErrorType.MalformedResponse: + builder.ResponseCode = (QueryResponseCode)0xFF; + break; + } + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + Assert.Equal(1, primaryAttempt); + Assert.Equal(1, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum DefinitveAnswerType + { + NoError, + NoData, + NameError, + } + + [Theory] + [InlineData(DefinitveAnswerType.NoError, false)] + [InlineData(DefinitveAnswerType.NoData, false)] + [InlineData(DefinitveAnswerType.NoData, true)] + [InlineData(DefinitveAnswerType.NameError, false)] + [InlineData(DefinitveAnswerType.NameError, true)] + public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, bool additionalData) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.retry.com"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + builder => + { + primaryAttempt++; + switch (type) + { + case DefinitveAnswerType.NoError: + builder.ResponseCode = QueryResponseCode.NoError; + builder.Answers.AddAddress(hostName, 3600, address); + break; + + case DefinitveAnswerType.NoData: + builder.ResponseCode = QueryResponseCode.NoError; + break; + + case DefinitveAnswerType.NameError: + builder.ResponseCode = QueryResponseCode.NameError; + break; + } + + if (additionalData) + { + builder.Authorities.AddStartOfAuthority(hostName, 300, "ns1.example.com", "hostmaster.example.com", 2023101001, 1, 3600, 300, 86400); + } + + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.ResponseCode = QueryResponseCode.Refused; + return Task.CompletedTask; + }); + + Assert.Equal(1, primaryAttempt); + Assert.Equal(0, secondaryAttempt); + + if (type == DefinitveAnswerType.NoError) + { + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + else + { + Assert.Empty(results); + } + } + + [Fact] + public async Task ExhaustedRetries_FailoverToNextServer() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "ExhaustedRetriesFailoverToNextServer"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + builder => + { + primaryAttempt++; + builder.ResponseCode = QueryResponseCode.ServerFailure; + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + Assert.Equal(Options.Attempts, primaryAttempt); + Assert.Equal(1, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum TransientErrorType + { + Timeout, + ServerFailure, + // TODO: simulate NetworkErrors + } + + [Theory] + [InlineData(TransientErrorType.Timeout)] + [InlineData(TransientErrorType.ServerFailure)] + public async Task TransientError_RetryOnSameServer(TransientErrorType type) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.transient.com"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + async builder => + { + primaryAttempt++; + if (primaryAttempt == 1) + { + switch (type) + { + case TransientErrorType.Timeout: + await Task.Delay(Options.Timeout.Multiply(1.5)); + builder.Answers.AddAddress(hostName, 3600, address); + break; + + case TransientErrorType.ServerFailure: + builder.ResponseCode = QueryResponseCode.ServerFailure; + break; + } + } + else + { + builder.Answers.AddAddress(hostName, 3600, address); + } + }, + builder => + { + secondaryAttempt++; + builder.ResponseCode = QueryResponseCode.Refused; + return Task.CompletedTask; + }); + + Assert.Equal(2, primaryAttempt); + Assert.Equal(0, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + private async Task RunWithFallbackServerHelper(string name, Func primaryHandler, Func fallbackHandler) + { + Task t = SetupUdpProcessFunction(primaryHandler); + using LoopbackDnsServer fallbackServer = new LoopbackDnsServer(); + Task t2 = SetupUdpProcessFunction(fallbackServer, fallbackHandler); + + Options.Servers = [DnsServer.DnsEndPoint, fallbackServer.DnsEndPoint]; + + return await Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + } + + [Fact] + public async Task NameError_NoRetry() + { + int counter = 0; + Task t = SetupUdpProcessFunction(builder => + { + counter++; + // authoritative answer that the name does not exist + builder.ResponseCode = QueryResponseCode.NameError; + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("nameerror-noretry", AddressFamily.InterNetwork); + + Assert.Empty(results); + Assert.Equal(1, counter); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs new file mode 100644 index 00000000000..40841e3d11a --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class TcpFailoverTests : LoopbackDnsTestBase +{ + public TcpFailoverTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task TcpFailover_Simple_Success() + { + string hostName = "tcp-simple.test"; + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() + { + string hostName = "tcp-server-closes.test"; + Options.Attempts = 1; + Options.Timeout = TimeSpan.FromSeconds(60); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + Task serverTask = DnsServer.ProcessTcpRequest(builder => + { + throw new InvalidOperationException("This forces closing the socket without writing any data"); + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + Assert.Empty(results); + + await Assert.ThrowsAsync(() => serverTask); + } + + [Fact] + public async Task TcpFailover_TcpNotAvailable_EmptyResult() + { + string hostName = "tcp-not-available.test"; + Options.Attempts = 1; + Options.Timeout = TimeSpan.FromMilliseconds(100000); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(results); + } + + [Fact] + public async Task TcpFailover_HeaderMismatch_ReturnsEmpty() + { + string hostName = "tcp-header-mismatch.test"; + Options.Timeout = TimeSpan.FromSeconds(1); + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.TransactionId++; + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Theory] + [InlineData("not-example.com", (int)QueryType.A, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.AAAA, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.A, 0)] + public async Task TcpFailover_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) + { + string hostName = "tcp-question-mismatch.test"; + Options.Timeout = TimeSpan.FromSeconds(1); + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(result); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs index 96fd46d47ea..c2751823c65 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs @@ -6,10 +6,11 @@ using Xunit; using Yarp.ReverseProxy.Configuration; using System.Net; -using DnsClient; -using DnsClient.Protocol; +using System.Net.Sockets; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.Extensions.ServiceDiscovery.Yarp.Tests; @@ -231,7 +232,10 @@ public async Task ServiceDiscoveryDestinationResolverTests_Configuration_Disallo [Fact] public async Task ServiceDiscoveryDestinationResolverTests_Dns() { + DnsResolver resolver = new DnsResolver(TimeProvider.System, NullLogger.Instance); + await using var services = new ServiceCollection() + .AddSingleton(resolver) .AddServiceDiscoveryCore() .AddDnsServiceEndpointProvider() .BuildServiceProvider(); @@ -265,32 +269,22 @@ public async Task ServiceDiscoveryDestinationResolverTests_Dns() [Fact] public async Task ServiceDiscoveryDestinationResolverTests_DnsSrv() { - var dnsClientMock = new FakeDnsClient + var dnsClientMock = new FakeDnsResolver { - QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) => + ResolveServiceAsyncFunc = (name, cancellationToken) => { - var response = new FakeDnsQueryResponse - { - Answers = new List - { - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c")) - }, - Additionals = new List - { - new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")), - new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback), - new ARecord(new ResourceRecordInfo("srv-c", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Loopback), - } - }; - - return Task.FromResult(response); + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Loopback)]) + ]; + + return ValueTask.FromResult(response); } }; await using var services = new ServiceCollection() - .AddSingleton(dnsClientMock) + .AddSingleton(dnsClientMock) .AddServiceDiscoveryCore() .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns") .BuildServiceProvider(); @@ -313,56 +307,17 @@ public async Task ServiceDiscoveryDestinationResolverTests_DnsSrv() a => Assert.Equal("https://127.0.0.1:7777/", a)); } - private sealed class FakeDnsClient : IDnsQuery + private sealed class FakeDnsResolver : IDnsResolver { - public Func>? QueryAsyncFunc { get; set; } - - public IDnsQueryResponse Query(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryAsync(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) - => QueryAsyncFunc!(query, queryType, queryClass, cancellationToken); - public Task QueryAsync(DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryAsync(DnsQuestion question, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - } + public Func>? ResolveIPAddressesAsyncFunc { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc!.Invoke(name, addressFamily, cancellationToken); - private sealed class FakeDnsQueryResponse : IDnsQueryResponse - { - public IReadOnlyList? Questions { get; set; } - public IReadOnlyList? Additionals { get; set; } - public IEnumerable? AllRecords { get; set; } - public IReadOnlyList? Answers { get; set; } - public IReadOnlyList? Authorities { get; set; } - public string? AuditTrail { get; set; } - public string? ErrorMessage { get; set; } - public bool HasError { get; set; } - public DnsResponseHeader? Header { get; set; } - public int MessageSize { get; set; } - public NameServer? NameServer { get; set; } - public DnsQuerySettings? Settings { get; set; } + public Func>? ResolveIPAddressesAsyncFunc2 { get; set; } + + public ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc2!.Invoke(name, cancellationToken); + + public Func>? ResolveServiceAsyncFunc { get; set; } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) => ResolveServiceAsyncFunc!.Invoke(name, cancellationToken); } }