diff --git a/Directory.Packages.props b/Directory.Packages.props
index e456e670d67..490e518213e 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -77,7 +77,6 @@
-
@@ -126,6 +125,8 @@
+
+
diff --git a/eng/Versions.props b/eng/Versions.props
index d032184c62b..636b494b864 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -50,6 +50,8 @@
9.0.6
10.0.0-preview.5.25277.114
+
+ 2.1.1
2.23.32-alpha
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs
index 6cc9f92bc46..7a2d1b632e0 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs
@@ -4,6 +4,7 @@
using System.Net;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
+using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
namespace Microsoft.Extensions.ServiceDiscovery.Dns;
@@ -12,6 +13,7 @@ internal sealed partial class DnsServiceEndpointProvider(
string hostName,
IOptionsMonitor options,
ILogger logger,
+ IDnsResolver resolver,
TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature
{
protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor;
@@ -29,17 +31,14 @@ protected override async Task ResolveAsyncCore()
var endpoints = new List();
var ttl = DefaultRefreshPeriod;
Log.AddressQuery(logger, ServiceName, hostName);
- var addresses = await System.Net.Dns.GetHostAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false);
+
+ var now = _timeProvider.GetUtcNow().DateTime;
+ var addresses = await resolver.ResolveIPAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false);
+
foreach (var address in addresses)
{
- var serviceEndpoint = ServiceEndpoint.Create(new IPEndPoint(address, 0));
- serviceEndpoint.Features.Set(this);
- if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint))
- {
- serviceEndpoint.Features.Set(this);
- }
-
- endpoints.Add(serviceEndpoint);
+ ttl = MinTtl(now, address.ExpiresAt, ttl);
+ endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, port: 0)));
}
if (endpoints.Count == 0)
@@ -48,5 +47,23 @@ protected override async Task ResolveAsyncCore()
}
SetResult(endpoints, ttl);
+
+ static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing)
+ {
+ var candidate = expiresAt - now;
+ return candidate < existing ? candidate : existing;
+ }
+
+ ServiceEndpoint CreateEndpoint(EndPoint endPoint)
+ {
+ var serviceEndpoint = ServiceEndpoint.Create(endPoint);
+ serviceEndpoint.Features.Set(this);
+ if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint))
+ {
+ serviceEndpoint.Features.Set(this);
+ }
+
+ return serviceEndpoint;
+ }
}
}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs
index 6c69cc7a760..311c06f631a 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs
@@ -14,7 +14,7 @@ internal abstract partial class DnsServiceEndpointProviderBase : IServiceEndpoin
private readonly object _lock = new();
private readonly ILogger _logger;
private readonly CancellationTokenSource _disposeCancellation = new();
- private readonly TimeProvider _timeProvider;
+ protected readonly TimeProvider _timeProvider;
private long _lastRefreshTimeStamp;
private Task _resolveTask = Task.CompletedTask;
private bool _hasEndpoints;
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs
index c241ad89dd3..1da21411e64 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs
@@ -4,18 +4,20 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
+using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
namespace Microsoft.Extensions.ServiceDiscovery.Dns;
internal sealed partial class DnsServiceEndpointProviderFactory(
IOptionsMonitor options,
ILogger logger,
+ IDnsResolver resolver,
TimeProvider timeProvider) : IServiceEndpointProviderFactory
{
///
public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider)
{
- provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, timeProvider);
+ provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, resolver, timeProvider);
return true;
}
}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs
index c174cda4f68..6d5ade5059e 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs
@@ -2,10 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Net;
-using DnsClient;
-using DnsClient.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
+using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
namespace Microsoft.Extensions.ServiceDiscovery.Dns;
@@ -15,7 +14,7 @@ internal sealed partial class DnsSrvServiceEndpointProvider(
string hostName,
IOptionsMonitor options,
ILogger logger,
- IDnsQuery dnsClient,
+ IDnsResolver resolver,
TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature
{
protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor;
@@ -35,56 +34,36 @@ protected override async Task ResolveAsyncCore()
var endpoints = new List();
var ttl = DefaultRefreshPeriod;
Log.SrvQuery(logger, ServiceName, srvQuery);
- var result = await dnsClient.QueryAsync(srvQuery, QueryType.SRV, cancellationToken: ShutdownToken).ConfigureAwait(false);
- if (result.HasError)
- {
- throw CreateException(srvQuery, result.ErrorMessage);
- }
- var lookupMapping = new Dictionary();
- foreach (var record in result.Additionals.Where(x => x is AddressRecord or CNameRecord))
- {
- ttl = MinTtl(record, ttl);
- lookupMapping[record.DomainName] = record;
- }
+ var now = _timeProvider.GetUtcNow().DateTime;
+ var result = await resolver.ResolveServiceAsync(srvQuery, cancellationToken: ShutdownToken).ConfigureAwait(false);
- var srvRecords = result.Answers.OfType();
- foreach (var record in srvRecords)
+ foreach (var record in result)
{
- if (!lookupMapping.TryGetValue(record.Target, out var targetRecord))
- {
- continue;
- }
+ ttl = MinTtl(now, record.ExpiresAt, ttl);
- ttl = MinTtl(record, ttl);
- if (targetRecord is AddressRecord addressRecord)
+ if (record.Addresses.Length > 0)
{
- endpoints.Add(CreateEndpoint(new IPEndPoint(addressRecord.Address, record.Port)));
+ foreach (var address in record.Addresses)
+ {
+ ttl = MinTtl(now, address.ExpiresAt, ttl);
+ endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, record.Port)));
+ }
}
- else if (targetRecord is CNameRecord canonicalNameRecord)
+ else
{
- endpoints.Add(CreateEndpoint(new DnsEndPoint(canonicalNameRecord.CanonicalName.Value.TrimEnd('.'), record.Port)));
+ endpoints.Add(CreateEndpoint(new DnsEndPoint(record.Target.TrimEnd('.'), record.Port)));
}
}
SetResult(endpoints, ttl);
- static TimeSpan MinTtl(DnsResourceRecord record, TimeSpan existing)
+ static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing)
{
- var candidate = TimeSpan.FromSeconds(record.TimeToLive);
+ var candidate = expiresAt - now;
return candidate < existing ? candidate : existing;
}
- InvalidOperationException CreateException(string dnsName, string errorMessage)
- {
- var msg = errorMessage switch
- {
- { Length: > 0 } => $"No DNS records were found for service '{ServiceName}' (DNS name: '{dnsName}'): {errorMessage}.",
- _ => $"No DNS records were found for service '{ServiceName}' (DNS name: '{dnsName}')."
- };
- return new InvalidOperationException(msg);
- }
-
ServiceEndpoint CreateEndpoint(EndPoint endPoint)
{
var serviceEndpoint = ServiceEndpoint.Create(endPoint);
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs
index fd0cb28353d..085ee30123b 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs
@@ -2,29 +2,29 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics.CodeAnalysis;
-using DnsClient;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
+using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
namespace Microsoft.Extensions.ServiceDiscovery.Dns;
internal sealed partial class DnsSrvServiceEndpointProviderFactory(
IOptionsMonitor options,
ILogger logger,
- IDnsQuery dnsClient,
+ IDnsResolver resolver,
TimeProvider timeProvider) : IServiceEndpointProviderFactory
{
private static readonly string s_serviceAccountPath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount");
private static readonly string s_serviceAccountNamespacePath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount", "namespace");
private static readonly string s_resolveConfPath = Path.Combine($"{Path.DirectorySeparatorChar}etc", "resolv.conf");
- private readonly string? _querySuffix = options.CurrentValue.QuerySuffix ?? GetKubernetesHostDomain();
+ private readonly string? _querySuffix = options.CurrentValue.QuerySuffix?.TrimStart('.') ?? GetKubernetesHostDomain();
///
public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider)
{
// If a default namespace is not specified, then this provider will attempt to infer the namespace from the service name, but only when running inside Kubernetes.
// Kubernetes DNS spec: https://github.com/kubernetes/dns/blob/master/docs/specification.md
- // SRV records are available for headless services with named ports.
+ // SRV records are available for headless services with named ports.
// They take the form $"_{portName}._{protocol}.{serviceName}.{namespace}.{suffix}"
// The suffix (after the service name) can be parsed from /etc/resolv.conf
// Otherwise, the namespace can be read from /var/run/secrets/kubernetes.io/serviceaccount/namespace and combined with an assumed suffix of "svc.cluster.local".
@@ -39,7 +39,7 @@ public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] ou
var portName = query.EndpointName ?? "default";
var srvQuery = $"_{portName}._tcp.{query.ServiceName}.{_querySuffix}";
- provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, dnsClient, timeProvider);
+ provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, resolver, timeProvider);
return true;
}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj
index 9854c93c01a..3aba9b3aaea 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj
@@ -9,7 +9,6 @@
-
@@ -23,7 +22,9 @@
-
+
+
+
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs
new file mode 100644
index 00000000000..094df3040d1
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs
@@ -0,0 +1,133 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Buffers.Binary;
+using System.Diagnostics;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal struct DnsDataReader : IDisposable
+{
+ public ArraySegment MessageBuffer { get; private set; }
+ bool _returnToPool;
+ private int _position;
+
+ public DnsDataReader(ArraySegment buffer, bool returnToPool = false)
+ {
+ MessageBuffer = buffer;
+ _position = 0;
+ _returnToPool = returnToPool;
+ }
+
+ public bool TryReadHeader(out DnsMessageHeader header)
+ {
+ Debug.Assert(_position == 0);
+
+ if (!DnsPrimitives.TryReadMessageHeader(MessageBuffer.AsSpan(), out header, out int bytesRead))
+ {
+ header = default;
+ return false;
+ }
+
+ _position += bytesRead;
+ return true;
+ }
+
+ internal bool TryReadQuestion(out EncodedDomainName name, out QueryType type, out QueryClass @class)
+ {
+ if (!TryReadDomainName(out name) ||
+ !TryReadUInt16(out ushort typeAsInt) ||
+ !TryReadUInt16(out ushort classAsInt))
+ {
+ type = 0;
+ @class = 0;
+ return false;
+ }
+
+ type = (QueryType)typeAsInt;
+ @class = (QueryClass)classAsInt;
+ return true;
+ }
+
+ public bool TryReadUInt16(out ushort value)
+ {
+ if (MessageBuffer.Count - _position < 2)
+ {
+ value = 0;
+ return false;
+ }
+
+ value = BinaryPrimitives.ReadUInt16BigEndian(MessageBuffer.AsSpan(_position));
+ _position += 2;
+ return true;
+ }
+
+ public bool TryReadUInt32(out uint value)
+ {
+ if (MessageBuffer.Count - _position < 4)
+ {
+ value = 0;
+ return false;
+ }
+
+ value = BinaryPrimitives.ReadUInt32BigEndian(MessageBuffer.AsSpan(_position));
+ _position += 4;
+ return true;
+ }
+
+ public bool TryReadResourceRecord(out DnsResourceRecord record)
+ {
+ if (!TryReadDomainName(out EncodedDomainName name) ||
+ !TryReadUInt16(out ushort type) ||
+ !TryReadUInt16(out ushort @class) ||
+ !TryReadUInt32(out uint ttl) ||
+ !TryReadUInt16(out ushort dataLength) ||
+ MessageBuffer.Count - _position < dataLength)
+ {
+ record = default;
+ return false;
+ }
+
+ ReadOnlyMemory data = MessageBuffer.AsMemory(_position, dataLength);
+ _position += dataLength;
+
+ record = new DnsResourceRecord(name, (QueryType)type, (QueryClass)@class, (int)ttl, data);
+ return true;
+ }
+
+ public bool TryReadDomainName(out EncodedDomainName name)
+ {
+ if (DnsPrimitives.TryReadQName(MessageBuffer, _position, out name, out int bytesRead))
+ {
+ _position += bytesRead;
+ return true;
+ }
+
+ return false;
+ }
+
+ public bool TryReadSpan(int length, out ReadOnlySpan name)
+ {
+ if (MessageBuffer.Count - _position < length)
+ {
+ name = default;
+ return false;
+ }
+
+ name = MessageBuffer.AsSpan(_position, length);
+ _position += length;
+ return true;
+ }
+
+ public void Dispose()
+ {
+ if (_returnToPool && MessageBuffer.Array != null)
+ {
+ ArrayPool.Shared.Return(MessageBuffer.Array);
+ }
+
+ _returnToPool = false;
+ MessageBuffer = default;
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs
new file mode 100644
index 00000000000..a0a11f0b808
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs
@@ -0,0 +1,121 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers.Binary;
+using System.Diagnostics;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal sealed class DnsDataWriter
+{
+ private readonly Memory _buffer;
+ private int _position;
+
+ internal DnsDataWriter(Memory buffer)
+ {
+ _buffer = buffer;
+ _position = 0;
+ }
+
+ public int Position => _position;
+
+ internal bool TryWriteHeader(in DnsMessageHeader header)
+ {
+ if (!DnsPrimitives.TryWriteMessageHeader(_buffer.Span.Slice(_position), header, out int written))
+ {
+ return false;
+ }
+
+ _position += written;
+ return true;
+ }
+
+ internal bool TryWriteQuestion(EncodedDomainName name, QueryType type, QueryClass @class)
+ {
+ if (!TryWriteDomainName(name) ||
+ !TryWriteUInt16((ushort)type) ||
+ !TryWriteUInt16((ushort)@class))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ private bool TryWriteDomainName(EncodedDomainName name)
+ {
+ foreach (var label in name.Labels)
+ {
+ // this should be already validated by the caller
+ Debug.Assert(label.Length <= 63, "Label length must not exceed 63 bytes.");
+
+ if (!TryWriteByte((byte)label.Length) ||
+ !TryWriteRawData(label.Span))
+ {
+ return false;
+ }
+ }
+
+ // root label
+ return TryWriteByte(0);
+ }
+
+ internal bool TryWriteDomainName(string name)
+ {
+ if (DnsPrimitives.TryWriteQName(_buffer.Span.Slice(_position), name, out int written))
+ {
+ _position += written;
+ return true;
+ }
+
+ return false;
+ }
+
+ internal bool TryWriteByte(byte value)
+ {
+ if (_buffer.Length - _position < 1)
+ {
+ return false;
+ }
+
+ _buffer.Span[_position] = value;
+ _position += 1;
+ return true;
+ }
+
+ internal bool TryWriteUInt16(ushort value)
+ {
+ if (_buffer.Length - _position < 2)
+ {
+ return false;
+ }
+
+ BinaryPrimitives.WriteUInt16BigEndian(_buffer.Span.Slice(_position), value);
+ _position += 2;
+ return true;
+ }
+
+ internal bool TryWriteUInt32(uint value)
+ {
+ if (_buffer.Length - _position < 4)
+ {
+ return false;
+ }
+
+ BinaryPrimitives.WriteUInt32BigEndian(_buffer.Span.Slice(_position), value);
+ _position += 4;
+ return true;
+ }
+
+ internal bool TryWriteRawData(ReadOnlySpan value)
+ {
+ if (_buffer.Length - _position < value.Length)
+ {
+ return false;
+ }
+
+ value.CopyTo(_buffer.Span.Slice(_position));
+ _position += value.Length;
+ return true;
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs
new file mode 100644
index 00000000000..b22273a04f2
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs
@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+// RFC 1035 4.1.1. Header section format
+internal struct DnsMessageHeader
+{
+ internal const int HeaderLength = 12;
+ public ushort TransactionId { get; set; }
+
+ internal QueryFlags QueryFlags { get; set; }
+
+ public ushort QueryCount { get; set; }
+
+ public ushort AnswerCount { get; set; }
+
+ public ushort AuthorityCount { get; set; }
+
+ public ushort AdditionalRecordCount { get; set; }
+
+ public QueryResponseCode ResponseCode
+ {
+ get => (QueryResponseCode)(QueryFlags & QueryFlags.ResponseCodeMask);
+ }
+
+ public bool IsResultTruncated
+ {
+ get => (QueryFlags & QueryFlags.ResultTruncated) != 0;
+ }
+
+ public bool IsResponse
+ {
+ get => (QueryFlags & QueryFlags.HasResponse) != 0;
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs
new file mode 100644
index 00000000000..e549abe2576
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs
@@ -0,0 +1,318 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Buffers.Binary;
+using System.Globalization;
+using System.Text;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal static class DnsPrimitives
+{
+ // Maximum length of a domain name in ASCII (excluding trailing dot)
+ internal const int MaxDomainNameLength = 253;
+
+ internal static bool TryReadMessageHeader(ReadOnlySpan buffer, out DnsMessageHeader header, out int bytesRead)
+ {
+ // RFC 1035 4.1.1. Header section format
+ if (buffer.Length < DnsMessageHeader.HeaderLength)
+ {
+ header = default;
+ bytesRead = 0;
+ return false;
+ }
+
+ header = new DnsMessageHeader
+ {
+ TransactionId = BinaryPrimitives.ReadUInt16BigEndian(buffer),
+ QueryFlags = (QueryFlags)BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2)),
+ QueryCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(4)),
+ AnswerCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(6)),
+ AuthorityCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(8)),
+ AdditionalRecordCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(10))
+ };
+
+ bytesRead = DnsMessageHeader.HeaderLength;
+ return true;
+ }
+
+ internal static bool TryWriteMessageHeader(Span buffer, DnsMessageHeader header, out int bytesWritten)
+ {
+ // RFC 1035 4.1.1. Header section format
+ if (buffer.Length < DnsMessageHeader.HeaderLength)
+ {
+ bytesWritten = 0;
+ return false;
+ }
+
+ BinaryPrimitives.WriteUInt16BigEndian(buffer, header.TransactionId);
+ BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(2), (ushort)header.QueryFlags);
+ BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(4), header.QueryCount);
+ BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(6), header.AnswerCount);
+ BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(8), header.AuthorityCount);
+ BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(10), header.AdditionalRecordCount);
+
+ bytesWritten = DnsMessageHeader.HeaderLength;
+ return true;
+ }
+
+ // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4
+ // labels 63 octets or less
+ // name 255 octets or less
+
+ private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.");
+ private static readonly IdnMapping s_idnMapping = new IdnMapping();
+ internal static bool TryWriteQName(Span destination, string name, out int written)
+ {
+ written = 0;
+
+ //
+ // RFC 1035 4.1.2.
+ //
+ // a domain name represented as a sequence of labels, where
+ // each label consists of a length octet followed by that
+ // number of octets. The domain name terminates with the
+ // zero length octet for the null label of the root. Note
+ // that this field may be an odd number of octets; no
+ // padding is used.
+ //
+ if (!Ascii.IsValid(name))
+ {
+ // IDN name, apply punycode
+ try
+ {
+ // IdnMapping performs some validation internally (such as label
+ // and domain name lengths), but is more relaxed than RFC
+ // 1035 (e.g. allows ~ chars), so even if this conversion does
+ // not throw, we still need to perform additional validation
+ name = s_idnMapping.GetAscii(name);
+ }
+ catch
+ {
+ return false;
+ }
+ }
+
+ if (name.Length > MaxDomainNameLength ||
+ name.AsSpan().ContainsAnyExcept(s_domainNameValidChars) ||
+ destination.IsEmpty ||
+ !Encoding.ASCII.TryGetBytes(name, destination.Slice(1), out int length) ||
+ destination.Length < length + 2)
+ {
+ // buffer too small
+ return false;
+ }
+
+ Span nameBuffer = destination.Slice(0, 1 + length);
+ Span label;
+ while (true)
+ {
+ // figure out the next label and prepend the length
+ int index = nameBuffer.Slice(1).IndexOf((byte)'.');
+ label = index == -1 ? nameBuffer.Slice(1) : nameBuffer.Slice(1, index);
+
+ if (label.Length == 0)
+ {
+ // empty label (explicit root) is only allowed at the end
+ if (index != -1)
+ {
+ written = 0;
+ return false;
+ }
+ }
+ // Label restrictions:
+ // - maximum 63 octets long
+ // - must start with a letter or digit (digit is allowed by RFC 1123)
+ // - may start with an underscore (underscore may be present only
+ // at the start of the label to support SRV records)
+ // - must end with a letter or digit
+ else if (label.Length > 63 ||
+ !char.IsAsciiLetterOrDigit((char)label[0]) && label[0] != '_' ||
+ label.Slice(1).Contains((byte)'_') ||
+ !char.IsAsciiLetterOrDigit((char)label[^1]))
+ {
+ written = 0;
+ return false;
+ }
+
+ nameBuffer[0] = (byte)label.Length;
+ written += label.Length + 1;
+
+ if (index == -1)
+ {
+ // this was the last label
+ break;
+ }
+
+ nameBuffer = nameBuffer.Slice(index + 1);
+ }
+
+ // Add root label if wasn't explicitly specified
+ if (label.Length != 0)
+ {
+ destination[written] = 0;
+ written++;
+ }
+
+ return true;
+ }
+
+ private static bool TryReadQNameCore(List> labels, int totalLength, ReadOnlyMemory messageBuffer, int offset, out int bytesRead, bool canStartWithPointer = true)
+ {
+ //
+ // domain name can be either
+ // - a sequence of labels, where each label consists of a length octet
+ // followed by that number of octets, terminated by a zero length octet
+ // (root label)
+ // - a pointer, where the first two bits are set to 1, and the remaining
+ // 14 bits are an offset (from the start of the message) to the true
+ // label
+ //
+ // It is not specified by the RFC if pointers must be backwards only,
+ // the code below prohibits forward (and self) pointers to avoid
+ // infinite loops. It also allows pointers only to point to a
+ // label, not to another pointer.
+ //
+
+ bytesRead = 0;
+ bool allowPointer = canStartWithPointer;
+
+ if (offset < 0 || offset >= messageBuffer.Length)
+ {
+ return false;
+ }
+
+ int currentOffset = offset;
+
+ while (true)
+ {
+ byte length = messageBuffer.Span[currentOffset];
+
+ if ((length & 0xC0) == 0x00)
+ {
+ // length followed by the label
+ if (length == 0)
+ {
+ // end of name
+ bytesRead = currentOffset - offset + 1;
+ return true;
+ }
+
+ if (currentOffset + 1 + length >= messageBuffer.Length)
+ {
+ // too many labels or truncated data
+ break;
+ }
+
+ // read next label/segment
+ labels.Add(messageBuffer.Slice(currentOffset + 1, length));
+ totalLength += 1 + length;
+
+ // subtract one for the length prefix of the first label
+ if (totalLength - 1 > MaxDomainNameLength)
+ {
+ // domain name is too long
+ return false;
+ }
+
+ currentOffset += 1 + length;
+ bytesRead += 1 + length;
+
+ // we read a label, they can be followed by pointer.
+ allowPointer = true;
+ }
+ else if ((length & 0xC0) == 0xC0)
+ {
+ // pointer, together with next byte gives the offset of the true label
+ if (!allowPointer || currentOffset + 1 >= messageBuffer.Length)
+ {
+ // pointer to pointer or truncated data
+ break;
+ }
+
+ bytesRead += 2;
+ int pointer = ((length & 0x3F) << 8) | messageBuffer.Span[currentOffset + 1];
+
+ // we prohibit self-references and forward pointers to avoid
+ // infinite loops, we do this by truncating the
+ // messageBuffer at the offset where we started reading the
+ // name. We also ignore the bytesRead from the recursive
+ // call, as we are only interested on how many bytes we read
+ // from the initial start of the name.
+ return TryReadQNameCore(labels, totalLength, messageBuffer.Slice(0, offset), pointer, out int _, false);
+ }
+ else
+ {
+ // top two bits are reserved, this means invalid data
+ break;
+ }
+ }
+
+ return false;
+
+ }
+
+ internal static bool TryReadQName(ReadOnlyMemory messageBuffer, int offset, out EncodedDomainName name, out int bytesRead)
+ {
+ List> labels = new List>();
+
+ if (TryReadQNameCore(labels, 0, messageBuffer, offset, out bytesRead))
+ {
+ name = new EncodedDomainName(labels);
+ return true;
+ }
+ else
+ {
+ bytesRead = 0;
+ name = default;
+ return false;
+ }
+ }
+
+ internal static bool TryReadService(ReadOnlyMemory buffer, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead)
+ {
+ // https://www.rfc-editor.org/rfc/rfc2782
+ if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span, out priority) ||
+ !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(2), out weight) ||
+ !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(4), out port) ||
+ !TryReadQName(buffer.Slice(6), 0, out target, out bytesRead))
+ {
+ target = default;
+ priority = 0;
+ weight = 0;
+ port = 0;
+ bytesRead = 0;
+ return false;
+ }
+
+ bytesRead += 6;
+ return true;
+ }
+
+ internal static bool TryReadSoa(ReadOnlyMemory buffer, out EncodedDomainName primaryNameServer, out EncodedDomainName responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead)
+ {
+ // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13
+ if (!TryReadQName(buffer, 0, out primaryNameServer, out int w1) ||
+ !TryReadQName(buffer.Slice(w1), 0, out responsibleMailAddress, out int w2) ||
+ !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2), out serial) ||
+ !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 4), out refresh) ||
+ !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 8), out retry) ||
+ !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 12), out expire) ||
+ !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 16), out minimum))
+ {
+ primaryNameServer = default;
+ responsibleMailAddress = default;
+ serial = 0;
+ refresh = 0;
+ retry = 0;
+ expire = 0;
+ minimum = 0;
+ bytesRead = 0;
+ return false;
+ }
+
+ bytesRead = w1 + w2 + 20;
+ return true;
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs
new file mode 100644
index 00000000000..adab9161737
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs
@@ -0,0 +1,39 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+using System.Net;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal partial class DnsResolver : IDnsResolver, IDisposable
+{
+ internal static partial class Log
+ {
+ [LoggerMessage(1, LogLevel.Debug, "Resolving {QueryType} {QueryName} on {Server} attempt {Attempt}", EventName = "Query")]
+ public static partial void Query(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt);
+
+ [LoggerMessage(2, LogLevel.Debug, "Result truncated for {QueryType} {QueryName} from {Server} attempt {Attempt}. Restarting over TCP", EventName = "ResultTruncated")]
+ public static partial void ResultTruncated(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt);
+
+ [LoggerMessage(3, LogLevel.Error, "Server {Server} replied with {ResponseCode} when querying {QueryType} {QueryName}", EventName = "ErrorResponseCode")]
+ public static partial void ErrorResponseCode(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, QueryResponseCode responseCode);
+
+ [LoggerMessage(4, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")]
+ public static partial void Timeout(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt);
+
+ [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: no data matching given query type.", EventName = "NoData")]
+ public static partial void NoData(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt);
+
+ [LoggerMessage(6, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: server indicates given name does not exist.", EventName = "NameError")]
+ public static partial void NameError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt);
+
+ [LoggerMessage(7, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed to return a valid DNS response.", EventName = "MalformedResponse")]
+ public static partial void MalformedResponse(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt);
+
+ [LoggerMessage(8, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed due to a network error.", EventName = "NetworkError")]
+ public static partial void NetworkError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception);
+
+ [LoggerMessage(9, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")]
+ public static partial void QueryError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception);
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs
new file mode 100644
index 00000000000..4be956cede9
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs
@@ -0,0 +1,115 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Diagnostics.Metrics;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal partial class DnsResolver
+{
+ internal static class Telemetry
+ {
+ private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver");
+ private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration");
+
+ private static bool IsEnabled() => s_queryDuration.Enabled;
+
+ public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType, long startingTimestamp)
+ {
+ if (IsEnabled())
+ {
+ return new NameResolutionActivity(hostName, queryType, startingTimestamp);
+ }
+
+ return default;
+ }
+
+ public static void StopNameResolution(string hostName, QueryType queryType, in NameResolutionActivity activity, object? answers, SendQueryError error, long endingTimestamp)
+ {
+ activity.Stop(answers, error, endingTimestamp, out TimeSpan duration);
+
+ if (!IsEnabled())
+ {
+ return;
+ }
+
+ var hostNameTag = KeyValuePair.Create("dns.question.name", (object?)hostName);
+ var queryTypeTag = KeyValuePair.Create("dns.question.type", (object?)queryType);
+
+ if (answers is not null)
+ {
+ s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag);
+ }
+ else
+ {
+ var errorTypeTag = KeyValuePair.Create("error.type", (object?)error.ToString());
+ s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag, errorTypeTag);
+ }
+ }
+ }
+
+ internal readonly struct NameResolutionActivity
+ {
+ private const string ActivitySourceName = "Microsoft.Extensions.ServiceDiscovery.Dns.Resolver";
+ private const string ActivityName = ActivitySourceName + ".Resolve";
+ private static readonly ActivitySource s_activitySource = new ActivitySource(ActivitySourceName);
+
+ private readonly long _startingTimestamp;
+ private readonly Activity? _activity; // null if activity is not started
+
+ public NameResolutionActivity(string hostName, QueryType queryType, long startingTimestamp)
+ {
+ _startingTimestamp = startingTimestamp;
+ _activity = s_activitySource.StartActivity(ActivityName, ActivityKind.Client);
+ if (_activity is not null)
+ {
+ _activity.DisplayName = $"Resolving {hostName}";
+ if (_activity.IsAllDataRequested)
+ {
+ _activity.SetTag("dns.question.name", hostName);
+ _activity.SetTag("dns.question.type", queryType.ToString());
+ }
+ }
+ }
+
+ public void Stop(object? answers, SendQueryError error, long endingTimestamp, out TimeSpan duration)
+ {
+ duration = Stopwatch.GetElapsedTime(_startingTimestamp, endingTimestamp);
+
+ if (_activity is null)
+ {
+ return;
+ }
+
+ if (_activity.IsAllDataRequested)
+ {
+ if (answers is not null)
+ {
+ static string[] ToStringHelper(T[] array) => array.Select(a => a!.ToString()!).ToArray();
+
+ string[]? answersArray = answers switch
+ {
+ ServiceResult[] serviceResults => ToStringHelper(serviceResults),
+ AddressResult[] addressResults => ToStringHelper(addressResults),
+ _ => null
+ };
+
+ Debug.Assert(answersArray is not null);
+ _activity.SetTag("dns.answers", answersArray);
+ }
+ else
+ {
+ _activity.SetTag("error.type", error.ToString());
+ }
+ }
+
+ if (answers is null)
+ {
+ _activity.SetStatus(ActivityStatusCode.Error);
+ }
+
+ _activity.Stop();
+ }
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs
new file mode 100644
index 00000000000..5722356a1c3
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs
@@ -0,0 +1,931 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Buffers.Binary;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Net;
+using System.Net.Sockets;
+using System.Runtime.InteropServices;
+using System.Security.Cryptography;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal sealed partial class DnsResolver : IDnsResolver, IDisposable
+{
+ private const int IPv4Length = 4;
+ private const int IPv6Length = 16;
+
+ // CancellationTokenSource.CancelAfter has a maximum timeout of Int32.MaxValue milliseconds.
+ private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue);
+
+ private bool _disposed;
+ private readonly ResolverOptions _options;
+ private readonly CancellationTokenSource _pendingRequestsCts = new();
+ private readonly TimeProvider _timeProvider;
+ private readonly ILogger _logger;
+
+ public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(timeProvider, logger, OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions())
+ {
+ }
+
+ internal DnsResolver(TimeProvider timeProvider, ILogger logger, ResolverOptions options)
+ {
+ _timeProvider = timeProvider;
+ _logger = logger;
+ _options = options;
+ Debug.Assert(_options.Servers.Count > 0);
+
+ if (options.Timeout != Timeout.InfiniteTimeSpan)
+ {
+ ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(options.Timeout, TimeSpan.Zero);
+ ArgumentOutOfRangeException.ThrowIfGreaterThan(options.Timeout, s_maxTimeout);
+ }
+ }
+
+ internal DnsResolver(ResolverOptions options) : this(TimeProvider.System, NullLogger.Instance, options)
+ {
+ }
+
+ internal DnsResolver(IEnumerable servers) : this(new ResolverOptions(servers.ToArray()))
+ {
+ }
+
+ internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server))
+ {
+ }
+
+ public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default)
+ {
+ ObjectDisposedException.ThrowIf(_disposed, this);
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // dnsSafeName is Disposed by SendQueryWithTelemetry
+ EncodedDomainName dnsSafeName = GetNormalizedHostName(name);
+ return SendQueryWithTelemetry(name, dnsSafeName, QueryType.SRV, ProcessResponse, cancellationToken);
+
+ static (SendQueryError, ServiceResult[]) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response)
+ {
+ var results = new List(response.Answers.Count);
+
+ foreach (var answer in response.Answers)
+ {
+ if (answer.Type == QueryType.SRV)
+ {
+ if (!DnsPrimitives.TryReadService(answer.Data, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) || bytesRead != answer.Data.Length)
+ {
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ List addresses = new List();
+ foreach (var additional in response.Additionals)
+ {
+ // From RFC 2782:
+ //
+ // Target
+ // The domain name of the target host. There MUST be one or more
+ // address records for this name, the name MUST NOT be an alias (in
+ // the sense of RFC 1034 or RFC 2181). Implementors are urged, but
+ // not required, to return the address record(s) in the Additional
+ // Data section. Unless and until permitted by future standards
+ // action, name compression is not to be used for this field.
+ //
+ // A Target of "." means that the service is decidedly not
+ // available at this domain.
+ if (additional.Name.Equals(target) && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA))
+ {
+ addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span)));
+ }
+ }
+
+ results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target.ToString(), addresses.ToArray()));
+ }
+ }
+
+ return (SendQueryError.NoError, results.ToArray());
+ }
+ }
+
+ public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default)
+ {
+ if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase))
+ {
+ // name localhost exists outside of DNS and can't be resolved by a DNS server
+ int len = (Socket.OSSupportsIPv4 ? 1 : 0) + (Socket.OSSupportsIPv6 ? 1 : 0);
+ AddressResult[] res = new AddressResult[len];
+
+ int index = 0;
+ if (Socket.OSSupportsIPv6) // prefer IPv6
+ {
+ res[index] = new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback);
+ index++;
+ }
+ if (Socket.OSSupportsIPv4)
+ {
+ res[index] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback);
+ }
+
+ return res;
+ }
+
+ var ipv4AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetwork, cancellationToken);
+ var ipv6AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetworkV6, cancellationToken);
+
+ AddressResult[] ipv4Addresses = await ipv4AddressesTask.ConfigureAwait(false);
+ AddressResult[] ipv6Addresses = await ipv6AddressesTask.ConfigureAwait(false);
+
+ AddressResult[] results = new AddressResult[ipv4Addresses.Length + ipv6Addresses.Length];
+ ipv6Addresses.CopyTo(results, 0);
+ ipv4Addresses.CopyTo(results, ipv6Addresses.Length);
+ return results;
+ }
+
+ public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default)
+ {
+ ObjectDisposedException.ThrowIf(_disposed, this);
+ cancellationToken.ThrowIfCancellationRequested();
+
+ if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6)
+ {
+ throw new ArgumentOutOfRangeException(nameof(addressFamily), addressFamily, "Invalid address family");
+ }
+
+ if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase))
+ {
+ // name localhost exists outside of DNS and can't be resolved by a DNS server
+ if (addressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4)
+ {
+ return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.Loopback)]);
+ }
+ else if (addressFamily == AddressFamily.InterNetworkV6 && Socket.OSSupportsIPv6)
+ {
+ return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback)]);
+ }
+
+ return ValueTask.FromResult([]);
+ }
+
+ // dnsSafeName is Disposed by SendQueryWithTelemetry
+ EncodedDomainName dnsSafeName = GetNormalizedHostName(name);
+ var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA;
+ return SendQueryWithTelemetry(name, dnsSafeName, queryType, ProcessResponse, cancellationToken);
+
+ static (SendQueryError error, AddressResult[] result) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response)
+ {
+ List results = new List(response.Answers.Count);
+
+ // Servers send back CNAME records together with associated A/AAAA records. Servers
+ // send only those CNAME records relevant to the query, and if there is a CNAME record,
+ // there should not be other records associated with the name. Therefore, we simply follow
+ // the list of CNAME aliases until we get to the primary name and return the A/AAAA records
+ // associated.
+ //
+ // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2
+ //
+ // Most of the servers send the CNAME records in order so that we can sequentially scan the
+ // answers, but nothing prevents the records from being in arbitrary order. Attempt the linear
+ // scan first and fallback to a slower but more robust method if necessary.
+
+ bool success = true;
+ EncodedDomainName currentAlias = dnsSafeName;
+
+ foreach (var answer in response.Answers)
+ {
+ switch (answer.Type)
+ {
+ case QueryType.CNAME:
+ if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target))
+ {
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ if (answer.Name.Equals(currentAlias))
+ {
+ currentAlias = target;
+ continue;
+ }
+
+ break;
+
+ case var type when type == queryType:
+ if (!TryReadAddress(answer, queryType, out IPAddress? address))
+ {
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ if (answer.Name.Equals(currentAlias))
+ {
+ results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address));
+ continue;
+ }
+
+ break;
+ }
+
+ // unexpected name or record type, fall back to more robust path
+ results.Clear();
+ success = false;
+ break;
+ }
+
+ if (success)
+ {
+ return (SendQueryError.NoError, results.ToArray());
+ }
+
+ // more expensive path for uncommon (but valid) cases where CNAME records are out of order. Use of Dictionary
+ // allows us to stay within O(n) complexity for the number of answers, but we will use more memory.
+ Dictionary aliasMap = new();
+ Dictionary> aRecordMap = new();
+ foreach (var answer in response.Answers)
+ {
+ if (answer.Type == QueryType.CNAME)
+ {
+ // map the alias to the target name
+ if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target))
+ {
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ if (!aliasMap.TryAdd(answer.Name, target))
+ {
+ // Duplicate CNAME record
+ return (SendQueryError.MalformedResponse, []);
+ }
+ }
+
+ if (answer.Type == queryType)
+ {
+ if (!TryReadAddress(answer, queryType, out IPAddress? address))
+ {
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ if (!aRecordMap.TryGetValue(answer.Name, out List? addressList))
+ {
+ addressList = new List();
+ aRecordMap.Add(answer.Name, addressList);
+ }
+
+ addressList.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address));
+ }
+ }
+
+ // follow the CNAME chain, limit the maximum number of iterations to avoid infinite loops.
+ int i = 0;
+ currentAlias = dnsSafeName;
+ while (aliasMap.TryGetValue(currentAlias, out EncodedDomainName nextAlias))
+ {
+ if (i >= aliasMap.Count)
+ {
+ // circular CNAME chain
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ i++;
+
+ if (aRecordMap.ContainsKey(currentAlias))
+ {
+ // both CNAME record and A/AAAA records exist for the current alias
+ return (SendQueryError.MalformedResponse, []);
+ }
+
+ currentAlias = nextAlias;
+ }
+
+ // Now we have the final target name, check if we have any A/AAAA records for it.
+ aRecordMap.TryGetValue(currentAlias, out List? finalAddressList);
+ return (SendQueryError.NoError, finalAddressList?.ToArray() ?? []);
+
+ static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messageBytes, out EncodedDomainName target)
+ {
+ Debug.Assert(record.Type == QueryType.CNAME, "Only CNAME records should be processed here.");
+
+ target = default;
+
+ // some servers use domain name compression even inside CNAME records. In order to decode those
+ // correctly, we need to pass the entire message to TryReadQName. The Data span inside the record
+ // should be backed by the array containing the entire DNS message. We just need to account for the
+ // 2 byte offset in case of TCP fallback.
+ var gotArray = MemoryMarshal.TryGetArray(record.Data, out ArraySegment segment);
+ Debug.Assert(gotArray, "Failed to get array segment");
+ Debug.Assert(segment.Array == messageBytes.Array, "record data backed by different array than the original message");
+
+ int messageOffset = messageBytes.Offset;
+
+ bool result = DnsPrimitives.TryReadQName(segment.Array.AsMemory(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset - messageOffset, out EncodedDomainName targetName, out int bytesRead) && bytesRead == record.Data.Length;
+ if (result)
+ {
+ target = targetName;
+ }
+
+ return result;
+ }
+
+ static bool TryReadAddress(in DnsResourceRecord record, QueryType type, [NotNullWhen(true)] out IPAddress? target)
+ {
+ Debug.Assert(record.Type is QueryType.A or QueryType.AAAA, "Only CNAME records should be processed here.");
+
+ target = null;
+ if (record.Type == QueryType.A && record.Data.Length != IPv4Length ||
+ record.Type == QueryType.AAAA && record.Data.Length != IPv6Length)
+ {
+ return false;
+ }
+
+ target = new IPAddress(record.Data.Span);
+ return true;
+ }
+ }
+ }
+
+ private async ValueTask SendQueryWithTelemetry(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken)
+ {
+ NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp());
+ (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, dnsSafeName, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false);
+ Telemetry.StopNameResolution(name, queryType, activity, null, error, _timeProvider.GetTimestamp());
+ dnsSafeName.Dispose();
+
+ return result;
+ }
+
+ internal struct SendQueryResult
+ {
+ public DnsResponse Response;
+ public SendQueryError Error;
+ }
+
+ async ValueTask<(SendQueryError error, TResult[] result)> SendQueryWithRetriesAsync(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken)
+ {
+ SendQueryError lastError = SendQueryError.InternalError; // will be overwritten by the first attempt
+ for (int index = 0; index < _options.Servers.Count; index++)
+ {
+ IPEndPoint serverEndPoint = _options.Servers[index];
+
+ for (int attempt = 1; attempt <= _options.Attempts; attempt++)
+ {
+ DnsResponse response = default;
+ try
+ {
+ TResult[] results = Array.Empty();
+
+ try
+ {
+ SendQueryResult queryResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cancellationToken).ConfigureAwait(false);
+ lastError = queryResult.Error;
+ response = queryResult.Response;
+
+ if (lastError == SendQueryError.NoError)
+ {
+ // Given that result.Error is NoError, there should be at least one answer.
+ Debug.Assert(response.Answers.Count > 0);
+ (lastError, results) = processResponseFunc(dnsSafeName, queryType, queryResult.Response);
+ }
+ }
+ catch (SocketException ex)
+ {
+ Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex);
+ lastError = SendQueryError.NetworkError;
+ }
+ catch (Exception ex) when (!cancellationToken.IsCancellationRequested)
+ {
+ // internal error, propagate
+ Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex);
+ throw;
+ }
+
+ switch (lastError)
+ {
+ //
+ // Definitive answers, no point retrying
+ //
+ case SendQueryError.NoError:
+ return (lastError, results);
+
+ case SendQueryError.NameError:
+ // authoritative answer that the name does not exist, no point in retrying
+ Log.NameError(_logger, queryType, name, serverEndPoint, attempt);
+ return (lastError, results);
+
+ case SendQueryError.NoData:
+ // no data available for the name from authoritative server
+ Log.NoData(_logger, queryType, name, serverEndPoint, attempt);
+ return (lastError, results);
+
+ //
+ // Transient errors, retry on the same server
+ //
+ case SendQueryError.Timeout:
+ Log.Timeout(_logger, queryType, name, serverEndPoint, attempt);
+ continue;
+
+ case SendQueryError.NetworkError:
+ // TODO: retry with exponential backoff?
+ continue;
+
+ case SendQueryError.ServerError when response.Header.ResponseCode == QueryResponseCode.ServerFailure:
+ // ServerFailure may indicate transient failure with upstream DNS servers, retry on the same server
+ Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode);
+ continue;
+
+ //
+ // Persistent errors, skip to the next server
+ //
+ case SendQueryError.ServerError:
+ // this should cover all response codes except NoError, NameError which are definite and handled above, and
+ // ServerFailure which is a transient error and handled above.
+ Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode);
+ break;
+
+ case SendQueryError.MalformedResponse:
+ Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt);
+ break;
+
+ case SendQueryError.InternalError:
+ // exception logged above.
+ break;
+ }
+
+ // actual break that causes skipping to the next server
+ break;
+ }
+ finally
+ {
+ response.Dispose();
+ }
+ }
+ }
+
+ // if we get here, we exhausted all servers and all attempts
+ return (lastError, []);
+ }
+
+ internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken)
+ {
+ (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
+
+ try
+ {
+ return await SendQueryToServerAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cts.Token).ConfigureAwait(false);
+ }
+ catch (OperationCanceledException) when (
+ !cancellationToken.IsCancellationRequested && // not cancelled by the caller
+ !pendingRequestsCts.IsCancellationRequested) // not cancelled by the global token (dispose)
+ // the only remaining token that could cancel this is the linked cts from the timeout.
+ {
+ Debug.Assert(cts.Token.IsCancellationRequested);
+ return new SendQueryResult { Error = SendQueryError.Timeout };
+ }
+ catch (OperationCanceledException ex) when (cancellationToken.IsCancellationRequested && ex.CancellationToken != cancellationToken)
+ {
+ // cancellation was initiated by the caller, but exception was triggered by a linked token,
+ // rethrow the exception with the caller's token.
+ cancellationToken.ThrowIfCancellationRequested();
+ throw new UnreachableException();
+ }
+ finally
+ {
+ if (disposeTokenSource)
+ {
+ cts.Dispose();
+ }
+ }
+ }
+
+ private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken)
+ {
+ Log.Query(_logger, queryType, name, serverEndPoint, attempt);
+
+ SendQueryError sendError = SendQueryError.NoError;
+ DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime;
+ DnsDataReader responseReader = default;
+ DnsMessageHeader header;
+
+ try
+ {
+ // use transport override if provided
+ if (_options._transportOverride != null)
+ {
+ (responseReader, header, sendError) = SendDnsQueryCustomTransport(_options._transportOverride, dnsSafeName, queryType);
+ }
+ else
+ {
+ (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false);
+
+ if (header.IsResultTruncated)
+ {
+ Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0);
+ responseReader.Dispose();
+ // TCP fallback
+ (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ if (sendError != SendQueryError.NoError)
+ {
+ // we failed to get back any response
+ return new SendQueryResult { Error = sendError };
+ }
+
+ if ((uint)header.ResponseCode > (uint)QueryResponseCode.Refused)
+ {
+ // Response code is outside of valid range
+ return new SendQueryResult
+ {
+ Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!),
+ Error = SendQueryError.MalformedResponse
+ };
+ }
+
+ // Recheck that the server echoes back the DNS question
+ if (header.QueryCount != 1 ||
+ !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) ||
+ !dnsSafeName.Equals(qName) || qType != queryType || qClass != QueryClass.Internet)
+ {
+ // DNS Question mismatch
+ return new SendQueryResult
+ {
+ Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!),
+ Error = SendQueryError.MalformedResponse
+ };
+ }
+
+ // Structurally separate the resource records, this will validate only the
+ // "outside structure" of the resource record, it will not validate the content.
+ int ttl = int.MaxValue;
+ if (!TryReadRecords(header.AnswerCount, ref ttl, ref responseReader, out List? answers) ||
+ !TryReadRecords(header.AuthorityCount, ref ttl, ref responseReader, out List? authorities) ||
+ !TryReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader, out List? additionals))
+ {
+ return new SendQueryResult
+ {
+ Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!),
+ Error = SendQueryError.MalformedResponse
+ };
+ }
+
+ DateTime expirationTime =
+ (answers.Count + authorities.Count + additionals.Count) > 0 ? queryStartedTime.AddSeconds(ttl) : queryStartedTime;
+
+ SendQueryError validationError = ValidateResponse(header.ResponseCode, queryStartedTime, answers, authorities, ref expirationTime);
+
+ // we transfer ownership of RawData to the response
+ DnsResponse response = new DnsResponse(responseReader.MessageBuffer, header, queryStartedTime, expirationTime, answers, authorities, additionals);
+ responseReader = default; // avoid disposing (and returning RawData to the pool)
+
+ return new SendQueryResult { Response = response, Error = validationError };
+ }
+ finally
+ {
+ responseReader.Dispose();
+ }
+
+ static bool TryReadRecords(int count, ref int ttl, ref DnsDataReader reader, out List records)
+ {
+ // Since `count` is attacker controlled, limit the initial capacity
+ // to 32 items to avoid excessive memory allocation. More than 32
+ // records are unusual so we don't need to optimize for them.
+ records = new(Math.Min(count, 32));
+
+ for (int i = 0; i < count; i++)
+ {
+ if (!reader.TryReadResourceRecord(out var record))
+ {
+ return false;
+ }
+
+ ttl = Math.Min(ttl, record.Ttl);
+ records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data));
+ }
+
+ return true;
+ }
+ }
+
+ internal static bool GetNegativeCacheExpiration(DateTime createdAt, List authorities, out DateTime expiration)
+ {
+ //
+ // RFC 2308 Section 5 - Caching Negative Answers
+ //
+ // Like normal answers negative answers have a time to live (TTL). As
+ // there is no record in the answer section to which this TTL can be
+ // applied, the TTL must be carried by another method. This is done by
+ // including the SOA record from the zone in the authority section of
+ // the reply. When the authoritative server creates this record its TTL
+ // is taken from the minimum of the SOA.MINIMUM field and SOA's TTL.
+ // This TTL decrements in a similar manner to a normal cached answer and
+ // upon reaching zero (0) indicates the cached negative answer MUST NOT
+ // be used again.
+ //
+
+ DnsResourceRecord? soa = authorities.FirstOrDefault(r => r.Type == QueryType.SOA);
+ if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data, out _, out _, out _, out _, out _, out _, out uint minimum, out _))
+ {
+ expiration = createdAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl));
+ return true;
+ }
+
+ expiration = default;
+ return false;
+ }
+
+ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, DateTime createdAt, List answers, List authorities, ref DateTime expiration)
+ {
+ if (responseCode == QueryResponseCode.NoError)
+ {
+ if (answers.Count > 0)
+ {
+ return SendQueryError.NoError;
+ }
+ //
+ // RFC 2308 Section 2.2 - No Data
+ //
+ // NODATA is indicated by an answer with the RCODE set to NOERROR and no
+ // relevant answers in the answer section. The authority section will
+ // contain an SOA record, or there will be no NS records there.
+ //
+ //
+ // RFC 2308 Section 5 - Caching Negative Answers
+ //
+ // A negative answer that resulted from a no data error (NODATA) should
+ // be cached such that it can be retrieved and returned in response to
+ // another query for the same that resulted in
+ // the cached negative response.
+ //
+ if (!authorities.Any(r => r.Type == QueryType.NS) && GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration))
+ {
+ expiration = newExpiration;
+ // _cache.TryAdd(name, queryType, expiration, Array.Empty());
+ }
+ return SendQueryError.NoData;
+ }
+
+ if (responseCode == QueryResponseCode.NameError)
+ {
+ //
+ // RFC 2308 Section 5 - Caching Negative Answers
+ //
+ // A negative answer that resulted from a name error (NXDOMAIN) should
+ // be cached such that it can be retrieved and returned in response to
+ // another query for the same that resulted in the
+ // cached negative response.
+ //
+ if (GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration))
+ {
+ expiration = newExpiration;
+ // _cache.TryAddNonexistent(name, expiration);
+ }
+
+ return SendQueryError.NameError;
+ }
+
+ return SendQueryError.ServerError;
+ }
+
+ internal static (DnsDataReader reader, DnsMessageHeader header, SendQueryError sendError) SendDnsQueryCustomTransport(Func, int, int> callback, EncodedDomainName dnsSafeName, QueryType queryType)
+ {
+ byte[] buffer = ArrayPool.Shared.Rent(2048);
+ try
+ {
+ (ushort transactionId, int length) = EncodeQuestion(buffer, dnsSafeName, queryType);
+ length = callback(buffer, length);
+
+ DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 0, length), true);
+
+ if (!responseReader.TryReadHeader(out DnsMessageHeader header) ||
+ header.TransactionId != transactionId ||
+ !header.IsResponse)
+ {
+ return (default, default, SendQueryError.MalformedResponse);
+ }
+
+ // transfer ownership of buffer to the caller
+ buffer = null!;
+ return (responseReader, header, SendQueryError.NoError);
+ }
+ finally
+ {
+ if (buffer != null)
+ {
+ ArrayPool.Shared.Return(buffer);
+ }
+ }
+ }
+
+ internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken)
+ {
+ var buffer = ArrayPool.Shared.Rent(512);
+ try
+ {
+ Memory memory = buffer;
+ (ushort transactionId, int length) = EncodeQuestion(memory, dnsSafeName, queryType);
+
+ using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
+ await socket.SendToAsync(memory.Slice(0, length), SocketFlags.None, serverEndPoint, cancellationToken).ConfigureAwait(false);
+
+ DnsDataReader responseReader;
+ DnsMessageHeader header;
+
+ while (true)
+ {
+ // Because this is UDP, the response must be in a single packet,
+ // if the response does not fit into a single UDP packet, the server will
+ // set the Truncated flag in the header, and we will need to retry with TCP.
+ int packetLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false);
+
+ if (packetLength < DnsMessageHeader.HeaderLength)
+ {
+ continue;
+ }
+
+ responseReader = new DnsDataReader(new ArraySegment(buffer, 0, packetLength), true);
+ if (!responseReader.TryReadHeader(out header) ||
+ header.TransactionId != transactionId ||
+ !header.IsResponse)
+ {
+ // header mismatch, this is not a response to our query
+ continue;
+ }
+
+ // ownership of the buffer is transferred to the reader, caller will dispose.
+ buffer = null!;
+ return (responseReader, header);
+ }
+ }
+ finally
+ {
+ if (buffer != null)
+ {
+ ArrayPool.Shared.Return(buffer);
+ }
+ }
+ }
+
+ internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header, SendQueryError error)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken)
+ {
+ var buffer = ArrayPool.Shared.Rent(8 * 1024);
+ try
+ {
+ // When sending over TCP, the message is prefixed by 2B length
+ (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), dnsSafeName, queryType);
+ BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length);
+
+ using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+ await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false);
+ await socket.SendAsync(buffer.AsMemory(0, length + 2), SocketFlags.None, cancellationToken).ConfigureAwait(false);
+
+ int responseLength = -1;
+ int bytesRead = 0;
+ while (responseLength < 0 || bytesRead < responseLength + 2)
+ {
+ int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false);
+ bytesRead += read;
+
+ if (read == 0)
+ {
+ // connection closed before receiving complete response message
+ return (default, default, SendQueryError.MalformedResponse);
+ }
+
+ if (responseLength < 0 && bytesRead >= 2)
+ {
+ responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2));
+
+ if (responseLength + 2 > buffer.Length)
+ {
+ // even though this is user-controlled pre-allocation, it is limited to
+ // 64 kB, so it should be fine.
+ var largerBuffer = ArrayPool.Shared.Rent(responseLength + 2);
+ Array.Copy(buffer, largerBuffer, bytesRead);
+ ArrayPool.Shared.Return(buffer);
+ buffer = largerBuffer;
+ }
+ }
+ }
+
+ DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 2, responseLength), true);
+ if (!responseReader.TryReadHeader(out DnsMessageHeader header) ||
+ header.TransactionId != transactionId ||
+ !header.IsResponse)
+ {
+ // header mismatch on TCP fallback
+ return (default, default, SendQueryError.MalformedResponse);
+ }
+
+ // transfer ownership of buffer to the caller
+ buffer = null!;
+ return (responseReader, header, SendQueryError.NoError);
+ }
+ finally
+ {
+ if (buffer != null)
+ {
+ ArrayPool.Shared.Return(buffer);
+ }
+ }
+ }
+
+ private static (ushort id, int length) EncodeQuestion(Memory buffer, EncodedDomainName dnsSafeName, QueryType queryType)
+ {
+ DnsMessageHeader header = new DnsMessageHeader
+ {
+ TransactionId = (ushort)RandomNumberGenerator.GetInt32(ushort.MaxValue + 1),
+ QueryFlags = QueryFlags.RecursionDesired,
+ QueryCount = 1
+ };
+
+ DnsDataWriter writer = new DnsDataWriter(buffer);
+ if (!writer.TryWriteHeader(header) ||
+ !writer.TryWriteQuestion(dnsSafeName, queryType, QueryClass.Internet))
+ {
+ // should never happen since we validated the name length before
+ throw new InvalidOperationException("Buffer too small");
+ }
+ return (header.TransactionId, writer.Position);
+ }
+
+ public void Dispose()
+ {
+ if (!_disposed)
+ {
+ _disposed = true;
+
+ // Cancel all pending requests (if any). Note that we don't call CancelPendingRequests() but cancel
+ // the CTS directly. The reason is that CancelPendingRequests() would cancel the current CTS and create
+ // a new CTS. We don't want a new CTS in this case.
+ _pendingRequestsCts.Cancel();
+ _pendingRequestsCts.Dispose();
+ }
+ }
+
+ private (CancellationTokenSource TokenSource, bool DisposeTokenSource, CancellationTokenSource PendingRequestsCts) PrepareCancellationTokenSource(CancellationToken cancellationToken)
+ {
+ // We need a CancellationTokenSource to use with the request. We always have the global
+ // _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may
+ // have a timeout. If we have a timeout or a caller-provided token, we need to create a new
+ // CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other
+ // unrelated operations). Otherwise, we can use the pending requests CTS directly.
+
+ // Snapshot the current pending requests cancellation source. It can change concurrently due to cancellation being requested
+ // and it being replaced, and we need a stable view of it: if cancellation occurs and the caller's token hasn't been canceled,
+ // it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas
+ // it's more approximate checking elapsed time.
+ CancellationTokenSource pendingRequestsCts = _pendingRequestsCts;
+ TimeSpan timeout = _options.Timeout;
+
+ bool hasTimeout = timeout != System.Threading.Timeout.InfiniteTimeSpan;
+ if (hasTimeout || cancellationToken.CanBeCanceled)
+ {
+ CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token);
+ if (hasTimeout)
+ {
+ cts.CancelAfter(timeout);
+ }
+
+ return (cts, DisposeTokenSource: true, pendingRequestsCts);
+ }
+
+ return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts);
+ }
+
+ private static EncodedDomainName GetNormalizedHostName(string name)
+ {
+ byte[] buffer = ArrayPool.Shared.Rent(256);
+ try
+ {
+ if (!DnsPrimitives.TryWriteQName(buffer, name, out _))
+ {
+ throw new ArgumentException($"'{name}' is not a valid DNS name.", nameof(name));
+ }
+
+ List> labels = new();
+ Memory memory = buffer.AsMemory();
+ while (true)
+ {
+ int len = memory.Span[0];
+
+ if (len == 0)
+ {
+ // root label, we are finished
+ break;
+ }
+
+ labels.Add(memory.Slice(1, len));
+ memory = memory.Slice(len + 1);
+ }
+
+ buffer = null!; // ownership transferred to the EncodedDomainName
+ return new EncodedDomainName(labels, buffer);
+ }
+ finally
+ {
+ if (buffer != null)
+ {
+ ArrayPool.Shared.Return(buffer);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs
new file mode 100644
index 00000000000..914ff9aac17
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal struct DnsResourceRecord
+{
+ public EncodedDomainName Name { get; }
+ public QueryType Type { get; }
+ public QueryClass Class { get; }
+ public int Ttl { get; }
+ public ReadOnlyMemory Data { get; }
+
+ public DnsResourceRecord(EncodedDomainName name, QueryType type, QueryClass @class, int ttl, ReadOnlyMemory data)
+ {
+ Name = name;
+ Type = type;
+ Class = @class;
+ Ttl = ttl;
+ Data = data;
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs
new file mode 100644
index 00000000000..5a7fc8a0b52
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs
@@ -0,0 +1,39 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal struct DnsResponse : IDisposable
+{
+ public DnsMessageHeader Header { get; }
+ public List Answers { get; }
+ public List Authorities { get; }
+ public List Additionals { get; }
+ public DateTime CreatedAt { get; }
+ public DateTime Expiration { get; }
+ public ArraySegment RawMessageBytes { get; private set; }
+
+ public DnsResponse(ArraySegment rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals)
+ {
+ RawMessageBytes = rawData;
+
+ Header = header;
+ CreatedAt = createdAt;
+ Expiration = expiration;
+ Answers = answers;
+ Authorities = authorities;
+ Additionals = additionals;
+ }
+
+ public void Dispose()
+ {
+ if (RawMessageBytes.Array != null)
+ {
+ ArrayPool.Shared.Return(RawMessageBytes.Array);
+ }
+
+ RawMessageBytes = default; // prevent further access to the raw data
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs
new file mode 100644
index 00000000000..4c258cac3ac
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs
@@ -0,0 +1,82 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Text;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal struct EncodedDomainName : IEquatable, IDisposable
+{
+ public IReadOnlyList> Labels { get; }
+ private byte[]? _pooledBuffer;
+
+ public EncodedDomainName(List> labels, byte[]? pooledBuffer = null)
+ {
+ Labels = labels;
+ _pooledBuffer = pooledBuffer;
+ }
+ public override string ToString()
+ {
+ StringBuilder sb = new StringBuilder();
+
+ foreach (var label in Labels)
+ {
+ if (sb.Length > 0)
+ {
+ sb.Append('.');
+ }
+ sb.Append(Encoding.ASCII.GetString(label.Span));
+ }
+
+ return sb.ToString();
+ }
+
+ public bool Equals(EncodedDomainName other)
+ {
+ if (Labels.Count != other.Labels.Count)
+ {
+ return false;
+ }
+
+ for (int i = 0; i < Labels.Count; i++)
+ {
+ if (!Ascii.EqualsIgnoreCase(Labels[i].Span, other.Labels[i].Span))
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ public override bool Equals(object? obj)
+ {
+ return obj is EncodedDomainName other && Equals(other);
+ }
+
+ public override int GetHashCode()
+ {
+ HashCode hash = new HashCode();
+
+ foreach (var label in Labels)
+ {
+ foreach (byte b in label.Span)
+ {
+ hash.Add((byte)char.ToLower((char)b));
+ }
+ }
+
+ return hash.ToHashCode();
+ }
+
+ public void Dispose()
+ {
+ if (_pooledBuffer != null)
+ {
+ ArrayPool.Shared.Return(_pooledBuffer);
+ }
+
+ _pooledBuffer = null;
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs
new file mode 100644
index 00000000000..e09168d9552
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs
@@ -0,0 +1,13 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net.Sockets;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal interface IDnsResolver
+{
+ ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default);
+ ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default);
+ ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default);
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs
new file mode 100644
index 00000000000..c2ef13f922e
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs
@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net;
+using System.Net.NetworkInformation;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal static class NetworkInfo
+{
+ // basic option to get DNS serves via NetworkInfo. We may get it directly later via proper APIs.
+ public static ResolverOptions GetOptions()
+ {
+ List servers = new List();
+
+ foreach (NetworkInterface nic in NetworkInterface.GetAllNetworkInterfaces())
+ {
+ IPInterfaceProperties properties = nic.GetIPProperties();
+ // avoid loopback, VPN etc. Should be re-visited.
+
+ if (nic.NetworkInterfaceType == NetworkInterfaceType.Ethernet && nic.OperationalStatus == OperationalStatus.Up)
+ {
+ foreach (IPAddress server in properties.DnsAddresses)
+ {
+ IPEndPoint ep = new IPEndPoint(server, 53); // 53 is standard DNS port
+ if (!servers.Contains(ep))
+ {
+ servers.Add(ep);
+ }
+ }
+ }
+ }
+
+ return new ResolverOptions(servers);
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs
new file mode 100644
index 00000000000..732ca0216da
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs
@@ -0,0 +1,9 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal enum QueryClass
+{
+ Internet = 1
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs
new file mode 100644
index 00000000000..02474b6cda1
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+[Flags]
+internal enum QueryFlags : ushort
+{
+ RecursionAvailable = 0x0080,
+ RecursionDesired = 0x0100,
+ ResultTruncated = 0x0200,
+ HasAuthorityAnswer = 0x0400,
+ HasResponse = 0x8000,
+ ResponseCodeMask = 0x000F,
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs
new file mode 100644
index 00000000000..dd51c712112
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs
@@ -0,0 +1,42 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+///
+/// The response code (RCODE) in a DNS query response.
+///
+internal enum QueryResponseCode : byte
+{
+ ///
+ /// No error condition
+ ///
+ NoError = 0,
+
+ ///
+ /// The name server was unable to interpret the query.
+ ///
+ FormatError = 1,
+
+ ///
+ /// The name server was unable to process this query due to a problem with the name server.
+ ///
+ ServerFailure = 2,
+
+ ///
+ /// Meaningful only for responses from an authoritative name server, this
+ /// code signifies that the domain name referenced in the query does not
+ /// exist.
+ ///
+ NameError = 3,
+
+ ///
+ /// The name server does not support the requested kind of query.
+ ///
+ NotImplemented = 4,
+
+ ///
+ /// The name server refuses to perform the specified operation for policy reasons.
+ ///
+ Refused = 5,
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs
new file mode 100644
index 00000000000..2ccc898a5b7
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+///
+/// DNS Query Types.
+///
+internal enum QueryType
+{
+ ///
+ /// A host address.
+ ///
+ A = 1,
+
+ ///
+ /// An authoritative name server.
+ ///
+ NS = 2,
+
+ ///
+ /// The canonical name for an alias.
+ ///
+ CNAME = 5,
+
+ ///
+ /// Marks the start of a zone of authority.
+ ///
+ SOA = 6,
+
+ ///
+ /// Mail exchange.
+ ///
+ MX = 15,
+
+ ///
+ /// Text strings.
+ ///
+ TXT = 16,
+
+ ///
+ /// IPv6 host address. (RFC 3596)
+ ///
+ AAAA = 28,
+
+ ///
+ /// Location information. (RFC 2782)
+ ///
+ SRV = 33,
+
+ ///
+ /// Wildcard match.
+ ///
+ All = 255
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs
new file mode 100644
index 00000000000..fbfdc5ae027
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs
@@ -0,0 +1,48 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net;
+using System.Runtime.Versioning;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal static class ResolvConf
+{
+ [SupportedOSPlatform("linux")]
+ [SupportedOSPlatform("osx")]
+ public static ResolverOptions GetOptions()
+ {
+ return GetOptions(new StreamReader("/etc/resolv.conf"));
+ }
+
+ public static ResolverOptions GetOptions(TextReader reader)
+ {
+ List serverList = new();
+
+ while (reader.ReadLine() is string line)
+ {
+ string[] tokens = line.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
+
+ if (line.StartsWith("nameserver"))
+ {
+ if (tokens.Length >= 2 && IPAddress.TryParse(tokens[1], out IPAddress? address))
+ {
+ serverList.Add(new IPEndPoint(address, 53)); // 53 is standard DNS port
+
+ if (serverList.Count == 3)
+ {
+ break; // resolv.conf manpage allow max 3 nameservers anyway
+ }
+ }
+ }
+ }
+
+ if (serverList.Count == 0)
+ {
+ // If no nameservers are configured, fall back to the default behavior of using the system resolver configuration.
+ return NetworkInfo.GetOptions();
+ }
+
+ return new ResolverOptions(serverList);
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs
new file mode 100644
index 00000000000..51d03f64bfd
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal sealed class ResolverOptions
+{
+ public IReadOnlyList Servers;
+ public int Attempts = 2;
+ public TimeSpan Timeout = TimeSpan.FromSeconds(3);
+
+ // override for testing purposes
+ internal Func, int, int>? _transportOverride;
+
+ public ResolverOptions(IReadOnlyList servers)
+ {
+ if (servers.Count == 0)
+ {
+ throw new ArgumentException("At least one DNS server is required.", nameof(servers));
+ }
+
+ Servers = servers;
+ }
+
+ public ResolverOptions(IPEndPoint server)
+ {
+ Servers = new IPEndPoint[] { server };
+ }
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs
new file mode 100644
index 00000000000..aed799ac8d6
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal record struct AddressResult(DateTime ExpiresAt, IPAddress Address);
+
+internal record struct ServiceResult(DateTime ExpiresAt, int Priority, int Weight, int Port, string Target, AddressResult[] Addresses);
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs
new file mode 100644
index 00000000000..3ba5632e207
--- /dev/null
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs
@@ -0,0 +1,47 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
+
+internal enum SendQueryError
+{
+ ///
+ /// DNS query was successful and returned response message with answers.
+ ///
+ NoError,
+
+ ///
+ /// Server failed to respond to the query withing specified timeout.
+ ///
+ Timeout,
+
+ ///
+ /// Server returned a response with an error code.
+ ///
+ ServerError,
+
+ ///
+ /// Server returned a malformed response.
+ ///
+ MalformedResponse,
+
+ ///
+ /// Server returned a response indicating that the name exists, but no data are available.
+ ///
+ NoData,
+
+ ///
+ /// Server returned a response indicating the name does not exist.
+ ///
+ NameError,
+
+ ///
+ /// Network-level error occurred during the query.
+ ///
+ NetworkError,
+
+ ///
+ /// Internal error on part of the implementation.
+ ///
+ InternalError,
+}
diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs
index 98b9de1fd68..7d05243f741 100644
--- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs
+++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs
@@ -1,11 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using DnsClient;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.ServiceDiscovery;
using Microsoft.Extensions.ServiceDiscovery.Dns;
+using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
namespace Microsoft.Extensions.Hosting;
@@ -46,7 +46,7 @@ public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceC
ArgumentNullException.ThrowIfNull(configureOptions);
services.AddServiceDiscoveryCore();
- services.TryAddSingleton();
+ services.TryAddSingleton();
services.AddSingleton();
var options = services.AddOptions();
options.Configure(o => configureOptions?.Invoke(o));
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore
new file mode 100644
index 00000000000..0151cc4e360
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore
@@ -0,0 +1,2 @@
+# corpuses generated by the fuzzing engine
+corpuses/**
\ No newline at end of file
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs
new file mode 100644
index 00000000000..1b180d74b9d
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs
@@ -0,0 +1,44 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Net;
+using System.Net.Sockets;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing;
+
+internal sealed class DnsResponseFuzzer : IFuzzer
+{
+ DnsResolver? _resolver;
+ byte[]? _buffer;
+ int _length;
+
+ public void FuzzTarget(ReadOnlySpan data)
+ {
+ // lazy init
+ if (_resolver == null)
+ {
+ _buffer = new byte[4096];
+ _resolver = new DnsResolver(new ResolverOptions(new IPEndPoint(IPAddress.Loopback, 53))
+ {
+ Timeout = TimeSpan.FromSeconds(5),
+ Attempts = 1,
+ _transportOverride = (buffer, length) =>
+ {
+ // the first two bytes are the random transaction ID, so we keep that
+ // and use the fuzzing payload for the rest of the DNS response
+ _buffer.AsSpan(0, Math.Min(_length, buffer.Length - 2)).CopyTo(buffer.Span.Slice(2));
+ return _length + 2;
+ }
+ });
+ }
+
+ data.CopyTo(_buffer!);
+ _length = data.Length;
+
+ // the _transportOverride makes the execution synchronous
+ ValueTask task = _resolver!.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork, CancellationToken.None);
+ Debug.Assert(task.IsCompleted, "Task should be completed synchronously");
+ task.GetAwaiter().GetResult();
+ }
+}
\ No newline at end of file
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs
new file mode 100644
index 00000000000..72f84b3c959
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing;
+
+internal sealed class EncodedDomainNameFuzzer : IFuzzer
+{
+ public void FuzzTarget(ReadOnlySpan data)
+ {
+ byte[] buffer = ArrayPool.Shared.Rent(data.Length);
+ try
+ {
+ data.CopyTo(buffer);
+
+ // attempt to read at any offset to really stress the parser
+ for (int i = 0; i < data.Length; i++)
+ {
+ if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, data.Length), i, out EncodedDomainName name, out _))
+ {
+ continue;
+ }
+
+ // the domain name should be readable
+ _ = name.ToString();
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(buffer);
+ }
+
+ }
+}
\ No newline at end of file
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs
new file mode 100644
index 00000000000..f657245a842
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs
@@ -0,0 +1,48 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Text;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing;
+
+internal sealed class WriteDomainNameRoundTripFuzzer : IFuzzer
+{
+ private static readonly System.Globalization.IdnMapping s_idnMapping = new();
+ public void FuzzTarget(ReadOnlySpan data)
+ {
+ // first byte is the offset of the domain name, rest is the actual
+ // (simulated) DNS message payload
+
+ byte[] buffer = ArrayPool.Shared.Rent(data.Length * 2);
+
+ try
+ {
+ string domainName = Encoding.UTF8.GetString(data);
+ if (!DnsPrimitives.TryWriteQName(buffer, domainName, out int written))
+ {
+ return;
+ }
+
+ if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, written), 0, out EncodedDomainName name, out int read))
+ {
+ return;
+ }
+
+ if (read != written)
+ {
+ throw new InvalidOperationException($"Read {read} bytes, but wrote {written} bytes");
+ }
+
+ string readName = name.ToString();
+
+ if (!string.Equals(s_idnMapping.GetAscii(domainName).TrimEnd('.'), readName, StringComparison.OrdinalIgnoreCase))
+ {
+ throw new InvalidOperationException($"Domain name mismatch: {readName} != {domainName}");
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(buffer);
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs
new file mode 100644
index 00000000000..2ff9d86b2ce
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs
@@ -0,0 +1,5 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+global using System.Buffers;
+global using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
\ No newline at end of file
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs
new file mode 100644
index 00000000000..4b4c8c99b4b
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing;
+
+public interface IFuzzer
+{
+ string Name => GetType().Name;
+ void FuzzTarget(ReadOnlySpan data);
+}
\ No newline at end of file
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj
new file mode 100644
index 00000000000..6572c27a1fa
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj
@@ -0,0 +1,18 @@
+
+
+
+ $(DefaultTargetFramework)
+ enable
+ enable
+ Exe
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs
new file mode 100644
index 00000000000..22b1580d1ac
--- /dev/null
+++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs
@@ -0,0 +1,64 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using SharpFuzz;
+
+namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing;
+
+public static class Program
+{
+ public static void Main(string[] args)
+ {
+ IFuzzer[] fuzzers = typeof(Program).Assembly.GetTypes()
+ .Where(t => t.IsClass && !t.IsAbstract)
+ .Where(t => t.GetInterfaces().Contains(typeof(IFuzzer)))
+ .Select(t => (IFuzzer)Activator.CreateInstance(t)!)
+ .OrderBy(f => f.Name, StringComparer.OrdinalIgnoreCase)
+ .ToArray();
+
+ void PrintUsage()
+ {
+ Console.Error.WriteLine($"""
+ Usage:
+ DotnetFuzzing list
+ DotnetFuzzing [input file/directory]
+ // DotnetFuzzing prepare-onefuzz