From 2c217a3962841e694b2b934c75991f83e9aef6cf Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 4 Oct 2024 08:45:45 +0200 Subject: [PATCH 01/45] First integration of Resolver sources --- .../TestShop.ServiceDefaults/Extensions.cs | 2 + .../DnsServiceEndpointProvider.cs | 43 +- .../DnsServiceEndpointProviderBase.Log.cs | 4 +- .../DnsServiceEndpointProviderBase.cs | 2 +- .../DnsServiceEndpointProviderFactory.cs | 4 +- .../DnsSrvServiceEndpointProvider.cs | 56 +- .../DnsSrvServiceEndpointProviderFactory.cs | 8 +- ...oft.Extensions.ServiceDiscovery.Dns.csproj | 2 +- .../Resolver/DnsDataReader.cs | 120 +++++ .../Resolver/DnsDataWriter.cs | 104 ++++ .../Resolver/DnsMessageHeader.cs | 106 ++++ .../Resolver/DnsPrimitives.cs | 242 +++++++++ .../Resolver/DnsResolver.cs | 489 ++++++++++++++++++ .../Resolver/DnsResourceRecord.cs | 22 + .../Resolver/DnsResponse.cs | 24 + .../Resolver/NetworkInfo.cs | 36 ++ .../Resolver/QueryClass.cs | 9 + .../Resolver/QueryFlags.cs | 16 + .../Resolver/QueryResponseCode.cs | 42 ++ .../Resolver/QueryType.cs | 55 ++ .../Resolver/ResolvConf.cs | 69 +++ .../Resolver/ResolverOptions.cs | 24 + .../Resolver/ResultTypes.cs | 26 + ...DiscoveryDnsServiceCollectionExtensions.cs | 4 +- .../Resolver/CancellationTests.cs | 31 ++ .../Resolver/DnsDataReaderTests.cs | 64 +++ .../Resolver/DnsDataWriterTests.cs | 140 +++++ .../Resolver/DnsPrimitivesTests.cs | 125 +++++ .../Resolver/LoopbackDnsServer.cs | 221 ++++++++ .../Resolver/LoopbackDnsTestBase.cs | 34 ++ .../Resolver/ResolvConfTests.cs | 25 + .../Resolver/ResolveAddressesTests.cs | 162 ++++++ .../Resolver/ResolveServiceTests.cs | 38 ++ .../Resolver/TcpFailoverTests.cs | 40 ++ .../Resolver/TestTimeProvider.cs | 12 + 35 files changed, 2344 insertions(+), 57 deletions(-) create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs diff --git a/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs b/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs index f2d77e9ffcc..58d3b5e240c 100644 --- a/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs +++ b/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs @@ -21,6 +21,8 @@ public static TBuilder AddServiceDefaults(this TBuilder builder) where builder.AddDefaultHealthChecks(); builder.Services.AddServiceDiscovery(); + // builder.Services.AddServiceDiscoveryCore(); + // builder.Services.AddDnsSrvServiceEndpointProvider(); builder.Services.ConfigureHttpClientDefaults(http => { diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs index 6cc9f92bc46..52f2da3014a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; +using System.Net.Sockets; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; @@ -12,6 +14,7 @@ internal sealed partial class DnsServiceEndpointProvider( string hostName, IOptionsMonitor options, ILogger logger, + DnsResolver resolver, TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature { protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; @@ -29,17 +32,21 @@ protected override async Task ResolveAsyncCore() var endpoints = new List(); var ttl = DefaultRefreshPeriod; Log.AddressQuery(logger, ServiceName, hostName); - var addresses = await System.Net.Dns.GetHostAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false); - foreach (var address in addresses) + + var now = _timeProvider.GetUtcNow().DateTime; + var ipv4Addresses = resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork, ShutdownToken).ConfigureAwait(false); + var ipv6Addresses = resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetworkV6, ShutdownToken).ConfigureAwait(false); + + foreach (var address in await ipv4Addresses) { - var serviceEndpoint = ServiceEndpoint.Create(new IPEndPoint(address, 0)); - serviceEndpoint.Features.Set(this); - if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint)) - { - serviceEndpoint.Features.Set(this); - } + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, 0))); + } - endpoints.Add(serviceEndpoint); + foreach (var address in await ipv6Addresses) + { + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, 0))); } if (endpoints.Count == 0) @@ -48,5 +55,23 @@ protected override async Task ResolveAsyncCore() } SetResult(endpoints, ttl); + + static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing) + { + var candidate = expiresAt - now; + return candidate < existing ? candidate : existing; + } + + ServiceEndpoint CreateEndpoint(EndPoint endPoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endPoint); + serviceEndpoint.Features.Set(this); + if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint)) + { + serviceEndpoint.Features.Set(this); + } + + return serviceEndpoint; + } } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs index 29aaaf8e930..9dbfafe4ef6 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs @@ -9,10 +9,10 @@ partial class DnsServiceEndpointProviderBase { internal static partial class Log { - [LoggerMessage(1, LogLevel.Trace, "Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", EventName = "SrvQuery")] + [LoggerMessage(1, LogLevel.Information, "Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", EventName = "SrvQuery")] public static partial void SrvQuery(ILogger logger, string serviceName, string recordName); - [LoggerMessage(2, LogLevel.Trace, "Resolving endpoints for service '{ServiceName}' using host lookup for name '{RecordName}'.", EventName = "AddressQuery")] + [LoggerMessage(2, LogLevel.Information, "Resolving endpoints for service '{ServiceName}' using host lookup for name '{RecordName}'.", EventName = "AddressQuery")] public static partial void AddressQuery(ILogger logger, string serviceName, string recordName); [LoggerMessage(3, LogLevel.Debug, "Skipping endpoint resolution for service '{ServiceName}': '{Reason}'.", EventName = "SkippedResolution")] diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs index 6c69cc7a760..311c06f631a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs @@ -14,7 +14,7 @@ internal abstract partial class DnsServiceEndpointProviderBase : IServiceEndpoin private readonly object _lock = new(); private readonly ILogger _logger; private readonly CancellationTokenSource _disposeCancellation = new(); - private readonly TimeProvider _timeProvider; + protected readonly TimeProvider _timeProvider; private long _lastRefreshTimeStamp; private Task _resolveTask = Task.CompletedTask; private bool _hasEndpoints; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs index c241ad89dd3..80fb009c0d1 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs @@ -4,18 +4,20 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; internal sealed partial class DnsServiceEndpointProviderFactory( IOptionsMonitor options, ILogger logger, + DnsResolver resolver, TimeProvider timeProvider) : IServiceEndpointProviderFactory { /// public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) { - provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, timeProvider); + provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, resolver, timeProvider); return true; } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs index c174cda4f68..9532c7a9ed5 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using DnsClient; -using DnsClient.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; @@ -15,7 +14,7 @@ internal sealed partial class DnsSrvServiceEndpointProvider( string hostName, IOptionsMonitor options, ILogger logger, - IDnsQuery dnsClient, + DnsResolver resolver, TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature { protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; @@ -35,56 +34,39 @@ protected override async Task ResolveAsyncCore() var endpoints = new List(); var ttl = DefaultRefreshPeriod; Log.SrvQuery(logger, ServiceName, srvQuery); - var result = await dnsClient.QueryAsync(srvQuery, QueryType.SRV, cancellationToken: ShutdownToken).ConfigureAwait(false); - if (result.HasError) - { - throw CreateException(srvQuery, result.ErrorMessage); - } + logger.LogInformation("Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", ServiceName, srvQuery); - var lookupMapping = new Dictionary(); - foreach (var record in result.Additionals.Where(x => x is AddressRecord or CNameRecord)) - { - ttl = MinTtl(record, ttl); - lookupMapping[record.DomainName] = record; - } + var now = _timeProvider.GetUtcNow().DateTime; + var result = await resolver.ResolveServiceAsync(srvQuery, cancellationToken: ShutdownToken).ConfigureAwait(false); + + logger.LogInformation("Resolved {Number} entries", result.Length); - var srvRecords = result.Answers.OfType(); - foreach (var record in srvRecords) + foreach (var record in result) { - if (!lookupMapping.TryGetValue(record.Target, out var targetRecord)) - { - continue; - } + ttl = MinTtl(now, record.ExpiresAt, ttl); - ttl = MinTtl(record, ttl); - if (targetRecord is AddressRecord addressRecord) + if (record.Addresses.Length > 0) { - endpoints.Add(CreateEndpoint(new IPEndPoint(addressRecord.Address, record.Port))); + foreach (var address in record.Addresses) + { + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, record.Port))); + } } - else if (targetRecord is CNameRecord canonicalNameRecord) + else { - endpoints.Add(CreateEndpoint(new DnsEndPoint(canonicalNameRecord.CanonicalName.Value.TrimEnd('.'), record.Port))); + endpoints.Add(CreateEndpoint(new DnsEndPoint(record.Target.TrimEnd('.'), record.Port))); } } SetResult(endpoints, ttl); - static TimeSpan MinTtl(DnsResourceRecord record, TimeSpan existing) + static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing) { - var candidate = TimeSpan.FromSeconds(record.TimeToLive); + var candidate = expiresAt - now; return candidate < existing ? candidate : existing; } - InvalidOperationException CreateException(string dnsName, string errorMessage) - { - var msg = errorMessage switch - { - { Length: > 0 } => $"No DNS records were found for service '{ServiceName}' (DNS name: '{dnsName}'): {errorMessage}.", - _ => $"No DNS records were found for service '{ServiceName}' (DNS name: '{dnsName}')." - }; - return new InvalidOperationException(msg); - } - ServiceEndpoint CreateEndpoint(EndPoint endPoint) { var serviceEndpoint = ServiceEndpoint.Create(endPoint); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs index fd0cb28353d..fb7d006ae9a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs @@ -2,16 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; -using DnsClient; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.ServiceDiscovery.Dns; internal sealed partial class DnsSrvServiceEndpointProviderFactory( IOptionsMonitor options, ILogger logger, - IDnsQuery dnsClient, + DnsResolver resolver, TimeProvider timeProvider) : IServiceEndpointProviderFactory { private static readonly string s_serviceAccountPath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount"); @@ -24,7 +24,7 @@ public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] ou { // If a default namespace is not specified, then this provider will attempt to infer the namespace from the service name, but only when running inside Kubernetes. // Kubernetes DNS spec: https://github.com/kubernetes/dns/blob/master/docs/specification.md - // SRV records are available for headless services with named ports. + // SRV records are available for headless services with named ports. // They take the form $"_{portName}._{protocol}.{serviceName}.{namespace}.{suffix}" // The suffix (after the service name) can be parsed from /etc/resolv.conf // Otherwise, the namespace can be read from /var/run/secrets/kubernetes.io/serviceaccount/namespace and combined with an assumed suffix of "svc.cluster.local". @@ -39,7 +39,7 @@ public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] ou var portName = query.EndpointName ?? "default"; var srvQuery = $"_{portName}._tcp.{query.ServiceName}.{_querySuffix}"; - provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, dnsClient, timeProvider); + provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, resolver, timeProvider); return true; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj index 3f503049e83..a5b4993f286 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj @@ -27,7 +27,7 @@ - + diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs new file mode 100644 index 00000000000..227698107df --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -0,0 +1,120 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsDataReader : IDisposable +{ + private ReadOnlyMemory _buffer; + private byte[]? _pooled; + private int _position; + + public DnsDataReader(ReadOnlyMemory buffer, byte[]? returnToPool = null) + { + _buffer = buffer; + _position = 0; + _pooled = returnToPool; + } + + public bool TryReadHeader(out DnsMessageHeader header) + { + if (_buffer.Length - _position < DnsMessageHeader.HeaderLength) + { + header = default; + return false; + } + + _position += DnsMessageHeader.HeaderLength; + header = MemoryMarshal.AsRef(_buffer.Span); + return true; + } + + internal bool TryReadQuestion([NotNullWhen(true)] out string? name, out QueryType type, out QueryClass @class) + { + if (!TryReadDomainName(out name) || + !TryReadUInt16(out ushort typeAsInt) || + !TryReadUInt16(out ushort classAsInt)) + { + type = 0; + @class = 0; + return false; + } + + type = (QueryType)typeAsInt; + @class = (QueryClass)classAsInt; + return true; + } + + public bool TryReadUInt16(out ushort value) + { + if (_buffer.Length - _position < 2) + { + value = 0; + return false; + } + + value = BinaryPrimitives.ReadUInt16BigEndian(_buffer.Span.Slice(_position)); + _position += 2; + return true; + } + + public bool TryReadUInt32(out uint value) + { + if (_buffer.Length - _position < 4) + { + value = 0; + return false; + } + + value = BinaryPrimitives.ReadUInt32BigEndian(_buffer.Span.Slice(_position)); + _position += 4; + return true; + } + + public bool TryReadResourceRecord(out DnsResourceRecord record) + { + if (!TryReadDomainName(out string? name) || + !TryReadUInt16(out ushort type) || + !TryReadUInt16(out ushort @class) || + !TryReadUInt32(out uint ttl) || + !TryReadUInt16(out ushort dataLength) || + _buffer.Length - _position < dataLength) + { + record = default; + return false; + } + + ReadOnlyMemory data = _buffer.Slice(_position, dataLength); + _position += dataLength; + + record = new DnsResourceRecord(name, (QueryType)type, (QueryClass)@class, (int)ttl, data); + return true; + } + + public bool TryReadDomainName([NotNullWhen(true)] out string? name) + { + if (DnsPrimitives.TryReadQName(_buffer.Span, _position, out name, out int bytesRead)) + { + _position += bytesRead; + return true; + } + + return false; + } + + public void Dispose() + { + if (_pooled is not null) + { + ArrayPool.Shared.Return(_pooled); + _pooled = null!; + } + + _buffer = default; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs new file mode 100644 index 00000000000..4abbf277328 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; +using System.Runtime.InteropServices; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal class DnsDataWriter +{ + private readonly Memory _buffer; + private int _position; + + internal DnsDataWriter(Memory buffer) + { + _buffer = buffer; + _position = 0; + } + + public int Position => _position; + + internal bool TryWriteHeader(in DnsMessageHeader header) + { + if (_buffer.Length - _position < DnsMessageHeader.HeaderLength) + { + return false; + } + + MemoryMarshal.Write(_buffer.Span, in header); + _position += DnsMessageHeader.HeaderLength; + return true; + } + + internal bool TryWriteQuestion(string name, QueryType type, QueryClass @class) + { + if (!TryWriteDomainName(name) || + !TryWriteUInt16((ushort)type) || + !TryWriteUInt16((ushort)@class)) + { + return false; + } + + return true; + } + + internal bool TryWriteResourceRecord(in DnsResourceRecord record) + { + if (!TryWriteDomainName(record.Name) || + !TryWriteUInt16((ushort)record.Type) || + !TryWriteUInt16((ushort)record.Class) || + !TryWriteUInt32((uint)record.Ttl)) + { + return false; + } + + if (record.Data.Length + 2 > _buffer.Length - _position) + { + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(_buffer.Span.Slice(_position), (ushort)record.Data.Length); + _position += 2; + + record.Data.Span.CopyTo(_buffer.Span.Slice(_position)); + _position += record.Data.Length; + + return true; + } + + internal bool TryWriteDomainName(string name) + { + if (DnsPrimitives.TryWriteQName(_buffer.Span.Slice(_position), name, out int written)) + { + _position += written; + return true; + } + + return false; + } + + internal bool TryWriteUInt16(ushort value) + { + if (_buffer.Length - _position < 2) + { + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(_buffer.Span.Slice(_position), value); + _position += 2; + return true; + } + + internal bool TryWriteUInt32(uint value) + { + if (_buffer.Length - _position < 4) + { + return false; + } + + BinaryPrimitives.WriteUInt32BigEndian(_buffer.Span.Slice(_position), value); + _position += 4; + return true; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs new file mode 100644 index 00000000000..c920375aa9c --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -0,0 +1,106 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +// RFC 1035 4.1.1. Header section format +internal struct DnsMessageHeader +{ + internal const int HeaderLength = 12; + + private ushort _transactionId; + private ushort _flags; + + private ushort _queryCount; + private ushort _answerCount; + private ushort _authorityCount; + private ushort _additionalRecordCount; + + internal ushort QueryCount + { + get => ReverseByteOrder(_queryCount); + set => _queryCount = ReverseByteOrder(value); + } + + internal ushort AnswerCount + { + get => ReverseByteOrder(_answerCount); + set => _answerCount = ReverseByteOrder(value); + } + + internal ushort AuthorityCount + { + get => ReverseByteOrder(_authorityCount); + set => _authorityCount = ReverseByteOrder(value); + } + + internal ushort AdditionalRecordCount + { + get => ReverseByteOrder(_additionalRecordCount); + set => _additionalRecordCount = ReverseByteOrder(value); + } + + internal ushort TransactionId + { + get => ReverseByteOrder(_transactionId); + set => _transactionId = ReverseByteOrder(value); + } + + internal QueryFlags QueryFlags + { + get => (QueryFlags)ReverseByteOrder(_flags); + set => _flags = ReverseByteOrder((ushort)value); + } + + internal bool IsRecursionDesired + { + get => (QueryFlags & QueryFlags.RecursionDesired) != 0; + set + { + if (value) + { + QueryFlags |= QueryFlags.RecursionDesired; + } + else + { + QueryFlags &= ~QueryFlags.RecursionDesired; + } + } + } + + internal QueryResponseCode ResponseCode + { + get => (QueryResponseCode)((_flags & 0x0F00) >> 8); + set => _flags = (ushort)((_flags & 0xF0FF) | ((ushort)value << 8)); + } + + internal bool IsResultTruncated => (QueryFlags & QueryFlags.ResultTruncated) != 0; + + internal bool IsResponse + { + get => (QueryFlags & QueryFlags.HasResponse) != 0; + set + { + if (value) + { + QueryFlags |= QueryFlags.HasResponse; + } + else + { + QueryFlags &= ~QueryFlags.HasResponse; + } + } + } + + internal void InitQueryHeader() + { + this = default; + TransactionId = (ushort)Random.Shared.Next(ushort.MaxValue); + IsRecursionDesired = true; + QueryCount = 1; + } + + private static ushort ReverseByteOrder(ushort value) => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value; +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs new file mode 100644 index 00000000000..96467442445 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -0,0 +1,242 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; +using System.Diagnostics.CodeAnalysis; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class DnsPrimitives +{ + internal static bool TryWriteQName(Span destination, string name, out int written) + { + // + // RFC 1035 4.1.2. + // + // a domain name represented as a sequence of labels, where + // each label consists of a length octet followed by that + // number of octets. The domain name terminates with the + // zero length octet for the null label of the root. Note + // that this field may be an odd number of octets; no + // padding is used. + // + + if (!Encoding.ASCII.TryGetBytes(name, destination.IsEmpty ? destination : destination.Slice(1), out int length) || destination.Length < length + 2) + { + // buffer too small + written = 0; + return false; + } + + destination[1 + length] = 0; // last label (root) + + Span nameBuffer = destination.Slice(0, 1 + length); + while (true) + { + // figure out the next label and prepend the length + int index = nameBuffer.Slice(1).IndexOf((byte)'.'); + int labelLen = index == -1 ? nameBuffer.Length - 1 : index; + + if (labelLen > 63) + { + throw new ArgumentException("Label is too long"); + } + + nameBuffer[0] = (byte)labelLen; + if (index == -1) + { + // this was the last label + break; + } + + nameBuffer = nameBuffer.Slice(index + 1); + } + + written = length + 2; + return true; + } + + private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messageBuffer, int offset, out int bytesRead) + { + // + // domain name can be either + // - a sequence of labels, where each label consists of a length octet + // followed by that number of octets, terminated by a zero length octet + // (root label) + // - a pointer, where the first two bits are set to 1, and the remaining + // 14 bits are an offset (from the start of the message) to the true + // label + // + // It is not specified by the RFC if pointers must be backwards only, + // the code below prohibits forward (and self) pointers to avoid + // infinite loops + // + + bytesRead = 0; + + if (offset < 0 || offset >= messageBuffer.Length) + { + return false; + } + + int currentOffset = offset; + + while (true) + { + byte length = messageBuffer[currentOffset]; + + if ((length & 0xC0) == 0x00) + { + // length followed by the label + if (length == 0) + { + // end of name + bytesRead = currentOffset - offset + 1; + return true; + } + + if (currentOffset + 1 + length < messageBuffer.Length) + { + // read next label/segment + if (sb.Length > 0) + { + sb.Append('.'); + } + sb.Append(Encoding.ASCII.GetString(messageBuffer.Slice(currentOffset + 1, length))); + currentOffset += 1 + length; + bytesRead += 1 + length; + } + else + { + // truncated data + break; + } + } + else if ((length & 0xC0) == 0xC0) + { + // pointer, together with next byte gives the offset of the true label + if (currentOffset + 1 < messageBuffer.Length) + { + bytesRead += 2; + int pointer = ((length & 0x3F) << 8) | messageBuffer[currentOffset + 1]; + + // we prohibit self-references and forward pointers to avoid + // infinite loops, we do this by truncating the + // messageBuffer at the offset where we started reading the + // name. We also ignore the bytesRead from the recursive + // call, as we are only interested on how many bytes we read + // from the initial start of the name. + return TryReadQNameCore(sb, messageBuffer.Slice(0, offset), pointer, out int _); + } + else + { + // truncated data + break; + } + } + else + { + // top two bits are reserved, this means invalid data + break; + } + } + + return false; + + } + + internal static bool TryReadQName(ReadOnlySpan messageBuffer, int offset, [NotNullWhen(true)] out string? name, out int bytesRead) + { + StringBuilder sb = new StringBuilder(); + + if (TryReadQNameCore(sb, messageBuffer, offset, out bytesRead)) + { + name = sb.ToString(); + return true; + } + else + { + bytesRead = 0; + name = null; + return false; + } + } + + internal static bool TryReadService(ReadOnlySpan buffer, out ushort priority, out ushort weight, out ushort port, [NotNullWhen(true)] out string? target, out int bytesRead) + { + if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer, out priority) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Slice(2), out weight) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Slice(4), out port) || + !TryReadQName(buffer.Slice(6), 0, out target, out bytesRead)) + { + target = null; + priority = 0; + weight = 0; + port = 0; + bytesRead = 0; + return false; + } + + bytesRead += 6; + return true; + } + + internal static bool TryWriteService(Span buffer, ushort priority, ushort weight, ushort port, string target, out int bytesWritten) + { + if (!BinaryPrimitives.TryWriteUInt16BigEndian(buffer, priority) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buffer.Slice(2), weight) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buffer.Slice(4), port) || + !TryWriteQName(buffer.Slice(6), target, out bytesWritten)) + { + bytesWritten = 0; + return false; + } + + bytesWritten += 6; + return true; + } + + internal static bool TryWriteSoa(Span buffer, string primaryNameServer, string responsibleMailAddress, uint serial, uint refresh, uint retry, uint expire, uint minimum, out int bytesWritten) + { + if (!TryWriteQName(buffer, primaryNameServer, out int w1) || + !TryWriteQName(buffer.Slice(w1), responsibleMailAddress, out int w2) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2), serial) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 4), refresh) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 8), retry) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 12), expire) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 16), minimum)) + { + bytesWritten = 0; + return false; + } + + bytesWritten = w1 + w2 + 20; + return true; + } + + internal static bool TryReadSoa(ReadOnlySpan buffer, [NotNullWhen(true)] out string? primaryNameServer, [NotNullWhen(true)] out string? responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) + { + if (!TryReadQName(buffer, 0, out primaryNameServer, out int w1) || + !TryReadQName(buffer.Slice(w1), 0, out responsibleMailAddress, out int w2) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2), out serial) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 4), out refresh) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 8), out retry) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 12), out expire) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 16), out minimum)) + { + primaryNameServer = null!; + responsibleMailAddress = null!; + serial = 0; + refresh = 0; + retry = 0; + expire = 0; + minimum = 0; + bytesRead = 0; + return false; + } + + bytesRead = w1 + w2 + 20; + return true; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs new file mode 100644 index 00000000000..dd8d9ebbf08 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -0,0 +1,489 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal class DnsResolver : IDisposable +{ + private const int MaximumNameLength = 253; + private const int IPv4Length = 4; + private const int IPv6Length = 16; + private const int HeaderSize = 12; + + private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); + + bool _disposed; + private readonly ResolverOptions _options; + private TimeSpan _timeout = System.Threading.Timeout.InfiniteTimeSpan; + private readonly CancellationTokenSource _pendingRequestsCts = new(); + + private TimeProvider _timeProvider = TimeProvider.System; + + internal void SetTimeProvider(TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + public DnsResolver() : this(OperatingSystem.IsWindows() ? NetworkInfo.GetOptions() : ResolvConf.GetOptions()) + { + } + + internal DnsResolver(ResolverOptions options) + { + _options = options; + if (options.Servers.Length == 0) + { + throw new ArgumentException("There are no DNS servers configured.", nameof(options)); + } + } + + internal DnsResolver(IEnumerable servers) : this(new ResolverOptions(servers.ToArray())) + { + } + + internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server)) + { + } + + public TimeSpan Timeout + { + get => _timeout; + set + { + ObjectDisposedException.ThrowIf(_disposed, this); + + if (value != System.Threading.Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(value, TimeSpan.Zero); + ArgumentOutOfRangeException.ThrowIfGreaterThan(value, s_maxTimeout); + } + _timeout = value; + } + } + + public async ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + + DnsResponse record = await SendQueryAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); + if (!ValidateResponse(record)) + { + return Array.Empty(); + } + + var results = new List(record.Answers.Count); + + foreach (var answer in record.Answers) + { + if (answer.Type == QueryType.SRV) + { + bool success = DnsPrimitives.TryReadService(answer.Data.Span, out ushort priority, out ushort weight, out ushort port, out string? target, out _); + Debug.Assert(success, "Failed to read SRV"); + + List addresses = new List(); + foreach (var additional in record.Additionals) + { + if (additional.Name == target && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) + { + addresses.Add(new AddressResult(record.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); + } + } + + results.Add(new ServiceResult(record.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target!, addresses.ToArray())); + } + } + + var result = results.ToArray(); + return result; + } + + public async ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + + if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6 && addressFamily != AddressFamily.Unspecified) + { + throw new ArgumentOutOfRangeException(nameof(addressFamily), addressFamily, "Invalid address family"); + } + + if (name.Length > MaximumNameLength) + { + throw new ArgumentException("Name is too long", nameof(name)); + } + + var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + + DnsResponse record = await SendQueryAsync(name, queryType, cancellationToken).ConfigureAwait(false); + if (!ValidateResponse(record)) + { + return Array.Empty(); + } + + var results = new List(record.Answers.Count); + + // servers send back CNAME records together with associated A/AAAA records + string currentAlias = name; + + foreach (var answer in record.Answers) + { + if (answer.Name != currentAlias) + { + continue; + } + + if (answer.Type == QueryType.CNAME) + { + bool success = DnsPrimitives.TryReadQName(answer.Data.Span, 0, out currentAlias!, out _); + Debug.Assert(success, "Failed to read CNAME"); + continue; + } + + else if (answer.Type == queryType) + { + Debug.Assert(answer.Data.Length == IPv4Length || answer.Data.Length == IPv6Length); + results.Add(new AddressResult(record.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); + } + } + + var result = results.ToArray(); + return result; + } + + internal static bool GetNegativeCacheExpiration(in DnsResponse response, out DateTime expiration) + { + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // Like normal answers negative answers have a time to live (TTL). As + // there is no record in the answer section to which this TTL can be + // applied, the TTL must be carried by another method. This is done by + // including the SOA record from the zone in the authority section of + // the reply. When the authoritative server creates this record its TTL + // is taken from the minimum of the SOA.MINIMUM field and SOA's TTL. + // This TTL decrements in a similar manner to a normal cached answer and + // upon reaching zero (0) indicates the cached negative answer MUST NOT + // be used again. + // + + DnsResourceRecord? soa = response.Authorities.FirstOrDefault(r => r.Type == QueryType.SOA); + if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data.Span, out string? mname, out string? rname, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out _)) + { + expiration = response.CreatedAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl)); + return true; + } + + expiration = default; + return false; + } + + internal static bool ValidateResponse(in DnsResponse response) + { + if (response.Header.ResponseCode == QueryResponseCode.NoError) + { + if (response.Answers.Count > 0) + { + return true; + } + // + // RFC 2308 Section 2.2 - No Data + // + // NODATA is indicated by an answer with the RCODE set to NOERROR and no + // relevant answers in the answer section. The authority section will + // contain an SOA record, or there will be no NS records there. + // + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // A negative answer that resulted from a no data error (NODATA) should + // be cached such that it can be retrieved and returned in response to + // another query for the same that resulted in + // the cached negative response. + // + if (!response.Authorities.Any(r => r.Type == QueryType.NS) && GetNegativeCacheExpiration(response, out DateTime expiration)) + { + // _cache.TryAdd(name, queryType, expiration, Array.Empty()); + } + return false; + } + + if (response.Header.ResponseCode == QueryResponseCode.NameError) + { + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // A negative answer that resulted from a name error (NXDOMAIN) should + // be cached such that it can be retrieved and returned in response to + // another query for the same that resulted in the + // cached negative response. + // + if (GetNegativeCacheExpiration(response, out DateTime expiration)) + { + // _cache.TryAddNonexistent(name, expiration); + } + + return false; + } + + return true; + } + + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(8 * 1024); + try + { + // When sending over TCP, the message is prefixed by 2B length + (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), name, queryType); + BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length); + + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); + await socket.SendAsync(buffer.AsMemory(0, length + 2), SocketFlags.None, cancellationToken).ConfigureAwait(false); + + int responseLength = -1; + int bytesRead = 0; + while (responseLength < 0 || bytesRead < length + 2) + { + int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); + bytesRead += read; + + if (responseLength < 0 && bytesRead >= 2) + { + responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); + + if (responseLength > buffer.Length) + { + var largerBuffer = ArrayPool.Shared.Rent(responseLength); + Array.Copy(buffer, largerBuffer, bytesRead); + ArrayPool.Shared.Return(buffer); + buffer = largerBuffer; + } + } + } + + DnsDataReader responseReader = new DnsDataReader(buffer.AsMemory(2, responseLength), buffer); + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + throw new InvalidOperationException("Invalid response: Header mismatch"); + } + + buffer = null!; + return (responseReader, header); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(512); + try + { + Memory memory = buffer; + (ushort transactionId, int length) = EncodeQuestion(memory, name, queryType); + + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); + + await socket.SendAsync(memory.Slice(0, length), SocketFlags.None, cancellationToken).ConfigureAwait(false); + + DnsDataReader responseReader; + DnsMessageHeader header; + + while (true) + { + int readLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false); + + if (readLength < HeaderSize) + { + continue; + } + + responseReader = new DnsDataReader(memory.Slice(0, readLength), buffer); + if (!responseReader.TryReadHeader(out header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + // the message is not a response for our query. + // don't dispose reader, we will reuse the buffer + continue; + } + + // ownership of the buffer is transferred to the reader, caller will dispose. + buffer = null!; + return (responseReader, header); + } + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + internal async ValueTask SendQueryAsync(string name, QueryType queryType, CancellationToken cancellationToken) + { + (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); + + try + { + return await SendQueryAsyncSlow(name, queryType, cts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException oce) when ( + !cancellationToken.IsCancellationRequested && // not cancelled by the caller + !pendingRequestsCts.IsCancellationRequested) // not cancelled by the global token (dispose) + // the only remaining token that could cancel this is the linked cts from the timeout. + { + Debug.Assert(cts.Token.IsCancellationRequested); + throw new TimeoutException("The operation has timed out.", oce); + } + finally + { + if (disposeTokenSource) + { + cts.Dispose(); + } + } + + async ValueTask SendQueryAsyncSlow(string name, QueryType queryType, CancellationToken cancellationToken) + { + DnsDataReader responseReader = default; + DnsMessageHeader header = default; + DateTime queryStartedTime = default; + + foreach (IPEndPoint serverEndPoint in _options.Servers) + { + queryStartedTime = _timeProvider.GetUtcNow().DateTime; + (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + + if (header.IsResultTruncated) + { + responseReader.Dispose(); + // TCP fallback + (responseReader, header) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + } + + if (header.QueryCount != 1 || + !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || + qName != name || qType != queryType || qClass != QueryClass.Internet) + { + // TODO: do we care? + throw new InvalidOperationException("Invalid response: Query mismatch"); + // return default; + } + + // TODO: on which response codes should we retry? + if (header.ResponseCode == QueryResponseCode.NoError) + { + break; + } + + responseReader.Dispose(); + } + + int ttl = int.MaxValue; + List answers = ReadRecords(header.AnswerCount, ref ttl, ref responseReader); + List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); + List additionals = ReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader); + + DnsResponse record = new(header, queryStartedTime, queryStartedTime.AddSeconds(ttl), answers, authorities, additionals); + responseReader.Dispose(); + + return record; + + static List ReadRecords(int count, ref int ttl, ref DnsDataReader reader) + { + List records = new(count); + + for (int i = 0; i < count; i++) + { + if (!reader.TryReadResourceRecord(out var record)) + { + // TODO how to handle corrupted responses? + throw new InvalidOperationException("Invalid response: Answer record"); + } + + ttl = Math.Min(ttl, record.Ttl); + // copy the data to a new array since the underlying array is pooled + records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data.ToArray())); + } + + return records; + } + } + } + + private static (ushort id, int length) EncodeQuestion(Memory buffer, string name, QueryType queryType) + { + DnsMessageHeader header = default; + header.InitQueryHeader(); + DnsDataWriter writer = new DnsDataWriter(buffer); + if (!writer.TryWriteHeader(header) || + !writer.TryWriteQuestion(name, queryType, QueryClass.Internet)) + { + // should never happen since we validated the name length + throw new InvalidOperationException("Buffer too small"); + } + return (header.TransactionId, writer.Position); + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + + // Cancel all pending requests (if any). Note that we don't call CancelPendingRequests() but cancel + // the CTS directly. The reason is that CancelPendingRequests() would cancel the current CTS and create + // a new CTS. We don't want a new CTS in this case. + _pendingRequestsCts.Cancel(); + _pendingRequestsCts.Dispose(); + } + } + + private (CancellationTokenSource TokenSource, bool DisposeTokenSource, CancellationTokenSource PendingRequestsCts) PrepareCancellationTokenSource(CancellationToken cancellationToken) + { + // We need a CancellationTokenSource to use with the request. We always have the global + // _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may + // have a timeout. If we have a timeout or a caller-provided token, we need to create a new + // CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other + // unrelated operations). Otherwise, we can use the pending requests CTS directly. + + // Snapshot the current pending requests cancellation source. It can change concurrently due to cancellation being requested + // and it being replaced, and we need a stable view of it: if cancellation occurs and the caller's token hasn't been canceled, + // it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas + // it's more approximate checking elapsed time. + CancellationTokenSource pendingRequestsCts = _pendingRequestsCts; + + bool hasTimeout = _timeout != System.Threading.Timeout.InfiniteTimeSpan; + if (hasTimeout || cancellationToken.CanBeCanceled) + { + CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token); + if (hasTimeout) + { + cts.CancelAfter(_timeout); + } + + return (cts, DisposeTokenSource: true, pendingRequestsCts); + } + + return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts); + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs new file mode 100644 index 00000000000..929fa893fd5 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsResourceRecord +{ + public string Name { get; } + public QueryType Type { get; } + public QueryClass Class { get; } + public int Ttl { get; } + public ReadOnlyMemory Data { get; } + + public DnsResourceRecord(string name, QueryType type, QueryClass @class, int ttl, ReadOnlyMemory data) + { + Name = name; + Type = type; + Class = @class; + Ttl = ttl; + Data = data; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs new file mode 100644 index 00000000000..582cb282730 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsResponse +{ + public DnsMessageHeader Header { get; } + public List Answers { get; } + public List Authorities { get; } + public List Additionals { get; } + public DateTime CreatedAt { get; } + public DateTime Expiration { get; } + + public DnsResponse(DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) + { + Header = header; + CreatedAt = createdAt; + Expiration = expiration; + Answers = answers; + Authorities = authorities; + Additionals = additionals; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs new file mode 100644 index 00000000000..fb9f331559b --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.NetworkInformation; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class NetworkInfo +{ + // basic option to get DNS serves via NetworkInfo. We may get it directly later via proper APIs. + public static ResolverOptions GetOptions() + { + List servers = new List(); + + foreach (NetworkInterface nic in NetworkInterface.GetAllNetworkInterfaces()) + { + IPInterfaceProperties properties = nic.GetIPProperties(); + // avoid loopback, VPN etc. Should be re-visited. + + if (nic.NetworkInterfaceType == NetworkInterfaceType.Ethernet && nic.OperationalStatus == OperationalStatus.Up) + { + foreach (IPAddress server in properties.DnsAddresses) + { + IPEndPoint ep = new IPEndPoint(server, 53); + if (!servers.Contains(ep)) + { + servers.Add(ep); + } + } + } + } + + return new ResolverOptions(servers!.ToArray()); + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs new file mode 100644 index 00000000000..732ca0216da --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal enum QueryClass +{ + Internet = 1 +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs new file mode 100644 index 00000000000..94fa019f54f --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +[Flags] +internal enum QueryFlags : ushort +{ + IsCheckingDisabled = 0x0010, + IsAuthenticData = 0x0020, + RecursionAvailable = 0x0080, + RecursionDesired = 0x0100, + ResultTruncated = 0x0200, + HasAuthorityAnswer = 0x0400, + HasResponse = 0x8000, +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs new file mode 100644 index 00000000000..20b8790f54c --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +/// +/// The response code (RCODE) in a DNS query response. +/// +enum QueryResponseCode : byte +{ + /// + /// No error condition + /// + NoError = 0, + + /// + /// The name server was unable to interpret the query. + /// + FormatError = 1, + + /// + /// The name server was unable to process this query due to a problem with the name server. + /// + ServerFailure = 2, + + /// + /// The name server does not support the requested kind of query. + /// + NotImplemented = 4, + + /// + /// Meaningful only for responses from an authoritative name server, this + /// code signifies that the domain name referenced in the query does not + /// exist. + /// + NameError = 3, + + /// + /// The name server refuses to perform the specified operation for policy reasons. + /// + Refused = 5, +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs new file mode 100644 index 00000000000..2ccc898a5b7 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +/// +/// DNS Query Types. +/// +internal enum QueryType +{ + /// + /// A host address. + /// + A = 1, + + /// + /// An authoritative name server. + /// + NS = 2, + + /// + /// The canonical name for an alias. + /// + CNAME = 5, + + /// + /// Marks the start of a zone of authority. + /// + SOA = 6, + + /// + /// Mail exchange. + /// + MX = 15, + + /// + /// Text strings. + /// + TXT = 16, + + /// + /// IPv6 host address. (RFC 3596) + /// + AAAA = 28, + + /// + /// Location information. (RFC 2782) + /// + SRV = 33, + + /// + /// Wildcard match. + /// + All = 255 +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs new file mode 100644 index 00000000000..4312b9d9d84 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs @@ -0,0 +1,69 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Sockets; +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class ResolvConf +{ + public static ResolverOptions GetOptions() + { + return GetOptions(new StreamReader("/etc/resolv.conf")); + } + + public static ResolverOptions GetOptions(TextReader reader) + { + int serverCount = 0; + int domainCount = 0; + + string[] lines = reader.ReadToEnd().Split('\n', StringSplitOptions.RemoveEmptyEntries); + foreach (string line in lines) + { + if (line.StartsWith("nameserver")) + { + serverCount++; + } + else if (line.StartsWith("search")) + { + domainCount++; + } + } + + if (serverCount == 0) + { + throw new SocketException((int)SocketError.AddressNotAvailable); + } + + IPEndPoint[] serverList = new IPEndPoint[serverCount]; + var options = new ResolverOptions(serverList); + if (domainCount > 0) + { + options.SearchDomains = new string[domainCount]; + } + + serverCount = 0; + domainCount = 0; + foreach (string line in lines) + { + string[] tokens = line.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if (tokens[0].Equals("nameserver")) + { + options.Servers[serverCount] = new IPEndPoint(IPAddress.Parse(tokens[1]), 53); + serverCount++; + } + else if (tokens[0].Equals("search")) + { + options.SearchDomains![domainCount] = tokens[1]; + domainCount++; + } + else if (tokens[0].Equals("domain")) + { + options.DefaultDomain = tokens[1]; + } + } + + return options; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs new file mode 100644 index 00000000000..8f0d2fd271d --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal class ResolverOptions +{ + public IPEndPoint[] Servers; + public string DefaultDomain = string.Empty; + public string[]? SearchDomains; + public bool UseHostsFile; + + public ResolverOptions(IPEndPoint[] servers) + { + Servers = servers; + } + + public ResolverOptions(IPEndPoint server) + { + Servers = new IPEndPoint[] { server }; + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs new file mode 100644 index 00000000000..5b2f1d7229c --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal record struct AddressResult(DateTime ExpiresAt, IPAddress Address); + +internal record struct ServiceResult(DateTime ExpiresAt, int Priority, int Weight, int Port, string Target, AddressResult[] Addresses); + +internal record struct TxtResult(int Ttl, byte[] Data) +{ + internal IEnumerable GetText() => GetText(Encoding.ASCII); + + internal IEnumerable GetText(Encoding encoding) + { + for (int i = 0; i < Data.Length;) + { + int length = Data[i]; + yield return encoding.GetString(Data, i + 1, length); + i += length + 1; + } + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs index 98b9de1fd68..3a8510743b4 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs @@ -1,11 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using DnsClient; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.ServiceDiscovery; using Microsoft.Extensions.ServiceDiscovery.Dns; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; namespace Microsoft.Extensions.Hosting; @@ -46,7 +46,7 @@ public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceC ArgumentNullException.ThrowIfNull(configureOptions); services.AddServiceDiscoveryCore(); - services.TryAddSingleton(); + services.TryAddSingleton(); services.AddSingleton(); var options = services.AddOptions(); options.Configure(o => configureOptions?.Invoke(o)); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs new file mode 100644 index 00000000000..105b64094ce --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using Xunit.Abstractions; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class CancellationTests : LoopbackDnsTestBase +{ + public CancellationTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task PreCanceledToken_Throws() + { + CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + } + + [Fact] + public async Task Timeout_Throws() + { + Resolver.Timeout = TimeSpan.FromSeconds(1); + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork)); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs new file mode 100644 index 00000000000..b889270e19e --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsDataReaderTests +{ + [Fact] + public void ReadResourceRecord_Success() + { + // example A record for example.com + byte[] buffer = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsDataReader reader = new DnsDataReader(buffer); + Assert.True(reader.TryReadResourceRecord(out DnsResourceRecord record)); + + Assert.Equal("www.example.com", record.Name); + Assert.Equal(QueryType.A, record.Type); + Assert.Equal(QueryClass.Internet, record.Class); + Assert.Equal(3600, record.Ttl); + Assert.Equal(4, record.Data.Length); + } + + [Fact] + public void ReadResourceRecord_Truncated_Fails() + { + // example A record for example.com + byte[] buffer = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + for (int i = 0; i < buffer.Length; i++) + { + DnsDataReader reader = new DnsDataReader(buffer.AsMemory(0, i)); + Assert.False(reader.TryReadResourceRecord(out _)); + } + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs new file mode 100644 index 00000000000..5adbff0c8ac --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs @@ -0,0 +1,140 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsDataWriterTests +{ + [Fact] + public void WriteResourceRecord_Success() + { + // example A record for example.com + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsResourceRecord record = new DnsResourceRecord("www.example.com", QueryType.A, QueryClass.Internet, 3600, new byte[4]); + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteResourceRecord(record)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + [Fact] + public void WriteResourceRecord_Truncated_Fails() + { + // example A record for example.com + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsResourceRecord record = new DnsResourceRecord("www.example.com", QueryType.A, QueryClass.Internet, 3600, new byte[4]); + + byte[] buffer = new byte[512]; + for (int i = 0; i < expected.Length; i++) + { + DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); + Assert.False(writer.TryWriteResourceRecord(record)); + } + } + + [Fact] + public void WriteQuestion_Success() + { + // example question for example.com (A record) + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01 + ]; + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteQuestion("www.example.com", QueryType.A, QueryClass.Internet)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + [Fact] + public void WriteQuestion_Truncated_Fails() + { + // example question for example.com (A record) + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01 + ]; + + byte[] buffer = new byte[512]; + for (int i = 0; i < expected.Length; i++) + { + DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); + Assert.False(writer.TryWriteQuestion("www.example.com", QueryType.A, QueryClass.Internet)); + } + } + + [Fact] + public void WriteHeader_Success() + { + // example header + byte[] expected = [ + // ID (0x1234) + 0x12, 0x34, + // Flags (0x5678) + 0x56, 0x78, + // Question count (1) + 0x00, 0x01, + // Answer count (0) + 0x00, 0x02, + // Authority count (0) + 0x00, 0x03, + // Additional count (0) + 0x00, 0x04 + ]; + + DnsMessageHeader header = new() + { + TransactionId = 0x1234, + QueryFlags = (QueryFlags)0x5678, + QueryCount = 1, + AnswerCount = 2, + AuthorityCount = 3, + AdditionalRecordCount = 4, + }; + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteHeader(header)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs new file mode 100644 index 00000000000..b2afc5510f3 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsPrimitivesTests +{ + public static TheoryData QNameData => new() + { + { "www.example.com", "\x0003www\x0007example\x0003com\x0000"u8.ToArray() }, + { "example.com", "\x0007example\x0003com\x0000"u8.ToArray() }, + { "com", "\x0003com\x0000"u8.ToArray() }, + { "example", "\x0007example\x0000"u8.ToArray() }, + { "www", "\x0003www\x0000"u8.ToArray() }, + { "a", "\x0001a\x0000"u8.ToArray() }, + }; + + [Theory] + [MemberData(nameof(QNameData))] + public void TryWriteQName_Success(string name, byte[] expected) + { + byte[] buffer = new byte[512]; + + Assert.True(DnsPrimitives.TryWriteQName(buffer, name, out int written)); + Assert.Equal(name.Length + 2, written); + Assert.Equal(expected, buffer.AsSpan().Slice(0, written).ToArray()); + } + + [Fact] + public void TryWriteQName_LabelTooLong_Throws() + { + byte[] buffer = new byte[512]; + + Assert.Throws(() => DnsPrimitives.TryWriteQName(buffer, new string('a', 70), out int written)); + } + + [Fact] + public void TryWriteQName_BufferTooShort_Fails() + { + byte[] buffer = new byte[512]; + string name = "www.example.com"; + + for (int i = 0; i < name.Length + 2; i++) + { + Assert.False(DnsPrimitives.TryWriteQName(buffer.AsSpan(0, i), name, out _)); + } + } + + [Theory] + [MemberData(nameof(QNameData))] + public void TryReadQName_Success(string expected, byte[] serialized) + { + Assert.True(DnsPrimitives.TryReadQName(serialized, 0, out string? actual, out int bytesRead)); + Assert.Equal(expected, actual); + Assert.Equal(serialized.Length, bytesRead); + } + + [Fact] + public void TryReadQName_TruncatedData_Fails() + { + ReadOnlySpan data = "\x0003www\x0007example\x0003com\x0000"u8; + + for (int i = 0; i < data.Length; i++) + { + Assert.False(DnsPrimitives.TryReadQName(data.Slice(0, i), 0, out _, out _)); + } + } + + [Fact] + public void TryReadQName_Pointer_Success() + { + // [7B padding], example.com. www->[ptr to example.com.] + Span data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data[^2] = 0xc0; + + Assert.True(DnsPrimitives.TryReadQName(data, data.Length - 6, out string? actual, out int bytesRead)); + Assert.Equal("www.example.com", actual); + Assert.Equal(6, bytesRead); + } + + [Fact] + public void TryReadQName_PointerTruncated_Fails() + { + // [7B padding], example.com. www->[ptr to example.com.] + Span data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data[^2] = 0xc0; + + for (int i = 0; i < data.Length; i++) + { + Assert.False(DnsPrimitives.TryReadQName(data.Slice(0, i), data.Length - 6, out _, out _)); + } + } + + [Fact] + public void TryReadQName_ForwardPointer_Fails() + { + // www->[ptr to example.com], [7B padding], example.com. + Span data = "\x03www\x00\0x000cpaddingexample\x0003com\x00"u8.ToArray(); + data[4] = 0xc0; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Fact] + public void TryReadQName_PointerToSelf_Fails() + { + // www->[ptr to www->...] + Span data = "\x0003www\x00c0"u8.ToArray(); + data[4] = 0xc0; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Fact] + public void TryReadQName_ReservedBits() + { + Span data = "\x0003www\x00c0"u8.ToArray(); + data[4] = 0xc0; + data[0] = 0x40; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs new file mode 100644 index 00000000000..fc0507beb35 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -0,0 +1,221 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +internal sealed class LoopbackDnsServer : IDisposable +{ + readonly Socket _dnsSocket; + readonly Socket _tcpSocket; + + public IPEndPoint DnsEndPoint => (IPEndPoint)_dnsSocket.LocalEndPoint!; + + public LoopbackDnsServer() + { + _dnsSocket = new(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + _dnsSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + + _tcpSocket = new(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _tcpSocket.Bind(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)_dnsSocket.LocalEndPoint!).Port)); + _tcpSocket.Listen(); + } + + public void Dispose() + { + _dnsSocket.Dispose(); + _tcpSocket.Dispose(); + } + + private static async Task ProcessRequestCore(ReadOnlyMemory message, Func action, Memory responseBuffer) + { + DnsDataReader reader = new DnsDataReader(message); + + if (!reader.TryReadHeader(out DnsMessageHeader header) || + !reader.TryReadQuestion(out var name, out var type, out var @class)) + { + return 0; + } + + LoopbackDnsResponseBuilder responseBuilder = new(name, type, @class); + responseBuilder.TransactionId = header.TransactionId; + responseBuilder.Flags = header.QueryFlags | QueryFlags.HasResponse; + responseBuilder.ResponseCode = QueryResponseCode.NoError; + + await action(responseBuilder); + + DnsDataWriter writer = new(responseBuffer); + if (!writer.TryWriteHeader(new DnsMessageHeader + { + TransactionId = responseBuilder.TransactionId, + QueryFlags = responseBuilder.Flags, + ResponseCode = responseBuilder.ResponseCode, + QueryCount = (ushort)responseBuilder.Questions.Count, + AnswerCount = (ushort)responseBuilder.Answers.Count, + AuthorityCount = (ushort)responseBuilder.Authorities.Count, + AdditionalRecordCount = (ushort)responseBuilder.Additionals.Count + })) + { + throw new InvalidOperationException("Failed to write header"); + }; + + foreach (var (questionName, questionType, questionClass) in responseBuilder.Questions) + { + if (!writer.TryWriteQuestion(questionName, questionType, questionClass)) + { + throw new InvalidOperationException("Failed to write question"); + } + } + + foreach (var answer in responseBuilder.Answers) + { + if (!writer.TryWriteResourceRecord(answer)) + { + throw new InvalidOperationException("Failed to write answer"); + } + } + + foreach (var authority in responseBuilder.Authorities) + { + if (!writer.TryWriteResourceRecord(authority)) + { + throw new InvalidOperationException("Failed to write authority"); + } + } + + foreach (var additional in responseBuilder.Additionals) + { + if (!writer.TryWriteResourceRecord(additional)) + { + throw new InvalidOperationException("Failed to write additional records"); + } + } + + return writer.Position; + } + + public async Task ProcessUdpRequest(Func action) + { + byte[] buffer = ArrayPool.Shared.Rent(512); + try + { + EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); + SocketReceiveFromResult result = await _dnsSocket.ReceiveFromAsync(buffer, remoteEndPoint); + + int bytesWritten = await ProcessRequestCore(buffer.AsMemory(0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); + + await _dnsSocket.SendToAsync(buffer.AsMemory(0, bytesWritten), SocketFlags.None, result.RemoteEndPoint); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public async Task ProcessTcpRequest(Func action) + { + using Socket tcpClient = await _tcpSocket.AcceptAsync(); + + byte[] buffer = ArrayPool.Shared.Rent(8 * 1024); + try + { + EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); + + int bytesRead = 0; + int length = -1; + while (length < 0 || bytesRead < length + 2) + { + int toRead = length < 0 ? 2 : length + 2 - bytesRead; + int read = await tcpClient.ReceiveAsync(buffer.AsMemory(bytesRead, toRead), SocketFlags.None); + bytesRead += read; + + if (length < 0 && bytesRead >= 2) + { + length = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); + } + } + + int bytesWritten = await ProcessRequestCore(buffer.AsMemory(2, length), action, buffer.AsMemory(2)); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(0, 2), (ushort)bytesWritten); + await tcpClient.SendAsync(buffer.AsMemory(0, bytesWritten + 2), SocketFlags.None); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} + +internal sealed class LoopbackDnsResponseBuilder +{ + public LoopbackDnsResponseBuilder(string name, QueryType type, QueryClass @class) + { + Name = name; + Type = type; + Class = @class; + Questions.Add((name, type, @class)); + } + + public ushort TransactionId { get; set; } + public QueryFlags Flags { get; set; } + public QueryResponseCode ResponseCode { get; set; } + + public string Name { get; } + public QueryType Type { get; } + public QueryClass Class { get; } + + public List<(string, QueryType, QueryClass)> Questions { get; } = new List<(string, QueryType, QueryClass)>(); + public List Answers { get; } = new List(); + public List Authorities { get; } = new List(); + public List Additionals { get; } = new List(); +} + +internal static class LoopbackDnsServerExtensions +{ + public static List AddAddress(this List records, string name, int ttl, IPAddress address) + { + QueryType type = address.AddressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + records.Add(new DnsResourceRecord(name, type, QueryClass.Internet, ttl, address.GetAddressBytes())); + return records; + } + + public static List AddCname(this List records, string name, int ttl, string alias) + { + byte[] buff = new byte[256]; + if (!DnsPrimitives.TryWriteQName(buff, alias, out int length)) + { + throw new InvalidOperationException("Failed to encode domain name"); + } + + records.Add(new DnsResourceRecord(name, QueryType.CNAME, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } + + public static List AddService(this List records, string name, int ttl, ushort priority, ushort weight, ushort port, string target) + { + byte[] buff = new byte[256]; + if (!DnsPrimitives.TryWriteService(buff, priority, weight, port, target, out int length)) + { + throw new InvalidOperationException("Failed to encode SRV record"); + } + + records.Add(new DnsResourceRecord(name, QueryType.SRV, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } + + public static List AddStartOfAuthority(this List records, string name, int ttl, string mname, string rname, uint serial, uint refresh, uint retry, uint expire, uint minimum) + { + byte[] buff = new byte[256]; + if (!DnsPrimitives.TryWriteSoa(buff, mname, rname, serial, refresh, retry, expire, minimum, out int length)) + { + throw new InvalidOperationException("Failed to encode SOA record"); + } + + records.Add(new DnsResourceRecord(name, QueryType.SOA, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs new file mode 100644 index 00000000000..aa20a3b30ca --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit.Abstractions; +using System.Runtime.CompilerServices; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public abstract class LoopbackDnsTestBase : IDisposable +{ + protected readonly ITestOutputHelper Output; + + internal LoopbackDnsServer DnsServer { get; } + internal DnsResolver Resolver { get; } + protected readonly TestTimeProvider TimeProvider; + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "SetTimeProvider")] + static extern void MockTimeProvider(DnsResolver instance, TimeProvider provider); + + public LoopbackDnsTestBase(ITestOutputHelper output) + { + Output = output; + DnsServer = new(); + Resolver = new([DnsServer.DnsEndPoint]); + Resolver.Timeout = TimeSpan.FromSeconds(5); + TimeProvider = new(); + MockTimeProvider(Resolver, TimeProvider); + } + + public void Dispose() + { + DnsServer.Dispose(); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs new file mode 100644 index 00000000000..a17e9a07159 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolvConfTests +{ + public void GetOptions() + { + var contents = @" +nameserver 10.96.0.10 +search default.svc.cluster.local svc.cluster.local cluster.local +options ndots:5 +@"; + + var reader = new StringReader(contents); + ResolverOptions options = ResolvConf.GetOptions(reader); + + IPEndPoint ipAddress = Assert.Single(options.Servers); + Assert.Equal(new IPEndPoint(IPAddress.Parse("10.96.0.10"), 53), ipAddress); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs new file mode 100644 index 00000000000..a6c6a904182 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -0,0 +1,162 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using Xunit.Abstractions; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolveAddressesTests : LoopbackDnsTestBase +{ + public ResolveAddressesTests(ITestOutputHelper output) : base(output) + { + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ResolveIPv4_NoData_Success(bool includeSoa) + { + Resolver.Timeout = TimeSpan.FromSeconds(1); + + _ = DnsServer.ProcessUdpRequest(builder => + { + if (includeSoa) + { + builder.Authorities.AddStartOfAuthority("ns.com", 240, "ns.com", "admin.ns.com", 1, 900, 180, 6048000, 3600); + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + Assert.Empty(results); + + if (includeSoa) + { + // if SOA is included, the negative result can be cached + Assert.Empty(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); + + // negative result does not affect other types + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetworkV6)); + } + else + { + // no caching -> new request, and the request should timeout + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) + { + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.ResponseCode = QueryResponseCode.NameError; + if (includeSoa) + { + builder.Authorities.AddStartOfAuthority("ns.com", 240, "ns.com", "admin.ns.com", 1, 900, 180, 6048000, 3600); + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + Assert.Empty(results); + + if (includeSoa) + { + // if SOA is included, the negative result can be cached for all types + Assert.Empty(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); + Assert.Empty(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetworkV6)); + } + else + { + Resolver.Timeout = TimeSpan.FromSeconds(1); + // no caching -> new request, and the request should timeout + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); + } + } + + [Fact] + public async Task ResolveIPv4_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example.com", 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example3.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_NotFound_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example.com", 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + + // extra address in the answer not connected to the above + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIP_InvalidAddressFamily_Throws() + { + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.Unknown)); + } + + [Fact] + public async Task ResolveIP_Cached_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + DnsServer.Dispose(); + + AddressResult cached = Assert.Single(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); + Assert.Equal(res, cached); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs new file mode 100644 index 00000000000..fe599ca1789 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using Xunit.Abstractions; +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolveServiceTests : LoopbackDnsTestBase +{ + public ResolveServiceTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task ResolveService_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddService("_s0._tcp.example.com", 3600, 1, 2, 8080, "www.example.com"); + builder.Additionals.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + ServiceResult[] results = await Resolver.ResolveServiceAsync("_s0._tcp.example.com"); + + ServiceResult result = Assert.Single(results); + Assert.Equal("www.example.com", result.Target); + Assert.Equal(1, result.Priority); + Assert.Equal(2, result.Weight); + Assert.Equal(8080, result.Port); + + AddressResult addressResult = Assert.Single(result.Addresses); + Assert.Equal(address, addressResult.Address); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs new file mode 100644 index 00000000000..d61f8ca8aa3 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using Xunit.Abstractions; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class TcpFailoverTests : LoopbackDnsTestBase +{ + public TcpFailoverTests(ITestOutputHelper output) : base(output) + { + Resolver.Timeout = TimeSpan.FromHours(5); + } + + [Fact] + public async Task TcpFailover_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs new file mode 100644 index 00000000000..453d4cb5a4e --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class TestTimeProvider : TimeProvider +{ + public DateTime Now { get; set; } = DateTime.UtcNow; + public void Advance(TimeSpan time) => Now += time; + + public override DateTimeOffset GetUtcNow() => Now; +} From 704e33c77e7d21d5e50eca5010efbdbc9cfb83be Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 4 Oct 2024 09:09:23 +0200 Subject: [PATCH 02/45] Allow mocking --- .../DnsServiceEndpointProvider.cs | 2 +- .../DnsServiceEndpointProviderFactory.cs | 2 +- .../DnsSrvServiceEndpointProvider.cs | 2 +- .../DnsSrvServiceEndpointProviderFactory.cs | 2 +- .../Resolver/DnsResolver.cs | 5 +++-- .../Resolver/IDnsResolver.cs | 12 ++++++++++++ ...ServiceDiscoveryDnsServiceCollectionExtensions.cs | 2 +- .../DnsServiceEndpointResolverTests.cs | 2 ++ 8 files changed, 22 insertions(+), 7 deletions(-) create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs index 52f2da3014a..3b5748e872c 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs @@ -14,7 +14,7 @@ internal sealed partial class DnsServiceEndpointProvider( string hostName, IOptionsMonitor options, ILogger logger, - DnsResolver resolver, + IDnsResolver resolver, TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature { protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs index 80fb009c0d1..1da21411e64 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs @@ -11,7 +11,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns; internal sealed partial class DnsServiceEndpointProviderFactory( IOptionsMonitor options, ILogger logger, - DnsResolver resolver, + IDnsResolver resolver, TimeProvider timeProvider) : IServiceEndpointProviderFactory { /// diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs index 9532c7a9ed5..dc47bb8b5cd 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs @@ -14,7 +14,7 @@ internal sealed partial class DnsSrvServiceEndpointProvider( string hostName, IOptionsMonitor options, ILogger logger, - DnsResolver resolver, + IDnsResolver resolver, TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature { protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs index fb7d006ae9a..e9b428e523c 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs @@ -11,7 +11,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns; internal sealed partial class DnsSrvServiceEndpointProviderFactory( IOptionsMonitor options, ILogger logger, - DnsResolver resolver, + IDnsResolver resolver, TimeProvider timeProvider) : IServiceEndpointProviderFactory { private static readonly string s_serviceAccountPath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount"); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index dd8d9ebbf08..be545acb522 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -9,7 +9,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal class DnsResolver : IDisposable +internal class DnsResolver : IDnsResolver, IDisposable { private const int MaximumNameLength = 253; private const int IPv4Length = 4; @@ -30,8 +30,9 @@ internal void SetTimeProvider(TimeProvider timeProvider) _timeProvider = timeProvider; } - public DnsResolver() : this(OperatingSystem.IsWindows() ? NetworkInfo.GetOptions() : ResolvConf.GetOptions()) + public DnsResolver(TimeProvider timeProvider) : this(OperatingSystem.IsWindows() ? NetworkInfo.GetOptions() : ResolvConf.GetOptions()) { + _timeProvider = timeProvider; } internal DnsResolver(ResolverOptions options) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs new file mode 100644 index 00000000000..906f45f83d5 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal interface IDnsResolver +{ + ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default); + ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default); +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs index 3a8510743b4..7d05243f741 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs @@ -46,7 +46,7 @@ public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceC ArgumentNullException.ThrowIfNull(configureOptions); services.AddServiceDiscoveryCore(); - services.TryAddSingleton(); + services.TryAddSingleton(); services.AddSingleton(); var options = services.AddOptions(); options.Configure(o => configureOptions?.Invoke(o)); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs index 2b3a7fd7cd3..309a937643a 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs @@ -4,6 +4,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Time.Testing; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Xunit; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; @@ -16,6 +17,7 @@ public async Task ResolveServiceEndpoint_Dns_MultiShot() var timeProvider = new FakeTimeProvider(); var services = new ServiceCollection() .AddSingleton(timeProvider) + .AddSingleton() .AddServiceDiscoveryCore() .AddDnsServiceEndpointProvider(o => o.DefaultRefreshPeriod = TimeSpan.FromSeconds(30)) .BuildServiceProvider(); From aff6d179ab12293feb9c8d01ba27d7d3a4e566f2 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 4 Oct 2024 09:30:14 +0200 Subject: [PATCH 03/45] Fix dispose --- .../Resolver/DnsResolver.cs | 12 +++-- .../Resolver/ResolveAddressesTests.cs | 49 +------------------ 2 files changed, 10 insertions(+), 51 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index be545acb522..58956332a52 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -369,8 +369,10 @@ async ValueTask SendQueryAsyncSlow(string name, QueryType queryType DnsMessageHeader header = default; DateTime queryStartedTime = default; - foreach (IPEndPoint serverEndPoint in _options.Servers) + for (int index = 0; index < _options.Servers.Length; index++) { + IPEndPoint serverEndPoint = _options.Servers[index]; + queryStartedTime = _timeProvider.GetUtcNow().DateTime; (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); @@ -396,7 +398,11 @@ async ValueTask SendQueryAsyncSlow(string name, QueryType queryType break; } - responseReader.Dispose(); + if (index < _options.Servers.Length - 1) + { + // keep the reader open for processing + responseReader.Dispose(); + } } int ttl = int.MaxValue; @@ -418,7 +424,7 @@ static List ReadRecords(int count, ref int ttl, ref DnsDataRe if (!reader.TryReadResourceRecord(out var record)) { // TODO how to handle corrupted responses? - throw new InvalidOperationException("Invalid response: Answer record"); + throw new InvalidOperationException("Invalid response: corrupted record"); } ttl = Math.Min(ttl, record.Ttl); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index a6c6a904182..6447e9e2ba9 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -12,6 +12,7 @@ public class ResolveAddressesTests : LoopbackDnsTestBase { public ResolveAddressesTests(ITestOutputHelper output) : base(output) { + Resolver.Timeout = TimeSpan.FromSeconds(5); } [Theory] @@ -19,8 +20,6 @@ public ResolveAddressesTests(ITestOutputHelper output) : base(output) [InlineData(true)] public async Task ResolveIPv4_NoData_Success(bool includeSoa) { - Resolver.Timeout = TimeSpan.FromSeconds(1); - _ = DnsServer.ProcessUdpRequest(builder => { if (includeSoa) @@ -32,20 +31,6 @@ public async Task ResolveIPv4_NoData_Success(bool includeSoa) AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); Assert.Empty(results); - - if (includeSoa) - { - // if SOA is included, the negative result can be cached - Assert.Empty(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); - - // negative result does not affect other types - await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetworkV6)); - } - else - { - // no caching -> new request, and the request should timeout - await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); - } } [Theory] @@ -65,19 +50,6 @@ public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); Assert.Empty(results); - - if (includeSoa) - { - // if SOA is included, the negative result can be cached for all types - Assert.Empty(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); - Assert.Empty(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetworkV6)); - } - else - { - Resolver.Timeout = TimeSpan.FromSeconds(1); - // no caching -> new request, and the request should timeout - await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); - } } [Fact] @@ -140,23 +112,4 @@ public async Task ResolveIP_InvalidAddressFamily_Throws() { await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.Unknown)); } - - [Fact] - public async Task ResolveIP_Cached_Success() - { - IPAddress address = IPAddress.Parse("172.213.245.111"); - _ = DnsServer.ProcessUdpRequest(builder => - { - builder.Answers.AddAddress("www.example.com", 3600, address); - return Task.CompletedTask; - }); - - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); - - AddressResult res = Assert.Single(results); - DnsServer.Dispose(); - - AddressResult cached = Assert.Single(await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork)); - Assert.Equal(res, cached); - } } From 15db8a14ce2b9005c0ba24607f899bace161427c Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 4 Oct 2024 13:35:20 +0200 Subject: [PATCH 04/45] Fix tests --- .../DnsServiceEndpointProvider.cs | 12 +- .../DnsSrvServiceEndpointProviderFactory.cs | 2 +- .../Resolver/DnsResolver.cs | 64 ++++++++- .../Resolver/IDnsResolver.cs | 1 + .../DnsServiceEndpointResolverTests.cs | 2 +- .../DnsSrvServiceEndpointResolverTests.cs | 123 +++++------------- .../Resolver/LoopbackDnsTestBase.cs | 5 +- .../Resolver/ResolveAddressesTests.cs | 12 ++ .../Resolver/TestTimeProvider.cs | 12 -- 9 files changed, 112 insertions(+), 121 deletions(-) delete mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs index 3b5748e872c..c969a9e91e9 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using System.Net.Sockets; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; @@ -34,16 +33,9 @@ protected override async Task ResolveAsyncCore() Log.AddressQuery(logger, ServiceName, hostName); var now = _timeProvider.GetUtcNow().DateTime; - var ipv4Addresses = resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork, ShutdownToken).ConfigureAwait(false); - var ipv6Addresses = resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetworkV6, ShutdownToken).ConfigureAwait(false); + var addresses = await resolver.ResolveIPAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false); - foreach (var address in await ipv4Addresses) - { - ttl = MinTtl(now, address.ExpiresAt, ttl); - endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, 0))); - } - - foreach (var address in await ipv6Addresses) + foreach (var address in addresses) { ttl = MinTtl(now, address.ExpiresAt, ttl); endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, 0))); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs index e9b428e523c..085ee30123b 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs @@ -17,7 +17,7 @@ internal sealed partial class DnsSrvServiceEndpointProviderFactory( private static readonly string s_serviceAccountPath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount"); private static readonly string s_serviceAccountNamespacePath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount", "namespace"); private static readonly string s_resolveConfPath = Path.Combine($"{Path.DirectorySeparatorChar}etc", "resolv.conf"); - private readonly string? _querySuffix = options.CurrentValue.QuerySuffix ?? GetKubernetesHostDomain(); + private readonly string? _querySuffix = options.CurrentValue.QuerySuffix?.TrimStart('.') ?? GetKubernetesHostDomain(); /// public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 58956332a52..d3a7cb664c6 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -91,6 +91,18 @@ public async ValueTask ResolveServiceAsync(string name, Cancell List addresses = new List(); foreach (var additional in record.Additionals) { + // From RFC 2782: + // + // Target + // The domain name of the target host. There MUST be one or more + // address records for this name, the name MUST NOT be an alias (in + // the sense of RFC 1034 or RFC 2181). Implementors are urged, but + // not required, to return the address record(s) in the Additional + // Data section. Unless and until permitted by future standards + // action, name compression is not to be used for this field. + // + // A Target of "." means that the service is decidedly not + // available at this domain. if (additional.Name == target && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) { addresses.Add(new AddressResult(record.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); @@ -105,16 +117,62 @@ public async ValueTask ResolveServiceAsync(string name, Cancell return result; } + public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) + { + if (name == "localhost") + { + // name localhost exists outside of DNS and can't be resolved by a DNS server + int len = (Socket.OSSupportsIPv4 ? 1 : 0) + (Socket.OSSupportsIPv6 ? 1 : 0); + AddressResult[] res = new AddressResult[len]; + + int index = 0; + if (Socket.OSSupportsIPv6) // prefer IPv6 + { + res[index] = new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback); + } + if (Socket.OSSupportsIPv4) + { + res[index++] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback); + } + } + + var ipv4AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetwork, cancellationToken); + var ipv6AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetworkV6, cancellationToken); + + AddressResult[] ipv4Addresses = await ipv4AddressesTask.ConfigureAwait(false); + AddressResult[] ipv6Addresses = await ipv6AddressesTask.ConfigureAwait(false); + + AddressResult[] results = new AddressResult[ipv4Addresses.Length + ipv6Addresses.Length]; + ipv6Addresses.CopyTo(results, 0); + ipv4Addresses.CopyTo(results, ipv6Addresses.Length); + return results; + } + public async ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) { ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); - if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6 && addressFamily != AddressFamily.Unspecified) + if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6) { throw new ArgumentOutOfRangeException(nameof(addressFamily), addressFamily, "Invalid address family"); } + if (name == "localhost") + { + // name localhost exists outside of DNS and can't be resolved by a DNS server + if (addressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4) + { + return [new AddressResult(DateTime.MaxValue, IPAddress.Loopback)]; + } + else if (addressFamily == AddressFamily.InterNetworkV6 && Socket.OSSupportsIPv6) + { + return [new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback)]; + } + + return Array.Empty(); + } + if (name.Length > MaximumNameLength) { throw new ArgumentException("Name is too long", nameof(name)); @@ -345,7 +403,7 @@ internal async ValueTask SendQueryAsync(string name, QueryType quer try { - return await SendQueryAsyncSlow(name, queryType, cts.Token).ConfigureAwait(false); + return await SendQueryAsyncCore(name, queryType, cts.Token).ConfigureAwait(false); } catch (OperationCanceledException oce) when ( !cancellationToken.IsCancellationRequested && // not cancelled by the caller @@ -363,7 +421,7 @@ internal async ValueTask SendQueryAsync(string name, QueryType quer } } - async ValueTask SendQueryAsyncSlow(string name, QueryType queryType, CancellationToken cancellationToken) + async ValueTask SendQueryAsyncCore(string name, QueryType queryType, CancellationToken cancellationToken) { DnsDataReader responseReader = default; DnsMessageHeader header = default; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs index 906f45f83d5..e09168d9552 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs @@ -8,5 +8,6 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal interface IDnsResolver { ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default); + ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default); ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default); } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs index 309a937643a..b949e713999 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs @@ -17,7 +17,7 @@ public async Task ResolveServiceEndpoint_Dns_MultiShot() var timeProvider = new FakeTimeProvider(); var services = new ServiceCollection() .AddSingleton(timeProvider) - .AddSingleton() + .AddSingleton() .AddServiceDiscoveryCore() .AddDnsServiceEndpointProvider(o => o.DefaultRefreshPeriod = TimeSpan.FromSeconds(30)) .BuildServiceProvider(); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs index b58d9e2f4ec..ec21bf9fa9c 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using DnsClient; -using DnsClient.Protocol; +using System.Net.Sockets; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Configuration.Memory; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Microsoft.Extensions.ServiceDiscovery.Internal; using Xunit; @@ -19,88 +19,38 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; /// public class DnsSrvServiceEndpointResolverTests { - private sealed class FakeDnsClient : IDnsQuery + private sealed class FakeDnsResolver : IDnsResolver { - public Func>? QueryAsyncFunc { get; set; } + public Func>? ResolveIPAddressesAsyncFunc { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc!.Invoke(name, addressFamily, cancellationToken); - public IDnsQueryResponse Query(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryAsync(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) - => QueryAsyncFunc!(query, queryType, queryClass, cancellationToken); - public Task QueryAsync(DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryAsync(DnsQuestion question, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - } + public Func>? ResolveIPAddressesAsyncFunc2 { get; set; } - private sealed class FakeDnsQueryResponse : IDnsQueryResponse - { - public IReadOnlyList? Questions { get; set; } - public IReadOnlyList? Additionals { get; set; } - public IEnumerable? AllRecords { get; set; } - public IReadOnlyList? Answers { get; set; } - public IReadOnlyList? Authorities { get; set; } - public string? AuditTrail { get; set; } - public string? ErrorMessage { get; set; } - public bool HasError { get; set; } - public DnsResponseHeader? Header { get; set; } - public int MessageSize { get; set; } - public NameServer? NameServer { get; set; } - public DnsQuerySettings? Settings { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc2!.Invoke(name, cancellationToken); + + public Func>? ResolveServiceAsyncFunc { get; set; } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) => ResolveServiceAsyncFunc!.Invoke(name, cancellationToken); } [Fact] public async Task ResolveServiceEndpoint_DnsSrv() { - var dnsClientMock = new FakeDnsClient + var dnsClientMock = new FakeDnsResolver { - QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) => + ResolveServiceAsyncFunc = (name, cancellationToken) => { - var response = new FakeDnsQueryResponse - { - Answers = new List - { - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c")) - }, - Additionals = new List - { - new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")), - new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback), - new CNameRecord(new ResourceRecordInfo("srv-c", ResourceRecordType.AAAA, queryClass, 64, 0), DnsString.Parse("remotehost")), - new TxtRecord(new ResourceRecordInfo("srv-a", ResourceRecordType.TXT, queryClass, 64, 0), ["some txt values"], ["some txt utf8 values"]) - } - }; + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", []) + ]; - return Task.FromResult(response); + return ValueTask.FromResult(response); } }; var services = new ServiceCollection() - .AddSingleton(dnsClientMock) + .AddSingleton(dnsClientMock) .AddServiceDiscoveryCore() .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns") .BuildServiceProvider(); @@ -119,7 +69,7 @@ public async Task ResolveServiceEndpoint_DnsSrv() var eps = initialResult.EndpointSource.Endpoints; Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint); Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint); - Assert.Equal(new DnsEndPoint("remotehost", 7777), eps[2].EndPoint); + Assert.Equal(new DnsEndPoint("srv-c", 7777), eps[2].EndPoint); Assert.All(initialResult.EndpointSource.Endpoints, ep => { @@ -137,28 +87,17 @@ public async Task ResolveServiceEndpoint_DnsSrv() [Theory] public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing(bool dnsFirst) { - var dnsClientMock = new FakeDnsClient + var dnsClientMock = new FakeDnsResolver { - QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) => + ResolveServiceAsyncFunc = (name, cancellationToken) => { - var response = new FakeDnsQueryResponse - { - Answers = new List - { - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c")) - }, - Additionals = new List - { - new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")), - new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback), - new CNameRecord(new ResourceRecordInfo("srv-c", ResourceRecordType.AAAA, queryClass, 64, 0), DnsString.Parse("remotehost")), - new TxtRecord(new ResourceRecordInfo("srv-a", ResourceRecordType.TXT, queryClass, 64, 0), ["some txt values"], ["some txt utf8 values"]) - } - }; + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", []) + ]; - return Task.FromResult(response); + return ValueTask.FromResult(response); } }; var configSource = new MemoryConfigurationSource @@ -171,7 +110,7 @@ public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing( }; var config = new ConfigurationBuilder().Add(configSource); var serviceCollection = new ServiceCollection() - .AddSingleton(dnsClientMock) + .AddSingleton(dnsClientMock) .AddSingleton(config.Build()) .AddServiceDiscoveryCore(); if (dnsFirst) @@ -211,7 +150,7 @@ public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing( var eps = initialResult.EndpointSource.Endpoints; Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint); Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint); - Assert.Equal(new DnsEndPoint("remotehost", 7777), eps[2].EndPoint); + Assert.Equal(new DnsEndPoint("srv-c", 7777), eps[2].EndPoint); Assert.All(initialResult.EndpointSource.Endpoints, ep => { diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs index aa20a3b30ca..ed562b364df 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -3,6 +3,7 @@ using Xunit.Abstractions; using System.Runtime.CompilerServices; +using Microsoft.Extensions.Time.Testing; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; @@ -12,7 +13,7 @@ public abstract class LoopbackDnsTestBase : IDisposable internal LoopbackDnsServer DnsServer { get; } internal DnsResolver Resolver { get; } - protected readonly TestTimeProvider TimeProvider; + protected readonly FakeTimeProvider TimeProvider; [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "SetTimeProvider")] static extern void MockTimeProvider(DnsResolver instance, TimeProvider provider); @@ -21,9 +22,9 @@ public LoopbackDnsTestBase(ITestOutputHelper output) { Output = output; DnsServer = new(); + TimeProvider = new(); Resolver = new([DnsServer.DnsEndPoint]); Resolver.Timeout = TimeSpan.FromSeconds(5); - TimeProvider = new(); MockTimeProvider(Resolver, TimeProvider); } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index 6447e9e2ba9..ed967ad7912 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -112,4 +112,16 @@ public async Task ResolveIP_InvalidAddressFamily_Throws() { await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.Unknown)); } + + [Theory] + [InlineData(AddressFamily.InterNetwork, "127.0.0.1")] + [InlineData(AddressFamily.InterNetworkV6, "::1")] + public async Task ResolveIP_Localhost_ReturnsLoopback(AddressFamily family, string addressAsString) + { + IPAddress address = IPAddress.Parse(addressAsString); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("localhost", family); + AddressResult result = Assert.Single(results); + + Assert.Equal(address, result.Address); + } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs deleted file mode 100644 index 453d4cb5a4e..00000000000 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TestTimeProvider.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; - -public class TestTimeProvider : TimeProvider -{ - public DateTime Now { get; set; } = DateTime.UtcNow; - public void Advance(TimeSpan time) => Now += time; - - public override DateTimeOffset GetUtcNow() => Now; -} From 6bda2356d5427e0dbbc4c34928d5f313c17d44b5 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 4 Oct 2024 13:49:59 +0200 Subject: [PATCH 05/45] Actually run ResolvConf tests --- .../Resolver/ResolvConf.cs | 47 ++++++------------- .../Resolver/ResolvConfTests.cs | 1 + 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs index 4312b9d9d84..d11c0fbff74 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs @@ -15,54 +15,35 @@ public static ResolverOptions GetOptions() public static ResolverOptions GetOptions(TextReader reader) { - int serverCount = 0; - int domainCount = 0; + List serverList = new(); + List searchDomains = new(); - string[] lines = reader.ReadToEnd().Split('\n', StringSplitOptions.RemoveEmptyEntries); - foreach (string line in lines) + while (reader.ReadLine() is string line) { + string[] tokens = line.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if (line.StartsWith("nameserver")) { - serverCount++; + if (tokens.Length >= 2 && IPAddress.TryParse(tokens[1], out IPAddress? address)) + { + serverList.Add(new IPEndPoint(address, 53)); + } } else if (line.StartsWith("search")) { - domainCount++; + searchDomains.AddRange(tokens.Skip(1)); } } - if (serverCount == 0) + if (serverList.Count == 0) { throw new SocketException((int)SocketError.AddressNotAvailable); } - IPEndPoint[] serverList = new IPEndPoint[serverCount]; - var options = new ResolverOptions(serverList); - if (domainCount > 0) + var options = new ResolverOptions(serverList.ToArray()) { - options.SearchDomains = new string[domainCount]; - } - - serverCount = 0; - domainCount = 0; - foreach (string line in lines) - { - string[] tokens = line.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); - if (tokens[0].Equals("nameserver")) - { - options.Servers[serverCount] = new IPEndPoint(IPAddress.Parse(tokens[1]), 53); - serverCount++; - } - else if (tokens[0].Equals("search")) - { - options.SearchDomains![domainCount] = tokens[1]; - domainCount++; - } - else if (tokens[0].Equals("domain")) - { - options.DefaultDomain = tokens[1]; - } - } + SearchDomains = searchDomains.Count > 0 ? searchDomains.ToArray() : default + }; return options; } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs index a17e9a07159..281ffbecd24 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs @@ -8,6 +8,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; public class ResolvConfTests { + [Fact] public void GetOptions() { var contents = @" From 6fc61e011f907481b1231a711bad4fe47dd6aadc Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 8 Oct 2024 11:53:32 +0200 Subject: [PATCH 06/45] Add retry functionality --- .../Resolver/DnsResolver.cs | 371 ++++++++++-------- .../Resolver/ResolverOptions.cs | 3 + .../Resolver/SendQueryError.cs | 13 + .../Resolver/CancellationTests.cs | 20 +- .../Resolver/LoopbackDnsTestBase.cs | 20 +- .../Resolver/ResolveAddressesTests.cs | 9 +- .../Resolver/RetryTests.cs | 49 +++ .../Resolver/TcpFailoverTests.cs | 1 - 8 files changed, 308 insertions(+), 178 deletions(-) create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index d3a7cb664c6..7b1b504ebb4 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -14,14 +14,13 @@ internal class DnsResolver : IDnsResolver, IDisposable private const int MaximumNameLength = 253; private const int IPv4Length = 4; private const int IPv6Length = 16; - private const int HeaderSize = 12; private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); bool _disposed; private readonly ResolverOptions _options; - private TimeSpan _timeout = System.Threading.Timeout.InfiniteTimeSpan; private readonly CancellationTokenSource _pendingRequestsCts = new(); + private int _maxRetries = 3; private TimeProvider _timeProvider = TimeProvider.System; @@ -42,6 +41,12 @@ internal DnsResolver(ResolverOptions options) { throw new ArgumentException("There are no DNS servers configured.", nameof(options)); } + + if (options.Timeout != Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(options.Timeout, TimeSpan.Zero); + ArgumentOutOfRangeException.ThrowIfGreaterThan(options.Timeout, s_maxTimeout); + } } internal DnsResolver(IEnumerable servers) : this(new ResolverOptions(servers.ToArray())) @@ -52,19 +57,14 @@ internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server)) { } - public TimeSpan Timeout + public int MaxRetries { - get => _timeout; + get => _maxRetries; set { ObjectDisposedException.ThrowIf(_disposed, this); - - if (value != System.Threading.Timeout.InfiniteTimeSpan) - { - ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(value, TimeSpan.Zero); - ArgumentOutOfRangeException.ThrowIfGreaterThan(value, s_maxTimeout); - } - _timeout = value; + ArgumentOutOfRangeException.ThrowIfLessThan(value, 0); + _maxRetries = value; } } @@ -73,15 +73,17 @@ public async ValueTask ResolveServiceAsync(string name, Cancell ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); - DnsResponse record = await SendQueryAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); - if (!ValidateResponse(record)) + SendQueryResult result = await SendQueryWithRetriesAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); + if (result.Error is not SendQueryError.NoError) { return Array.Empty(); } - var results = new List(record.Answers.Count); + DnsResponse response = result.Response; + + var results = new List(response.Answers.Count); - foreach (var answer in record.Answers) + foreach (var answer in response.Answers) { if (answer.Type == QueryType.SRV) { @@ -89,7 +91,7 @@ public async ValueTask ResolveServiceAsync(string name, Cancell Debug.Assert(success, "Failed to read SRV"); List addresses = new List(); - foreach (var additional in record.Additionals) + foreach (var additional in response.Additionals) { // From RFC 2782: // @@ -105,16 +107,15 @@ public async ValueTask ResolveServiceAsync(string name, Cancell // available at this domain. if (additional.Name == target && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) { - addresses.Add(new AddressResult(record.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); + addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); } } - results.Add(new ServiceResult(record.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target!, addresses.ToArray())); + results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target!, addresses.ToArray())); } } - var result = results.ToArray(); - return result; + return results.ToArray(); } public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) @@ -180,18 +181,19 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - DnsResponse record = await SendQueryAsync(name, queryType, cancellationToken).ConfigureAwait(false); - if (!ValidateResponse(record)) + SendQueryResult result = await SendQueryWithRetriesAsync(name, queryType, cancellationToken).ConfigureAwait(false); + if (result.Error is not SendQueryError.NoError) { return Array.Empty(); } - var results = new List(record.Answers.Count); + DnsResponse response = result.Response; + var results = new List(response.Answers.Count); // servers send back CNAME records together with associated A/AAAA records string currentAlias = name; - foreach (var answer in record.Answers) + foreach (var answer in response.Answers) { if (answer.Name != currentAlias) { @@ -208,14 +210,144 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add else if (answer.Type == queryType) { Debug.Assert(answer.Data.Length == IPv4Length || answer.Data.Length == IPv6Length); - results.Add(new AddressResult(record.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); + results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); + } + } + + return results.ToArray(); + } + + internal struct SendQueryResult + { + public DnsResponse Response; + public SendQueryError Error; + } + + async ValueTask SendQueryWithRetriesAsync(string name, QueryType queryType, CancellationToken cancellationToken) + { + SendQueryResult result = default; + + for (int index = 0; index < _options.Servers.Length; index++) + { + for (int attempt = 0; attempt < _options.Attempts; attempt++) + { + result = await SendQueryToServerWithTimeoutAsync(_options.Servers[index], name, queryType, index == _options.Servers.Length - 1, cancellationToken).ConfigureAwait(false); + + // TODO: we probably should skip to the next server in case of some errors + if (result.Error is SendQueryError.NoError) + { + break; + } } } - var result = results.ToArray(); + // we have at least one server and we always keep the last received response. return result; } + internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, CancellationToken cancellationToken) + { + (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); + + try + { + return await SendQueryToServerAsync(serverEndPoint, name, queryType, isLastServer, cts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when ( + !cancellationToken.IsCancellationRequested && // not cancelled by the caller + !pendingRequestsCts.IsCancellationRequested) // not cancelled by the global token (dispose) + // the only remaining token that could cancel this is the linked cts from the timeout. + { + Debug.Assert(cts.Token.IsCancellationRequested); + return new SendQueryResult { Error = SendQueryError.Timeout }; + } + catch (OperationCanceledException ex) when (cancellationToken.IsCancellationRequested && ex.CancellationToken != cancellationToken) + { + // cancellation was initiated by the caller, but exception was triggered by a linked token, + // rethrow the exception with the caller's token. + cancellationToken.ThrowIfCancellationRequested(); + throw new UnreachableException(); + } + finally + { + if (disposeTokenSource) + { + cts.Dispose(); + } + } + } + + async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, CancellationToken cancellationToken) + { + DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; + (DnsDataReader responseReader, DnsMessageHeader header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + + try + { + if (header.IsResultTruncated) + { + responseReader.Dispose(); + // TCP fallback + (responseReader, header) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + } + + if (header.QueryCount != 1 || + !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || + qName != name || qType != queryType || qClass != QueryClass.Internet) + { + // TODO: do we care? + throw new InvalidOperationException("Invalid response: Query mismatch"); + // return default; + } + + if (header.ResponseCode != QueryResponseCode.NoError) + { + return new SendQueryResult { Error = SendQueryError.ServerError }; + } + + if (header.ResponseCode != QueryResponseCode.NoError && !isLastServer) + { + // we exhausted attempts on this server, try the next one + responseReader.Dispose(); + return default; + } + + int ttl = int.MaxValue; + List answers = ReadRecords(header.AnswerCount, ref ttl, ref responseReader); + List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); + List additionals = ReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader); + + DnsResponse response = new(header, queryStartedTime, queryStartedTime.AddSeconds(ttl), answers, authorities, additionals); + responseReader.Dispose(); + + return new SendQueryResult { Response = response, Error = ValidateResponse(response) }; + } + finally + { + responseReader.Dispose(); + } + + static List ReadRecords(int count, ref int ttl, ref DnsDataReader reader) + { + List records = new(count); + + for (int i = 0; i < count; i++) + { + if (!reader.TryReadResourceRecord(out var record)) + { + // TODO how to handle corrupted responses? + throw new InvalidOperationException("Invalid response: corrupted record"); + } + + ttl = Math.Min(ttl, record.Ttl); + // copy the data to a new array since the underlying array is pooled + records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data.ToArray())); + } + + return records; + } + } + internal static bool GetNegativeCacheExpiration(in DnsResponse response, out DateTime expiration) { // @@ -243,13 +375,13 @@ internal static bool GetNegativeCacheExpiration(in DnsResponse response, out Dat return false; } - internal static bool ValidateResponse(in DnsResponse response) + internal static SendQueryError ValidateResponse(in DnsResponse response) { if (response.Header.ResponseCode == QueryResponseCode.NoError) { if (response.Answers.Count > 0) { - return true; + return SendQueryError.NoError; } // // RFC 2308 Section 2.2 - No Data @@ -270,7 +402,7 @@ internal static bool ValidateResponse(in DnsResponse response) { // _cache.TryAdd(name, queryType, expiration, Array.Empty()); } - return false; + return SendQueryError.NoData; } if (response.Header.ResponseCode == QueryResponseCode.NameError) @@ -288,64 +420,10 @@ internal static bool ValidateResponse(in DnsResponse response) // _cache.TryAddNonexistent(name, expiration); } - return false; + return SendQueryError.ServerError; } - return true; - } - - internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) - { - var buffer = ArrayPool.Shared.Rent(8 * 1024); - try - { - // When sending over TCP, the message is prefixed by 2B length - (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), name, queryType); - BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length); - - using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); - await socket.SendAsync(buffer.AsMemory(0, length + 2), SocketFlags.None, cancellationToken).ConfigureAwait(false); - - int responseLength = -1; - int bytesRead = 0; - while (responseLength < 0 || bytesRead < length + 2) - { - int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); - bytesRead += read; - - if (responseLength < 0 && bytesRead >= 2) - { - responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); - - if (responseLength > buffer.Length) - { - var largerBuffer = ArrayPool.Shared.Rent(responseLength); - Array.Copy(buffer, largerBuffer, bytesRead); - ArrayPool.Shared.Return(buffer); - buffer = largerBuffer; - } - } - } - - DnsDataReader responseReader = new DnsDataReader(buffer.AsMemory(2, responseLength), buffer); - if (!responseReader.TryReadHeader(out DnsMessageHeader header) || - header.TransactionId != transactionId || - !header.IsResponse) - { - throw new InvalidOperationException("Invalid response: Header mismatch"); - } - - buffer = null!; - return (responseReader, header); - } - finally - { - if (buffer != null) - { - ArrayPool.Shared.Return(buffer); - } - } + return SendQueryError.ServerError; } internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) @@ -368,7 +446,7 @@ internal static bool ValidateResponse(in DnsResponse response) { int readLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false); - if (readLength < HeaderSize) + if (readLength < DnsMessageHeader.HeaderLength) { continue; } @@ -397,100 +475,56 @@ internal static bool ValidateResponse(in DnsResponse response) } } - internal async ValueTask SendQueryAsync(string name, QueryType queryType, CancellationToken cancellationToken) + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) { - (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); - + var buffer = ArrayPool.Shared.Rent(8 * 1024); try { - return await SendQueryAsyncCore(name, queryType, cts.Token).ConfigureAwait(false); - } - catch (OperationCanceledException oce) when ( - !cancellationToken.IsCancellationRequested && // not cancelled by the caller - !pendingRequestsCts.IsCancellationRequested) // not cancelled by the global token (dispose) - // the only remaining token that could cancel this is the linked cts from the timeout. - { - Debug.Assert(cts.Token.IsCancellationRequested); - throw new TimeoutException("The operation has timed out.", oce); - } - finally - { - if (disposeTokenSource) - { - cts.Dispose(); - } - } + // When sending over TCP, the message is prefixed by 2B length + (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), name, queryType); + BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length); - async ValueTask SendQueryAsyncCore(string name, QueryType queryType, CancellationToken cancellationToken) - { - DnsDataReader responseReader = default; - DnsMessageHeader header = default; - DateTime queryStartedTime = default; + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); + await socket.SendAsync(buffer.AsMemory(0, length + 2), SocketFlags.None, cancellationToken).ConfigureAwait(false); - for (int index = 0; index < _options.Servers.Length; index++) + int responseLength = -1; + int bytesRead = 0; + while (responseLength < 0 || bytesRead < length + 2) { - IPEndPoint serverEndPoint = _options.Servers[index]; - - queryStartedTime = _timeProvider.GetUtcNow().DateTime; - (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); - - if (header.IsResultTruncated) - { - responseReader.Dispose(); - // TCP fallback - (responseReader, header) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); - } - - if (header.QueryCount != 1 || - !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || - qName != name || qType != queryType || qClass != QueryClass.Internet) - { - // TODO: do we care? - throw new InvalidOperationException("Invalid response: Query mismatch"); - // return default; - } + int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); + bytesRead += read; - // TODO: on which response codes should we retry? - if (header.ResponseCode == QueryResponseCode.NoError) + if (responseLength < 0 && bytesRead >= 2) { - break; - } + responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); - if (index < _options.Servers.Length - 1) - { - // keep the reader open for processing - responseReader.Dispose(); + if (responseLength > buffer.Length) + { + var largerBuffer = ArrayPool.Shared.Rent(responseLength); + Array.Copy(buffer, largerBuffer, bytesRead); + ArrayPool.Shared.Return(buffer); + buffer = largerBuffer; + } } } - int ttl = int.MaxValue; - List answers = ReadRecords(header.AnswerCount, ref ttl, ref responseReader); - List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); - List additionals = ReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader); - - DnsResponse record = new(header, queryStartedTime, queryStartedTime.AddSeconds(ttl), answers, authorities, additionals); - responseReader.Dispose(); - - return record; - - static List ReadRecords(int count, ref int ttl, ref DnsDataReader reader) + DnsDataReader responseReader = new DnsDataReader(buffer.AsMemory(2, responseLength), buffer); + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) { - List records = new(count); - - for (int i = 0; i < count; i++) - { - if (!reader.TryReadResourceRecord(out var record)) - { - // TODO how to handle corrupted responses? - throw new InvalidOperationException("Invalid response: corrupted record"); - } - - ttl = Math.Min(ttl, record.Ttl); - // copy the data to a new array since the underlying array is pooled - records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data.ToArray())); - } + throw new InvalidOperationException("Invalid response: Header mismatch"); + } - return records; + buffer = null!; + return (responseReader, header); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); } } } @@ -536,14 +570,15 @@ public void Dispose() // it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas // it's more approximate checking elapsed time. CancellationTokenSource pendingRequestsCts = _pendingRequestsCts; + TimeSpan timeout = _options.Timeout; - bool hasTimeout = _timeout != System.Threading.Timeout.InfiniteTimeSpan; + bool hasTimeout = timeout != System.Threading.Timeout.InfiniteTimeSpan; if (hasTimeout || cancellationToken.CanBeCanceled) { CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token); if (hasTimeout) { - cts.CancelAfter(_timeout); + cts.CancelAfter(timeout); } return (cts, DisposeTokenSource: true, pendingRequestsCts); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs index 8f0d2fd271d..e3eda4d105f 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -12,6 +12,9 @@ internal class ResolverOptions public string[]? SearchDomains; public bool UseHostsFile; + public int Attempts = 2; + public TimeSpan Timeout = TimeSpan.FromSeconds(3); + public ResolverOptions(IPEndPoint[] servers) { Servers = servers; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs new file mode 100644 index 00000000000..c7f9be88783 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal enum SendQueryError +{ + NoError, + Timeout, + ServerError, + ParseError, + NoData, +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs index 105b64094ce..b0ace03a8db 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs @@ -19,13 +19,25 @@ public async Task PreCanceledToken_Throws() CancellationTokenSource cts = new CancellationTokenSource(); cts.Cancel(); - await Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + var ex = await Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + + Assert.Equal(cts.Token, ex.CancellationToken); } [Fact] - public async Task Timeout_Throws() + public async Task CancellationInProgress_Throws() { - Resolver.Timeout = TimeSpan.FromSeconds(1); - await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork)); + CancellationTokenSource cts = new CancellationTokenSource(); + + var task = Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + + await DnsServer.ProcessUdpRequest(_ => + { + cts.Cancel(); + return Task.CompletedTask; + }); + + OperationCanceledException ex = await task; + Assert.Equal(cts.Token, ex.CancellationToken); } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs index ed562b364df..1daa77e2bec 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -12,7 +12,9 @@ public abstract class LoopbackDnsTestBase : IDisposable protected readonly ITestOutputHelper Output; internal LoopbackDnsServer DnsServer { get; } - internal DnsResolver Resolver { get; } + private readonly Lazy _resolverLazy; + internal DnsResolver Resolver => _resolverLazy.Value; + internal ResolverOptions Options { get; } protected readonly FakeTimeProvider TimeProvider; [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "SetTimeProvider")] @@ -23,9 +25,19 @@ public LoopbackDnsTestBase(ITestOutputHelper output) Output = output; DnsServer = new(); TimeProvider = new(); - Resolver = new([DnsServer.DnsEndPoint]); - Resolver.Timeout = TimeSpan.FromSeconds(5); - MockTimeProvider(Resolver, TimeProvider); + Options = new([DnsServer.DnsEndPoint]) + { + Timeout = TimeSpan.FromSeconds(5), + Attempts = 1, + }; + _resolverLazy = new(InitializeResolver); + } + + DnsResolver InitializeResolver() + { + var resolver = new DnsResolver(Options); + MockTimeProvider(resolver, TimeProvider); + return resolver; } public void Dispose() diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index ed967ad7912..beaaaf26d18 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -12,7 +12,6 @@ public class ResolveAddressesTests : LoopbackDnsTestBase { public ResolveAddressesTests(ITestOutputHelper output) : base(output) { - Resolver.Timeout = TimeSpan.FromSeconds(5); } [Theory] @@ -124,4 +123,12 @@ public async Task ResolveIP_Localhost_ReturnsLoopback(AddressFamily family, stri Assert.Equal(address, result.Address); } + + [Fact] + public async Task Resolve_Timeout_ReturnsEmpty() + { + Options.Timeout = TimeSpan.FromSeconds(1); + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + Assert.Empty(result); + } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs new file mode 100644 index 00000000000..665169d32c0 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Sockets; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class RetryTests : LoopbackDnsTestBase +{ + public RetryTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task Retry_Simple_Success() + { + Options.Attempts = 3; + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = Task.Run(async () => + { + for (int attempt = 1; attempt <= 3; attempt++) + { + await DnsServer.ProcessUdpRequest(builder => + { + if (attempt == 3) + { + builder.Answers.AddAddress("www.example.com", 3600, address); + } + else + { + builder.ResponseCode = QueryResponseCode.ServerFailure; + } + return Task.CompletedTask; + }); + } + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + + } +} diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index d61f8ca8aa3..bd4a1ef4070 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -12,7 +12,6 @@ public class TcpFailoverTests : LoopbackDnsTestBase { public TcpFailoverTests(ITestOutputHelper output) : base(output) { - Resolver.Timeout = TimeSpan.FromHours(5); } [Fact] From 2b6d91e9494e8ae2f75f9d7be3e9f2428e2a296f Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 10 Oct 2024 14:12:28 +0200 Subject: [PATCH 07/45] Add logging --- ...oft.Extensions.ServiceDiscovery.Dns.csproj | 1 - .../Resolver/DnsResolver.Log.cs | 30 +++++++++ .../Resolver/DnsResolver.cs | 67 ++++++++++++------- 3 files changed, 74 insertions(+), 24 deletions(-) create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj index a5b4993f286..c4c9dc1ba88 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj @@ -13,7 +13,6 @@ - diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs new file mode 100644 index 00000000000..ce8f4358b56 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System.Net; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal partial class DnsResolver : IDnsResolver, IDisposable +{ + internal static partial class Log + { + [LoggerMessage(1, LogLevel.Information, "Resolving {QueryType} {QueryName} on {Server} attempt {Attempt}", EventName = "Query")] + public static partial void Query(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(2, LogLevel.Debug, "Result truncated for {QueryType} {QueryName} from {Server} attempt {Attempt}. Restarting over TCP", EventName = "ResultTruncated")] + public static partial void ResultTruncated(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(3, LogLevel.Error, "Server {Server} replied with {ResponseCode} when querying {QueryType} {QueryName}", EventName = "ErrorResponseCode")] + public static partial void ErrorResponseCode(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, QueryResponseCode responseCode); + + [LoggerMessage(4, LogLevel.Information, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")] + public static partial void Timeout(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} returned no data", EventName = "NoData")] + public static partial void NoData(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(6, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] + public static partial void QueryError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 7b1b504ebb4..ee0f8778168 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -6,10 +6,12 @@ using System.Diagnostics; using System.Net; using System.Net.Sockets; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal class DnsResolver : IDnsResolver, IDisposable +internal partial class DnsResolver : IDnsResolver, IDisposable { private const int MaximumNameLength = 253; private const int IPv4Length = 4; @@ -20,22 +22,23 @@ internal class DnsResolver : IDnsResolver, IDisposable bool _disposed; private readonly ResolverOptions _options; private readonly CancellationTokenSource _pendingRequestsCts = new(); - private int _maxRetries = 3; - private TimeProvider _timeProvider = TimeProvider.System; + private readonly ILogger _logger; internal void SetTimeProvider(TimeProvider timeProvider) { _timeProvider = timeProvider; } - public DnsResolver(TimeProvider timeProvider) : this(OperatingSystem.IsWindows() ? NetworkInfo.GetOptions() : ResolvConf.GetOptions()) + public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(OperatingSystem.IsWindows() ? NetworkInfo.GetOptions() : ResolvConf.GetOptions()) { _timeProvider = timeProvider; + _logger = logger; } internal DnsResolver(ResolverOptions options) { + _logger = NullLogger.Instance; _options = options; if (options.Servers.Length == 0) { @@ -57,17 +60,6 @@ internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server)) { } - public int MaxRetries - { - get => _maxRetries; - set - { - ObjectDisposedException.ThrowIf(_disposed, this); - ArgumentOutOfRangeException.ThrowIfLessThan(value, 0); - _maxRetries = value; - } - } - public async ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) { ObjectDisposedException.ThrowIf(_disposed, this); @@ -229,29 +221,51 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp for (int index = 0; index < _options.Servers.Length; index++) { + IPEndPoint serverEndPoint = _options.Servers[index]; + for (int attempt = 0; attempt < _options.Attempts; attempt++) { - result = await SendQueryToServerWithTimeoutAsync(_options.Servers[index], name, queryType, index == _options.Servers.Length - 1, cancellationToken).ConfigureAwait(false); - // TODO: we probably should skip to the next server in case of some errors - if (result.Error is SendQueryError.NoError) + try { - break; + result = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, index == _options.Servers.Length - 1, attempt, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); + continue; // retry or skip to the next server + } + + switch (result.Error) + { + case SendQueryError.NoError: + goto exit; + case SendQueryError.Timeout: + // TODO: should we retry on timeout or skip to the next server? + Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); + break; + case SendQueryError.ServerError: + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Response.Header.ResponseCode); + break; + case SendQueryError.NoData: + Log.NoData(_logger, queryType, name, serverEndPoint, attempt); + break; } } } + exit: // we have at least one server and we always keep the last received response. return result; } - internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, CancellationToken cancellationToken) + internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) { (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); try { - return await SendQueryToServerAsync(serverEndPoint, name, queryType, isLastServer, cts.Token).ConfigureAwait(false); + return await SendQueryToServerAsync(serverEndPoint, name, queryType, isLastServer, attempt, cts.Token).ConfigureAwait(false); } catch (OperationCanceledException) when ( !cancellationToken.IsCancellationRequested && // not cancelled by the caller @@ -277,8 +291,10 @@ internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEn } } - async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, CancellationToken cancellationToken) + async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) { + Log.Query(_logger, queryType, name, serverEndPoint, attempt); + DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; (DnsDataReader responseReader, DnsMessageHeader header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); @@ -286,6 +302,7 @@ async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoin { if (header.IsResultTruncated) { + Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); responseReader.Dispose(); // TCP fallback (responseReader, header) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); @@ -302,7 +319,11 @@ async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoin if (header.ResponseCode != QueryResponseCode.NoError) { - return new SendQueryResult { Error = SendQueryError.ServerError }; + return new SendQueryResult + { + Response = new DnsResponse(header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.ServerError + }; } if (header.ResponseCode != QueryResponseCode.NoError && !isLastServer) From 2137954cafb0402952bcf02fa8f2f8c49ac19db4 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 10 Oct 2024 14:29:04 +0200 Subject: [PATCH 08/45] Fix Microsoft.Extensions.ServiceDiscovery.Yarp.Tests --- ...oft.Extensions.ServiceDiscovery.Dns.csproj | 1 + .../YarpServiceDiscoveryTests.cs | 93 +++++-------------- 2 files changed, 23 insertions(+), 71 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj index c4c9dc1ba88..9ea318cefb1 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj @@ -27,6 +27,7 @@ + diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs index 5efb4a98c2c..74e8d5800af 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs @@ -6,8 +6,8 @@ using Xunit; using Yarp.ReverseProxy.Configuration; using System.Net; -using DnsClient; -using DnsClient.Protocol; +using System.Net.Sockets; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Options; @@ -266,32 +266,22 @@ public async Task ServiceDiscoveryDestinationResolverTests_Dns() [Fact] public async Task ServiceDiscoveryDestinationResolverTests_DnsSrv() { - var dnsClientMock = new FakeDnsClient + var dnsClientMock = new FakeDnsResolver { - QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) => + ResolveServiceAsyncFunc = (name, cancellationToken) => { - var response = new FakeDnsQueryResponse - { - Answers = new List - { - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")), - new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c")) - }, - Additionals = new List - { - new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")), - new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback), - new ARecord(new ResourceRecordInfo("srv-c", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Loopback), - } - }; - - return Task.FromResult(response); + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Loopback)]) + ]; + + return ValueTask.FromResult(response); } }; await using var services = new ServiceCollection() - .AddSingleton(dnsClientMock) + .AddSingleton(dnsClientMock) .AddServiceDiscoveryCore() .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns") .BuildServiceProvider(); @@ -314,56 +304,17 @@ public async Task ServiceDiscoveryDestinationResolverTests_DnsSrv() a => Assert.Equal("https://127.0.0.1:7777/", a)); } - private sealed class FakeDnsClient : IDnsQuery + private sealed class FakeDnsResolver : IDnsResolver { - public Func>? QueryAsyncFunc { get; set; } - - public IDnsQueryResponse Query(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse Query(DnsQuestion question, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryAsync(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) - => QueryAsyncFunc!(query, queryType, queryClass, cancellationToken); - public Task QueryAsync(DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryAsync(DnsQuestion question, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryCache(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryReverse(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryReverseAsync(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServer(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, DnsQuestion question, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerAsync(IReadOnlyCollection servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress) => throw new NotImplementedException(); - public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task QueryServerReverseAsync(IReadOnlyCollection servers, IPAddress ipAddress, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - } + public Func>? ResolveIPAddressesAsyncFunc { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc!.Invoke(name, addressFamily, cancellationToken); - private sealed class FakeDnsQueryResponse : IDnsQueryResponse - { - public IReadOnlyList? Questions { get; set; } - public IReadOnlyList? Additionals { get; set; } - public IEnumerable? AllRecords { get; set; } - public IReadOnlyList? Answers { get; set; } - public IReadOnlyList? Authorities { get; set; } - public string? AuditTrail { get; set; } - public string? ErrorMessage { get; set; } - public bool HasError { get; set; } - public DnsResponseHeader? Header { get; set; } - public int MessageSize { get; set; } - public NameServer? NameServer { get; set; } - public DnsQuerySettings? Settings { get; set; } + public Func>? ResolveIPAddressesAsyncFunc2 { get; set; } + + public ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc2!.Invoke(name, cancellationToken); + + public Func>? ResolveServiceAsyncFunc { get; set; } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) => ResolveServiceAsyncFunc!.Invoke(name, cancellationToken); } } From 2fc7de4198b763d2aa1d5e64b0507d622bf3aacb Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 26 Nov 2024 14:47:26 +0100 Subject: [PATCH 09/45] Activity WIP --- .../Resolver/DnsResolver.Activity.cs | 96 +++++++++++++++++++ .../Resolver/DnsResolver.cs | 14 ++- 2 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs new file mode 100644 index 00000000000..ce0deeb2396 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +// using System.Diagnostics.Metrics; +// using System.Diagnostics.Tracing; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal partial class DnsResolver +{ + internal static class Telemetry + { + // private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"); + // private static readonly Counter s_queryCounter = s_meter.CreateCounter("queries", "Number of DNS queries"); + // private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration"); + + public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType) + { + return new NameResolutionActivity(hostName, queryType); + } + + public static void StopNameResolution(in NameResolutionActivity activity, object? answers, SendQueryError error) + { + if (!activity.Stop(answers, error, out TimeSpan _)) + { + return; + } + } + } + + internal readonly struct NameResolutionActivity + { + private const string ActivitySourceName = "Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"; + private const string ActivityName = ActivitySourceName + ".Resolve"; + private static readonly ActivitySource s_activitySource = new ActivitySource(ActivitySourceName); + + private readonly long _startingTimestamp; + private readonly Activity? _activity; // null if activity is not started + + public NameResolutionActivity(string hostName, QueryType queryType) + { + _startingTimestamp = Stopwatch.GetTimestamp(); + _activity = s_activitySource.StartActivity(ActivityName, ActivityKind.Client); + if (_activity is not null) + { + _activity.DisplayName = $"Resolving {hostName}"; + if (_activity.IsAllDataRequested) + { + _activity.SetTag("dns.question.name", hostName); + _activity.SetTag("dns.question.type", queryType.ToString()); + } + } + } + + public bool Stop(object? answers, SendQueryError error, out TimeSpan duration) + { + if (_activity is null) + { + duration = TimeSpan.Zero; + return false; + } + + if (_activity.IsAllDataRequested) + { + if (answers is not null) + { + static string[] ToStringHelper(T[] array) => array.Select(a => a!.ToString()!).ToArray(); + + string[]? answersArray = answers switch + { + ServiceResult[] serviceResults => ToStringHelper(serviceResults), + AddressResult[] addressResults => ToStringHelper(addressResults), + _ => null + }; + + Debug.Assert(answersArray is not null); + _activity.SetTag("dns.answers", answersArray); + } + else + { + _activity.SetTag("error.type", error.ToString()); + } + } + + if (answers is null) + { + _activity.SetStatus(ActivityStatusCode.Error); + } + + _activity.Stop(); + duration = Stopwatch.GetElapsedTime(_startingTimestamp); + return true; + } + } +} diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index ee0f8778168..63a225efd51 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -65,9 +65,12 @@ public async ValueTask ResolveServiceAsync(string name, Cancell ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); + NameResolutionActivity activity = Telemetry.StartNameResolution(name, QueryType.SRV); SendQueryResult result = await SendQueryWithRetriesAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); + if (result.Error is not SendQueryError.NoError) { + Telemetry.StopNameResolution(activity, null, result.Error); return Array.Empty(); } @@ -107,7 +110,9 @@ public async ValueTask ResolveServiceAsync(string name, Cancell } } - return results.ToArray(); + ServiceResult[] res = results.ToArray(); + Telemetry.StopNameResolution(activity, res, result.Error); + return res; } public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) @@ -172,10 +177,11 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add } var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - + NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType); SendQueryResult result = await SendQueryWithRetriesAsync(name, queryType, cancellationToken).ConfigureAwait(false); if (result.Error is not SendQueryError.NoError) { + Telemetry.StopNameResolution(activity, null, result.Error); return Array.Empty(); } @@ -206,7 +212,9 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add } } - return results.ToArray(); + AddressResult[] res = results.ToArray(); + Telemetry.StopNameResolution(activity, res, result.Error); + return res; } internal struct SendQueryResult From 9b4de29726a87ef0c2f74bca37df7a35004ae2b2 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 28 Nov 2024 12:57:05 +0100 Subject: [PATCH 10/45] Fix reading CNAME with pointers in domain name segments --- .../Resolver/DnsDataReader.cs | 10 +++---- .../Resolver/DnsResolver.cs | 28 ++++++++++++------- .../Resolver/DnsResponse.cs | 19 +++++++++++-- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index 227698107df..192fa0a948e 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -10,15 +10,15 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal struct DnsDataReader : IDisposable { + public byte[]? RawData { get; private set; } private ReadOnlyMemory _buffer; - private byte[]? _pooled; private int _position; public DnsDataReader(ReadOnlyMemory buffer, byte[]? returnToPool = null) { _buffer = buffer; _position = 0; - _pooled = returnToPool; + RawData = returnToPool; } public bool TryReadHeader(out DnsMessageHeader header) @@ -109,10 +109,10 @@ public bool TryReadDomainName([NotNullWhen(true)] out string? name) public void Dispose() { - if (_pooled is not null) + if (RawData is not null) { - ArrayPool.Shared.Return(_pooled); - _pooled = null!; + ArrayPool.Shared.Return(RawData); + RawData = null!; } _buffer = default; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 63a225efd51..f99aa266201 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Net; using System.Net.Sockets; +using System.Runtime.InteropServices; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -74,7 +75,7 @@ public async ValueTask ResolveServiceAsync(string name, Cancell return Array.Empty(); } - DnsResponse response = result.Response; + using DnsResponse response = result.Response; var results = new List(response.Answers.Count); @@ -185,7 +186,7 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add return Array.Empty(); } - DnsResponse response = result.Response; + using DnsResponse response = result.Response; var results = new List(response.Answers.Count); // servers send back CNAME records together with associated A/AAAA records @@ -200,8 +201,16 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add if (answer.Type == QueryType.CNAME) { - bool success = DnsPrimitives.TryReadQName(answer.Data.Span, 0, out currentAlias!, out _); - Debug.Assert(success, "Failed to read CNAME"); + // Although RFC does not necessarily allow pointers segments in CNAME domain names, some servers do use them + // so we need to pass the entire buffer to TryReadQName with the proper offset. The data should be always + // backed by the array containing the full response. + + var success = MemoryMarshal.TryGetArray(answer.Data, out ArraySegment segment); + Debug.Assert(success, "Failed to get array segment"); + if (!DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out currentAlias!, out _)) + { + throw new InvalidOperationException("Invalid response: CNAME record"); + } continue; } @@ -329,7 +338,7 @@ async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoin { return new SendQueryResult { - Response = new DnsResponse(header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.ServerError }; } @@ -337,7 +346,6 @@ async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoin if (header.ResponseCode != QueryResponseCode.NoError && !isLastServer) { // we exhausted attempts on this server, try the next one - responseReader.Dispose(); return default; } @@ -346,8 +354,9 @@ async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoin List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); List additionals = ReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader); - DnsResponse response = new(header, queryStartedTime, queryStartedTime.AddSeconds(ttl), answers, authorities, additionals); - responseReader.Dispose(); + // we transfer ownership of RawData to the response + DnsResponse response = new(responseReader.RawData!, header, queryStartedTime, queryStartedTime.AddSeconds(ttl), answers, authorities, additionals); + responseReader = default; // avoid disposing (and returning RawData to the pool) return new SendQueryResult { Response = response, Error = ValidateResponse(response) }; } @@ -369,8 +378,7 @@ static List ReadRecords(int count, ref int ttl, ref DnsDataRe } ttl = Math.Min(ttl, record.Ttl); - // copy the data to a new array since the underlying array is pooled - records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data.ToArray())); + records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data)); } return records; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs index 582cb282730..b76ccf1f47f 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs @@ -1,9 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; + namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal struct DnsResponse +internal struct DnsResponse : IDisposable { public DnsMessageHeader Header { get; } public List Answers { get; } @@ -11,9 +13,13 @@ internal struct DnsResponse public List Additionals { get; } public DateTime CreatedAt { get; } public DateTime Expiration { get; } + public ReadOnlyMemory RawData => _rawData ?? ReadOnlyMemory.Empty; + private byte[]? _rawData; - public DnsResponse(DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) + public DnsResponse(byte[] rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) { + _rawData = rawData; + Header = header; CreatedAt = createdAt; Expiration = expiration; @@ -21,4 +27,13 @@ public DnsResponse(DnsMessageHeader header, DateTime createdAt, DateTime expirat Authorities = authorities; Additionals = additionals; } + + public void Dispose() + { + if (_rawData != null) + { + ArrayPool.Shared.Return(_rawData); + _rawData = null; + } + } } From 3ad5a7524db7b08275e682e015b7f79943995950 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 28 Nov 2024 15:05:35 +0100 Subject: [PATCH 11/45] Reenable telemetry --- .../Resolver/DnsResolver.Activity.cs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs index ce0deeb2396..d62423fe136 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs @@ -2,8 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; -// using System.Diagnostics.Metrics; -// using System.Diagnostics.Tracing; +using System.Diagnostics.Metrics; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; @@ -11,9 +10,9 @@ internal partial class DnsResolver { internal static class Telemetry { - // private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"); - // private static readonly Counter s_queryCounter = s_meter.CreateCounter("queries", "Number of DNS queries"); - // private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration"); + private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"); + private static readonly Counter s_queryCounter = s_meter.CreateCounter("queries", "Number of DNS queries"); + private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration"); public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType) { @@ -22,10 +21,13 @@ public static NameResolutionActivity StartNameResolution(string hostName, QueryT public static void StopNameResolution(in NameResolutionActivity activity, object? answers, SendQueryError error) { - if (!activity.Stop(answers, error, out TimeSpan _)) + if (!activity.Stop(answers, error, out TimeSpan duration)) { return; } + + s_queryCounter.Add(1); + s_queryDuration.Record(duration.TotalMilliseconds); } } From 40f8f9bbb515d62a5bc7d11d75a254242ca87929 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 28 Nov 2024 17:07:42 +0100 Subject: [PATCH 12/45] Last changes to telemetry --- .../Resolver/DnsResolver.Log.cs | 4 +- ...r.Activity.cs => DnsResolver.Telemetry.cs} | 45 +++++++++++++------ .../Resolver/DnsResolver.cs | 12 ++--- 3 files changed, 39 insertions(+), 22 deletions(-) rename src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/{DnsResolver.Activity.cs => DnsResolver.Telemetry.cs} (67%) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs index ce8f4358b56..7a4f4223ba5 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs @@ -9,7 +9,7 @@ internal partial class DnsResolver : IDnsResolver, IDisposable { internal static partial class Log { - [LoggerMessage(1, LogLevel.Information, "Resolving {QueryType} {QueryName} on {Server} attempt {Attempt}", EventName = "Query")] + [LoggerMessage(1, LogLevel.Debug, "Resolving {QueryType} {QueryName} on {Server} attempt {Attempt}", EventName = "Query")] public static partial void Query(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); [LoggerMessage(2, LogLevel.Debug, "Result truncated for {QueryType} {QueryName} from {Server} attempt {Attempt}. Restarting over TCP", EventName = "ResultTruncated")] @@ -18,7 +18,7 @@ internal static partial class Log [LoggerMessage(3, LogLevel.Error, "Server {Server} replied with {ResponseCode} when querying {QueryType} {QueryName}", EventName = "ErrorResponseCode")] public static partial void ErrorResponseCode(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, QueryResponseCode responseCode); - [LoggerMessage(4, LogLevel.Information, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")] + [LoggerMessage(4, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")] public static partial void Timeout(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} returned no data", EventName = "NoData")] diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs similarity index 67% rename from src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs rename to src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs index d62423fe136..4be956cede9 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Activity.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs @@ -11,23 +11,41 @@ internal partial class DnsResolver internal static class Telemetry { private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"); - private static readonly Counter s_queryCounter = s_meter.CreateCounter("queries", "Number of DNS queries"); private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration"); - public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType) + private static bool IsEnabled() => s_queryDuration.Enabled; + + public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType, long startingTimestamp) { - return new NameResolutionActivity(hostName, queryType); + if (IsEnabled()) + { + return new NameResolutionActivity(hostName, queryType, startingTimestamp); + } + + return default; } - public static void StopNameResolution(in NameResolutionActivity activity, object? answers, SendQueryError error) + public static void StopNameResolution(string hostName, QueryType queryType, in NameResolutionActivity activity, object? answers, SendQueryError error, long endingTimestamp) { - if (!activity.Stop(answers, error, out TimeSpan duration)) + activity.Stop(answers, error, endingTimestamp, out TimeSpan duration); + + if (!IsEnabled()) { return; } - s_queryCounter.Add(1); - s_queryDuration.Record(duration.TotalMilliseconds); + var hostNameTag = KeyValuePair.Create("dns.question.name", (object?)hostName); + var queryTypeTag = KeyValuePair.Create("dns.question.type", (object?)queryType); + + if (answers is not null) + { + s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag); + } + else + { + var errorTypeTag = KeyValuePair.Create("error.type", (object?)error.ToString()); + s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag, errorTypeTag); + } } } @@ -40,9 +58,9 @@ internal readonly struct NameResolutionActivity private readonly long _startingTimestamp; private readonly Activity? _activity; // null if activity is not started - public NameResolutionActivity(string hostName, QueryType queryType) + public NameResolutionActivity(string hostName, QueryType queryType, long startingTimestamp) { - _startingTimestamp = Stopwatch.GetTimestamp(); + _startingTimestamp = startingTimestamp; _activity = s_activitySource.StartActivity(ActivityName, ActivityKind.Client); if (_activity is not null) { @@ -55,12 +73,13 @@ public NameResolutionActivity(string hostName, QueryType queryType) } } - public bool Stop(object? answers, SendQueryError error, out TimeSpan duration) + public void Stop(object? answers, SendQueryError error, long endingTimestamp, out TimeSpan duration) { + duration = Stopwatch.GetElapsedTime(_startingTimestamp, endingTimestamp); + if (_activity is null) { - duration = TimeSpan.Zero; - return false; + return; } if (_activity.IsAllDataRequested) @@ -91,8 +110,6 @@ public bool Stop(object? answers, SendQueryError error, out TimeSpan duration) } _activity.Stop(); - duration = Stopwatch.GetElapsedTime(_startingTimestamp); - return true; } } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index f99aa266201..ad5cca3909e 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -66,12 +66,12 @@ public async ValueTask ResolveServiceAsync(string name, Cancell ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); - NameResolutionActivity activity = Telemetry.StartNameResolution(name, QueryType.SRV); + NameResolutionActivity activity = Telemetry.StartNameResolution(name, QueryType.SRV, _timeProvider.GetTimestamp()); SendQueryResult result = await SendQueryWithRetriesAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); if (result.Error is not SendQueryError.NoError) { - Telemetry.StopNameResolution(activity, null, result.Error); + Telemetry.StopNameResolution(name, QueryType.SRV, activity, null, result.Error, _timeProvider.GetTimestamp()); return Array.Empty(); } @@ -112,7 +112,7 @@ public async ValueTask ResolveServiceAsync(string name, Cancell } ServiceResult[] res = results.ToArray(); - Telemetry.StopNameResolution(activity, res, result.Error); + Telemetry.StopNameResolution(name, QueryType.SRV, activity, res, result.Error, _timeProvider.GetTimestamp()); return res; } @@ -178,11 +178,11 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add } var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType); + NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); SendQueryResult result = await SendQueryWithRetriesAsync(name, queryType, cancellationToken).ConfigureAwait(false); if (result.Error is not SendQueryError.NoError) { - Telemetry.StopNameResolution(activity, null, result.Error); + Telemetry.StopNameResolution(name, queryType, activity, null, result.Error, _timeProvider.GetTimestamp()); return Array.Empty(); } @@ -222,7 +222,7 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add } AddressResult[] res = results.ToArray(); - Telemetry.StopNameResolution(activity, res, result.Error); + Telemetry.StopNameResolution(name, queryType, activity, res, result.Error, _timeProvider.GetTimestamp()); return res; } From e3856057eba15bdc590cca96b2bab766bb82ef31 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Wed, 15 Jan 2025 10:20:36 +0100 Subject: [PATCH 13/45] Use strong RNG for Transaction ID generation --- .../Resolver/DnsMessageHeader.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs index c920375aa9c..675e150e3b2 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers.Binary; +using System.Security.Cryptography; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; @@ -97,7 +98,7 @@ internal bool IsResponse internal void InitQueryHeader() { this = default; - TransactionId = (ushort)Random.Shared.Next(ushort.MaxValue); + TransactionId = (ushort)RandomNumberGenerator.GetInt32(short.MaxValue + 1); IsRecursionDesired = true; QueryCount = 1; } From 3cf0a9d28d9bd05ddd98fde71591bc7ca543827f Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Mon, 3 Feb 2025 13:38:59 +0100 Subject: [PATCH 14/45] Check against too long domain name --- .../Resolver/DnsPrimitives.cs | 10 +++++++ .../Resolver/DnsResolver.cs | 21 +++++++++----- .../Resolver/DnsPrimitivesTests.cs | 28 +++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs index 96467442445..1c3cf9b0e1c 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -9,6 +9,8 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal static class DnsPrimitives { + internal const int MaxDomainNameLength = 253; + internal static bool TryWriteQName(Span destination, string name, out int written) { // @@ -103,7 +105,15 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag { sb.Append('.'); } + sb.Append(Encoding.ASCII.GetString(messageBuffer.Slice(currentOffset + 1, length))); + + if (sb.Length > MaxDomainNameLength) + { + // domain name is too long + return false; + } + currentOffset += 1 + length; bytesRead += 1 + length; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index ad5cca3909e..b562e025b76 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; +using System.Globalization; using System.Net; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -14,7 +15,6 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal partial class DnsResolver : IDnsResolver, IDisposable { - private const int MaximumNameLength = 253; private const int IPv4Length = 4; private const int IPv6Length = 16; @@ -66,6 +66,8 @@ public async ValueTask ResolveServiceAsync(string name, Cancell ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); + name = GetNormalizedHostName(name); + NameResolutionActivity activity = Telemetry.StartNameResolution(name, QueryType.SRV, _timeProvider.GetTimestamp()); SendQueryResult result = await SendQueryWithRetriesAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); @@ -118,7 +120,7 @@ public async ValueTask ResolveServiceAsync(string name, Cancell public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) { - if (name == "localhost") + if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase)) { // name localhost exists outside of DNS and can't be resolved by a DNS server int len = (Socket.OSSupportsIPv4 ? 1 : 0) + (Socket.OSSupportsIPv6 ? 1 : 0); @@ -157,7 +159,7 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add throw new ArgumentOutOfRangeException(nameof(addressFamily), addressFamily, "Invalid address family"); } - if (name == "localhost") + if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase)) { // name localhost exists outside of DNS and can't be resolved by a DNS server if (addressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4) @@ -172,10 +174,7 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add return Array.Empty(); } - if (name.Length > MaximumNameLength) - { - throw new ArgumentException("Name is too long", nameof(name)); - } + name = GetNormalizedHostName(name); var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); @@ -623,4 +622,12 @@ public void Dispose() return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts); } + + private static readonly IdnMapping s_idnMapping = new IdnMapping(); + + private static string GetNormalizedHostName(string name) + { + // TODO: better exception message + return s_idnMapping.GetAscii(name); + } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs index b2afc5510f3..d5690b405e9 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs @@ -122,4 +122,32 @@ public void TryReadQName_ReservedBits() Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); } + + [Theory] + [InlineData(253)] + [InlineData(254)] + [InlineData(255)] + public void TryReadQName_NameTooLong(int length) + { + // longest possible label is 63 bytes + 1 byte for length + byte[] labelData = new byte[64]; + Array.Fill(labelData, (byte)'a'); + labelData[0] = 63; + + int remainder = length - 3 * 64; + + byte[] lastLabelData = new byte[remainder + 1]; + Array.Fill(lastLabelData, (byte)'a'); + lastLabelData[0] = (byte)remainder; + + byte[] data = Enumerable.Repeat(labelData, 3).SelectMany(x => x).Concat(lastLabelData).Concat(new byte[1]).ToArray(); + if (length > 253) + { + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + else + { + Assert.True(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + } } From 6016e41982dcb99b906ac2de638d320ee9db76d1 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Mon, 3 Feb 2025 14:19:18 +0100 Subject: [PATCH 15/45] Disallow pointer to pointer. --- .../Resolver/DnsPrimitives.cs | 76 ++++++++++--------- .../Resolver/DnsPrimitivesTests.cs | 18 ++++- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs index 1c3cf9b0e1c..c6a90a5624d 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -9,6 +9,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal static class DnsPrimitives { + // Maximum length of a domain name in ASCII (excluding trailing dot) internal const int MaxDomainNameLength = 253; internal static bool TryWriteQName(Span destination, string name, out int written) @@ -59,7 +60,7 @@ internal static bool TryWriteQName(Span destination, string name, out int return true; } - private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messageBuffer, int offset, out int bytesRead) + private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messageBuffer, int offset, out int bytesRead, bool canStartWithPointer = true) { // // domain name can be either @@ -72,10 +73,12 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag // // It is not specified by the RFC if pointers must be backwards only, // the code below prohibits forward (and self) pointers to avoid - // infinite loops + // infinite loops. It also allows pointers only to point to a + // label, not to another pointer. // bytesRead = 0; + bool allowPointer = canStartWithPointer; if (offset < 0 || offset >= messageBuffer.Length) { @@ -98,52 +101,51 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag return true; } - if (currentOffset + 1 + length < messageBuffer.Length) + if (currentOffset + 1 + length >= messageBuffer.Length) { - // read next label/segment - if (sb.Length > 0) - { - sb.Append('.'); - } - - sb.Append(Encoding.ASCII.GetString(messageBuffer.Slice(currentOffset + 1, length))); - - if (sb.Length > MaxDomainNameLength) - { - // domain name is too long - return false; - } - - currentOffset += 1 + length; - bytesRead += 1 + length; + // too many labels or truncated data + break; } - else + + // read next label/segment + if (sb.Length > 0) { - // truncated data - break; + sb.Append('.'); + } + + sb.Append(Encoding.ASCII.GetString(messageBuffer.Slice(currentOffset + 1, length))); + + if (sb.Length > MaxDomainNameLength) + { + // domain name is too long + return false; } + + currentOffset += 1 + length; + bytesRead += 1 + length; + + // we read a label, they can be followed by pointer. + allowPointer = true; } else if ((length & 0xC0) == 0xC0) { // pointer, together with next byte gives the offset of the true label - if (currentOffset + 1 < messageBuffer.Length) - { - bytesRead += 2; - int pointer = ((length & 0x3F) << 8) | messageBuffer[currentOffset + 1]; - - // we prohibit self-references and forward pointers to avoid - // infinite loops, we do this by truncating the - // messageBuffer at the offset where we started reading the - // name. We also ignore the bytesRead from the recursive - // call, as we are only interested on how many bytes we read - // from the initial start of the name. - return TryReadQNameCore(sb, messageBuffer.Slice(0, offset), pointer, out int _); - } - else + if (!allowPointer || currentOffset + 1 >= messageBuffer.Length) { - // truncated data + // pointer to pointer or truncated data break; } + + bytesRead += 2; + int pointer = ((length & 0x3F) << 8) | messageBuffer[currentOffset + 1]; + + // we prohibit self-references and forward pointers to avoid + // infinite loops, we do this by truncating the + // messageBuffer at the offset where we started reading the + // name. We also ignore the bytesRead from the recursive + // call, as we are only interested on how many bytes we read + // from the initial start of the name. + return TryReadQNameCore(sb, messageBuffer.Slice(0, offset), pointer, out int _, false); } else { diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs index d5690b405e9..32d4ee02ce9 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs @@ -97,7 +97,7 @@ public void TryReadQName_PointerTruncated_Fails() public void TryReadQName_ForwardPointer_Fails() { // www->[ptr to example.com], [7B padding], example.com. - Span data = "\x03www\x00\0x000cpaddingexample\x0003com\x00"u8.ToArray(); + Span data = "\x03www\x00\x000dpadding\x0007example\x0003com\x00"u8.ToArray(); data[4] = 0xc0; Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); @@ -107,17 +107,29 @@ public void TryReadQName_ForwardPointer_Fails() public void TryReadQName_PointerToSelf_Fails() { // www->[ptr to www->...] - Span data = "\x0003www\x00c0"u8.ToArray(); + Span data = "\x0003www\0\0"u8.ToArray(); data[4] = 0xc0; Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); } + [Fact] + public void TryReadQName_PointerToPointer_Fails() + { + // com, example[->com], example2[->[->com]] + Span data = "\x0003com\0\x0007example\0\0\x0008example2\0\0"u8.ToArray(); + data[13] = 0xc0; + data[14] = 0x00; // -> com + data[24] = 0xc0; + data[25] = 13; // -> -> com + + Assert.False(DnsPrimitives.TryReadQName(data, 15, out _, out _)); + } + [Fact] public void TryReadQName_ReservedBits() { Span data = "\x0003www\x00c0"u8.ToArray(); - data[4] = 0xc0; data[0] = 0x40; Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); From f8c189df9a38cbf5359fbf503d8db9211ac08007 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Mon, 3 Feb 2025 17:18:50 +0100 Subject: [PATCH 16/45] Code review feedback (easy fixes) --- .../DnsServiceEndpointProvider.cs | 2 +- .../DnsServiceEndpointProviderBase.Log.cs | 4 ++-- .../DnsSrvServiceEndpointProvider.cs | 3 --- .../Resolver/DnsDataReader.cs | 2 +- .../Resolver/DnsPrimitives.cs | 6 ++++++ .../Resolver/DnsResolver.cs | 11 +++++------ .../Resolver/NetworkInfo.cs | 2 +- .../Resolver/QueryFlags.cs | 2 -- .../Resolver/QueryResponseCode.cs | 10 +++++----- .../Resolver/ResolvConf.cs | 5 ++++- .../Resolver/ResolverOptions.cs | 7 ++++++- .../Resolver/ResultTypes.cs | 16 ---------------- 12 files changed, 31 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs index c969a9e91e9..7a2d1b632e0 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs @@ -38,7 +38,7 @@ protected override async Task ResolveAsyncCore() foreach (var address in addresses) { ttl = MinTtl(now, address.ExpiresAt, ttl); - endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, 0))); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, port: 0))); } if (endpoints.Count == 0) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs index 9dbfafe4ef6..29aaaf8e930 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs @@ -9,10 +9,10 @@ partial class DnsServiceEndpointProviderBase { internal static partial class Log { - [LoggerMessage(1, LogLevel.Information, "Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", EventName = "SrvQuery")] + [LoggerMessage(1, LogLevel.Trace, "Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", EventName = "SrvQuery")] public static partial void SrvQuery(ILogger logger, string serviceName, string recordName); - [LoggerMessage(2, LogLevel.Information, "Resolving endpoints for service '{ServiceName}' using host lookup for name '{RecordName}'.", EventName = "AddressQuery")] + [LoggerMessage(2, LogLevel.Trace, "Resolving endpoints for service '{ServiceName}' using host lookup for name '{RecordName}'.", EventName = "AddressQuery")] public static partial void AddressQuery(ILogger logger, string serviceName, string recordName); [LoggerMessage(3, LogLevel.Debug, "Skipping endpoint resolution for service '{ServiceName}': '{Reason}'.", EventName = "SkippedResolution")] diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs index dc47bb8b5cd..6d5ade5059e 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs @@ -34,13 +34,10 @@ protected override async Task ResolveAsyncCore() var endpoints = new List(); var ttl = DefaultRefreshPeriod; Log.SrvQuery(logger, ServiceName, srvQuery); - logger.LogInformation("Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", ServiceName, srvQuery); var now = _timeProvider.GetUtcNow().DateTime; var result = await resolver.ResolveServiceAsync(srvQuery, cancellationToken: ShutdownToken).ConfigureAwait(false); - logger.LogInformation("Resolved {Number} entries", result.Length); - foreach (var record in result) { ttl = MinTtl(now, record.ExpiresAt, ttl); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index 192fa0a948e..bbe42a74997 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -112,7 +112,7 @@ public void Dispose() if (RawData is not null) { ArrayPool.Shared.Return(RawData); - RawData = null!; + RawData = null; } _buffer = default; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs index c6a90a5624d..0455fbaf41c 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -41,6 +41,8 @@ internal static bool TryWriteQName(Span destination, string name, out int int index = nameBuffer.Slice(1).IndexOf((byte)'.'); int labelLen = index == -1 ? nameBuffer.Length - 1 : index; + // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4 + // labels 63 octets or less if (labelLen > 63) { throw new ArgumentException("Label is too long"); @@ -177,6 +179,7 @@ internal static bool TryReadQName(ReadOnlySpan messageBuffer, int offset, internal static bool TryReadService(ReadOnlySpan buffer, out ushort priority, out ushort weight, out ushort port, [NotNullWhen(true)] out string? target, out int bytesRead) { + // https://www.rfc-editor.org/rfc/rfc2782 if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer, out priority) || !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Slice(2), out weight) || !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Slice(4), out port) || @@ -196,6 +199,7 @@ internal static bool TryReadService(ReadOnlySpan buffer, out ushort priori internal static bool TryWriteService(Span buffer, ushort priority, ushort weight, ushort port, string target, out int bytesWritten) { + // https://www.rfc-editor.org/rfc/rfc2782 if (!BinaryPrimitives.TryWriteUInt16BigEndian(buffer, priority) || !BinaryPrimitives.TryWriteUInt16BigEndian(buffer.Slice(2), weight) || !BinaryPrimitives.TryWriteUInt16BigEndian(buffer.Slice(4), port) || @@ -211,6 +215,7 @@ internal static bool TryWriteService(Span buffer, ushort priority, ushort internal static bool TryWriteSoa(Span buffer, string primaryNameServer, string responsibleMailAddress, uint serial, uint refresh, uint retry, uint expire, uint minimum, out int bytesWritten) { + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 if (!TryWriteQName(buffer, primaryNameServer, out int w1) || !TryWriteQName(buffer.Slice(w1), responsibleMailAddress, out int w2) || !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2), serial) || @@ -229,6 +234,7 @@ internal static bool TryWriteSoa(Span buffer, string primaryNameServer, st internal static bool TryReadSoa(ReadOnlySpan buffer, [NotNullWhen(true)] out string? primaryNameServer, [NotNullWhen(true)] out string? responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) { + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 if (!TryReadQName(buffer, 0, out primaryNameServer, out int w1) || !TryReadQName(buffer.Slice(w1), 0, out responsibleMailAddress, out int w2) || !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2), out serial) || diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index b562e025b76..522fbbe2509 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -20,7 +20,7 @@ internal partial class DnsResolver : IDnsResolver, IDisposable private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); - bool _disposed; + private bool _disposed; private readonly ResolverOptions _options; private readonly CancellationTokenSource _pendingRequestsCts = new(); private TimeProvider _timeProvider = TimeProvider.System; @@ -31,7 +31,7 @@ internal void SetTimeProvider(TimeProvider timeProvider) _timeProvider = timeProvider; } - public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(OperatingSystem.IsWindows() ? NetworkInfo.GetOptions() : ResolvConf.GetOptions()) + public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions()) { _timeProvider = timeProvider; _logger = logger; @@ -255,7 +255,7 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp switch (result.Error) { case SendQueryError.NoError: - goto exit; + return result; case SendQueryError.Timeout: // TODO: should we retry on timeout or skip to the next server? Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); @@ -270,8 +270,7 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp } } - exit: - // we have at least one server and we always keep the last received response. + // return the last error received return result; } @@ -307,7 +306,7 @@ internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEn } } - async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) + private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) { Log.Query(_logger, queryType, name, serverEndPoint, attempt); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs index fb9f331559b..7da0bb9e9f7 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs @@ -22,7 +22,7 @@ public static ResolverOptions GetOptions() { foreach (IPAddress server in properties.DnsAddresses) { - IPEndPoint ep = new IPEndPoint(server, 53); + IPEndPoint ep = new IPEndPoint(server, 53); // 53 is standard DNS port if (!servers.Contains(ep)) { servers.Add(ep); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs index 94fa019f54f..d983a83b35a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs @@ -6,8 +6,6 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; [Flags] internal enum QueryFlags : ushort { - IsCheckingDisabled = 0x0010, - IsAuthenticData = 0x0020, RecursionAvailable = 0x0080, RecursionDesired = 0x0100, ResultTruncated = 0x0200, diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs index 20b8790f54c..a0b0f6e40fe 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs @@ -23,11 +23,6 @@ enum QueryResponseCode : byte /// ServerFailure = 2, - /// - /// The name server does not support the requested kind of query. - /// - NotImplemented = 4, - /// /// Meaningful only for responses from an authoritative name server, this /// code signifies that the domain name referenced in the query does not @@ -35,6 +30,11 @@ enum QueryResponseCode : byte /// NameError = 3, + /// + /// The name server does not support the requested kind of query. + /// + NotImplemented = 4, + /// /// The name server refuses to perform the specified operation for policy reasons. /// diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs index d11c0fbff74..5132634c3de 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs @@ -3,11 +3,14 @@ using System.Net.Sockets; using System.Net; +using System.Runtime.Versioning; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal static class ResolvConf { + [SupportedOSPlatform("linux")] + [SupportedOSPlatform("osx")] public static ResolverOptions GetOptions() { return GetOptions(new StreamReader("/etc/resolv.conf")); @@ -26,7 +29,7 @@ public static ResolverOptions GetOptions(TextReader reader) { if (tokens.Length >= 2 && IPAddress.TryParse(tokens[1], out IPAddress? address)) { - serverList.Add(new IPEndPoint(address, 53)); + serverList.Add(new IPEndPoint(address, 53)); // 53 is standard DNS port } } else if (line.StartsWith("search")) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs index e3eda4d105f..504055e9548 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -5,7 +5,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal class ResolverOptions +internal sealed class ResolverOptions { public IPEndPoint[] Servers; public string DefaultDomain = string.Empty; @@ -17,6 +17,11 @@ internal class ResolverOptions public ResolverOptions(IPEndPoint[] servers) { + if (servers.Length == 0) + { + throw new ArgumentException("At least one server is required.", nameof(servers)); + } + Servers = servers; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs index 5b2f1d7229c..aed799ac8d6 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs @@ -2,25 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using System.Text; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal record struct AddressResult(DateTime ExpiresAt, IPAddress Address); internal record struct ServiceResult(DateTime ExpiresAt, int Priority, int Weight, int Port, string Target, AddressResult[] Addresses); - -internal record struct TxtResult(int Ttl, byte[] Data) -{ - internal IEnumerable GetText() => GetText(Encoding.ASCII); - - internal IEnumerable GetText(Encoding encoding) - { - for (int i = 0; i < Data.Length;) - { - int length = Data[i]; - yield return encoding.GetString(Data, i + 1, length); - i += length + 1; - } - } -} From d9298890144b896ab541cdf72a853270d068b74d Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 6 Feb 2025 14:09:53 +0100 Subject: [PATCH 17/45] More code review feedback --- .../Resolver/DnsDataReader.cs | 5 ++- .../Resolver/DnsMessageHeader.cs | 9 +++- .../Resolver/DnsPrimitives.cs | 45 ++++--------------- .../Resolver/DnsResolver.cs | 1 + .../Resolver/LoopbackDnsServer.cs | 24 ++++++++-- 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index bbe42a74997..ac2b4e99c48 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -3,6 +3,7 @@ using System.Buffers; using System.Buffers.Binary; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; @@ -23,7 +24,9 @@ public DnsDataReader(ReadOnlyMemory buffer, byte[]? returnToPool = null) public bool TryReadHeader(out DnsMessageHeader header) { - if (_buffer.Length - _position < DnsMessageHeader.HeaderLength) + Debug.Assert(_position == 0); + + if (_buffer.Length < DnsMessageHeader.HeaderLength) { header = default; return false; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs index 675e150e3b2..ebf887de7ca 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -2,21 +2,28 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers.Binary; +using System.Runtime.InteropServices; using System.Security.Cryptography; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; // RFC 1035 4.1.1. Header section format +[StructLayout(LayoutKind.Explicit, Size = HeaderLength)] internal struct DnsMessageHeader { internal const int HeaderLength = 12; + [FieldOffset(0)] private ushort _transactionId; + [FieldOffset(2)] private ushort _flags; - + [FieldOffset(4)] private ushort _queryCount; + [FieldOffset(6)] private ushort _answerCount; + [FieldOffset(8)] private ushort _authorityCount; + [FieldOffset(10)] private ushort _additionalRecordCount; internal ushort QueryCount diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs index 0455fbaf41c..dad97eae497 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers.Binary; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text; @@ -25,7 +26,11 @@ internal static bool TryWriteQName(Span destination, string name, out int // padding is used. // - if (!Encoding.ASCII.TryGetBytes(name, destination.IsEmpty ? destination : destination.Slice(1), out int length) || destination.Length < length + 2) + // The is assumed to be already validated and puny-encoded if needed + Debug.Assert(name.Length <= MaxDomainNameLength); + Debug.Assert(Ascii.IsValid(name)); + + if (destination.IsEmpty || !Encoding.ASCII.TryGetBytes(name, destination.Slice(1), out int length) || destination.Length < length + 2) { // buffer too small written = 0; @@ -38,13 +43,14 @@ internal static bool TryWriteQName(Span destination, string name, out int while (true) { // figure out the next label and prepend the length - int index = nameBuffer.Slice(1).IndexOf((byte)'.'); + int index = nameBuffer.Slice(1).IndexOf((byte)'.'); int labelLen = index == -1 ? nameBuffer.Length - 1 : index; // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4 // labels 63 octets or less if (labelLen > 63) { + // this should never happen, as we validate the name before calling this method throw new ArgumentException("Label is too long"); } @@ -197,41 +203,6 @@ internal static bool TryReadService(ReadOnlySpan buffer, out ushort priori return true; } - internal static bool TryWriteService(Span buffer, ushort priority, ushort weight, ushort port, string target, out int bytesWritten) - { - // https://www.rfc-editor.org/rfc/rfc2782 - if (!BinaryPrimitives.TryWriteUInt16BigEndian(buffer, priority) || - !BinaryPrimitives.TryWriteUInt16BigEndian(buffer.Slice(2), weight) || - !BinaryPrimitives.TryWriteUInt16BigEndian(buffer.Slice(4), port) || - !TryWriteQName(buffer.Slice(6), target, out bytesWritten)) - { - bytesWritten = 0; - return false; - } - - bytesWritten += 6; - return true; - } - - internal static bool TryWriteSoa(Span buffer, string primaryNameServer, string responsibleMailAddress, uint serial, uint refresh, uint retry, uint expire, uint minimum, out int bytesWritten) - { - // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 - if (!TryWriteQName(buffer, primaryNameServer, out int w1) || - !TryWriteQName(buffer.Slice(w1), responsibleMailAddress, out int w2) || - !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2), serial) || - !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 4), refresh) || - !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 8), retry) || - !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 12), expire) || - !BinaryPrimitives.TryWriteUInt32BigEndian(buffer.Slice(w1 + w2 + 16), minimum)) - { - bytesWritten = 0; - return false; - } - - bytesWritten = w1 + w2 + 20; - return true; - } - internal static bool TryReadSoa(ReadOnlySpan buffer, [NotNullWhen(true)] out string? primaryNameServer, [NotNullWhen(true)] out string? responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) { // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 522fbbe2509..255e67e746a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -552,6 +552,7 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) throw new InvalidOperationException("Invalid response: Header mismatch"); } + // transfer ownership of buffer to the caller buffer = null!; return (responseReader, header); } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index fc0507beb35..233154e6e32 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -61,7 +61,8 @@ private static async Task ProcessRequestCore(ReadOnlyMemory message, })) { throw new InvalidOperationException("Failed to write header"); - }; + } + ; foreach (var (questionName, questionType, questionClass) in responseBuilder.Questions) { @@ -198,11 +199,18 @@ public static List AddCname(this List reco public static List AddService(this List records, string name, int ttl, ushort priority, ushort weight, ushort port, string target) { byte[] buff = new byte[256]; - if (!DnsPrimitives.TryWriteService(buff, priority, weight, port, target, out int length)) + + // https://www.rfc-editor.org/rfc/rfc2782 + if (!BinaryPrimitives.TryWriteUInt16BigEndian(buff, priority) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buff.AsSpan(2), weight) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buff.AsSpan(4), port) || + !DnsPrimitives.TryWriteQName(buff.AsSpan(6), target, out int length)) { throw new InvalidOperationException("Failed to encode SRV record"); } + length += 6; + records.Add(new DnsResourceRecord(name, QueryType.SRV, QueryClass.Internet, ttl, buff.AsMemory(0, length))); return records; } @@ -210,11 +218,21 @@ public static List AddService(this List re public static List AddStartOfAuthority(this List records, string name, int ttl, string mname, string rname, uint serial, uint refresh, uint retry, uint expire, uint minimum) { byte[] buff = new byte[256]; - if (!DnsPrimitives.TryWriteSoa(buff, mname, rname, serial, refresh, retry, expire, minimum, out int length)) + + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 + if (!DnsPrimitives.TryWriteQName(buff, mname, out int w1) || + !DnsPrimitives.TryWriteQName(buff.AsSpan(w1), rname, out int w2) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2), serial) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 4), refresh) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 8), retry) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 12), expire) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 16), minimum)) { throw new InvalidOperationException("Failed to encode SOA record"); } + int length = w1 + w2 + 20; + records.Add(new DnsResourceRecord(name, QueryType.SOA, QueryClass.Internet, ttl, buff.AsMemory(0, length))); return records; } From 00e41d764e91014e8a859bcca74ee082b78f4fe3 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 6 Feb 2025 14:20:21 +0100 Subject: [PATCH 18/45] Fix increment --- .../Resolver/DnsResolver.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 255e67e746a..5ec825c38d6 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -18,6 +18,7 @@ internal partial class DnsResolver : IDnsResolver, IDisposable private const int IPv4Length = 4; private const int IPv6Length = 16; + // CancellationTokenSource.CancelAfter has a maximum timeout of Int32.MaxValue milliseconds. private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); private bool _disposed; @@ -130,10 +131,11 @@ public async ValueTask ResolveIPAddressesAsync(string name, Can if (Socket.OSSupportsIPv6) // prefer IPv6 { res[index] = new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback); + index++; } if (Socket.OSSupportsIPv4) { - res[index++] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback); + res[index] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback); } } From 7267fb794bb8ad92bd6450fb91946e9000318ddf Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 7 Feb 2025 09:52:59 +0100 Subject: [PATCH 19/45] Detect CNAME loops --- .../Resolver/DnsResolver.cs | 80 +++++++++++++------ .../Resolver/ResolveAddressesTests.cs | 38 ++++++++- 2 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 5ec825c38d6..f87052045bd 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -137,6 +137,8 @@ public async ValueTask ResolveIPAddressesAsync(string name, Can { res[index] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback); } + + return res; } var ipv4AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetwork, cancellationToken); @@ -188,39 +190,67 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add } using DnsResponse response = result.Response; + + // Given that result.Error is NoError, there should be at least one answer. + Debug.Assert(response.Answers.Count > 0); var results = new List(response.Answers.Count); - // servers send back CNAME records together with associated A/AAAA records + // Servers send back CNAME records together with associated A/AAAA records. Servers + // send only those CNAME records relevant to the query, and if there is a CNAME record, + // there should not be other records associated with the name. Therefore, we simply follow + // the list of CNAME aliases until we get to the primary name and return the A/AAAA records + // associated. + // + // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 + // + // Most of the servers send the CNAME records in order so that we can sequentially scan the + // answers, but nothing prevents the records from being in arbitrary order. Therefore, when + // we encounter a CNAME record, we continue down the list and allow looping back to the beginning + // in case the CNAME chain is not in order. + // string currentAlias = name; + int i = 0; + int endIndex = 0; - foreach (var answer in response.Answers) + do { - if (answer.Name != currentAlias) - { - continue; - } + DnsResourceRecord answer = response.Answers[i]; - if (answer.Type == QueryType.CNAME) + if (answer.Name == currentAlias) { - // Although RFC does not necessarily allow pointers segments in CNAME domain names, some servers do use them - // so we need to pass the entire buffer to TryReadQName with the proper offset. The data should be always - // backed by the array containing the full response. + if (answer.Type == QueryType.CNAME) + { + // Although RFC does not necessarily allow pointer segments in CNAME domain names, some servers do use them + // so we need to pass the entire buffer to TryReadQName with the proper offset. The data should be always + // backed by the array containing the full response. + + var success = MemoryMarshal.TryGetArray(answer.Data, out ArraySegment segment); + Debug.Assert(success, "Failed to get array segment"); + if (!DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out currentAlias!, out _)) + { + // TODO: how to handle corrupted responses? + throw new InvalidOperationException("Failed to parse CNAME record"); + } - var success = MemoryMarshal.TryGetArray(answer.Data, out ArraySegment segment); - Debug.Assert(success, "Failed to get array segment"); - if (!DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out currentAlias!, out _)) + // We need to start over. start with following answers and allow looping back + endIndex = i; + + if (string.Equals(currentAlias, name, StringComparison.OrdinalIgnoreCase)) + { + // CNAME records looped back to original question dns name (=> malformed response). Stop processing. + break; + } + } + else if (answer.Type == queryType) { - throw new InvalidOperationException("Invalid response: CNAME record"); + Debug.Assert(answer.Data.Length == IPv4Length || answer.Data.Length == IPv6Length); + results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); } - continue; } - else if (answer.Type == queryType) - { - Debug.Assert(answer.Data.Length == IPv4Length || answer.Data.Length == IPv6Length); - results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); - } + i = (i + 1) % response.Answers.Count; } + while (i != endIndex); AddressResult[] res = results.ToArray(); Telemetry.StopNameResolution(name, queryType, activity, res, result.Error, _timeProvider.GetTimestamp()); @@ -273,6 +303,7 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp } // return the last error received + // TODO: will this always have nondefault value? return result; } @@ -329,9 +360,12 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || qName != name || qType != queryType || qClass != QueryClass.Internet) { - // TODO: do we care? - throw new InvalidOperationException("Invalid response: Query mismatch"); - // return default; + // DNS Question mismatch + return new SendQueryResult + { + Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.ServerError + }; } if (header.ResponseCode != QueryResponseCode.NoError) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index beaaaf26d18..b3c55998060 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -69,7 +69,7 @@ public async Task ResolveIPv4_Simple_Success() } [Fact] - public async Task ResolveIPv4_Aliases_Success() + public async Task ResolveIPv4_Aliases_InOrder_Success() { IPAddress address = IPAddress.Parse("172.213.245.111"); _ = DnsServer.ProcessUdpRequest(builder => @@ -87,6 +87,42 @@ public async Task ResolveIPv4_Aliases_Success() Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); } + [Fact] + public async Task ResolveIPv4_Aliases_OutOfOrder_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example3.com", 3600, address); + builder.Answers.AddCname("www.example.com", 3600, "www.example2.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example3.com", 3600, "www.example1.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + + Assert.Empty(results); + } + [Fact] public async Task ResolveIPv4_Aliases_NotFound_Success() { From c1b76d37d10f1091cadd6ea8535a34092e4a91da Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 7 Feb 2025 13:06:47 +0100 Subject: [PATCH 20/45] Handle empty Tcp fallback responses and lack of TCP failover support on server. --- .../Resolver/DnsResolver.Log.cs | 8 +- .../Resolver/DnsResolver.cs | 77 +++++++++++++------ .../Resolver/SendQueryError.cs | 31 +++++++- .../Resolver/LoopbackDnsServer.cs | 6 +- .../Resolver/TcpFailoverTests.cs | 44 +++++++++++ 5 files changed, 138 insertions(+), 28 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs index 7a4f4223ba5..b3762db9dd9 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs @@ -24,7 +24,13 @@ internal static partial class Log [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} returned no data", EventName = "NoData")] public static partial void NoData(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); - [LoggerMessage(6, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] + [LoggerMessage(6, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed to return a valid DNS response.", EventName = "MalformedResponse")] + public static partial void MalformedResponse(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(7, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed due to a network error.", EventName = "NetworkError")] + public static partial void NetworkError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); + + [LoggerMessage(8, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] public static partial void QueryError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index f87052045bd..e3f2c0cd62b 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -265,46 +265,62 @@ internal struct SendQueryResult async ValueTask SendQueryWithRetriesAsync(string name, QueryType queryType, CancellationToken cancellationToken) { - SendQueryResult result = default; + SendQueryResult? result = default; for (int index = 0; index < _options.Servers.Length; index++) { IPEndPoint serverEndPoint = _options.Servers[index]; - for (int attempt = 0; attempt < _options.Attempts; attempt++) + for (int attempt = 1; attempt <= _options.Attempts; attempt++) { try { result = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, index == _options.Servers.Length - 1, attempt, cancellationToken).ConfigureAwait(false); } + catch (SocketException ex) + { + Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex); + result = new SendQueryResult { Error = SendQueryError.NetworkError }; + continue; // retry or skip to the next server + } catch (Exception ex) when (!cancellationToken.IsCancellationRequested) { Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); continue; // retry or skip to the next server } - switch (result.Error) + Debug.Assert(result.HasValue); + + switch (result.Value.Error) { case SendQueryError.NoError: - return result; + return result.Value; case SendQueryError.Timeout: // TODO: should we retry on timeout or skip to the next server? Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); break; case SendQueryError.ServerError: - Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Response.Header.ResponseCode); + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Value.Response.Header.ResponseCode); break; case SendQueryError.NoData: Log.NoData(_logger, queryType, name, serverEndPoint, attempt); break; + case SendQueryError.MalformedResponse: + Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt); + break; } } } - // return the last error received - // TODO: will this always have nondefault value? - return result; + // we should have an error result by now, except when we threw an exception due to internal bug + // (handled here), or cancellation (handled by the caller). + // if (!result.HasValue) + // { + // result = new SendQueryResult { Error = SendQueryError.InternalError }; + // } + + return result!.Value; } internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) @@ -343,6 +359,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve { Log.Query(_logger, queryType, name, serverEndPoint, attempt); + SendQueryError sendError = SendQueryError.NoError; DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; (DnsDataReader responseReader, DnsMessageHeader header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); @@ -353,7 +370,13 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); responseReader.Dispose(); // TCP fallback - (responseReader, header) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + } + + if (sendError != SendQueryError.NoError) + { + // we failed to get back any response + return new SendQueryResult { Error = sendError }; } if (header.QueryCount != 1 || @@ -364,11 +387,15 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve return new SendQueryResult { Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), - Error = SendQueryError.ServerError + Error = SendQueryError.MalformedResponse }; } - if (header.ResponseCode != QueryResponseCode.NoError) + // we are interested in returned RRs only in case of NOERROR response code, + // if this is not a successful response and we have attempts remaining, + // we skip parsing the response and retry. + // TODO: test server failover behavior + if (header.ResponseCode != QueryResponseCode.NoError && (!isLastServer || attempt != _options.Attempts)) { return new SendQueryResult { @@ -377,12 +404,6 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve }; } - if (header.ResponseCode != QueryResponseCode.NoError && !isLastServer) - { - // we exhausted attempts on this server, try the next one - return default; - } - int ttl = int.MaxValue; List answers = ReadRecords(header.AnswerCount, ref ttl, ref responseReader); List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); @@ -506,9 +527,7 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) (ushort transactionId, int length) = EncodeQuestion(memory, name, queryType); using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp); - await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); - - await socket.SendAsync(memory.Slice(0, length), SocketFlags.None, cancellationToken).ConfigureAwait(false); + await socket.SendToAsync(memory.Slice(0, length), SocketFlags.None, serverEndPoint, cancellationToken).ConfigureAwait(false); DnsDataReader responseReader; DnsMessageHeader header; @@ -527,8 +546,7 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) header.TransactionId != transactionId || !header.IsResponse) { - // the message is not a response for our query. - // don't dispose reader, we will reuse the buffer + // header mismatch, this is not a response to our query continue; } @@ -546,7 +564,7 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) } } - internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header, SendQueryError error)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) { var buffer = ArrayPool.Shared.Rent(8 * 1024); try @@ -566,12 +584,20 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); bytesRead += read; + if (read == 0) + { + // connection closed before receiving complete response message + return (default, default, SendQueryError.MalformedResponse); + } + if (responseLength < 0 && bytesRead >= 2) { responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); if (responseLength > buffer.Length) { + // even though this is user-controlled pre-allocation, it is limited to + // 64 kB, so it should be fine. var largerBuffer = ArrayPool.Shared.Rent(responseLength); Array.Copy(buffer, largerBuffer, bytesRead); ArrayPool.Shared.Return(buffer); @@ -585,12 +611,13 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) header.TransactionId != transactionId || !header.IsResponse) { - throw new InvalidOperationException("Invalid response: Header mismatch"); + // header mismatch on TCP fallback + return (default, default, SendQueryError.MalformedResponse); } // transfer ownership of buffer to the caller buffer = null!; - return (responseReader, header); + return (responseReader, header, SendQueryError.NoError); } finally { diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs index c7f9be88783..6b8e2bddf35 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs @@ -5,9 +5,38 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal enum SendQueryError { + /// + /// DNS query was successful and returned response message with answers. + /// NoError, + + /// + /// Server failed to respond to the query withing specified timeout. + /// Timeout, + + /// + /// Server returned a response with an error code. + /// ServerError, - ParseError, + + /// + /// Server returned a malformed response. + /// + MalformedResponse, + + /// + /// Server returned a response indicating no data are available. + /// NoData, + + /// + /// Network-level error occurred during the query. + /// + NetworkError, + + /// + /// Internal error on part of the implementation. + /// + InternalError, } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 233154e6e32..41e2be134a5 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -31,6 +31,11 @@ public void Dispose() _tcpSocket.Dispose(); } + public void DisableTcpFallback() + { + _tcpSocket.Close(); + } + private static async Task ProcessRequestCore(ReadOnlyMemory message, Func action, Memory responseBuffer) { DnsDataReader reader = new DnsDataReader(message); @@ -62,7 +67,6 @@ private static async Task ProcessRequestCore(ReadOnlyMemory message, { throw new InvalidOperationException("Failed to write header"); } - ; foreach (var (questionName, questionType, questionClass) in responseBuilder.Questions) { diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index bd4a1ef4070..67ced0b2ac0 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -36,4 +36,48 @@ public async Task TcpFailover_Simple_Success() Assert.Equal(address, res.Address); Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); } + + [Fact] + public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() + { + Options.Attempts = 1; + Options.Timeout = TimeSpan.FromSeconds(60); + + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + Task serverTask = DnsServer.ProcessTcpRequest(builder => + { + throw new InvalidOperationException("This forces closing the socket without writing any data"); + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + Assert.Empty(results); + + await Assert.ThrowsAsync(() => serverTask); + } + + [Fact] + public async Task TcpFailover_TcpNotAvailable_EmptyResult() + { + Options.Attempts = 1; + Options.Timeout = TimeSpan.FromMilliseconds(100000); + + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + // turn off TCP support the server + DnsServer.DisableTcpFallback(); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + Assert.Empty(results); + } } From 86490cc932ec52581754e7502ab64bb2c5ace226 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 7 Feb 2025 13:08:43 +0100 Subject: [PATCH 21/45] Dispose result.Response when overwritten by another result. --- .../Resolver/DnsResolver.cs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index e3f2c0cd62b..f83db1a853a 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -276,7 +276,14 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp try { - result = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, index == _options.Servers.Length - 1, attempt, cancellationToken).ConfigureAwait(false); + SendQueryResult newResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, index == _options.Servers.Length - 1, attempt, cancellationToken).ConfigureAwait(false); + + if (result.HasValue) + { + result.Value.Response.Dispose(); + } + + result = newResult; } catch (SocketException ex) { @@ -315,10 +322,10 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp // we should have an error result by now, except when we threw an exception due to internal bug // (handled here), or cancellation (handled by the caller). - // if (!result.HasValue) - // { - // result = new SendQueryResult { Error = SendQueryError.InternalError }; - // } + if (!result.HasValue) + { + result = new SendQueryResult { Error = SendQueryError.InternalError }; + } return result!.Value; } From 92b29052f3a13fafcf876b65e5426e808c4521f7 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 11 Feb 2025 15:28:16 +0100 Subject: [PATCH 22/45] Move DnsMessageHeader parsing to DnsPrimitives --- .../Resolver/DnsDataReader.cs | 5 +- .../Resolver/DnsDataWriter.cs | 8 +- .../Resolver/DnsMessageHeader.cs | 100 ++---------------- .../Resolver/DnsPrimitives.cs | 48 +++++++++ .../Resolver/DnsResolver.cs | 15 ++- .../Resolver/ResolverOptions.cs | 1 - .../Resolver/LoopbackDnsServer.cs | 3 +- 7 files changed, 75 insertions(+), 105 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index ac2b4e99c48..c9e705a0787 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -26,14 +26,13 @@ public bool TryReadHeader(out DnsMessageHeader header) { Debug.Assert(_position == 0); - if (_buffer.Length < DnsMessageHeader.HeaderLength) + if (!DnsPrimitives.TryReadMessageHeader(_buffer.Span, out header, out int bytesRead)) { header = default; return false; } - _position += DnsMessageHeader.HeaderLength; - header = MemoryMarshal.AsRef(_buffer.Span); + _position += bytesRead; return true; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs index 4abbf277328..4b116763ceb 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers.Binary; -using System.Runtime.InteropServices; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal class DnsDataWriter +internal sealed class DnsDataWriter { private readonly Memory _buffer; private int _position; @@ -21,13 +20,12 @@ internal DnsDataWriter(Memory buffer) internal bool TryWriteHeader(in DnsMessageHeader header) { - if (_buffer.Length - _position < DnsMessageHeader.HeaderLength) + if (!DnsPrimitives.TryWriteMessageHeader(_buffer.Span.Slice(_position), header, out int written)) { return false; } - MemoryMarshal.Write(_buffer.Span, in header); - _position += DnsMessageHeader.HeaderLength; + _position += written; return true; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs index ebf887de7ca..b3d5effb0d6 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -1,114 +1,36 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers.Binary; -using System.Runtime.InteropServices; -using System.Security.Cryptography; - namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; // RFC 1035 4.1.1. Header section format -[StructLayout(LayoutKind.Explicit, Size = HeaderLength)] internal struct DnsMessageHeader { internal const int HeaderLength = 12; + public ushort TransactionId { get; set; } - [FieldOffset(0)] - private ushort _transactionId; - [FieldOffset(2)] - private ushort _flags; - [FieldOffset(4)] - private ushort _queryCount; - [FieldOffset(6)] - private ushort _answerCount; - [FieldOffset(8)] - private ushort _authorityCount; - [FieldOffset(10)] - private ushort _additionalRecordCount; - - internal ushort QueryCount - { - get => ReverseByteOrder(_queryCount); - set => _queryCount = ReverseByteOrder(value); - } - - internal ushort AnswerCount - { - get => ReverseByteOrder(_answerCount); - set => _answerCount = ReverseByteOrder(value); - } + internal QueryFlags QueryFlags { get; set; } - internal ushort AuthorityCount - { - get => ReverseByteOrder(_authorityCount); - set => _authorityCount = ReverseByteOrder(value); - } + public ushort QueryCount { get; set; } - internal ushort AdditionalRecordCount - { - get => ReverseByteOrder(_additionalRecordCount); - set => _additionalRecordCount = ReverseByteOrder(value); - } + public ushort AnswerCount { get; set; } - internal ushort TransactionId - { - get => ReverseByteOrder(_transactionId); - set => _transactionId = ReverseByteOrder(value); - } + public ushort AuthorityCount { get; set; } - internal QueryFlags QueryFlags - { - get => (QueryFlags)ReverseByteOrder(_flags); - set => _flags = ReverseByteOrder((ushort)value); - } + public ushort AdditionalRecordCount { get; set; } - internal bool IsRecursionDesired + public QueryResponseCode ResponseCode { - get => (QueryFlags & QueryFlags.RecursionDesired) != 0; - set - { - if (value) - { - QueryFlags |= QueryFlags.RecursionDesired; - } - else - { - QueryFlags &= ~QueryFlags.RecursionDesired; - } - } + get => (QueryResponseCode)((int)QueryFlags & 0x000F); } - internal QueryResponseCode ResponseCode + public bool IsResultTruncated { - get => (QueryResponseCode)((_flags & 0x0F00) >> 8); - set => _flags = (ushort)((_flags & 0xF0FF) | ((ushort)value << 8)); + get => (QueryFlags & QueryFlags.ResultTruncated) != 0; } - internal bool IsResultTruncated => (QueryFlags & QueryFlags.ResultTruncated) != 0; - - internal bool IsResponse + public bool IsResponse { get => (QueryFlags & QueryFlags.HasResponse) != 0; - set - { - if (value) - { - QueryFlags |= QueryFlags.HasResponse; - } - else - { - QueryFlags &= ~QueryFlags.HasResponse; - } - } } - - internal void InitQueryHeader() - { - this = default; - TransactionId = (ushort)RandomNumberGenerator.GetInt32(short.MaxValue + 1); - IsRecursionDesired = true; - QueryCount = 1; - } - - private static ushort ReverseByteOrder(ushort value) => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value; } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs index dad97eae497..47db4f7ac6f 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -13,6 +13,54 @@ internal static class DnsPrimitives // Maximum length of a domain name in ASCII (excluding trailing dot) internal const int MaxDomainNameLength = 253; + internal static bool TryReadMessageHeader(ReadOnlySpan buffer, out DnsMessageHeader header, out int bytesRead) + { + // RFC 1035 4.1.1. Header section format + if (buffer.Length < DnsMessageHeader.HeaderLength) + { + header = default; + bytesRead = 0; + return false; + } + + header = new DnsMessageHeader + { + TransactionId = BinaryPrimitives.ReadUInt16BigEndian(buffer), + QueryFlags = (QueryFlags)BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2)), + QueryCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(4)), + AnswerCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(6)), + AuthorityCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(8)), + AdditionalRecordCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(10)) + }; + + bytesRead = DnsMessageHeader.HeaderLength; + return true; + } + + internal static bool TryWriteMessageHeader(Span buffer, DnsMessageHeader header, out int bytesWritten) + { + // RFC 1035 4.1.1. Header section format + if (buffer.Length < DnsMessageHeader.HeaderLength) + { + bytesWritten = 0; + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(buffer, header.TransactionId); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(2), (ushort)header.QueryFlags); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(4), header.QueryCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(6), header.AnswerCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(8), header.AuthorityCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(10), header.AdditionalRecordCount); + + bytesWritten = DnsMessageHeader.HeaderLength; + return true; + } + + // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4 + // labels 63 octets or less + // name 255 octets or less + internal static bool TryWriteQName(Span destination, string name, out int written) { // diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index f83db1a853a..7ef080bb572 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -8,12 +8,13 @@ using System.Net; using System.Net.Sockets; using System.Runtime.InteropServices; +using System.Security.Cryptography; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal partial class DnsResolver : IDnsResolver, IDisposable +internal sealed partial class DnsResolver : IDnsResolver, IDisposable { private const int IPv4Length = 4; private const int IPv6Length = 16; @@ -273,7 +274,6 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp for (int attempt = 1; attempt <= _options.Attempts; attempt++) { - try { SendQueryResult newResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, index == _options.Servers.Length - 1, attempt, cancellationToken).ConfigureAwait(false); @@ -637,13 +637,18 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) private static (ushort id, int length) EncodeQuestion(Memory buffer, string name, QueryType queryType) { - DnsMessageHeader header = default; - header.InitQueryHeader(); + DnsMessageHeader header = new DnsMessageHeader + { + TransactionId = (ushort)RandomNumberGenerator.GetInt32(ushort.MaxValue + 1), + QueryFlags = QueryFlags.RecursionDesired, + QueryCount = 1 + }; + DnsDataWriter writer = new DnsDataWriter(buffer); if (!writer.TryWriteHeader(header) || !writer.TryWriteQuestion(name, queryType, QueryClass.Internet)) { - // should never happen since we validated the name length + // should never happen since we validated the name length before throw new InvalidOperationException("Buffer too small"); } return (header.TransactionId, writer.Position); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs index 504055e9548..82e874a1b25 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -10,7 +10,6 @@ internal sealed class ResolverOptions public IPEndPoint[] Servers; public string DefaultDomain = string.Empty; public string[]? SearchDomains; - public bool UseHostsFile; public int Attempts = 2; public TimeSpan Timeout = TimeSpan.FromSeconds(3); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 41e2be134a5..9ca1267261f 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -57,8 +57,7 @@ private static async Task ProcessRequestCore(ReadOnlyMemory message, if (!writer.TryWriteHeader(new DnsMessageHeader { TransactionId = responseBuilder.TransactionId, - QueryFlags = responseBuilder.Flags, - ResponseCode = responseBuilder.ResponseCode, + QueryFlags = responseBuilder.Flags | (QueryFlags)responseBuilder.ResponseCode, QueryCount = (ushort)responseBuilder.Questions.Count, AnswerCount = (ushort)responseBuilder.Answers.Count, AuthorityCount = (ushort)responseBuilder.Authorities.Count, From fdf190483954e34a8002644b692d3d9d807970f7 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Wed, 12 Feb 2025 16:42:37 +0100 Subject: [PATCH 23/45] Rework and add tests to retries and failover --- .../Resolver/DnsDataReader.cs | 1 - .../Resolver/DnsResolver.Log.cs | 11 +- .../Resolver/DnsResolver.cs | 108 +++++--- .../Resolver/QueryResponseCode.cs | 2 +- .../Resolver/SendQueryError.cs | 7 +- .../Resolver/RetryTests.cs | 232 ++++++++++++++++-- 6 files changed, 300 insertions(+), 61 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index c9e705a0787..831f03b082c 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -5,7 +5,6 @@ using System.Buffers.Binary; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Runtime.InteropServices; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs index b3762db9dd9..adab9161737 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs @@ -21,16 +21,19 @@ internal static partial class Log [LoggerMessage(4, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")] public static partial void Timeout(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); - [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} returned no data", EventName = "NoData")] + [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: no data matching given query type.", EventName = "NoData")] public static partial void NoData(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); - [LoggerMessage(6, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed to return a valid DNS response.", EventName = "MalformedResponse")] + [LoggerMessage(6, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: server indicates given name does not exist.", EventName = "NameError")] + public static partial void NameError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(7, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed to return a valid DNS response.", EventName = "MalformedResponse")] public static partial void MalformedResponse(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); - [LoggerMessage(7, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed due to a network error.", EventName = "NetworkError")] + [LoggerMessage(8, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed due to a network error.", EventName = "NetworkError")] public static partial void NetworkError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); - [LoggerMessage(8, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] + [LoggerMessage(9, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] public static partial void QueryError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 7ef080bb572..93928b1e9c5 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -276,7 +276,7 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp { try { - SendQueryResult newResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, index == _options.Servers.Length - 1, attempt, cancellationToken).ConfigureAwait(false); + SendQueryResult newResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, attempt, cancellationToken).ConfigureAwait(false); if (result.HasValue) { @@ -289,39 +289,73 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp { Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex); result = new SendQueryResult { Error = SendQueryError.NetworkError }; - continue; // retry or skip to the next server } catch (Exception ex) when (!cancellationToken.IsCancellationRequested) { Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); - continue; // retry or skip to the next server + result = new SendQueryResult { Error = SendQueryError.InternalError }; } - Debug.Assert(result.HasValue); - switch (result.Value.Error) { + // + // Definitive answers, no point retrying + // case SendQueryError.NoError: return result.Value; + + case SendQueryError.NameError: + // authoritative answer that the name does not exist, no point in retrying + Log.NameError(_logger, queryType, name, serverEndPoint, attempt); + return result.Value; + + case SendQueryError.NoData: + // no data available for the name from authoritative server + Log.NoData(_logger, queryType, name, serverEndPoint, attempt); + return result.Value; + + // + // Transient errors, retry on the same server + // case SendQueryError.Timeout: - // TODO: should we retry on timeout or skip to the next server? Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); - break; + continue; + + case SendQueryError.NetworkError: + // TODO: retry with exponential backoff? + continue; + + case SendQueryError.ServerError when result.Value.Response!.Header.ResponseCode == QueryResponseCode.ServerFailure: + // ServerFailure may indicate transient failure with upstream DNS servers, retry on the same server + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Value.Response.Header.ResponseCode); + continue; + + // + // Persistent errors, skip to the next server + // case SendQueryError.ServerError: + // this should cover all response codes except NoError, NameError which are definite and handled above, and + // ServerFailure which is a transient error and handled above. Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Value.Response.Header.ResponseCode); break; - case SendQueryError.NoData: - Log.NoData(_logger, queryType, name, serverEndPoint, attempt); - break; + case SendQueryError.MalformedResponse: Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt); break; + + case SendQueryError.InternalError: + // exception logged above. + break; } + + // actual break that causes skipping to the next server + break; } } - // we should have an error result by now, except when we threw an exception due to internal bug - // (handled here), or cancellation (handled by the caller). + // we should have a result by now + Debug.Assert(result.HasValue); + if (!result.HasValue) { result = new SendQueryResult { Error = SendQueryError.InternalError }; @@ -330,13 +364,13 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp return result!.Value; } - internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) + internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, int attempt, CancellationToken cancellationToken) { (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); try { - return await SendQueryToServerAsync(serverEndPoint, name, queryType, isLastServer, attempt, cts.Token).ConfigureAwait(false); + return await SendQueryToServerAsync(serverEndPoint, name, queryType, attempt, cts.Token).ConfigureAwait(false); } catch (OperationCanceledException) when ( !cancellationToken.IsCancellationRequested && // not cancelled by the caller @@ -362,7 +396,7 @@ internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEn } } - private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, bool isLastServer, int attempt, CancellationToken cancellationToken) + private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, int attempt, CancellationToken cancellationToken) { Log.Query(_logger, queryType, name, serverEndPoint, attempt); @@ -398,29 +432,21 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve }; } - // we are interested in returned RRs only in case of NOERROR response code, - // if this is not a successful response and we have attempts remaining, - // we skip parsing the response and retry. - // TODO: test server failover behavior - if (header.ResponseCode != QueryResponseCode.NoError && (!isLastServer || attempt != _options.Attempts)) - { - return new SendQueryResult - { - Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), - Error = SendQueryError.ServerError - }; - } - int ttl = int.MaxValue; List answers = ReadRecords(header.AnswerCount, ref ttl, ref responseReader); List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); List additionals = ReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader); + DateTime expirationTime = + (answers.Count + authorities.Count + additionals.Count) > 0 ? queryStartedTime.AddSeconds(ttl) : queryStartedTime; + + SendQueryError validationError = ValidateResponse(header.ResponseCode, queryStartedTime, answers, authorities, ref expirationTime); + // we transfer ownership of RawData to the response - DnsResponse response = new(responseReader.RawData!, header, queryStartedTime, queryStartedTime.AddSeconds(ttl), answers, authorities, additionals); + DnsResponse response = new(responseReader.RawData!, header, queryStartedTime, expirationTime, answers, authorities, additionals); responseReader = default; // avoid disposing (and returning RawData to the pool) - return new SendQueryResult { Response = response, Error = ValidateResponse(response) }; + return new SendQueryResult { Response = response, Error = validationError }; } finally { @@ -447,7 +473,7 @@ static List ReadRecords(int count, ref int ttl, ref DnsDataRe } } - internal static bool GetNegativeCacheExpiration(in DnsResponse response, out DateTime expiration) + internal static bool GetNegativeCacheExpiration(DateTime createdAt, List authorities, out DateTime expiration) { // // RFC 2308 Section 5 - Caching Negative Answers @@ -463,10 +489,10 @@ internal static bool GetNegativeCacheExpiration(in DnsResponse response, out Dat // be used again. // - DnsResourceRecord? soa = response.Authorities.FirstOrDefault(r => r.Type == QueryType.SOA); + DnsResourceRecord? soa = authorities.FirstOrDefault(r => r.Type == QueryType.SOA); if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data.Span, out string? mname, out string? rname, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out _)) { - expiration = response.CreatedAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl)); + expiration = createdAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl)); return true; } @@ -474,11 +500,11 @@ internal static bool GetNegativeCacheExpiration(in DnsResponse response, out Dat return false; } - internal static SendQueryError ValidateResponse(in DnsResponse response) + internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, DateTime createdAt, List answers, List authorities, ref DateTime expiration) { - if (response.Header.ResponseCode == QueryResponseCode.NoError) + if (responseCode == QueryResponseCode.NoError) { - if (response.Answers.Count > 0) + if (answers.Count > 0) { return SendQueryError.NoError; } @@ -497,14 +523,15 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) // another query for the same that resulted in // the cached negative response. // - if (!response.Authorities.Any(r => r.Type == QueryType.NS) && GetNegativeCacheExpiration(response, out DateTime expiration)) + if (!authorities.Any(r => r.Type == QueryType.NS) && GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration)) { + expiration = newExpiration; // _cache.TryAdd(name, queryType, expiration, Array.Empty()); } return SendQueryError.NoData; } - if (response.Header.ResponseCode == QueryResponseCode.NameError) + if (responseCode == QueryResponseCode.NameError) { // // RFC 2308 Section 5 - Caching Negative Answers @@ -514,12 +541,13 @@ internal static SendQueryError ValidateResponse(in DnsResponse response) // another query for the same that resulted in the // cached negative response. // - if (GetNegativeCacheExpiration(response, out DateTime expiration)) + if (GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration)) { + expiration = newExpiration; // _cache.TryAddNonexistent(name, expiration); } - return SendQueryError.ServerError; + return SendQueryError.NameError; } return SendQueryError.ServerError; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs index a0b0f6e40fe..dd51c712112 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs @@ -6,7 +6,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; /// /// The response code (RCODE) in a DNS query response. /// -enum QueryResponseCode : byte +internal enum QueryResponseCode : byte { /// /// No error condition diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs index 6b8e2bddf35..3ba5632e207 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs @@ -26,10 +26,15 @@ internal enum SendQueryError MalformedResponse, /// - /// Server returned a response indicating no data are available. + /// Server returned a response indicating that the name exists, but no data are available. /// NoData, + /// + /// Server returned a response indicating the name does not exist. + /// + NameError, + /// /// Network-level error occurred during the query. /// diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs index 665169d32c0..5ee7faa8632 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -12,38 +12,242 @@ public class RetryTests : LoopbackDnsTestBase { public RetryTests(ITestOutputHelper output) : base(output) { + Options.Attempts = 3; + } + + private static void SetupUdpProcessFunction(LoopbackDnsServer server, Func func) + { + _ = Task.Run(async () => + { + try + { + while (true) + { + await server.ProcessUdpRequest(func); + } + } + catch (SocketException) + { + // Test teardown closed the socket, ignore + } + }); + } + + private void SetupUdpProcessFunction(Func func) + { + SetupUdpProcessFunction(DnsServer, func); } [Fact] public async Task Retry_Simple_Success() { - Options.Attempts = 3; IPAddress address = IPAddress.Parse("172.213.245.111"); - _ = Task.Run(async () => + int attempt = 0; + + SetupUdpProcessFunction(builder => { - for (int attempt = 1; attempt <= 3; attempt++) + attempt++; + if (attempt == Options.Attempts) { - await DnsServer.ProcessUdpRequest(builder => + builder.Answers.AddAddress("www.example.com", 3600, address); + } + else + { + builder.ResponseCode = QueryResponseCode.ServerFailure; + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Theory] + [InlineData(/* QueryResponseCode.NotImplemented */ 4)] + [InlineData(/* QueryResponseCode.Refused */ 5)] + public async Task PersistentErrorsResponseCode_FailoverToNextServer(int responseCode) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + builder => + { + primaryAttempt++; + builder.ResponseCode = (QueryResponseCode)responseCode; + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + Assert.Equal(1, primaryAttempt); + Assert.Equal(1, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum DefinitveAnswerType + { + NoError, + NoData, + NameError, + } + + [Theory] + [InlineData(DefinitveAnswerType.NoError, false)] + [InlineData(DefinitveAnswerType.NoData, false)] + [InlineData(DefinitveAnswerType.NoData, true)] + [InlineData(DefinitveAnswerType.NameError, false)] + [InlineData(DefinitveAnswerType.NameError, true)] + public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, bool additionalData) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + builder => + { + primaryAttempt++; + switch (type) { - if (attempt == 3) - { + case DefinitveAnswerType.NoError: + builder.ResponseCode = QueryResponseCode.NoError; builder.Answers.AddAddress("www.example.com", 3600, address); - } - else + break; + + case DefinitveAnswerType.NoData: + builder.ResponseCode = QueryResponseCode.NoError; + break; + + case DefinitveAnswerType.NameError: + builder.ResponseCode = QueryResponseCode.NameError; + break; + } + + if (additionalData) + { + builder.Authorities.AddStartOfAuthority("www.example.com", 300, "ns1.example.com", "hostmaster.example.com", 2023101001, 1, 3600, 300, 86400); + } + + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.ResponseCode = QueryResponseCode.Refused; + return Task.CompletedTask; + }); + + Assert.Equal(1, primaryAttempt); + Assert.Equal(0, secondaryAttempt); + + if (type == DefinitveAnswerType.NoError) + { + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + else + { + Assert.Empty(results); + } + } + + public enum TransientErrorType + { + Timeout, + ServerFailure, + // TODO: simulate NetworkErrors + } + + [Theory] + [InlineData(TransientErrorType.Timeout)] + [InlineData(TransientErrorType.ServerFailure)] + public async Task TransientError_RetryOnSameServer(TransientErrorType type) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + async builder => + { + primaryAttempt++; + if (primaryAttempt == 1) + { + switch (type) { - builder.ResponseCode = QueryResponseCode.ServerFailure; + case TransientErrorType.Timeout: + await Task.Delay(Options.Timeout.Multiply(1.5)); + builder.Answers.AddAddress("www.example.com", 3600, address); + break; + + case TransientErrorType.ServerFailure: + builder.ResponseCode = QueryResponseCode.ServerFailure; + break; } - return Task.CompletedTask; - }); - } - }); + } + else + { + builder.Answers.AddAddress("www.example.com", 3600, address); + } + }, + builder => + { + secondaryAttempt++; + builder.ResponseCode = QueryResponseCode.Refused; + return Task.CompletedTask; + }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + Assert.Equal(2, primaryAttempt); + Assert.Equal(0, secondaryAttempt); AddressResult res = Assert.Single(results); Assert.Equal(address, res.Address); Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + private async Task RunWithFallbackServerHelper(string name, Func primaryHandler, Func fallbackHandler) + { + SetupUdpProcessFunction(primaryHandler); + using LoopbackDnsServer fallbackServer = new LoopbackDnsServer(); + SetupUdpProcessFunction(fallbackServer, fallbackHandler); + + Options.Servers = [DnsServer.DnsEndPoint, fallbackServer.DnsEndPoint]; + + return await Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + } + + [Fact] + public async Task NameError_NoRetry() + { + int counter = 0; + SetupUdpProcessFunction(builder => + { + counter++; + // authoritative answer that the name does not exist + builder.ResponseCode = QueryResponseCode.NameError; + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + Assert.Empty(results); + Assert.Equal(1, counter); } } From 61814e7c0a91254e81dd4525305a02acc42e83b6 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 13 Feb 2025 14:11:44 +0100 Subject: [PATCH 24/45] Test failover after exhausting attempts on one server --- .../Resolver/RetryTests.cs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs index 5ee7faa8632..b7c12bfc763 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -167,6 +167,36 @@ public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, } } + [Fact] + public async Task ExhaustedRetries_FailoverToNextServer() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + builder => + { + primaryAttempt++; + builder.ResponseCode = QueryResponseCode.ServerFailure; + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + Assert.Equal(Options.Attempts, primaryAttempt); + Assert.Equal(1, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + public enum TransientErrorType { Timeout, From 82a8108ba1628c2adc1964539bc9f8b5f88d1765 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 13 Feb 2025 15:29:44 +0100 Subject: [PATCH 25/45] Streamline options, remove unsupported options from ResolverOptions --- .../Resolver/DnsResolver.cs | 7 ++----- .../Resolver/NetworkInfo.cs | 2 +- .../Resolver/ResolvConf.cs | 21 +++++++------------ .../Resolver/ResolverOptions.cs | 11 ++++------ 4 files changed, 15 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 93928b1e9c5..b7e1af4d935 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -43,10 +43,7 @@ internal DnsResolver(ResolverOptions options) { _logger = NullLogger.Instance; _options = options; - if (options.Servers.Length == 0) - { - throw new ArgumentException("There are no DNS servers configured.", nameof(options)); - } + Debug.Assert(_options.Servers.Count > 0); if (options.Timeout != Timeout.InfiniteTimeSpan) { @@ -268,7 +265,7 @@ async ValueTask SendQueryWithRetriesAsync(string name, QueryTyp { SendQueryResult? result = default; - for (int index = 0; index < _options.Servers.Length; index++) + for (int index = 0; index < _options.Servers.Count; index++) { IPEndPoint serverEndPoint = _options.Servers[index]; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs index 7da0bb9e9f7..c2ef13f922e 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs @@ -31,6 +31,6 @@ public static ResolverOptions GetOptions() } } - return new ResolverOptions(servers!.ToArray()); + return new ResolverOptions(servers); } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs index 5132634c3de..fbfdc5ae027 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Net.Sockets; using System.Net; using System.Runtime.Versioning; @@ -19,7 +18,6 @@ public static ResolverOptions GetOptions() public static ResolverOptions GetOptions(TextReader reader) { List serverList = new(); - List searchDomains = new(); while (reader.ReadLine() is string line) { @@ -30,24 +28,21 @@ public static ResolverOptions GetOptions(TextReader reader) if (tokens.Length >= 2 && IPAddress.TryParse(tokens[1], out IPAddress? address)) { serverList.Add(new IPEndPoint(address, 53)); // 53 is standard DNS port + + if (serverList.Count == 3) + { + break; // resolv.conf manpage allow max 3 nameservers anyway + } } } - else if (line.StartsWith("search")) - { - searchDomains.AddRange(tokens.Skip(1)); - } } if (serverList.Count == 0) { - throw new SocketException((int)SocketError.AddressNotAvailable); + // If no nameservers are configured, fall back to the default behavior of using the system resolver configuration. + return NetworkInfo.GetOptions(); } - var options = new ResolverOptions(serverList.ToArray()) - { - SearchDomains = searchDomains.Count > 0 ? searchDomains.ToArray() : default - }; - - return options; + return new ResolverOptions(serverList); } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs index 82e874a1b25..673091453ad 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -7,18 +7,15 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal sealed class ResolverOptions { - public IPEndPoint[] Servers; - public string DefaultDomain = string.Empty; - public string[]? SearchDomains; - + public IReadOnlyList Servers; public int Attempts = 2; public TimeSpan Timeout = TimeSpan.FromSeconds(3); - public ResolverOptions(IPEndPoint[] servers) + public ResolverOptions(IReadOnlyList servers) { - if (servers.Length == 0) + if (servers.Count == 0) { - throw new ArgumentException("At least one server is required.", nameof(servers)); + throw new ArgumentException("At least one DNS server is required.", nameof(servers)); } Servers = servers; From fa627fdf35568f438561a4573d40d06c2b7e895e Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 13 Feb 2025 17:14:09 +0100 Subject: [PATCH 26/45] Better handling of malformed responses --- .../Resolver/DnsResolver.cs | 401 +++++++++--------- .../Resolver/RetryTests.cs | 29 +- 2 files changed, 235 insertions(+), 195 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index b7e1af4d935..8875dc7b9d0 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -60,61 +60,55 @@ internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server)) { } - public async ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) { ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); name = GetNormalizedHostName(name); - NameResolutionActivity activity = Telemetry.StartNameResolution(name, QueryType.SRV, _timeProvider.GetTimestamp()); - SendQueryResult result = await SendQueryWithRetriesAsync(name, QueryType.SRV, cancellationToken).ConfigureAwait(false); + return SendQueryWithTelemetry(name, QueryType.SRV, ProcessResponse, cancellationToken); - if (result.Error is not SendQueryError.NoError) + static (SendQueryError, ServiceResult[]) ProcessResponse(string name, QueryType queryType, DnsResponse response) { - Telemetry.StopNameResolution(name, QueryType.SRV, activity, null, result.Error, _timeProvider.GetTimestamp()); - return Array.Empty(); - } - - using DnsResponse response = result.Response; - - var results = new List(response.Answers.Count); + var results = new List(response.Answers.Count); - foreach (var answer in response.Answers) - { - if (answer.Type == QueryType.SRV) + foreach (var answer in response.Answers) { - bool success = DnsPrimitives.TryReadService(answer.Data.Span, out ushort priority, out ushort weight, out ushort port, out string? target, out _); - Debug.Assert(success, "Failed to read SRV"); - - List addresses = new List(); - foreach (var additional in response.Additionals) + if (answer.Type == QueryType.SRV) { - // From RFC 2782: - // - // Target - // The domain name of the target host. There MUST be one or more - // address records for this name, the name MUST NOT be an alias (in - // the sense of RFC 1034 or RFC 2181). Implementors are urged, but - // not required, to return the address record(s) in the Additional - // Data section. Unless and until permitted by future standards - // action, name compression is not to be used for this field. - // - // A Target of "." means that the service is decidedly not - // available at this domain. - if (additional.Name == target && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) + if (!DnsPrimitives.TryReadService(answer.Data.Span, out ushort priority, out ushort weight, out ushort port, out string? target, out int bytesRead) || bytesRead != answer.Data.Length) { - addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); + return (SendQueryError.MalformedResponse, []); + } + + List addresses = new List(); + foreach (var additional in response.Additionals) + { + // From RFC 2782: + // + // Target + // The domain name of the target host. There MUST be one or more + // address records for this name, the name MUST NOT be an alias (in + // the sense of RFC 1034 or RFC 2181). Implementors are urged, but + // not required, to return the address record(s) in the Additional + // Data section. Unless and until permitted by future standards + // action, name compression is not to be used for this field. + // + // A Target of "." means that the service is decidedly not + // available at this domain. + if (additional.Name == target && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) + { + addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); + } } - } - results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target!, addresses.ToArray())); + results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target!, addresses.ToArray())); + } } - } - ServiceResult[] res = results.ToArray(); - Telemetry.StopNameResolution(name, QueryType.SRV, activity, res, result.Error, _timeProvider.GetTimestamp()); - return res; + return (SendQueryError.NoError, results.ToArray()); + } } public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) @@ -151,7 +145,7 @@ public async ValueTask ResolveIPAddressesAsync(string name, Can return results; } - public async ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) { ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); @@ -166,93 +160,95 @@ public async ValueTask ResolveIPAddressesAsync(string name, Add // name localhost exists outside of DNS and can't be resolved by a DNS server if (addressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4) { - return [new AddressResult(DateTime.MaxValue, IPAddress.Loopback)]; + return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.Loopback)]); } else if (addressFamily == AddressFamily.InterNetworkV6 && Socket.OSSupportsIPv6) { - return [new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback)]; + return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback)]); } - return Array.Empty(); + return ValueTask.FromResult([]); } name = GetNormalizedHostName(name); - var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); - SendQueryResult result = await SendQueryWithRetriesAsync(name, queryType, cancellationToken).ConfigureAwait(false); - if (result.Error is not SendQueryError.NoError) - { - Telemetry.StopNameResolution(name, queryType, activity, null, result.Error, _timeProvider.GetTimestamp()); - return Array.Empty(); - } - - using DnsResponse response = result.Response; - - // Given that result.Error is NoError, there should be at least one answer. - Debug.Assert(response.Answers.Count > 0); - var results = new List(response.Answers.Count); + return SendQueryWithTelemetry(name, queryType, ProcessResponse, cancellationToken); - // Servers send back CNAME records together with associated A/AAAA records. Servers - // send only those CNAME records relevant to the query, and if there is a CNAME record, - // there should not be other records associated with the name. Therefore, we simply follow - // the list of CNAME aliases until we get to the primary name and return the A/AAAA records - // associated. - // - // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 - // - // Most of the servers send the CNAME records in order so that we can sequentially scan the - // answers, but nothing prevents the records from being in arbitrary order. Therefore, when - // we encounter a CNAME record, we continue down the list and allow looping back to the beginning - // in case the CNAME chain is not in order. - // - string currentAlias = name; - int i = 0; - int endIndex = 0; - - do + static (SendQueryError error, AddressResult[] result) ProcessResponse(string name, QueryType queryType, DnsResponse response) { - DnsResourceRecord answer = response.Answers[i]; + List results = new List(response.Answers.Count); - if (answer.Name == currentAlias) + // Servers send back CNAME records together with associated A/AAAA records. Servers + // send only those CNAME records relevant to the query, and if there is a CNAME record, + // there should not be other records associated with the name. Therefore, we simply follow + // the list of CNAME aliases until we get to the primary name and return the A/AAAA records + // associated. + // + // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 + // + // Most of the servers send the CNAME records in order so that we can sequentially scan the + // answers, but nothing prevents the records from being in arbitrary order. Therefore, when + // we encounter a CNAME record, we continue down the list and allow looping back to the beginning + // in case the CNAME chain is not in order. + // + string currentAlias = name; + int i = 0; + int endIndex = 0; + + do { - if (answer.Type == QueryType.CNAME) - { - // Although RFC does not necessarily allow pointer segments in CNAME domain names, some servers do use them - // so we need to pass the entire buffer to TryReadQName with the proper offset. The data should be always - // backed by the array containing the full response. + DnsResourceRecord answer = response.Answers[i]; - var success = MemoryMarshal.TryGetArray(answer.Data, out ArraySegment segment); - Debug.Assert(success, "Failed to get array segment"); - if (!DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out currentAlias!, out _)) + if (answer.Name == currentAlias) + { + if (answer.Type == QueryType.CNAME) { - // TODO: how to handle corrupted responses? - throw new InvalidOperationException("Failed to parse CNAME record"); + // Although RFC does not necessarily allow pointer segments in CNAME domain names, some servers do use them + // so we need to pass the entire buffer to TryReadQName with the proper offset. The data should be always + // backed by the array containing the full response. + + var success = MemoryMarshal.TryGetArray(answer.Data, out ArraySegment segment); + Debug.Assert(success, "Failed to get array segment"); + if (!DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out currentAlias!, out int bytesRead) || bytesRead != answer.Data.Length) + { + return (SendQueryError.MalformedResponse, []); + } + + // We need to start over. start with following answers and allow looping back + endIndex = i; + + if (string.Equals(currentAlias, name, StringComparison.OrdinalIgnoreCase)) + { + // CNAME records looped back to original question dns name (=> malformed response). Stop processing. + return (SendQueryError.MalformedResponse, []); + } } - - // We need to start over. start with following answers and allow looping back - endIndex = i; - - if (string.Equals(currentAlias, name, StringComparison.OrdinalIgnoreCase)) + else if (answer.Type == queryType) { - // CNAME records looped back to original question dns name (=> malformed response). Stop processing. - break; + if (answer.Data.Length != IPv4Length && answer.Data.Length != IPv6Length) + { + return (SendQueryError.MalformedResponse, []); + } + + results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); } } - else if (answer.Type == queryType) - { - Debug.Assert(answer.Data.Length == IPv4Length || answer.Data.Length == IPv6Length); - results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); - } + + i = (i + 1) % response.Answers.Count; } + while (i != endIndex); - i = (i + 1) % response.Answers.Count; + return (SendQueryError.NoError, results.ToArray()); } - while (i != endIndex); + } + + private async ValueTask SendQueryWithTelemetry(string name, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + { + NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); + (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false); + Telemetry.StopNameResolution(name, queryType, activity, null, error, _timeProvider.GetTimestamp()); - AddressResult[] res = results.ToArray(); - Telemetry.StopNameResolution(name, queryType, activity, res, result.Error, _timeProvider.GetTimestamp()); - return res; + return result; } internal struct SendQueryResult @@ -261,104 +257,108 @@ internal struct SendQueryResult public SendQueryError Error; } - async ValueTask SendQueryWithRetriesAsync(string name, QueryType queryType, CancellationToken cancellationToken) + async ValueTask<(SendQueryError error, TResult[] result)> SendQueryWithRetriesAsync(string name, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) { - SendQueryResult? result = default; - + SendQueryError lastError = SendQueryError.InternalError; // will be overwritten by the first attempt for (int index = 0; index < _options.Servers.Count; index++) { IPEndPoint serverEndPoint = _options.Servers[index]; for (int attempt = 1; attempt <= _options.Attempts; attempt++) { + DnsResponse response = default; try { - SendQueryResult newResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, attempt, cancellationToken).ConfigureAwait(false); + TResult[] results = Array.Empty(); - if (result.HasValue) + try + { + SendQueryResult queryResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, attempt, cancellationToken).ConfigureAwait(false); + lastError = queryResult.Error; + response = queryResult.Response; + + if (lastError == SendQueryError.NoError) + { + // Given that result.Error is NoError, there should be at least one answer. + Debug.Assert(response.Answers.Count > 0); + (lastError, results) = processResponseFunc(name, queryType, queryResult.Response); + } + } + catch (SocketException ex) { - result.Value.Response.Dispose(); + Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex); + lastError = SendQueryError.NetworkError; + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); + lastError = SendQueryError.InternalError; } - result = newResult; - } - catch (SocketException ex) - { - Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex); - result = new SendQueryResult { Error = SendQueryError.NetworkError }; - } - catch (Exception ex) when (!cancellationToken.IsCancellationRequested) - { - Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); - result = new SendQueryResult { Error = SendQueryError.InternalError }; - } + switch (lastError) + { + // + // Definitive answers, no point retrying + // + case SendQueryError.NoError: + return (lastError, results); + + case SendQueryError.NameError: + // authoritative answer that the name does not exist, no point in retrying + Log.NameError(_logger, queryType, name, serverEndPoint, attempt); + return (lastError, results); + + case SendQueryError.NoData: + // no data available for the name from authoritative server + Log.NoData(_logger, queryType, name, serverEndPoint, attempt); + return (lastError, results); + + // + // Transient errors, retry on the same server + // + case SendQueryError.Timeout: + Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); + continue; + + case SendQueryError.NetworkError: + // TODO: retry with exponential backoff? + continue; + + case SendQueryError.ServerError when response.Header.ResponseCode == QueryResponseCode.ServerFailure: + // ServerFailure may indicate transient failure with upstream DNS servers, retry on the same server + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode); + continue; + + // + // Persistent errors, skip to the next server + // + case SendQueryError.ServerError: + // this should cover all response codes except NoError, NameError which are definite and handled above, and + // ServerFailure which is a transient error and handled above. + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode); + break; + + case SendQueryError.MalformedResponse: + Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt); + break; + + case SendQueryError.InternalError: + // exception logged above. + break; + } - switch (result.Value.Error) + // actual break that causes skipping to the next server + break; + } + finally { - // - // Definitive answers, no point retrying - // - case SendQueryError.NoError: - return result.Value; - - case SendQueryError.NameError: - // authoritative answer that the name does not exist, no point in retrying - Log.NameError(_logger, queryType, name, serverEndPoint, attempt); - return result.Value; - - case SendQueryError.NoData: - // no data available for the name from authoritative server - Log.NoData(_logger, queryType, name, serverEndPoint, attempt); - return result.Value; - - // - // Transient errors, retry on the same server - // - case SendQueryError.Timeout: - Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); - continue; - - case SendQueryError.NetworkError: - // TODO: retry with exponential backoff? - continue; - - case SendQueryError.ServerError when result.Value.Response!.Header.ResponseCode == QueryResponseCode.ServerFailure: - // ServerFailure may indicate transient failure with upstream DNS servers, retry on the same server - Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Value.Response.Header.ResponseCode); - continue; - - // - // Persistent errors, skip to the next server - // - case SendQueryError.ServerError: - // this should cover all response codes except NoError, NameError which are definite and handled above, and - // ServerFailure which is a transient error and handled above. - Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, result.Value.Response.Header.ResponseCode); - break; - - case SendQueryError.MalformedResponse: - Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt); - break; - - case SendQueryError.InternalError: - // exception logged above. - break; + response.Dispose(); } - - // actual break that causes skipping to the next server - break; } } - // we should have a result by now - Debug.Assert(result.HasValue); - - if (!result.HasValue) - { - result = new SendQueryResult { Error = SendQueryError.InternalError }; - } - - return result!.Value; + // if we get here, we exhausted all servers and all attempts + return (lastError, []); } internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, int attempt, CancellationToken cancellationToken) @@ -417,6 +417,17 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve return new SendQueryResult { Error = sendError }; } + if ((uint)header.ResponseCode > (uint)QueryResponseCode.Refused) + { + // Response code is outside of valid range + return new SendQueryResult + { + Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + // Recheck that the server echoes back the DNS question if (header.QueryCount != 1 || !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || qName != name || qType != queryType || qClass != QueryClass.Internet) @@ -429,10 +440,19 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve }; } + // Structurally separate the resource records, this will validate only the + // "outside structure" of the resource record, it will not validate the content. int ttl = int.MaxValue; - List answers = ReadRecords(header.AnswerCount, ref ttl, ref responseReader); - List authorities = ReadRecords(header.AuthorityCount, ref ttl, ref responseReader); - List additionals = ReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader); + if (!TryReadRecords(header.AnswerCount, ref ttl, ref responseReader, out List? answers) || + !TryReadRecords(header.AuthorityCount, ref ttl, ref responseReader, out List? authorities) || + !TryReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader, out List? additionals)) + { + return new SendQueryResult + { + Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } DateTime expirationTime = (answers.Count + authorities.Count + additionals.Count) > 0 ? queryStartedTime.AddSeconds(ttl) : queryStartedTime; @@ -450,23 +470,22 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve responseReader.Dispose(); } - static List ReadRecords(int count, ref int ttl, ref DnsDataReader reader) + static bool TryReadRecords(int count, ref int ttl, ref DnsDataReader reader, out List records) { - List records = new(count); + records = new(count); for (int i = 0; i < count; i++) { if (!reader.TryReadResourceRecord(out var record)) { - // TODO how to handle corrupted responses? - throw new InvalidOperationException("Invalid response: corrupted record"); + return false; } ttl = Math.Min(ttl, record.Ttl); records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data)); } - return records; + return true; } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs index b7c12bfc763..17a38b3a302 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -66,10 +66,18 @@ public async Task Retry_Simple_Success() Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); } + public enum PersistentErrorType + { + NotImplemented, + Refused, + MalformedResponse + } + [Theory] - [InlineData(/* QueryResponseCode.NotImplemented */ 4)] - [InlineData(/* QueryResponseCode.Refused */ 5)] - public async Task PersistentErrorsResponseCode_FailoverToNextServer(int responseCode) + [InlineData(PersistentErrorType.NotImplemented)] + [InlineData(PersistentErrorType.Refused)] + [InlineData(PersistentErrorType.MalformedResponse)] + public async Task PersistentErrorsResponseCode_FailoverToNextServer(PersistentErrorType type) { IPAddress address = IPAddress.Parse("172.213.245.111"); @@ -80,7 +88,20 @@ public async Task PersistentErrorsResponseCode_FailoverToNextServer(int response builder => { primaryAttempt++; - builder.ResponseCode = (QueryResponseCode)responseCode; + switch (type) + { + case PersistentErrorType.NotImplemented: + builder.ResponseCode = QueryResponseCode.NotImplemented; + break; + + case PersistentErrorType.Refused: + builder.ResponseCode = QueryResponseCode.Refused; + break; + + case PersistentErrorType.MalformedResponse: + builder.ResponseCode = (QueryResponseCode)0xFF; + break; + } return Task.CompletedTask; }, builder => From 4ec79ca32d8d13420472c6f6d99211b798f6f271 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 25 Mar 2025 16:39:38 +0100 Subject: [PATCH 27/45] Code review feedback --- .../Resolver/DnsMessageHeader.cs | 2 +- .../Resolver/DnsResolver.cs | 13 ++++++++----- .../Resolver/DnsResponse.cs | 2 +- .../Resolver/QueryFlags.cs | 1 + .../Resolver/TcpFailoverTests.cs | 2 -- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs index b3d5effb0d6..b22273a04f2 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -21,7 +21,7 @@ internal struct DnsMessageHeader public QueryResponseCode ResponseCode { - get => (QueryResponseCode)((int)QueryFlags & 0x000F); + get => (QueryResponseCode)(QueryFlags & QueryFlags.ResponseCodeMask); } public bool IsResultTruncated diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 8875dc7b9d0..95532cd09a3 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -422,7 +422,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve // Response code is outside of valid range return new SendQueryResult { - Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(null, header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.MalformedResponse }; } @@ -435,7 +435,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve // DNS Question mismatch return new SendQueryResult { - Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(null, header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.MalformedResponse }; } @@ -449,7 +449,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve { return new SendQueryResult { - Response = new DnsResponse(Array.Empty(), header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(null, header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.MalformedResponse }; } @@ -460,7 +460,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve SendQueryError validationError = ValidateResponse(header.ResponseCode, queryStartedTime, answers, authorities, ref expirationTime); // we transfer ownership of RawData to the response - DnsResponse response = new(responseReader.RawData!, header, queryStartedTime, expirationTime, answers, authorities, additionals); + DnsResponse response = new DnsResponse(responseReader.RawData!, header, queryStartedTime, expirationTime, answers, authorities, additionals); responseReader = default; // avoid disposing (and returning RawData to the pool) return new SendQueryResult { Response = response, Error = validationError }; @@ -472,7 +472,10 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve static bool TryReadRecords(int count, ref int ttl, ref DnsDataReader reader, out List records) { - records = new(count); + // Since `count` is attacker controlled, limit the initial capacity + // to 32 items to avoid excessive memory allocation. More than 32 + // records are unusual so we don't need to optimize for them. + records = new(Math.Min(count, 32)); for (int i = 0; i < count; i++) { diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs index b76ccf1f47f..f980eb7091f 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs @@ -16,7 +16,7 @@ internal struct DnsResponse : IDisposable public ReadOnlyMemory RawData => _rawData ?? ReadOnlyMemory.Empty; private byte[]? _rawData; - public DnsResponse(byte[] rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) + public DnsResponse(byte[]? rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) { _rawData = rawData; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs index d983a83b35a..02474b6cda1 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs @@ -11,4 +11,5 @@ internal enum QueryFlags : ushort ResultTruncated = 0x0200, HasAuthorityAnswer = 0x0400, HasResponse = 0x8000, + ResponseCodeMask = 0x000F, } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index 67ced0b2ac0..ac10d0d6d8c 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -43,7 +43,6 @@ public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() Options.Attempts = 1; Options.Timeout = TimeSpan.FromSeconds(60); - IPAddress address = IPAddress.Parse("172.213.245.111"); _ = DnsServer.ProcessUdpRequest(builder => { builder.Flags |= QueryFlags.ResultTruncated; @@ -67,7 +66,6 @@ public async Task TcpFailover_TcpNotAvailable_EmptyResult() Options.Attempts = 1; Options.Timeout = TimeSpan.FromMilliseconds(100000); - IPAddress address = IPAddress.Parse("172.213.245.111"); _ = DnsServer.ProcessUdpRequest(builder => { builder.Flags |= QueryFlags.ResultTruncated; From 123375d54e2886e03aeb6212c75c46adf9e9cc5a Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 27 Mar 2025 12:16:10 +0100 Subject: [PATCH 28/45] More feedback --- .../Resolver/DnsResolver.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 95532cd09a3..e670e60a98b 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -588,14 +588,17 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, while (true) { - int readLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false); + // Because this is UDP, the response must be in a single packet, + // if the response does not fit into a single UDP packet, the server will + // set the Truncated flag in the header, and we will need to retry with TCP. + int packetLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false); - if (readLength < DnsMessageHeader.HeaderLength) + if (packetLength < DnsMessageHeader.HeaderLength) { continue; } - responseReader = new DnsDataReader(memory.Slice(0, readLength), buffer); + responseReader = new DnsDataReader(memory.Slice(0, packetLength), buffer); if (!responseReader.TryReadHeader(out header) || header.TransactionId != transactionId || !header.IsResponse) From 63116165e31adfe3fc2d5bdc6b233ff17716c8f7 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 27 Mar 2025 13:53:55 +0100 Subject: [PATCH 29/45] Guarantee linear parsing of CNAME chains --- .../Resolver/DnsResolver.cs | 161 ++++++++++++++---- 1 file changed, 130 insertions(+), 31 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index e670e60a98b..1ce7d1fdc89 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Net; using System.Net.Sockets; @@ -187,58 +188,156 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 // // Most of the servers send the CNAME records in order so that we can sequentially scan the - // answers, but nothing prevents the records from being in arbitrary order. Therefore, when - // we encounter a CNAME record, we continue down the list and allow looping back to the beginning - // in case the CNAME chain is not in order. - // + // answers, but nothing prevents the records from being in arbitrary order. Attempt the linear + // scan first and fallback to a slower but more robust method if necessary. + + bool success = true; string currentAlias = name; - int i = 0; - int endIndex = 0; - do + foreach (var answer in response.Answers) { - DnsResourceRecord answer = response.Answers[i]; - if (answer.Name == currentAlias) + switch (answer.Type) { - if (answer.Type == QueryType.CNAME) - { - // Although RFC does not necessarily allow pointer segments in CNAME domain names, some servers do use them - // so we need to pass the entire buffer to TryReadQName with the proper offset. The data should be always - // backed by the array containing the full response. - - var success = MemoryMarshal.TryGetArray(answer.Data, out ArraySegment segment); - Debug.Assert(success, "Failed to get array segment"); - if (!DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out currentAlias!, out int bytesRead) || bytesRead != answer.Data.Length) + case QueryType.CNAME: + if (!TryReadTarget(answer, out string? target)) { return (SendQueryError.MalformedResponse, []); } - // We need to start over. start with following answers and allow looping back - endIndex = i; + if (string.Equals(answer.Name, currentAlias, StringComparison.OrdinalIgnoreCase)) + { + currentAlias = target; + continue; + } + + break; - if (string.Equals(currentAlias, name, StringComparison.OrdinalIgnoreCase)) + case var type when type == queryType: + if (!TryReadAddress(answer, queryType, out IPAddress? address)) { - // CNAME records looped back to original question dns name (=> malformed response). Stop processing. return (SendQueryError.MalformedResponse, []); } - } - else if (answer.Type == queryType) - { - if (answer.Data.Length != IPv4Length && answer.Data.Length != IPv6Length) + + if (string.Equals(answer.Name, currentAlias, StringComparison.OrdinalIgnoreCase)) { - return (SendQueryError.MalformedResponse, []); + results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); + continue; } - results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), new IPAddress(answer.Data.Span))); + break; + } + + // unexpected name or record type, fall back to more robust path + results.Clear(); + success = false; + break; + } + + if (success) + { + return (SendQueryError.NoError, results.ToArray()); + } + + // more expensive path for uncommon (but valid) cases where CNAME records are out of order. Use of Dictionary + // allows us to stay within O(n) complexity for the number of answers, but we will use more memory. + Dictionary aliasMap = new Dictionary(StringComparer.OrdinalIgnoreCase); + Dictionary> aRecordMap = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var answer in response.Answers) + { + if (answer.Type == QueryType.CNAME) + { + // map the alias to the target name + if (!TryReadTarget(answer, out string? target)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (!aliasMap.TryAdd(answer.Name, target)) + { + // Duplicate CNAME record + return (SendQueryError.MalformedResponse, []); } } - i = (i + 1) % response.Answers.Count; + if (answer.Type == queryType) + { + if (!TryReadAddress(answer, queryType, out IPAddress? address)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (!aRecordMap.TryGetValue(answer.Name, out List? addressList)) + { + addressList = new List(); + aRecordMap.Add(answer.Name, addressList); + } + + addressList.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); + } } - while (i != endIndex); - return (SendQueryError.NoError, results.ToArray()); + // follow the CNAME chain, limit the maximum number of iterations to avoid infinite loops. + int i = 0; + currentAlias = name; + while (aliasMap.TryGetValue(currentAlias, out string? nextAlias)) + { + if (i >= aliasMap.Count) + { + // circular CNAME chain + return (SendQueryError.MalformedResponse, []); + } + + i++; + + if (aRecordMap.ContainsKey(currentAlias)) + { + // both CNAME record and A/AAAA records exist for the current alias + return (SendQueryError.MalformedResponse, []); + } + + currentAlias = nextAlias; + } + + // Now we have the final target name, check if we have any A/AAAA records for it. + aRecordMap.TryGetValue(currentAlias, out List? finalAddressList); + return (SendQueryError.NoError, finalAddressList?.ToArray() ?? []); + + static bool TryReadTarget(in DnsResourceRecord record, [NotNullWhen(true)] out string? target) + { + Debug.Assert(record.Type == QueryType.CNAME, "Only CNAME records should be processed here."); + + target = null; + + // some servers use domain name compression even inside CNAME records. In order to decode those + // correctly, we need to pass the entire message to TryReadQName. The Data span inside the record + // should be backed by the array containing the entire DNS message. + var gotArray = MemoryMarshal.TryGetArray(record.Data, out ArraySegment segment); + Debug.Assert(gotArray, "Failed to get array segment"); + + bool result = DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out string? targetName, out int bytesRead) && bytesRead == record.Data.Length; + if (result) + { + target = targetName; + } + + return result; + } + + static bool TryReadAddress(in DnsResourceRecord record, QueryType type, [NotNullWhen(true)] out IPAddress? target) + { + Debug.Assert(record.Type is QueryType.A or QueryType.AAAA, "Only CNAME records should be processed here."); + + target = null; + if (record.Type == QueryType.A && record.Data.Length != IPv4Length || + record.Type == QueryType.AAAA && record.Data.Length != IPv6Length) + { + return false; + } + + target = new IPAddress(record.Data.Span); + return true; + } } } From c57a8c5a9e38685526932a71ce126b99654414eb Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 27 Mar 2025 14:25:24 +0100 Subject: [PATCH 30/45] Fix decoding compressed CNAME in TCP fallback --- .../Resolver/DnsDataReader.cs | 35 ++++++++++--------- .../Resolver/DnsResolver.cs | 26 ++++++++------ .../Resolver/DnsResponse.cs | 14 ++++---- .../Resolver/DnsDataReaderTests.cs | 2 +- .../Resolver/LoopbackDnsServer.cs | 6 ++-- 5 files changed, 44 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index 831f03b082c..0dbd6b90b0e 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -10,22 +10,22 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal struct DnsDataReader : IDisposable { - public byte[]? RawData { get; private set; } - private ReadOnlyMemory _buffer; + public ArraySegment MessageBuffer { get; private set; } + bool _returnToPool; private int _position; - public DnsDataReader(ReadOnlyMemory buffer, byte[]? returnToPool = null) + public DnsDataReader(ArraySegment buffer, bool returnToPool = false) { - _buffer = buffer; + MessageBuffer = buffer; _position = 0; - RawData = returnToPool; + _returnToPool = returnToPool; } public bool TryReadHeader(out DnsMessageHeader header) { Debug.Assert(_position == 0); - if (!DnsPrimitives.TryReadMessageHeader(_buffer.Span, out header, out int bytesRead)) + if (!DnsPrimitives.TryReadMessageHeader(MessageBuffer.AsSpan(), out header, out int bytesRead)) { header = default; return false; @@ -53,26 +53,26 @@ internal bool TryReadQuestion([NotNullWhen(true)] out string? name, out QueryTyp public bool TryReadUInt16(out ushort value) { - if (_buffer.Length - _position < 2) + if (MessageBuffer.Count - _position < 2) { value = 0; return false; } - value = BinaryPrimitives.ReadUInt16BigEndian(_buffer.Span.Slice(_position)); + value = BinaryPrimitives.ReadUInt16BigEndian(MessageBuffer.AsSpan(_position)); _position += 2; return true; } public bool TryReadUInt32(out uint value) { - if (_buffer.Length - _position < 4) + if (MessageBuffer.Count - _position < 4) { value = 0; return false; } - value = BinaryPrimitives.ReadUInt32BigEndian(_buffer.Span.Slice(_position)); + value = BinaryPrimitives.ReadUInt32BigEndian(MessageBuffer.AsSpan(_position)); _position += 4; return true; } @@ -84,13 +84,13 @@ public bool TryReadResourceRecord(out DnsResourceRecord record) !TryReadUInt16(out ushort @class) || !TryReadUInt32(out uint ttl) || !TryReadUInt16(out ushort dataLength) || - _buffer.Length - _position < dataLength) + MessageBuffer.Count - _position < dataLength) { record = default; return false; } - ReadOnlyMemory data = _buffer.Slice(_position, dataLength); + ReadOnlyMemory data = MessageBuffer.AsMemory(_position, dataLength); _position += dataLength; record = new DnsResourceRecord(name, (QueryType)type, (QueryClass)@class, (int)ttl, data); @@ -99,7 +99,7 @@ record = default; public bool TryReadDomainName([NotNullWhen(true)] out string? name) { - if (DnsPrimitives.TryReadQName(_buffer.Span, _position, out name, out int bytesRead)) + if (DnsPrimitives.TryReadQName(MessageBuffer.AsSpan(), _position, out name, out int bytesRead)) { _position += bytesRead; return true; @@ -110,12 +110,13 @@ public bool TryReadDomainName([NotNullWhen(true)] out string? name) public void Dispose() { - if (RawData is not null) + if (!_returnToPool || MessageBuffer.Array == null) { - ArrayPool.Shared.Return(RawData); - RawData = null; + return; // nothing to do if we are not returning to the pool } - _buffer = default; + _returnToPool = false; + ArrayPool.Shared.Return(MessageBuffer.Array); + MessageBuffer = default; } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 1ce7d1fdc89..4ec1ea51301 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -200,7 +200,7 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa switch (answer.Type) { case QueryType.CNAME: - if (!TryReadTarget(answer, out string? target)) + if (!TryReadTarget(answer, response.RawMessageBytes, out string? target)) { return (SendQueryError.MalformedResponse, []); } @@ -248,7 +248,7 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa if (answer.Type == QueryType.CNAME) { // map the alias to the target name - if (!TryReadTarget(answer, out string? target)) + if (!TryReadTarget(answer, response.RawMessageBytes, out string? target)) { return (SendQueryError.MalformedResponse, []); } @@ -303,7 +303,7 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa aRecordMap.TryGetValue(currentAlias, out List? finalAddressList); return (SendQueryError.NoError, finalAddressList?.ToArray() ?? []); - static bool TryReadTarget(in DnsResourceRecord record, [NotNullWhen(true)] out string? target) + static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messageBytes, [NotNullWhen(true)] out string? target) { Debug.Assert(record.Type == QueryType.CNAME, "Only CNAME records should be processed here."); @@ -311,11 +311,15 @@ static bool TryReadTarget(in DnsResourceRecord record, [NotNullWhen(true)] out s // some servers use domain name compression even inside CNAME records. In order to decode those // correctly, we need to pass the entire message to TryReadQName. The Data span inside the record - // should be backed by the array containing the entire DNS message. + // should be backed by the array containing the entire DNS message. We just need to account for the + // 2 byte offset in case of TCP fallback. var gotArray = MemoryMarshal.TryGetArray(record.Data, out ArraySegment segment); Debug.Assert(gotArray, "Failed to get array segment"); + Debug.Assert(segment.Array == messageBytes.Array, "record data backed by different array than the original message"); - bool result = DnsPrimitives.TryReadQName(segment.Array.AsSpan(0, segment.Offset + segment.Count), segment.Offset, out string? targetName, out int bytesRead) && bytesRead == record.Data.Length; + int messageOffset = messageBytes.Offset; + + bool result = DnsPrimitives.TryReadQName(segment.Array.AsSpan(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset, out string? targetName, out int bytesRead) && bytesRead == record.Data.Length; if (result) { target = targetName; @@ -521,7 +525,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve // Response code is outside of valid range return new SendQueryResult { - Response = new DnsResponse(null, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.MalformedResponse }; } @@ -534,7 +538,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve // DNS Question mismatch return new SendQueryResult { - Response = new DnsResponse(null, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.MalformedResponse }; } @@ -548,7 +552,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve { return new SendQueryResult { - Response = new DnsResponse(null, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), Error = SendQueryError.MalformedResponse }; } @@ -559,7 +563,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve SendQueryError validationError = ValidateResponse(header.ResponseCode, queryStartedTime, answers, authorities, ref expirationTime); // we transfer ownership of RawData to the response - DnsResponse response = new DnsResponse(responseReader.RawData!, header, queryStartedTime, expirationTime, answers, authorities, additionals); + DnsResponse response = new DnsResponse(responseReader.MessageBuffer, header, queryStartedTime, expirationTime, answers, authorities, additionals); responseReader = default; // avoid disposing (and returning RawData to the pool) return new SendQueryResult { Response = response, Error = validationError }; @@ -697,7 +701,7 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, continue; } - responseReader = new DnsDataReader(memory.Slice(0, packetLength), buffer); + responseReader = new DnsDataReader(new ArraySegment(buffer, 0, packetLength), true); if (!responseReader.TryReadHeader(out header) || header.TransactionId != transactionId || !header.IsResponse) @@ -762,7 +766,7 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, } } - DnsDataReader responseReader = new DnsDataReader(buffer.AsMemory(2, responseLength), buffer); + DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 2, responseLength), true); if (!responseReader.TryReadHeader(out DnsMessageHeader header) || header.TransactionId != transactionId || !header.IsResponse) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs index f980eb7091f..5a7fc8a0b52 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs @@ -13,12 +13,11 @@ internal struct DnsResponse : IDisposable public List Additionals { get; } public DateTime CreatedAt { get; } public DateTime Expiration { get; } - public ReadOnlyMemory RawData => _rawData ?? ReadOnlyMemory.Empty; - private byte[]? _rawData; + public ArraySegment RawMessageBytes { get; private set; } - public DnsResponse(byte[]? rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) + public DnsResponse(ArraySegment rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) { - _rawData = rawData; + RawMessageBytes = rawData; Header = header; CreatedAt = createdAt; @@ -30,10 +29,11 @@ public DnsResponse(byte[]? rawData, DnsMessageHeader header, DateTime createdAt, public void Dispose() { - if (_rawData != null) + if (RawMessageBytes.Array != null) { - ArrayPool.Shared.Return(_rawData); - _rawData = null; + ArrayPool.Shared.Return(RawMessageBytes.Array); } + + RawMessageBytes = default; // prevent further access to the raw data } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs index b889270e19e..241d6a8a7a9 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs @@ -57,7 +57,7 @@ public void ReadResourceRecord_Truncated_Fails() for (int i = 0; i < buffer.Length; i++) { - DnsDataReader reader = new DnsDataReader(buffer.AsMemory(0, i)); + DnsDataReader reader = new DnsDataReader(new ArraySegment(buffer, 0, i)); Assert.False(reader.TryReadResourceRecord(out _)); } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 9ca1267261f..4160f2e3cb1 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -36,7 +36,7 @@ public void DisableTcpFallback() _tcpSocket.Close(); } - private static async Task ProcessRequestCore(ReadOnlyMemory message, Func action, Memory responseBuffer) + private static async Task ProcessRequestCore(ArraySegment message, Func action, Memory responseBuffer) { DnsDataReader reader = new DnsDataReader(message); @@ -110,7 +110,7 @@ public async Task ProcessUdpRequest(Func actio EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); SocketReceiveFromResult result = await _dnsSocket.ReceiveFromAsync(buffer, remoteEndPoint); - int bytesWritten = await ProcessRequestCore(buffer.AsMemory(0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); + int bytesWritten = await ProcessRequestCore(new ArraySegment(buffer, 0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); await _dnsSocket.SendToAsync(buffer.AsMemory(0, bytesWritten), SocketFlags.None, result.RemoteEndPoint); } @@ -143,7 +143,7 @@ public async Task ProcessTcpRequest(Func actio } } - int bytesWritten = await ProcessRequestCore(buffer.AsMemory(2, length), action, buffer.AsMemory(2)); + int bytesWritten = await ProcessRequestCore(new ArraySegment(buffer, 2, length), action, buffer.AsMemory(2)); BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(0, 2), (ushort)bytesWritten); await tcpClient.SendAsync(buffer.AsMemory(0, bytesWritten + 2), SocketFlags.None); } From 83ca95872a323012f15e7aa410c321772dd32b03 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Mon, 12 May 2025 16:15:04 +0200 Subject: [PATCH 31/45] More code coverage --- .../Resolver/DnsResolver.cs | 1 - .../Resolver/LoopbackDnsServer.cs | 137 +++++++++++------- .../Resolver/ResolveAddressesTests.cs | 121 +++++++++++++++- .../Resolver/TcpFailoverTests.cs | 49 +++++++ 4 files changed, 250 insertions(+), 58 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 4ec1ea51301..8b06cbff55b 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -196,7 +196,6 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa foreach (var answer in response.Answers) { - switch (answer.Type) { case QueryType.CNAME: diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 4160f2e3cb1..12ab8bf59b3 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -36,7 +36,7 @@ public void DisableTcpFallback() _tcpSocket.Close(); } - private static async Task ProcessRequestCore(ArraySegment message, Func action, Memory responseBuffer) + private static async Task ProcessRequestCore(IPEndPoint remoteEndPoint, ArraySegment message, Func action, Memory responseBuffer) { DnsDataReader reader = new DnsDataReader(message); @@ -51,58 +51,12 @@ private static async Task ProcessRequestCore(ArraySegment message, Fu responseBuilder.Flags = header.QueryFlags | QueryFlags.HasResponse; responseBuilder.ResponseCode = QueryResponseCode.NoError; - await action(responseBuilder); + await action(responseBuilder, remoteEndPoint); - DnsDataWriter writer = new(responseBuffer); - if (!writer.TryWriteHeader(new DnsMessageHeader - { - TransactionId = responseBuilder.TransactionId, - QueryFlags = responseBuilder.Flags | (QueryFlags)responseBuilder.ResponseCode, - QueryCount = (ushort)responseBuilder.Questions.Count, - AnswerCount = (ushort)responseBuilder.Answers.Count, - AuthorityCount = (ushort)responseBuilder.Authorities.Count, - AdditionalRecordCount = (ushort)responseBuilder.Additionals.Count - })) - { - throw new InvalidOperationException("Failed to write header"); - } - - foreach (var (questionName, questionType, questionClass) in responseBuilder.Questions) - { - if (!writer.TryWriteQuestion(questionName, questionType, questionClass)) - { - throw new InvalidOperationException("Failed to write question"); - } - } - - foreach (var answer in responseBuilder.Answers) - { - if (!writer.TryWriteResourceRecord(answer)) - { - throw new InvalidOperationException("Failed to write answer"); - } - } - - foreach (var authority in responseBuilder.Authorities) - { - if (!writer.TryWriteResourceRecord(authority)) - { - throw new InvalidOperationException("Failed to write authority"); - } - } - - foreach (var additional in responseBuilder.Additionals) - { - if (!writer.TryWriteResourceRecord(additional)) - { - throw new InvalidOperationException("Failed to write additional records"); - } - } - - return writer.Position; + return responseBuilder.Write(responseBuffer); } - public async Task ProcessUdpRequest(Func action) + public async Task ProcessUdpRequest(Func action) { byte[] buffer = ArrayPool.Shared.Rent(512); try @@ -110,7 +64,7 @@ public async Task ProcessUdpRequest(Func actio EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); SocketReceiveFromResult result = await _dnsSocket.ReceiveFromAsync(buffer, remoteEndPoint); - int bytesWritten = await ProcessRequestCore(new ArraySegment(buffer, 0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); + int bytesWritten = await ProcessRequestCore((IPEndPoint)result.RemoteEndPoint, new ArraySegment(buffer, 0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); await _dnsSocket.SendToAsync(buffer.AsMemory(0, bytesWritten), SocketFlags.None, result.RemoteEndPoint); } @@ -120,15 +74,18 @@ public async Task ProcessUdpRequest(Func actio } } - public async Task ProcessTcpRequest(Func action) + public Task ProcessUdpRequest(Func action) + { + return ProcessUdpRequest((builder, _) => action(builder)); + } + + public async Task ProcessTcpRequest(Func action) { using Socket tcpClient = await _tcpSocket.AcceptAsync(); byte[] buffer = ArrayPool.Shared.Rent(8 * 1024); try { - EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); - int bytesRead = 0; int length = -1; while (length < 0 || bytesRead < length + 2) @@ -143,7 +100,7 @@ public async Task ProcessTcpRequest(Func actio } } - int bytesWritten = await ProcessRequestCore(new ArraySegment(buffer, 2, length), action, buffer.AsMemory(2)); + int bytesWritten = await ProcessRequestCore((IPEndPoint)tcpClient.RemoteEndPoint!, new ArraySegment(buffer, 2, length), action, buffer.AsMemory(2)); BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(0, 2), (ushort)bytesWritten); await tcpClient.SendAsync(buffer.AsMemory(0, bytesWritten + 2), SocketFlags.None); } @@ -152,6 +109,11 @@ public async Task ProcessTcpRequest(Func actio ArrayPool.Shared.Return(buffer); } } + + public Task ProcessTcpRequest(Func action) + { + return ProcessTcpRequest((builder, _) => action(builder)); + } } internal sealed class LoopbackDnsResponseBuilder @@ -176,6 +138,71 @@ public LoopbackDnsResponseBuilder(string name, QueryType type, QueryClass @class public List Answers { get; } = new List(); public List Authorities { get; } = new List(); public List Additionals { get; } = new List(); + + public int Write(Memory responseBuffer) + { + DnsDataWriter writer = new(responseBuffer); + if (!writer.TryWriteHeader(new DnsMessageHeader + { + TransactionId = TransactionId, + QueryFlags = Flags | (QueryFlags)ResponseCode, + QueryCount = (ushort)Questions.Count, + AnswerCount = (ushort)Answers.Count, + AuthorityCount = (ushort)Authorities.Count, + AdditionalRecordCount = (ushort)Additionals.Count + })) + { + throw new InvalidOperationException("Failed to write header"); + } + + foreach (var (questionName, questionType, questionClass) in Questions) + { + if (!writer.TryWriteQuestion(questionName, questionType, questionClass)) + { + throw new InvalidOperationException("Failed to write question"); + } + } + + foreach (var answer in Answers) + { + if (!writer.TryWriteResourceRecord(answer)) + { + throw new InvalidOperationException("Failed to write answer"); + } + } + + foreach (var authority in Authorities) + { + if (!writer.TryWriteResourceRecord(authority)) + { + throw new InvalidOperationException("Failed to write authority"); + } + } + + foreach (var additional in Additionals) + { + if (!writer.TryWriteResourceRecord(additional)) + { + throw new InvalidOperationException("Failed to write additional records"); + } + } + + return writer.Position; + } + + public byte[] GetMessageBytes() + { + byte[] buffer = ArrayPool.Shared.Rent(512); + try + { + int bytesWritten = Write(buffer.AsMemory(0, 512)); + return buffer.AsSpan(0, bytesWritten).ToArray(); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } } internal static class LoopbackDnsServerExtensions diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index b3c55998060..25c5baab910 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -109,7 +109,6 @@ public async Task ResolveIPv4_Aliases_OutOfOrder_Success() [Fact] public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() { - IPAddress address = IPAddress.Parse("172.213.245.111"); _ = DnsServer.ProcessUdpRequest(builder => { builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); @@ -123,6 +122,58 @@ public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() Assert.Empty(results); } + [Fact] + public async Task ResolveIPv4_Aliases_Loop_Reverse_ReturnsEmpty() + { + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example3.com", 3600, "www.example1.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Alias_And_Address() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example2.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_DuplicateAlias() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example4.com"); + builder.Answers.AddAddress("www.example2.com", 3600, address); + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + + Assert.Empty(results); + } + [Fact] public async Task ResolveIPv4_Aliases_NotFound_Success() { @@ -167,4 +218,70 @@ public async Task Resolve_Timeout_ReturnsEmpty() AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); Assert.Empty(result); } -} + + [Theory] + [InlineData("not-example.com", (int)QueryType.A, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.AAAA, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.A, 0)] + public async Task Resolve_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) + { + Options.Timeout = TimeSpan.FromSeconds(1); + + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Fact] + public async Task Resolve_HeaderMismatch_Ignores() + { + string name = "example.com"; + Options.Timeout = TimeSpan.FromSeconds(5); + + SemaphoreSlim responseSemaphore = new SemaphoreSlim(0, 1); + SemaphoreSlim requestSemaphore = new SemaphoreSlim(0, 1); + + IPEndPoint clientAddress = null!; + + IPAddress address = IPAddress.Parse("172.213.245.111"); + ushort transactionId = 0x1234; + _ = DnsServer.ProcessUdpRequest((builder, clientAddr) => + { + clientAddress = clientAddr; + transactionId = (ushort)(builder.TransactionId + 1); + + builder.Answers.AddAddress(name, 3600, address); + requestSemaphore.Release(); + return responseSemaphore.WaitAsync(); + }); + + ValueTask task = Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + + await requestSemaphore.WaitAsync().WaitAsync(Options.Timeout); + + using Socket socket = new Socket(clientAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + LoopbackDnsResponseBuilder responseBuilder = new LoopbackDnsResponseBuilder(name, QueryType.A, QueryClass.Internet) + { + TransactionId = transactionId, + ResponseCode = QueryResponseCode.NoError + }; + + responseBuilder.Questions.Add((name, QueryType.A, QueryClass.Internet)); + responseBuilder.Answers.AddAddress(name, 3600, IPAddress.Loopback); + socket.SendTo(responseBuilder.GetMessageBytes(), clientAddress); + + responseSemaphore.Release(); + + AddressResult[] results = await task; + AddressResult result = Assert.Single(results); + + Assert.Equal(address, result.Address); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index ac10d0d6d8c..28e1898a19e 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -78,4 +78,53 @@ public async Task TcpFailover_TcpNotAvailable_EmptyResult() AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); Assert.Empty(results); } + + [Fact] + public async Task TcpFailover_HeaderMismatch_ReturnsEmpty() + { + Options.Timeout = TimeSpan.FromSeconds(1); + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.TransactionId++; + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Theory] + [InlineData("not-example.com", (int)QueryType.A, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.AAAA, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.A, 0)] + public async Task TcpFailover_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) + { + Options.Timeout = TimeSpan.FromSeconds(1); + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); + builder.Answers.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + Assert.Empty(result); + } } From 03e6a997d4c923bebffd226cfd90e16077a72619 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 15 May 2025 16:57:04 +0200 Subject: [PATCH 32/45] Correctly handle IDN names --- .../Resolver/DnsDataReader.cs | 22 ++- .../Resolver/DnsDataWriter.cs | 57 +++++--- .../Resolver/DnsPrimitives.cs | 133 +++++++++++------- .../Resolver/DnsResolver.cs | 125 +++++++++------- .../Resolver/DnsResourceRecord.cs | 4 +- .../Resolver/EncodedDomainName.cs | 70 +++++++++ .../Resolver/DnsDataReaderTests.cs | 2 +- .../Resolver/DnsDataWriterTests.cs | 16 ++- .../Resolver/DnsPrimitivesTests.cs | 74 +++++++--- .../Resolver/LoopbackDnsServer.cs | 69 ++++++++- .../Resolver/ResolveAddressesTests.cs | 11 +- 11 files changed, 426 insertions(+), 157 deletions(-) create mode 100644 src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index 0dbd6b90b0e..bc0d662b36c 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -4,7 +4,6 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; @@ -35,7 +34,7 @@ public bool TryReadHeader(out DnsMessageHeader header) return true; } - internal bool TryReadQuestion([NotNullWhen(true)] out string? name, out QueryType type, out QueryClass @class) + internal bool TryReadQuestion(out EncodedDomainName name, out QueryType type, out QueryClass @class) { if (!TryReadDomainName(out name) || !TryReadUInt16(out ushort typeAsInt) || @@ -79,7 +78,7 @@ public bool TryReadUInt32(out uint value) public bool TryReadResourceRecord(out DnsResourceRecord record) { - if (!TryReadDomainName(out string? name) || + if (!TryReadDomainName(out EncodedDomainName name) || !TryReadUInt16(out ushort type) || !TryReadUInt16(out ushort @class) || !TryReadUInt32(out uint ttl) || @@ -97,9 +96,9 @@ record = default; return true; } - public bool TryReadDomainName([NotNullWhen(true)] out string? name) + public bool TryReadDomainName(out EncodedDomainName name) { - if (DnsPrimitives.TryReadQName(MessageBuffer.AsSpan(), _position, out name, out int bytesRead)) + if (DnsPrimitives.TryReadQName(MessageBuffer, _position, out name, out int bytesRead)) { _position += bytesRead; return true; @@ -108,6 +107,19 @@ public bool TryReadDomainName([NotNullWhen(true)] out string? name) return false; } + public bool TryReadSpan(int length, out ReadOnlySpan name) + { + if (MessageBuffer.Count - _position < length) + { + name = default; + return false; + } + + name = MessageBuffer.AsSpan(_position, length); + _position += length; + return true; + } + public void Dispose() { if (!_returnToPool || MessageBuffer.Array == null) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs index 4b116763ceb..a0a11f0b808 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers.Binary; +using System.Diagnostics; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; @@ -29,7 +30,7 @@ internal bool TryWriteHeader(in DnsMessageHeader header) return true; } - internal bool TryWriteQuestion(string name, QueryType type, QueryClass @class) + internal bool TryWriteQuestion(EncodedDomainName name, QueryType type, QueryClass @class) { if (!TryWriteDomainName(name) || !TryWriteUInt16((ushort)type) || @@ -41,28 +42,22 @@ internal bool TryWriteQuestion(string name, QueryType type, QueryClass @class) return true; } - internal bool TryWriteResourceRecord(in DnsResourceRecord record) + private bool TryWriteDomainName(EncodedDomainName name) { - if (!TryWriteDomainName(record.Name) || - !TryWriteUInt16((ushort)record.Type) || - !TryWriteUInt16((ushort)record.Class) || - !TryWriteUInt32((uint)record.Ttl)) + foreach (var label in name.Labels) { - return false; + // this should be already validated by the caller + Debug.Assert(label.Length <= 63, "Label length must not exceed 63 bytes."); + + if (!TryWriteByte((byte)label.Length) || + !TryWriteRawData(label.Span)) + { + return false; + } } - if (record.Data.Length + 2 > _buffer.Length - _position) - { - return false; - } - - BinaryPrimitives.WriteUInt16BigEndian(_buffer.Span.Slice(_position), (ushort)record.Data.Length); - _position += 2; - - record.Data.Span.CopyTo(_buffer.Span.Slice(_position)); - _position += record.Data.Length; - - return true; + // root label + return TryWriteByte(0); } internal bool TryWriteDomainName(string name) @@ -76,6 +71,18 @@ internal bool TryWriteDomainName(string name) return false; } + internal bool TryWriteByte(byte value) + { + if (_buffer.Length - _position < 1) + { + return false; + } + + _buffer.Span[_position] = value; + _position += 1; + return true; + } + internal bool TryWriteUInt16(ushort value) { if (_buffer.Length - _position < 2) @@ -99,4 +106,16 @@ internal bool TryWriteUInt32(uint value) _position += 4; return true; } + + internal bool TryWriteRawData(ReadOnlySpan value) + { + if (_buffer.Length - _position < value.Length) + { + return false; + } + + value.CopyTo(_buffer.Span.Slice(_position)); + _position += value.Length; + return true; + } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs index 47db4f7ac6f..e549abe2576 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -1,9 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Buffers.Binary; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Text; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; @@ -61,8 +61,12 @@ internal static bool TryWriteMessageHeader(Span buffer, DnsMessageHeader h // labels 63 octets or less // name 255 octets or less + private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."); + private static readonly IdnMapping s_idnMapping = new IdnMapping(); internal static bool TryWriteQName(Span destination, string name, out int written) { + written = 0; + // // RFC 1035 4.1.2. // @@ -73,36 +77,68 @@ internal static bool TryWriteQName(Span destination, string name, out int // that this field may be an odd number of octets; no // padding is used. // + if (!Ascii.IsValid(name)) + { + // IDN name, apply punycode + try + { + // IdnMapping performs some validation internally (such as label + // and domain name lengths), but is more relaxed than RFC + // 1035 (e.g. allows ~ chars), so even if this conversion does + // not throw, we still need to perform additional validation + name = s_idnMapping.GetAscii(name); + } + catch + { + return false; + } + } - // The is assumed to be already validated and puny-encoded if needed - Debug.Assert(name.Length <= MaxDomainNameLength); - Debug.Assert(Ascii.IsValid(name)); - - if (destination.IsEmpty || !Encoding.ASCII.TryGetBytes(name, destination.Slice(1), out int length) || destination.Length < length + 2) + if (name.Length > MaxDomainNameLength || + name.AsSpan().ContainsAnyExcept(s_domainNameValidChars) || + destination.IsEmpty || + !Encoding.ASCII.TryGetBytes(name, destination.Slice(1), out int length) || + destination.Length < length + 2) { // buffer too small - written = 0; return false; } - destination[1 + length] = 0; // last label (root) - Span nameBuffer = destination.Slice(0, 1 + length); + Span label; while (true) { // figure out the next label and prepend the length int index = nameBuffer.Slice(1).IndexOf((byte)'.'); - int labelLen = index == -1 ? nameBuffer.Length - 1 : index; + label = index == -1 ? nameBuffer.Slice(1) : nameBuffer.Slice(1, index); - // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4 - // labels 63 octets or less - if (labelLen > 63) + if (label.Length == 0) { - // this should never happen, as we validate the name before calling this method - throw new ArgumentException("Label is too long"); + // empty label (explicit root) is only allowed at the end + if (index != -1) + { + written = 0; + return false; + } + } + // Label restrictions: + // - maximum 63 octets long + // - must start with a letter or digit (digit is allowed by RFC 1123) + // - may start with an underscore (underscore may be present only + // at the start of the label to support SRV records) + // - must end with a letter or digit + else if (label.Length > 63 || + !char.IsAsciiLetterOrDigit((char)label[0]) && label[0] != '_' || + label.Slice(1).Contains((byte)'_') || + !char.IsAsciiLetterOrDigit((char)label[^1])) + { + written = 0; + return false; } - nameBuffer[0] = (byte)labelLen; + nameBuffer[0] = (byte)label.Length; + written += label.Length + 1; + if (index == -1) { // this was the last label @@ -112,11 +148,17 @@ internal static bool TryWriteQName(Span destination, string name, out int nameBuffer = nameBuffer.Slice(index + 1); } - written = length + 2; + // Add root label if wasn't explicitly specified + if (label.Length != 0) + { + destination[written] = 0; + written++; + } + return true; } - private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messageBuffer, int offset, out int bytesRead, bool canStartWithPointer = true) + private static bool TryReadQNameCore(List> labels, int totalLength, ReadOnlyMemory messageBuffer, int offset, out int bytesRead, bool canStartWithPointer = true) { // // domain name can be either @@ -145,7 +187,7 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag while (true) { - byte length = messageBuffer[currentOffset]; + byte length = messageBuffer.Span[currentOffset]; if ((length & 0xC0) == 0x00) { @@ -164,14 +206,11 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag } // read next label/segment - if (sb.Length > 0) - { - sb.Append('.'); - } - - sb.Append(Encoding.ASCII.GetString(messageBuffer.Slice(currentOffset + 1, length))); + labels.Add(messageBuffer.Slice(currentOffset + 1, length)); + totalLength += 1 + length; - if (sb.Length > MaxDomainNameLength) + // subtract one for the length prefix of the first label + if (totalLength - 1 > MaxDomainNameLength) { // domain name is too long return false; @@ -193,7 +232,7 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag } bytesRead += 2; - int pointer = ((length & 0x3F) << 8) | messageBuffer[currentOffset + 1]; + int pointer = ((length & 0x3F) << 8) | messageBuffer.Span[currentOffset + 1]; // we prohibit self-references and forward pointers to avoid // infinite loops, we do this by truncating the @@ -201,7 +240,7 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag // name. We also ignore the bytesRead from the recursive // call, as we are only interested on how many bytes we read // from the initial start of the name. - return TryReadQNameCore(sb, messageBuffer.Slice(0, offset), pointer, out int _, false); + return TryReadQNameCore(labels, totalLength, messageBuffer.Slice(0, offset), pointer, out int _, false); } else { @@ -214,32 +253,32 @@ private static bool TryReadQNameCore(StringBuilder sb, ReadOnlySpan messag } - internal static bool TryReadQName(ReadOnlySpan messageBuffer, int offset, [NotNullWhen(true)] out string? name, out int bytesRead) + internal static bool TryReadQName(ReadOnlyMemory messageBuffer, int offset, out EncodedDomainName name, out int bytesRead) { - StringBuilder sb = new StringBuilder(); + List> labels = new List>(); - if (TryReadQNameCore(sb, messageBuffer, offset, out bytesRead)) + if (TryReadQNameCore(labels, 0, messageBuffer, offset, out bytesRead)) { - name = sb.ToString(); + name = new EncodedDomainName(labels); return true; } else { bytesRead = 0; - name = null; + name = default; return false; } } - internal static bool TryReadService(ReadOnlySpan buffer, out ushort priority, out ushort weight, out ushort port, [NotNullWhen(true)] out string? target, out int bytesRead) + internal static bool TryReadService(ReadOnlyMemory buffer, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) { // https://www.rfc-editor.org/rfc/rfc2782 - if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer, out priority) || - !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Slice(2), out weight) || - !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Slice(4), out port) || + if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span, out priority) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(2), out weight) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(4), out port) || !TryReadQName(buffer.Slice(6), 0, out target, out bytesRead)) { - target = null; + target = default; priority = 0; weight = 0; port = 0; @@ -251,19 +290,19 @@ internal static bool TryReadService(ReadOnlySpan buffer, out ushort priori return true; } - internal static bool TryReadSoa(ReadOnlySpan buffer, [NotNullWhen(true)] out string? primaryNameServer, [NotNullWhen(true)] out string? responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) + internal static bool TryReadSoa(ReadOnlyMemory buffer, out EncodedDomainName primaryNameServer, out EncodedDomainName responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) { // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 if (!TryReadQName(buffer, 0, out primaryNameServer, out int w1) || !TryReadQName(buffer.Slice(w1), 0, out responsibleMailAddress, out int w2) || - !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2), out serial) || - !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 4), out refresh) || - !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 8), out retry) || - !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 12), out expire) || - !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Slice(w1 + w2 + 16), out minimum)) + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2), out serial) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 4), out refresh) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 8), out retry) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 12), out expire) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 16), out minimum)) { - primaryNameServer = null!; - responsibleMailAddress = null!; + primaryNameServer = default; + responsibleMailAddress = default; serial = 0; refresh = 0; retry = 0; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 8b06cbff55b..59fa10fe26b 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -5,7 +5,6 @@ using System.Buffers.Binary; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Globalization; using System.Net; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -66,11 +65,18 @@ public ValueTask ResolveServiceAsync(string name, CancellationT ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); - name = GetNormalizedHostName(name); - - return SendQueryWithTelemetry(name, QueryType.SRV, ProcessResponse, cancellationToken); + byte[] buffer = ArrayPool.Shared.Rent(256); + try + { + EncodedDomainName dnsSafeName = GetNormalizedHostName(name, buffer); + return SendQueryWithTelemetry(name, dnsSafeName, QueryType.SRV, ProcessResponse, cancellationToken); + } + finally + { + ArrayPool.Shared.Return(buffer); + } - static (SendQueryError, ServiceResult[]) ProcessResponse(string name, QueryType queryType, DnsResponse response) + static (SendQueryError, ServiceResult[]) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) { var results = new List(response.Answers.Count); @@ -78,7 +84,7 @@ public ValueTask ResolveServiceAsync(string name, CancellationT { if (answer.Type == QueryType.SRV) { - if (!DnsPrimitives.TryReadService(answer.Data.Span, out ushort priority, out ushort weight, out ushort port, out string? target, out int bytesRead) || bytesRead != answer.Data.Length) + if (!DnsPrimitives.TryReadService(answer.Data, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) || bytesRead != answer.Data.Length) { return (SendQueryError.MalformedResponse, []); } @@ -98,13 +104,13 @@ public ValueTask ResolveServiceAsync(string name, CancellationT // // A Target of "." means that the service is decidedly not // available at this domain. - if (additional.Name == target && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) + if (additional.Name.Equals(target) && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) { addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); } } - results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target!, addresses.ToArray())); + results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target.ToString(), addresses.ToArray())); } } @@ -171,11 +177,19 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa return ValueTask.FromResult([]); } - name = GetNormalizedHostName(name); - var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - return SendQueryWithTelemetry(name, queryType, ProcessResponse, cancellationToken); + byte[] buffer = ArrayPool.Shared.Rent(256); + try + { + EncodedDomainName dnsSafeName = GetNormalizedHostName(name, buffer); + var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + return SendQueryWithTelemetry(name, dnsSafeName, queryType, ProcessResponse, cancellationToken); + } + finally + { + ArrayPool.Shared.Return(buffer); + } - static (SendQueryError error, AddressResult[] result) ProcessResponse(string name, QueryType queryType, DnsResponse response) + static (SendQueryError error, AddressResult[] result) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) { List results = new List(response.Answers.Count); @@ -192,19 +206,19 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa // scan first and fallback to a slower but more robust method if necessary. bool success = true; - string currentAlias = name; + EncodedDomainName currentAlias = dnsSafeName; foreach (var answer in response.Answers) { switch (answer.Type) { case QueryType.CNAME: - if (!TryReadTarget(answer, response.RawMessageBytes, out string? target)) + if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target)) { return (SendQueryError.MalformedResponse, []); } - if (string.Equals(answer.Name, currentAlias, StringComparison.OrdinalIgnoreCase)) + if (answer.Name.Equals(currentAlias)) { currentAlias = target; continue; @@ -218,7 +232,7 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa return (SendQueryError.MalformedResponse, []); } - if (string.Equals(answer.Name, currentAlias, StringComparison.OrdinalIgnoreCase)) + if (answer.Name.Equals(currentAlias)) { results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); continue; @@ -240,14 +254,14 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa // more expensive path for uncommon (but valid) cases where CNAME records are out of order. Use of Dictionary // allows us to stay within O(n) complexity for the number of answers, but we will use more memory. - Dictionary aliasMap = new Dictionary(StringComparer.OrdinalIgnoreCase); - Dictionary> aRecordMap = new Dictionary>(StringComparer.OrdinalIgnoreCase); + Dictionary aliasMap = new(); + Dictionary> aRecordMap = new(); foreach (var answer in response.Answers) { if (answer.Type == QueryType.CNAME) { // map the alias to the target name - if (!TryReadTarget(answer, response.RawMessageBytes, out string? target)) + if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target)) { return (SendQueryError.MalformedResponse, []); } @@ -278,8 +292,8 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa // follow the CNAME chain, limit the maximum number of iterations to avoid infinite loops. int i = 0; - currentAlias = name; - while (aliasMap.TryGetValue(currentAlias, out string? nextAlias)) + currentAlias = dnsSafeName; + while (aliasMap.TryGetValue(currentAlias, out EncodedDomainName nextAlias)) { if (i >= aliasMap.Count) { @@ -302,11 +316,11 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa aRecordMap.TryGetValue(currentAlias, out List? finalAddressList); return (SendQueryError.NoError, finalAddressList?.ToArray() ?? []); - static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messageBytes, [NotNullWhen(true)] out string? target) + static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messageBytes, out EncodedDomainName target) { Debug.Assert(record.Type == QueryType.CNAME, "Only CNAME records should be processed here."); - target = null; + target = default; // some servers use domain name compression even inside CNAME records. In order to decode those // correctly, we need to pass the entire message to TryReadQName. The Data span inside the record @@ -318,7 +332,7 @@ static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messag int messageOffset = messageBytes.Offset; - bool result = DnsPrimitives.TryReadQName(segment.Array.AsSpan(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset, out string? targetName, out int bytesRead) && bytesRead == record.Data.Length; + bool result = DnsPrimitives.TryReadQName(segment.Array.AsMemory(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset, out EncodedDomainName targetName, out int bytesRead) && bytesRead == record.Data.Length; if (result) { target = targetName; @@ -344,10 +358,10 @@ static bool TryReadAddress(in DnsResourceRecord record, QueryType type, [NotNull } } - private async ValueTask SendQueryWithTelemetry(string name, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + private async ValueTask SendQueryWithTelemetry(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) { NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); - (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false); + (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, dnsSafeName, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false); Telemetry.StopNameResolution(name, queryType, activity, null, error, _timeProvider.GetTimestamp()); return result; @@ -359,7 +373,7 @@ internal struct SendQueryResult public SendQueryError Error; } - async ValueTask<(SendQueryError error, TResult[] result)> SendQueryWithRetriesAsync(string name, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + async ValueTask<(SendQueryError error, TResult[] result)> SendQueryWithRetriesAsync(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) { SendQueryError lastError = SendQueryError.InternalError; // will be overwritten by the first attempt for (int index = 0; index < _options.Servers.Count; index++) @@ -375,7 +389,7 @@ internal struct SendQueryResult try { - SendQueryResult queryResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, queryType, attempt, cancellationToken).ConfigureAwait(false); + SendQueryResult queryResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cancellationToken).ConfigureAwait(false); lastError = queryResult.Error; response = queryResult.Response; @@ -383,7 +397,7 @@ internal struct SendQueryResult { // Given that result.Error is NoError, there should be at least one answer. Debug.Assert(response.Answers.Count > 0); - (lastError, results) = processResponseFunc(name, queryType, queryResult.Response); + (lastError, results) = processResponseFunc(dnsSafeName, queryType, queryResult.Response); } } catch (SocketException ex) @@ -463,13 +477,13 @@ internal struct SendQueryResult return (lastError, []); } - internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, int attempt, CancellationToken cancellationToken) + internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken) { (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); try { - return await SendQueryToServerAsync(serverEndPoint, name, queryType, attempt, cts.Token).ConfigureAwait(false); + return await SendQueryToServerAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cts.Token).ConfigureAwait(false); } catch (OperationCanceledException) when ( !cancellationToken.IsCancellationRequested && // not cancelled by the caller @@ -495,13 +509,13 @@ internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEn } } - private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, int attempt, CancellationToken cancellationToken) + private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken) { Log.Query(_logger, queryType, name, serverEndPoint, attempt); SendQueryError sendError = SendQueryError.NoError; DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; - (DnsDataReader responseReader, DnsMessageHeader header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + (DnsDataReader responseReader, DnsMessageHeader header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); try { @@ -510,7 +524,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); responseReader.Dispose(); // TCP fallback - (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, name, queryType, cancellationToken).ConfigureAwait(false); + (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); } if (sendError != SendQueryError.NoError) @@ -532,7 +546,7 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve // Recheck that the server echoes back the DNS question if (header.QueryCount != 1 || !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || - qName != name || qType != queryType || qClass != QueryClass.Internet) + !dnsSafeName.Equals(qName) || qType != queryType || qClass != QueryClass.Internet) { // DNS Question mismatch return new SendQueryResult @@ -611,7 +625,7 @@ internal static bool GetNegativeCacheExpiration(DateTime createdAt, List r.Type == QueryType.SOA); - if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data.Span, out string? mname, out string? rname, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out _)) + if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data, out _, out _, out _, out _, out _, out _, out uint minimum, out _)) { expiration = createdAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl)); return true; @@ -674,13 +688,13 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, return SendQueryError.ServerError; } - internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) { var buffer = ArrayPool.Shared.Rent(512); try { Memory memory = buffer; - (ushort transactionId, int length) = EncodeQuestion(memory, name, queryType); + (ushort transactionId, int length) = EncodeQuestion(memory, dnsSafeName, queryType); using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp); await socket.SendToAsync(memory.Slice(0, length), SocketFlags.None, serverEndPoint, cancellationToken).ConfigureAwait(false); @@ -723,13 +737,13 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, } } - internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header, SendQueryError error)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, string name, QueryType queryType, CancellationToken cancellationToken) + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header, SendQueryError error)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) { var buffer = ArrayPool.Shared.Rent(8 * 1024); try { // When sending over TCP, the message is prefixed by 2B length - (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), name, queryType); + (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), dnsSafeName, queryType); BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length); using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); @@ -787,7 +801,7 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, } } - private static (ushort id, int length) EncodeQuestion(Memory buffer, string name, QueryType queryType) + private static (ushort id, int length) EncodeQuestion(Memory buffer, EncodedDomainName dnsSafeName, QueryType queryType) { DnsMessageHeader header = new DnsMessageHeader { @@ -798,7 +812,7 @@ private static (ushort id, int length) EncodeQuestion(Memory buffer, strin DnsDataWriter writer = new DnsDataWriter(buffer); if (!writer.TryWriteHeader(header) || - !writer.TryWriteQuestion(name, queryType, QueryClass.Internet)) + !writer.TryWriteQuestion(dnsSafeName, queryType, QueryClass.Internet)) { // should never happen since we validated the name length before throw new InvalidOperationException("Buffer too small"); @@ -850,11 +864,28 @@ public void Dispose() return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts); } - private static readonly IdnMapping s_idnMapping = new IdnMapping(); - - private static string GetNormalizedHostName(string name) + private static EncodedDomainName GetNormalizedHostName(string name, Memory buffer) { - // TODO: better exception message - return s_idnMapping.GetAscii(name); + if (!DnsPrimitives.TryWriteQName(buffer.Span, name, out _)) + { + throw new ArgumentException($"'{name}' is not a valid DNS name.", nameof(name)); + } + + List> labels = new(); + while (true) + { + int len = buffer.Span[0]; + + if (len == 0) + { + // root label, we are finished + break; + } + + labels.Add(buffer.Slice(1, len)); + buffer = buffer.Slice(len + 1); + } + + return new EncodedDomainName(labels); } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs index 929fa893fd5..914ff9aac17 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs @@ -5,13 +5,13 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; internal struct DnsResourceRecord { - public string Name { get; } + public EncodedDomainName Name { get; } public QueryType Type { get; } public QueryClass Class { get; } public int Ttl { get; } public ReadOnlyMemory Data { get; } - public DnsResourceRecord(string name, QueryType type, QueryClass @class, int ttl, ReadOnlyMemory data) + public DnsResourceRecord(EncodedDomainName name, QueryType type, QueryClass @class, int ttl, ReadOnlyMemory data) { Name = name; Type = type; diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs new file mode 100644 index 00000000000..153e11dec20 --- /dev/null +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct EncodedDomainName : IEquatable +{ + public IReadOnlyList> Labels { get; } + + public EncodedDomainName(List> labels) + { + Labels = labels; + } + + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + + foreach (var label in Labels) + { + if (sb.Length > 0) + { + sb.Append('.'); + } + sb.Append(Encoding.ASCII.GetString(label.Span)); + } + + return sb.ToString(); + } + + public bool Equals(EncodedDomainName other) + { + if (Labels.Count != other.Labels.Count) + { + return false; + } + + for (int i = 0; i < Labels.Count; i++) + { + if (!Ascii.EqualsIgnoreCase(Labels[i].Span, other.Labels[i].Span)) + { + return false; + } + } + + return true; + } + + public override bool Equals(object? obj) + { + return obj is EncodedDomainName other && Equals(other); + } + + public override int GetHashCode() + { + HashCode hash = new HashCode(); + + foreach (var label in Labels) + { + foreach (byte b in label.Span) + { + hash.Add((byte)char.ToLower((char)b)); + } + } + + return hash.ToHashCode(); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs index 241d6a8a7a9..aad32fe785f 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs @@ -29,7 +29,7 @@ public void ReadResourceRecord_Success() DnsDataReader reader = new DnsDataReader(buffer); Assert.True(reader.TryReadResourceRecord(out DnsResourceRecord record)); - Assert.Equal("www.example.com", record.Name); + Assert.Equal("www.example.com", record.Name.ToString()); Assert.Equal(QueryType.A, record.Type); Assert.Equal(QueryClass.Internet, record.Class); Assert.Equal(3600, record.Ttl); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs index 5adbff0c8ac..b2039ce5a4c 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs @@ -26,7 +26,7 @@ public void WriteResourceRecord_Success() 0x00, 0x00, 0x00, 0x00 ]; - DnsResourceRecord record = new DnsResourceRecord("www.example.com", QueryType.A, QueryClass.Internet, 3600, new byte[4]); + DnsResourceRecord record = new DnsResourceRecord(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet, 3600, new byte[4]); byte[] buffer = new byte[512]; DnsDataWriter writer = new DnsDataWriter(buffer); @@ -53,7 +53,7 @@ public void WriteResourceRecord_Truncated_Fails() 0x00, 0x00, 0x00, 0x00 ]; - DnsResourceRecord record = new DnsResourceRecord("www.example.com", QueryType.A, QueryClass.Internet, 3600, new byte[4]); + DnsResourceRecord record = new DnsResourceRecord(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet, 3600, new byte[4]); byte[] buffer = new byte[512]; for (int i = 0; i < expected.Length; i++) @@ -78,7 +78,7 @@ public void WriteQuestion_Success() byte[] buffer = new byte[512]; DnsDataWriter writer = new DnsDataWriter(buffer); - Assert.True(writer.TryWriteQuestion("www.example.com", QueryType.A, QueryClass.Internet)); + Assert.True(writer.TryWriteQuestion(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet)); Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); } @@ -99,7 +99,7 @@ public void WriteQuestion_Truncated_Fails() for (int i = 0; i < expected.Length; i++) { DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); - Assert.False(writer.TryWriteQuestion("www.example.com", QueryType.A, QueryClass.Internet)); + Assert.False(writer.TryWriteQuestion(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet)); } } @@ -137,4 +137,12 @@ public void WriteHeader_Success() Assert.True(writer.TryWriteHeader(header)); Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); } + + private static EncodedDomainName EncodeDomainName(string name) + { + byte[] nameBuffer = new byte[512]; + Assert.True(DnsPrimitives.TryWriteQName(nameBuffer, name, out int nameLength)); + Assert.True(DnsPrimitives.TryReadQName(nameBuffer.AsMemory(0, nameLength), 0, out EncodedDomainName encodedDomainName, out _)); + return encodedDomainName; + } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs index 32d4ee02ce9..6733a553bad 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs @@ -29,11 +29,11 @@ public void TryWriteQName_Success(string name, byte[] expected) } [Fact] - public void TryWriteQName_LabelTooLong_Throws() + public void TryWriteQName_LabelTooLong_False() { byte[] buffer = new byte[512]; - Assert.Throws(() => DnsPrimitives.TryWriteQName(buffer, new string('a', 70), out int written)); + Assert.False(DnsPrimitives.TryWriteQName(buffer, new string('a', 70), out _)); } [Fact] @@ -48,19 +48,49 @@ public void TryWriteQName_BufferTooShort_Fails() } } + [Theory] + [InlineData("www.-0.com")] + [InlineData("www.-a.com")] + [InlineData("www.a-.com")] + [InlineData("www.a_a.com")] + [InlineData("www.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.com")] // 64 occurrences of 'a' (too long) + [InlineData("www.a~a.com")] // 64 occurrences of 'a' (too long) + [InlineData("www..com")] + [InlineData("www..")] + public void TryWriteQName_InvalidName_ReturnsFalse(string name) + { + byte[] buffer = new byte[512]; + Assert.False(DnsPrimitives.TryWriteQName(buffer, name, out _)); + } + + [Fact] + public void TryWriteQName_ExplicitRoot_Success() + { + string name1 = "www.example.com"; + string name2 = "www.example.com."; + + byte[] buffer1 = new byte[512]; + byte[] buffer2 = new byte[512]; + + Assert.True(DnsPrimitives.TryWriteQName(buffer1, name1, out int written1)); + Assert.True(DnsPrimitives.TryWriteQName(buffer2, name2, out int written2)); + Assert.Equal(written1, written2); + Assert.Equal(buffer1.AsSpan().Slice(0, written1).ToArray(), buffer2.AsSpan().Slice(0, written2).ToArray()); + } + [Theory] [MemberData(nameof(QNameData))] public void TryReadQName_Success(string expected, byte[] serialized) { - Assert.True(DnsPrimitives.TryReadQName(serialized, 0, out string? actual, out int bytesRead)); - Assert.Equal(expected, actual); + Assert.True(DnsPrimitives.TryReadQName(serialized, 0, out EncodedDomainName actual, out int bytesRead)); + Assert.Equal(expected, actual.ToString()); Assert.Equal(serialized.Length, bytesRead); } [Fact] public void TryReadQName_TruncatedData_Fails() { - ReadOnlySpan data = "\x0003www\x0007example\x0003com\x0000"u8; + ReadOnlyMemory data = "\x0003www\x0007example\x0003com\x0000"u8.ToArray(); for (int i = 0; i < data.Length; i++) { @@ -72,11 +102,11 @@ public void TryReadQName_TruncatedData_Fails() public void TryReadQName_Pointer_Success() { // [7B padding], example.com. www->[ptr to example.com.] - Span data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); - data[^2] = 0xc0; + Memory data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data.Span[^2] = 0xc0; - Assert.True(DnsPrimitives.TryReadQName(data, data.Length - 6, out string? actual, out int bytesRead)); - Assert.Equal("www.example.com", actual); + Assert.True(DnsPrimitives.TryReadQName(data, data.Length - 6, out EncodedDomainName actual, out int bytesRead)); + Assert.Equal("www.example.com", actual.ToString()); Assert.Equal(6, bytesRead); } @@ -84,8 +114,8 @@ public void TryReadQName_Pointer_Success() public void TryReadQName_PointerTruncated_Fails() { // [7B padding], example.com. www->[ptr to example.com.] - Span data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); - data[^2] = 0xc0; + Memory data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data.Span[^2] = 0xc0; for (int i = 0; i < data.Length; i++) { @@ -97,8 +127,8 @@ public void TryReadQName_PointerTruncated_Fails() public void TryReadQName_ForwardPointer_Fails() { // www->[ptr to example.com], [7B padding], example.com. - Span data = "\x03www\x00\x000dpadding\x0007example\x0003com\x00"u8.ToArray(); - data[4] = 0xc0; + Memory data = "\x03www\x00\x000dpadding\x0007example\x0003com\x00"u8.ToArray(); + data.Span[4] = 0xc0; Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); } @@ -107,8 +137,8 @@ public void TryReadQName_ForwardPointer_Fails() public void TryReadQName_PointerToSelf_Fails() { // www->[ptr to www->...] - Span data = "\x0003www\0\0"u8.ToArray(); - data[4] = 0xc0; + Memory data = "\x0003www\0\0"u8.ToArray(); + data.Span[4] = 0xc0; Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); } @@ -117,11 +147,11 @@ public void TryReadQName_PointerToSelf_Fails() public void TryReadQName_PointerToPointer_Fails() { // com, example[->com], example2[->[->com]] - Span data = "\x0003com\0\x0007example\0\0\x0008example2\0\0"u8.ToArray(); - data[13] = 0xc0; - data[14] = 0x00; // -> com - data[24] = 0xc0; - data[25] = 13; // -> -> com + Memory data = "\x0003com\0\x0007example\0\0\x0008example2\0\0"u8.ToArray(); + data.Span[13] = 0xc0; + data.Span[14] = 0x00; // -> com + data.Span[24] = 0xc0; + data.Span[25] = 13; // -> -> com Assert.False(DnsPrimitives.TryReadQName(data, 15, out _, out _)); } @@ -129,8 +159,8 @@ public void TryReadQName_PointerToPointer_Fails() [Fact] public void TryReadQName_ReservedBits() { - Span data = "\x0003www\x00c0"u8.ToArray(); - data[0] = 0x40; + Memory data = "\x0003www\x00c0"u8.ToArray(); + data.Span[0] = 0x40; Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 12ab8bf59b3..78eac5c845c 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -3,8 +3,10 @@ using System.Buffers; using System.Buffers.Binary; +using System.Globalization; using System.Net; using System.Net.Sockets; +using System.Text; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; @@ -46,7 +48,7 @@ private static async Task ProcessRequestCore(IPEndPoint remoteEndPoint, Arr return 0; } - LoopbackDnsResponseBuilder responseBuilder = new(name, type, @class); + LoopbackDnsResponseBuilder responseBuilder = new(name.ToString(), type, @class); responseBuilder.TransactionId = header.TransactionId; responseBuilder.Flags = header.QueryFlags | QueryFlags.HasResponse; responseBuilder.ResponseCode = QueryResponseCode.NoError; @@ -155,13 +157,20 @@ public int Write(Memory responseBuffer) throw new InvalidOperationException("Failed to write header"); } + byte[] buffer = ArrayPool.Shared.Rent(512); foreach (var (questionName, questionType, questionClass) in Questions) { - if (!writer.TryWriteQuestion(questionName, questionType, questionClass)) + if (!DnsPrimitives.TryWriteQName(buffer, questionName, out int length) || + !DnsPrimitives.TryReadQName(buffer.AsMemory(0, length), 0, out EncodedDomainName encodedName, out _)) + { + throw new InvalidOperationException("Failed to encode domain name"); + } + if (!writer.TryWriteQuestion(encodedName, questionType, questionClass)) { throw new InvalidOperationException("Failed to write question"); } } + ArrayPool.Shared.Return(buffer); foreach (var answer in Answers) { @@ -207,10 +216,20 @@ public byte[] GetMessageBytes() internal static class LoopbackDnsServerExtensions { + private static readonly IdnMapping s_idnMapping = new IdnMapping(); + + private static EncodedDomainName EncodeDomainName(string name) + { + var encodedLabels = name.Split('.', StringSplitOptions.RemoveEmptyEntries).Select(label => (ReadOnlyMemory)Encoding.UTF8.GetBytes(s_idnMapping.GetAscii(label))) + .ToList(); + + return new EncodedDomainName(encodedLabels); + } + public static List AddAddress(this List records, string name, int ttl, IPAddress address) { QueryType type = address.AddressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - records.Add(new DnsResourceRecord(name, type, QueryClass.Internet, ttl, address.GetAddressBytes())); + records.Add(new DnsResourceRecord(EncodeDomainName(name), type, QueryClass.Internet, ttl, address.GetAddressBytes())); return records; } @@ -222,7 +241,7 @@ public static List AddCname(this List reco throw new InvalidOperationException("Failed to encode domain name"); } - records.Add(new DnsResourceRecord(name, QueryType.CNAME, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.CNAME, QueryClass.Internet, ttl, buff.AsMemory(0, length))); return records; } @@ -241,7 +260,7 @@ public static List AddService(this List re length += 6; - records.Add(new DnsResourceRecord(name, QueryType.SRV, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.SRV, QueryClass.Internet, ttl, buff.AsMemory(0, length))); return records; } @@ -263,7 +282,45 @@ public static List AddStartOfAuthority(this List 63) + { + throw new InvalidOperationException("Label length exceeds maximum of 63 bytes"); + } + + if (!writer.TryWriteByte((byte)label.Length) || + !writer.TryWriteRawData(label.Span)) + { + return false; + } + } + + // root label + return writer.TryWriteByte(0); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index 25c5baab910..9e32fcb4cff 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -51,17 +51,20 @@ public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) Assert.Empty(results); } - [Fact] - public async Task ResolveIPv4_Simple_Success() + [Theory] + [InlineData("www.example.com")] + [InlineData("www.example.com.")] + [InlineData("www.ř.com")] + public async Task ResolveIPv4_Simple_Success(string name) { IPAddress address = IPAddress.Parse("172.213.245.111"); _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(name, 3600, address); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); AddressResult res = Assert.Single(results); Assert.Equal(address, res.Address); From 46418fe72c697461b8985b081537a3e1e7c7d841 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 22 May 2025 16:41:58 +0200 Subject: [PATCH 33/45] Minor fixes --- .../Resolver/DnsResolver.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 59fa10fe26b..ca7a1eed6f7 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -332,7 +332,7 @@ static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messag int messageOffset = messageBytes.Offset; - bool result = DnsPrimitives.TryReadQName(segment.Array.AsMemory(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset, out EncodedDomainName targetName, out int bytesRead) && bytesRead == record.Data.Length; + bool result = DnsPrimitives.TryReadQName(segment.Array.AsMemory(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset - messageOffset, out EncodedDomainName targetName, out int bytesRead) && bytesRead == record.Data.Length; if (result) { target = targetName; @@ -752,7 +752,7 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, int responseLength = -1; int bytesRead = 0; - while (responseLength < 0 || bytesRead < length + 2) + while (responseLength < 0 || bytesRead < responseLength + 2) { int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); bytesRead += read; @@ -767,11 +767,11 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, { responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); - if (responseLength > buffer.Length) + if (responseLength + 2 > buffer.Length) { // even though this is user-controlled pre-allocation, it is limited to // 64 kB, so it should be fine. - var largerBuffer = ArrayPool.Shared.Rent(responseLength); + var largerBuffer = ArrayPool.Shared.Rent(responseLength + 2); Array.Copy(buffer, largerBuffer, bytesRead); ArrayPool.Shared.Return(buffer); buffer = largerBuffer; From 4075e4fc1ecd54cce26f41ffba166d4a3accfcd9 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 23 May 2025 09:20:09 +0200 Subject: [PATCH 34/45] Add Fuzzing tests for DNS resolver --- Directory.Packages.props | 2 + eng/Versions.props | 2 + ...oft.Extensions.ServiceDiscovery.Dns.csproj | 1 + .../Resolver/DnsResolver.cs | 62 ++++++++-- .../Resolver/ResolverOptions.cs | 3 + .../.gitignore | 2 + .../Fuzzers/DnsResponseFuzzer.cs | 44 ++++++++ .../Fuzzers/EncodedDomainNameFuzzer.cs | 38 +++++++ .../Fuzzers/WriteDomainNameRoundTripFuzzer.cs | 48 ++++++++ .../GlobalUsings.cs | 5 + .../IFuzzer.cs | 10 ++ ....ServiceDiscovery.Dns.Tests.Fuzzing.csproj | 18 +++ .../Program.cs | 64 +++++++++++ .../DnsResponseFuzzer/ip-www.example.com | Bin 0 -> 141 bytes .../corpus-seed/DnsResponseFuzzer/name-error | Bin 0 -> 31 bytes .../DnsResponseFuzzer/name-error-2 | Bin 0 -> 91 bytes .../corpus-seed/DnsResponseFuzzer/no-data | Bin 0 -> 91 bytes .../DnsResponseFuzzer/server-error | Bin 0 -> 31 bytes .../WriteDomainNameRoundTripFuzzer/example | 1 + .../WriteDomainNameRoundTripFuzzer/nonascii | 1 + .../WriteDomainNameRoundTripFuzzer/toolong | 1 + .../run.ps1 | 106 ++++++++++++++++++ 22 files changed, 400 insertions(+), 8 deletions(-) create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 diff --git a/Directory.Packages.props b/Directory.Packages.props index 840338d04ef..af7fec39b9e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -164,6 +164,8 @@ + + diff --git a/eng/Versions.props b/eng/Versions.props index 6e798fb4b75..c4544506319 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -67,6 +67,8 @@ 8.0.1 8.0.5 + + 2.2.0 diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj index 9ea318cefb1..d50c4be70df 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj @@ -27,6 +27,7 @@ + diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index ca7a1eed6f7..5eb916909a8 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -33,7 +33,11 @@ internal void SetTimeProvider(TimeProvider timeProvider) _timeProvider = timeProvider; } - public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions()) + public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(timeProvider, logger, OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions()) + { + } + + public DnsResolver(TimeProvider timeProvider, ILogger logger, ResolverOptions options) : this(options) { _timeProvider = timeProvider; _logger = logger; @@ -407,8 +411,9 @@ internal struct SendQueryResult } catch (Exception ex) when (!cancellationToken.IsCancellationRequested) { + // internal error, propagate Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); - lastError = SendQueryError.InternalError; + throw; } switch (lastError) @@ -515,16 +520,27 @@ private async ValueTask SendQueryToServerAsync(IPEndPoint serve SendQueryError sendError = SendQueryError.NoError; DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; - (DnsDataReader responseReader, DnsMessageHeader header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + DnsDataReader responseReader = default; + DnsMessageHeader header; try { - if (header.IsResultTruncated) + // use transport override if provided + if (_options._transportOverride != null) { - Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); - responseReader.Dispose(); - // TCP fallback - (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + (responseReader, header, sendError) = SendDnsQueryCustomTransport(_options._transportOverride, dnsSafeName, queryType); + } + else + { + (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + + if (header.IsResultTruncated) + { + Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); + responseReader.Dispose(); + // TCP fallback + (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + } } if (sendError != SendQueryError.NoError) @@ -688,6 +704,36 @@ internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, return SendQueryError.ServerError; } + internal static (DnsDataReader reader, DnsMessageHeader header, SendQueryError sendError) SendDnsQueryCustomTransport(Func, int, int> callback, EncodedDomainName dnsSafeName, QueryType queryType) + { + byte[] buffer = ArrayPool.Shared.Rent(2048); + try + { + (ushort transactionId, int length) = EncodeQuestion(buffer, dnsSafeName, queryType); + length = callback(buffer, length); + + DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 0, length), true); + + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + return (default, default, SendQueryError.MalformedResponse); + } + + // transfer ownership of buffer to the caller + buffer = null!; + return (responseReader, header, SendQueryError.NoError); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) { var buffer = ArrayPool.Shared.Rent(512); diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs index 673091453ad..51d03f64bfd 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -11,6 +11,9 @@ internal sealed class ResolverOptions public int Attempts = 2; public TimeSpan Timeout = TimeSpan.FromSeconds(3); + // override for testing purposes + internal Func, int, int>? _transportOverride; + public ResolverOptions(IReadOnlyList servers) { if (servers.Count == 0) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore new file mode 100644 index 00000000000..0151cc4e360 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore @@ -0,0 +1,2 @@ +# corpuses generated by the fuzzing engine +corpuses/** \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs new file mode 100644 index 00000000000..1b180d74b9d --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class DnsResponseFuzzer : IFuzzer +{ + DnsResolver? _resolver; + byte[]? _buffer; + int _length; + + public void FuzzTarget(ReadOnlySpan data) + { + // lazy init + if (_resolver == null) + { + _buffer = new byte[4096]; + _resolver = new DnsResolver(new ResolverOptions(new IPEndPoint(IPAddress.Loopback, 53)) + { + Timeout = TimeSpan.FromSeconds(5), + Attempts = 1, + _transportOverride = (buffer, length) => + { + // the first two bytes are the random transaction ID, so we keep that + // and use the fuzzing payload for the rest of the DNS response + _buffer.AsSpan(0, Math.Min(_length, buffer.Length - 2)).CopyTo(buffer.Span.Slice(2)); + return _length + 2; + } + }); + } + + data.CopyTo(_buffer!); + _length = data.Length; + + // the _transportOverride makes the execution synchronous + ValueTask task = _resolver!.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork, CancellationToken.None); + Debug.Assert(task.IsCompleted, "Task should be completed synchronously"); + task.GetAwaiter().GetResult(); + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs new file mode 100644 index 00000000000..0d37bad24b3 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class EncodedDomainNameFuzzer : IFuzzer +{ + public void FuzzTarget(ReadOnlySpan data) + { + // first byte is the offset of the domain name, rest is the actual + // (simulated) DNS message payload + + if (data.Length < 1) + { + return; + } + + byte[] buffer = ArrayPool.Shared.Rent(data.Length); + try + { + int offset = data[0]; + data.Slice(1).CopyTo(buffer); + + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, data.Length - 1), offset, out EncodedDomainName name, out _)) + { + return; + } + + // the domain name should be readable + _ = name.ToString(); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs new file mode 100644 index 00000000000..f657245a842 --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class WriteDomainNameRoundTripFuzzer : IFuzzer +{ + private static readonly System.Globalization.IdnMapping s_idnMapping = new(); + public void FuzzTarget(ReadOnlySpan data) + { + // first byte is the offset of the domain name, rest is the actual + // (simulated) DNS message payload + + byte[] buffer = ArrayPool.Shared.Rent(data.Length * 2); + + try + { + string domainName = Encoding.UTF8.GetString(data); + if (!DnsPrimitives.TryWriteQName(buffer, domainName, out int written)) + { + return; + } + + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, written), 0, out EncodedDomainName name, out int read)) + { + return; + } + + if (read != written) + { + throw new InvalidOperationException($"Read {read} bytes, but wrote {written} bytes"); + } + + string readName = name.ToString(); + + if (!string.Equals(s_idnMapping.GetAscii(domainName).TrimEnd('.'), readName, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException($"Domain name mismatch: {readName} != {domainName}"); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs new file mode 100644 index 00000000000..2ff9d86b2ce --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs @@ -0,0 +1,5 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +global using System.Buffers; +global using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs new file mode 100644 index 00000000000..4b4c8c99b4b --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +public interface IFuzzer +{ + string Name => GetType().Name; + void FuzzTarget(ReadOnlySpan data); +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj new file mode 100644 index 00000000000..6572c27a1fa --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj @@ -0,0 +1,18 @@ + + + + $(DefaultTargetFramework) + enable + enable + Exe + + + + + + + + + + + diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs new file mode 100644 index 00000000000..22b1580d1ac --- /dev/null +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using SharpFuzz; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +public static class Program +{ + public static void Main(string[] args) + { + IFuzzer[] fuzzers = typeof(Program).Assembly.GetTypes() + .Where(t => t.IsClass && !t.IsAbstract) + .Where(t => t.GetInterfaces().Contains(typeof(IFuzzer))) + .Select(t => (IFuzzer)Activator.CreateInstance(t)!) + .OrderBy(f => f.Name, StringComparer.OrdinalIgnoreCase) + .ToArray(); + + void PrintUsage() + { + Console.Error.WriteLine($""" + Usage: + DotnetFuzzing list + DotnetFuzzing [input file/directory] + // DotnetFuzzing prepare-onefuzz + + Available fuzzers: + {string.Join(Environment.NewLine, fuzzers.Select(f => $" {f.Name}"))} + """); + } + + if (args.Length == 0) + { + PrintUsage(); + return; + } + + string arg = args[0]; + IFuzzer? fuzzer = fuzzers.FirstOrDefault(f => string.Equals(f.Name, arg, StringComparison.OrdinalIgnoreCase)); + if (fuzzer == null) + { + Console.Error.WriteLine($"Unknown fuzzer: {arg}"); + PrintUsage(); + return; + } + + string? inputFiles = args.Length > 1 ? args[1] : null; + if (string.IsNullOrEmpty(inputFiles)) + { + // no input files, let the fuzzer generate + Fuzzer.LibFuzzer.Run(fuzzer.FuzzTarget); + return; + } + + string[] files = Directory.Exists(inputFiles) + ? Directory.GetFiles(inputFiles) + : [inputFiles]; + + foreach (string inputFile in files) + { + fuzzer.FuzzTarget(File.ReadAllBytes(inputFile)); + } + } +} \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com new file mode 100644 index 0000000000000000000000000000000000000000..bb40cd100c651f4d67ef95ae5d0b10edf886eb34 GIT binary patch literal 141 zcmZo{U|?imVE_W=^73-_)QZI1f}B+5X=X_(b6#o*!vS50YC{(W5!OUQ6C)#*l;Y$fw#4kj+{DZSUI(HJSlAD Date: Fri, 23 May 2025 10:46:12 +0200 Subject: [PATCH 35/45] Improve fuzzing of EncodedDomainName --- .../Fuzzers/EncodedDomainNameFuzzer.cs | 25 +++++++----------- .../ip-www.example.com | Bin 0 -> 143 bytes 2 files changed, 10 insertions(+), 15 deletions(-) create mode 100644 tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs index 0d37bad24b3..72f84b3c959 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs @@ -7,27 +7,22 @@ internal sealed class EncodedDomainNameFuzzer : IFuzzer { public void FuzzTarget(ReadOnlySpan data) { - // first byte is the offset of the domain name, rest is the actual - // (simulated) DNS message payload - - if (data.Length < 1) - { - return; - } - byte[] buffer = ArrayPool.Shared.Rent(data.Length); try { - int offset = data[0]; - data.Slice(1).CopyTo(buffer); + data.CopyTo(buffer); - if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, data.Length - 1), offset, out EncodedDomainName name, out _)) + // attempt to read at any offset to really stress the parser + for (int i = 0; i < data.Length; i++) { - return; - } + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, data.Length), i, out EncodedDomainName name, out _)) + { + continue; + } - // the domain name should be readable - _ = name.ToString(); + // the domain name should be readable + _ = name.ToString(); + } } finally { diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com new file mode 100644 index 0000000000000000000000000000000000000000..c227840c1688d3774bf148f50a7c377637489193 GIT binary patch literal 143 zcmZQzXl!6$WME+c0_O7aa`x1U#N2|MROaOTTn3;7;{hH9Rv^v5@QOhRSpgeRfo_=z zXKG4%YH?|1Nh)(*Y6-&uU5IK!7X}g5L_-rJBbJon Date: Tue, 10 Jun 2025 13:14:30 +0200 Subject: [PATCH 36/45] Remove commented out code --- playground/TestShop/TestShop.ServiceDefaults/Extensions.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs b/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs index 39a58575f75..5ec75c84763 100644 --- a/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs +++ b/playground/TestShop/TestShop.ServiceDefaults/Extensions.cs @@ -24,8 +24,6 @@ public static TBuilder AddServiceDefaults(this TBuilder builder) where builder.AddDefaultHealthChecks(); builder.Services.AddServiceDiscovery(); - // builder.Services.AddServiceDiscoveryCore(); - // builder.Services.AddDnsSrvServiceEndpointProvider(); builder.Services.ConfigureHttpClientDefaults(http => { From 3c387e9a511a56d0535f67b8fcf68041f50f76a9 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 10 Jun 2025 13:50:40 +0200 Subject: [PATCH 37/45] Fix build --- .../Resolver/CancellationTests.cs | 1 - .../Resolver/LoopbackDnsTestBase.cs | 2 +- .../Resolver/ResolveAddressesTests.cs | 1 - .../Resolver/ResolveServiceTests.cs | 1 - .../Resolver/RetryTests.cs | 1 - .../Resolver/TcpFailoverTests.cs | 1 - 6 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs index b0ace03a8db..8c646ac18ee 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using Xunit; -using Xunit.Abstractions; using System.Net.Sockets; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs index 1daa77e2bec..e1785b52800 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -1,7 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using Xunit.Abstractions; +using Xunit; using System.Runtime.CompilerServices; using Microsoft.Extensions.Time.Testing; diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index 9e32fcb4cff..e66c0d3390d 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using Xunit; -using Xunit.Abstractions; using System.Net; using System.Net.Sockets; diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs index fe599ca1789..e1cd1df2959 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using Xunit; -using Xunit.Abstractions; using System.Net; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs index 17a38b3a302..aacb50e3314 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -4,7 +4,6 @@ using System.Net; using System.Net.Sockets; using Xunit; -using Xunit.Abstractions; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index 28e1898a19e..51de64d44eb 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using Xunit; -using Xunit.Abstractions; using System.Net; using System.Net.Sockets; From c6e3d630b6bdb06b12a54961afd194b7b1a14976 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 10 Jun 2025 18:46:22 +0200 Subject: [PATCH 38/45] Fix Yarp service discovery tests --- .../YarpServiceDiscoveryTests.cs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs index 49bbf6eabf4..c2751823c65 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs @@ -10,6 +10,7 @@ using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.Extensions.ServiceDiscovery.Yarp.Tests; @@ -231,7 +232,10 @@ public async Task ServiceDiscoveryDestinationResolverTests_Configuration_Disallo [Fact] public async Task ServiceDiscoveryDestinationResolverTests_Dns() { + DnsResolver resolver = new DnsResolver(TimeProvider.System, NullLogger.Instance); + await using var services = new ServiceCollection() + .AddSingleton(resolver) .AddServiceDiscoveryCore() .AddDnsServiceEndpointProvider() .BuildServiceProvider(); From e66229ab47a3a1cc7883fc2e8d8363e7460c7c94 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 12 Jun 2025 16:22:37 +0200 Subject: [PATCH 39/45] Lazy TCP socket allocation in LoopbackTests --- .../Resolver/LoopbackDnsServer.cs | 22 +++++++++---------- .../Resolver/TcpFailoverTests.cs | 3 --- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 78eac5c845c..8ac1d3f0bd8 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -12,8 +12,8 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; internal sealed class LoopbackDnsServer : IDisposable { - readonly Socket _dnsSocket; - readonly Socket _tcpSocket; + private readonly Socket _dnsSocket; + private Socket? _tcpSocket; public IPEndPoint DnsEndPoint => (IPEndPoint)_dnsSocket.LocalEndPoint!; @@ -21,21 +21,12 @@ public LoopbackDnsServer() { _dnsSocket = new(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); _dnsSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - - _tcpSocket = new(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - _tcpSocket.Bind(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)_dnsSocket.LocalEndPoint!).Port)); - _tcpSocket.Listen(); } public void Dispose() { _dnsSocket.Dispose(); - _tcpSocket.Dispose(); - } - - public void DisableTcpFallback() - { - _tcpSocket.Close(); + _tcpSocket?.Dispose(); } private static async Task ProcessRequestCore(IPEndPoint remoteEndPoint, ArraySegment message, Func action, Memory responseBuffer) @@ -83,6 +74,13 @@ public Task ProcessUdpRequest(Func action) public async Task ProcessTcpRequest(Func action) { + if (_tcpSocket is null) + { + _tcpSocket = new(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _tcpSocket.Bind(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)_dnsSocket.LocalEndPoint!).Port)); + _tcpSocket.Listen(); + } + using Socket tcpClient = await _tcpSocket.AcceptAsync(); byte[] buffer = ArrayPool.Shared.Rent(8 * 1024); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index 51de64d44eb..328f6c60f69 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -71,9 +71,6 @@ public async Task TcpFailover_TcpNotAvailable_EmptyResult() return Task.CompletedTask; }); - // turn off TCP support the server - DnsServer.DisableTcpFallback(); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); Assert.Empty(results); } From da018f7f6220d9a35b2171518f15671a6d9fafbb Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 12 Jun 2025 16:39:10 +0200 Subject: [PATCH 40/45] LoopbackTestBaseLogging --- .../Resolver/DnsResolver.cs | 18 ++++++------------ .../Resolver/LoopbackDnsTestBase.cs | 13 +++++++------ 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 5eb916909a8..8c4a63324fd 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -25,27 +25,17 @@ internal sealed partial class DnsResolver : IDnsResolver, IDisposable private bool _disposed; private readonly ResolverOptions _options; private readonly CancellationTokenSource _pendingRequestsCts = new(); - private TimeProvider _timeProvider = TimeProvider.System; + private readonly TimeProvider _timeProvider; private readonly ILogger _logger; - internal void SetTimeProvider(TimeProvider timeProvider) - { - _timeProvider = timeProvider; - } - public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(timeProvider, logger, OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions()) { } - public DnsResolver(TimeProvider timeProvider, ILogger logger, ResolverOptions options) : this(options) + internal DnsResolver(TimeProvider timeProvider, ILogger logger, ResolverOptions options) { _timeProvider = timeProvider; _logger = logger; - } - - internal DnsResolver(ResolverOptions options) - { - _logger = NullLogger.Instance; _options = options; Debug.Assert(_options.Servers.Count > 0); @@ -56,6 +46,10 @@ internal DnsResolver(ResolverOptions options) } } + internal DnsResolver(ResolverOptions options) : this(TimeProvider.System, NullLogger.Instance, options) + { + } + internal DnsResolver(IEnumerable servers) : this(new ResolverOptions(servers.ToArray())) { } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs index e1785b52800..f76621db93c 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -2,8 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using Xunit; -using System.Runtime.CompilerServices; using Microsoft.Extensions.Time.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; @@ -17,9 +18,6 @@ public abstract class LoopbackDnsTestBase : IDisposable internal ResolverOptions Options { get; } protected readonly FakeTimeProvider TimeProvider; - [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "SetTimeProvider")] - static extern void MockTimeProvider(DnsResolver instance, TimeProvider provider); - public LoopbackDnsTestBase(ITestOutputHelper output) { Output = output; @@ -35,8 +33,11 @@ public LoopbackDnsTestBase(ITestOutputHelper output) DnsResolver InitializeResolver() { - var resolver = new DnsResolver(Options); - MockTimeProvider(resolver, TimeProvider); + ServiceCollection services = new(); + services.AddXunitLogging(Output); + + // construct DnsResolver manually via internal constructor which accepts ResolverOptions + var resolver = new DnsResolver(TimeProvider, services.BuildServiceProvider().GetRequiredService>(), Options); return resolver; } From 269741294dbb86e55216b436bad97948befac87a Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Thu, 12 Jun 2025 16:47:04 +0200 Subject: [PATCH 41/45] fixup! LoopbackTestBaseLogging --- .../Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj index 24faf1d8abe..f6202d17d37 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj @@ -12,6 +12,10 @@ + + + + From 214cc9be3e88db9196b34ea3457c60aea6faaa0a Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Tue, 17 Jun 2025 09:56:54 +0200 Subject: [PATCH 42/45] Downgrade SharpFuzz --- eng/Versions.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/Versions.props b/eng/Versions.props index a8797944827..54a47c40973 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -48,7 +48,7 @@ 9.0.2 - 2.2.0 + 2.1.1 From 0d891cb52ae6b59d7ff9cf95e6a7c3b63980903b Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Wed, 18 Jun 2025 11:12:46 +0200 Subject: [PATCH 43/45] Fix buffer use after returning to pool --- .../Resolver/DnsDataReader.cs | 5 +- .../Resolver/DnsResolver.cs | 74 +++++++++---------- .../Resolver/EncodedDomainName.cs | 18 ++++- .../Resolver/LoopbackDnsServer.cs | 7 ++ .../Resolver/ResolveAddressesTests.cs | 64 ++++++++++------ .../Resolver/RetryTests.cs | 50 +++++++------ .../Resolver/TcpFailoverTests.cs | 22 ++++-- 7 files changed, 144 insertions(+), 96 deletions(-) diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs index bc0d662b36c..094df3040d1 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -122,13 +122,12 @@ public bool TryReadSpan(int length, out ReadOnlySpan name) public void Dispose() { - if (!_returnToPool || MessageBuffer.Array == null) + if (_returnToPool && MessageBuffer.Array != null) { - return; // nothing to do if we are not returning to the pool + ArrayPool.Shared.Return(MessageBuffer.Array); } _returnToPool = false; - ArrayPool.Shared.Return(MessageBuffer.Array); MessageBuffer = default; } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs index 8c4a63324fd..5722356a1c3 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -63,16 +63,9 @@ public ValueTask ResolveServiceAsync(string name, CancellationT ObjectDisposedException.ThrowIf(_disposed, this); cancellationToken.ThrowIfCancellationRequested(); - byte[] buffer = ArrayPool.Shared.Rent(256); - try - { - EncodedDomainName dnsSafeName = GetNormalizedHostName(name, buffer); - return SendQueryWithTelemetry(name, dnsSafeName, QueryType.SRV, ProcessResponse, cancellationToken); - } - finally - { - ArrayPool.Shared.Return(buffer); - } + // dnsSafeName is Disposed by SendQueryWithTelemetry + EncodedDomainName dnsSafeName = GetNormalizedHostName(name); + return SendQueryWithTelemetry(name, dnsSafeName, QueryType.SRV, ProcessResponse, cancellationToken); static (SendQueryError, ServiceResult[]) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) { @@ -175,17 +168,10 @@ public ValueTask ResolveIPAddressesAsync(string name, AddressFa return ValueTask.FromResult([]); } - byte[] buffer = ArrayPool.Shared.Rent(256); - try - { - EncodedDomainName dnsSafeName = GetNormalizedHostName(name, buffer); - var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; - return SendQueryWithTelemetry(name, dnsSafeName, queryType, ProcessResponse, cancellationToken); - } - finally - { - ArrayPool.Shared.Return(buffer); - } + // dnsSafeName is Disposed by SendQueryWithTelemetry + EncodedDomainName dnsSafeName = GetNormalizedHostName(name); + var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + return SendQueryWithTelemetry(name, dnsSafeName, queryType, ProcessResponse, cancellationToken); static (SendQueryError error, AddressResult[] result) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) { @@ -361,6 +347,7 @@ private async ValueTask SendQueryWithTelemetry(string name, NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, dnsSafeName, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false); Telemetry.StopNameResolution(name, queryType, activity, null, error, _timeProvider.GetTimestamp()); + dnsSafeName.Dispose(); return result; } @@ -904,28 +891,41 @@ public void Dispose() return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts); } - private static EncodedDomainName GetNormalizedHostName(string name, Memory buffer) + private static EncodedDomainName GetNormalizedHostName(string name) { - if (!DnsPrimitives.TryWriteQName(buffer.Span, name, out _)) - { - throw new ArgumentException($"'{name}' is not a valid DNS name.", nameof(name)); - } - - List> labels = new(); - while (true) + byte[] buffer = ArrayPool.Shared.Rent(256); + try { - int len = buffer.Span[0]; + if (!DnsPrimitives.TryWriteQName(buffer, name, out _)) + { + throw new ArgumentException($"'{name}' is not a valid DNS name.", nameof(name)); + } - if (len == 0) + List> labels = new(); + Memory memory = buffer.AsMemory(); + while (true) { - // root label, we are finished - break; + int len = memory.Span[0]; + + if (len == 0) + { + // root label, we are finished + break; + } + + labels.Add(memory.Slice(1, len)); + memory = memory.Slice(len + 1); } - labels.Add(buffer.Slice(1, len)); - buffer = buffer.Slice(len + 1); + buffer = null!; // ownership transferred to the EncodedDomainName + return new EncodedDomainName(labels, buffer); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } } - - return new EncodedDomainName(labels); } } diff --git a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs index 153e11dec20..4c258cac3ac 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs @@ -1,19 +1,21 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Text; namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; -internal struct EncodedDomainName : IEquatable +internal struct EncodedDomainName : IEquatable, IDisposable { public IReadOnlyList> Labels { get; } + private byte[]? _pooledBuffer; - public EncodedDomainName(List> labels) + public EncodedDomainName(List> labels, byte[]? pooledBuffer = null) { Labels = labels; + _pooledBuffer = pooledBuffer; } - public override string ToString() { StringBuilder sb = new StringBuilder(); @@ -67,4 +69,14 @@ public override int GetHashCode() return hash.ToHashCode(); } + + public void Dispose() + { + if (_pooledBuffer != null) + { + ArrayPool.Shared.Return(_pooledBuffer); + } + + _pooledBuffer = null; + } } \ No newline at end of file diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs index 8ac1d3f0bd8..4789e21c575 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -118,12 +118,19 @@ public Task ProcessTcpRequest(Func action) internal sealed class LoopbackDnsResponseBuilder { + private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."); + public LoopbackDnsResponseBuilder(string name, QueryType type, QueryClass @class) { Name = name; Type = type; Class = @class; Questions.Add((name, type, @class)); + + if (name.AsSpan().ContainsAnyExcept(s_domainNameValidChars)) + { + throw new ArgumentException($"Invalid characters in domain name '{name}'"); + } } public ushort TransactionId { get; set; } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs index e66c0d3390d..b87e1362f3d 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -18,6 +18,8 @@ public ResolveAddressesTests(ITestOutputHelper output) : base(output) [InlineData(true)] public async Task ResolveIPv4_NoData_Success(bool includeSoa) { + string hostName = "nodata.test"; + _ = DnsServer.ProcessUdpRequest(builder => { if (includeSoa) @@ -27,7 +29,7 @@ public async Task ResolveIPv4_NoData_Success(bool includeSoa) return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } @@ -36,6 +38,8 @@ public async Task ResolveIPv4_NoData_Success(bool includeSoa) [InlineData(true)] public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) { + string hostName = "nosuchname.test"; + _ = DnsServer.ProcessUdpRequest(builder => { builder.ResponseCode = QueryResponseCode.NameError; @@ -46,13 +50,13 @@ public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } [Theory] - [InlineData("www.example.com")] - [InlineData("www.example.com.")] + [InlineData("www.resolveipv4.com")] + [InlineData("www.resolveipv4.com.")] [InlineData("www.ř.com")] public async Task ResolveIPv4_Simple_Success(string name) { @@ -74,15 +78,17 @@ public async Task ResolveIPv4_Simple_Success(string name) public async Task ResolveIPv4_Aliases_InOrder_Success() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-in-order.test"; + _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddCname("www.example.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); builder.Answers.AddAddress("www.example3.com", 3600, address); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); AddressResult res = Assert.Single(results); Assert.Equal(address, res.Address); @@ -93,15 +99,17 @@ public async Task ResolveIPv4_Aliases_InOrder_Success() public async Task ResolveIPv4_Aliases_OutOfOrder_Success() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-out-of-order.test"; + _ = DnsServer.ProcessUdpRequest(builder => { builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); builder.Answers.AddAddress("www.example3.com", 3600, address); - builder.Answers.AddCname("www.example.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); AddressResult res = Assert.Single(results); Assert.Equal(address, res.Address); @@ -111,15 +119,17 @@ public async Task ResolveIPv4_Aliases_OutOfOrder_Success() [Fact] public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() { + string hostName = "alias-loop2.test"; + _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); - builder.Answers.AddCname("www.example3.com", 3600, "www.example1.com"); + builder.Answers.AddCname("www.example3.com", 3600, hostName); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } @@ -127,15 +137,17 @@ public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() [Fact] public async Task ResolveIPv4_Aliases_Loop_Reverse_ReturnsEmpty() { + string hostName = "alias-loop2.test"; + _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddCname("www.example3.com", 3600, "www.example1.com"); + builder.Answers.AddCname("www.example3.com", 3600, hostName); builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); - builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } @@ -144,15 +156,17 @@ public async Task ResolveIPv4_Aliases_Loop_Reverse_ReturnsEmpty() public async Task ResolveIPv4_Alias_And_Address() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-address.test"; + _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); builder.Answers.AddAddress("www.example2.com", 3600, address); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } @@ -161,9 +175,11 @@ public async Task ResolveIPv4_Alias_And_Address() public async Task ResolveIPv4_DuplicateAlias() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "duplicate-alias.test"; + _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddCname("www.example1.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); builder.Answers.AddCname("www.example2.com", 3600, "www.example4.com"); builder.Answers.AddAddress("www.example2.com", 3600, address); @@ -171,7 +187,7 @@ public async Task ResolveIPv4_DuplicateAlias() return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example1.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } @@ -180,9 +196,11 @@ public async Task ResolveIPv4_DuplicateAlias() public async Task ResolveIPv4_Aliases_NotFound_Success() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-no-found.test"; + _ = DnsServer.ProcessUdpRequest(builder => { - builder.Answers.AddCname("www.example.com", 3600, "www.example2.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); // extra address in the answer not connected to the above @@ -190,7 +208,7 @@ public async Task ResolveIPv4_Aliases_NotFound_Success() return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } @@ -198,7 +216,7 @@ public async Task ResolveIPv4_Aliases_NotFound_Success() [Fact] public async Task ResolveIP_InvalidAddressFamily_Throws() { - await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.Unknown)); + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("invalid-af.test", AddressFamily.Unknown)); } [Theory] @@ -217,7 +235,7 @@ public async Task ResolveIP_Localhost_ReturnsLoopback(AddressFamily family, stri public async Task Resolve_Timeout_ReturnsEmpty() { Options.Timeout = TimeSpan.FromSeconds(1); - AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("timeout-empty.test", AddressFamily.InterNetwork); Assert.Empty(result); } @@ -244,7 +262,7 @@ public async Task Resolve_QuestionMismatch_ReturnsEmpty(string name, int type, i [Fact] public async Task Resolve_HeaderMismatch_Ignores() { - string name = "example.com"; + string name = "header-mismatch.test"; Options.Timeout = TimeSpan.FromSeconds(5); SemaphoreSlim responseSemaphore = new SemaphoreSlim(0, 1); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs index aacb50e3314..800905d1ac5 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -14,9 +14,9 @@ public RetryTests(ITestOutputHelper output) : base(output) Options.Attempts = 3; } - private static void SetupUdpProcessFunction(LoopbackDnsServer server, Func func) + private Task SetupUdpProcessFunction(LoopbackDnsServer server, Func func) { - _ = Task.Run(async () => + return Task.Run(async () => { try { @@ -25,31 +25,33 @@ private static void SetupUdpProcessFunction(LoopbackDnsServer server, Func func) + private Task SetupUdpProcessFunction(Func func) { - SetupUdpProcessFunction(DnsServer, func); + return SetupUdpProcessFunction(DnsServer, func); } [Fact] public async Task Retry_Simple_Success() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "retry-simple-success.com"; int attempt = 0; - SetupUdpProcessFunction(builder => + Task t = SetupUdpProcessFunction(builder => { attempt++; if (attempt == Options.Attempts) { - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); } else { @@ -58,7 +60,7 @@ public async Task Retry_Simple_Success() return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); AddressResult res = Assert.Single(results); Assert.Equal(address, res.Address); @@ -79,11 +81,12 @@ public enum PersistentErrorType public async Task PersistentErrorsResponseCode_FailoverToNextServer(PersistentErrorType type) { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.persistent.com"; int primaryAttempt = 0; int secondaryAttempt = 0; - AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + AddressResult[] results = await RunWithFallbackServerHelper(hostName, builder => { primaryAttempt++; @@ -106,7 +109,7 @@ public async Task PersistentErrorsResponseCode_FailoverToNextServer(PersistentEr builder => { secondaryAttempt++; - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); return Task.CompletedTask; }); @@ -134,11 +137,12 @@ public enum DefinitveAnswerType public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, bool additionalData) { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.retry.com"; int primaryAttempt = 0; int secondaryAttempt = 0; - AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + AddressResult[] results = await RunWithFallbackServerHelper(hostName, builder => { primaryAttempt++; @@ -146,7 +150,7 @@ public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, { case DefinitveAnswerType.NoError: builder.ResponseCode = QueryResponseCode.NoError; - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); break; case DefinitveAnswerType.NoData: @@ -160,7 +164,7 @@ public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, if (additionalData) { - builder.Authorities.AddStartOfAuthority("www.example.com", 300, "ns1.example.com", "hostmaster.example.com", 2023101001, 1, 3600, 300, 86400); + builder.Authorities.AddStartOfAuthority(hostName, 300, "ns1.example.com", "hostmaster.example.com", 2023101001, 1, 3600, 300, 86400); } return Task.CompletedTask; @@ -191,11 +195,12 @@ public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, public async Task ExhaustedRetries_FailoverToNextServer() { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "ExhaustedRetriesFailoverToNextServer"; int primaryAttempt = 0; int secondaryAttempt = 0; - AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + AddressResult[] results = await RunWithFallbackServerHelper(hostName, builder => { primaryAttempt++; @@ -205,7 +210,7 @@ public async Task ExhaustedRetries_FailoverToNextServer() builder => { secondaryAttempt++; - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); return Task.CompletedTask; }); @@ -230,11 +235,12 @@ public enum TransientErrorType public async Task TransientError_RetryOnSameServer(TransientErrorType type) { IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.transient.com"; int primaryAttempt = 0; int secondaryAttempt = 0; - AddressResult[] results = await RunWithFallbackServerHelper("www.example.com", + AddressResult[] results = await RunWithFallbackServerHelper(hostName, async builder => { primaryAttempt++; @@ -244,7 +250,7 @@ public async Task TransientError_RetryOnSameServer(TransientErrorType type) { case TransientErrorType.Timeout: await Task.Delay(Options.Timeout.Multiply(1.5)); - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); break; case TransientErrorType.ServerFailure: @@ -254,7 +260,7 @@ public async Task TransientError_RetryOnSameServer(TransientErrorType type) } else { - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); } }, builder => @@ -274,9 +280,9 @@ public async Task TransientError_RetryOnSameServer(TransientErrorType type) private async Task RunWithFallbackServerHelper(string name, Func primaryHandler, Func fallbackHandler) { - SetupUdpProcessFunction(primaryHandler); + Task t = SetupUdpProcessFunction(primaryHandler); using LoopbackDnsServer fallbackServer = new LoopbackDnsServer(); - SetupUdpProcessFunction(fallbackServer, fallbackHandler); + Task t2 = SetupUdpProcessFunction(fallbackServer, fallbackHandler); Options.Servers = [DnsServer.DnsEndPoint, fallbackServer.DnsEndPoint]; @@ -287,7 +293,7 @@ private async Task RunWithFallbackServerHelper(string name, Fun public async Task NameError_NoRetry() { int counter = 0; - SetupUdpProcessFunction(builder => + Task t = SetupUdpProcessFunction(builder => { counter++; // authoritative answer that the name does not exist @@ -295,7 +301,7 @@ public async Task NameError_NoRetry() return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("nameerror-noretry", AddressFamily.InterNetwork); Assert.Empty(results); Assert.Equal(1, counter); diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs index 328f6c60f69..40841e3d11a 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -16,7 +16,9 @@ public TcpFailoverTests(ITestOutputHelper output) : base(output) [Fact] public async Task TcpFailover_Simple_Success() { + string hostName = "tcp-simple.test"; IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => { builder.Flags |= QueryFlags.ResultTruncated; @@ -25,11 +27,11 @@ public async Task TcpFailover_Simple_Success() _ = DnsServer.ProcessTcpRequest(builder => { - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); AddressResult res = Assert.Single(results); Assert.Equal(address, res.Address); @@ -39,6 +41,7 @@ public async Task TcpFailover_Simple_Success() [Fact] public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() { + string hostName = "tcp-server-closes.test"; Options.Attempts = 1; Options.Timeout = TimeSpan.FromSeconds(60); @@ -53,7 +56,7 @@ public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() throw new InvalidOperationException("This forces closing the socket without writing any data"); }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); Assert.Empty(results); await Assert.ThrowsAsync(() => serverTask); @@ -62,6 +65,7 @@ public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() [Fact] public async Task TcpFailover_TcpNotAvailable_EmptyResult() { + string hostName = "tcp-not-available.test"; Options.Attempts = 1; Options.Timeout = TimeSpan.FromMilliseconds(100000); @@ -71,13 +75,14 @@ public async Task TcpFailover_TcpNotAvailable_EmptyResult() return Task.CompletedTask; }); - AddressResult[] results = await Resolver.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(results); } [Fact] public async Task TcpFailover_HeaderMismatch_ReturnsEmpty() { + string hostName = "tcp-header-mismatch.test"; Options.Timeout = TimeSpan.FromSeconds(1); IPAddress address = IPAddress.Parse("172.213.245.111"); @@ -90,11 +95,11 @@ public async Task TcpFailover_HeaderMismatch_ReturnsEmpty() _ = DnsServer.ProcessTcpRequest(builder => { builder.TransactionId++; - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); return Task.CompletedTask; }); - AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + AddressResult[] result = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(result); } @@ -104,6 +109,7 @@ public async Task TcpFailover_HeaderMismatch_ReturnsEmpty() [InlineData("example.com", (int)QueryType.A, 0)] public async Task TcpFailover_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) { + string hostName = "tcp-question-mismatch.test"; Options.Timeout = TimeSpan.FromSeconds(1); IPAddress address = IPAddress.Parse("172.213.245.111"); @@ -116,11 +122,11 @@ public async Task TcpFailover_QuestionMismatch_ReturnsEmpty(string name, int typ _ = DnsServer.ProcessTcpRequest(builder => { builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); - builder.Answers.AddAddress("www.example.com", 3600, address); + builder.Answers.AddAddress(hostName, 3600, address); return Task.CompletedTask; }); - AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + AddressResult[] result = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); Assert.Empty(result); } } From 35a5809148fe232a26eb103b57457d1df73ae9a2 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Wed, 18 Jun 2025 11:57:32 +0200 Subject: [PATCH 44/45] Fix build --- eng/Versions.props | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/eng/Versions.props b/eng/Versions.props index eccd3dd6c73..2c9e9571bc8 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -50,14 +50,10 @@ 9.4.0 9.4.0 -<<<<<<< managed-resolver - 9.0.2 - - 2.1.1 -======= 9.0.6 10.0.0-preview.5.25277.114 ->>>>>>> main + + 2.1.1 From df785d1c6048783ad37ed55ac3169458c1d70a5d Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Wed, 25 Jun 2025 09:19:43 +0200 Subject: [PATCH 45/45] Remove DnsClient package ref --- Directory.Packages.props | 1 - 1 file changed, 1 deletion(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 01a62f4c552..490e518213e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -77,7 +77,6 @@ -