diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index e4f5ea11353..3b5a0c53b49 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -4,6 +4,10 @@ https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 893c2ebbd49952ca49e93298148af2d95a61a0a4 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 893c2ebbd49952ca49e93298148af2d95a61a0a4 + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 893c2ebbd49952ca49e93298148af2d95a61a0a4 @@ -80,6 +84,10 @@ https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 893c2ebbd49952ca49e93298148af2d95a61a0a4 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 893c2ebbd49952ca49e93298148af2d95a61a0a4 + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 893c2ebbd49952ca49e93298148af2d95a61a0a4 @@ -180,6 +188,10 @@ https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore ff66c263be7ed395794bdaf616322977b8ec897c + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + ff66c263be7ed395794bdaf616322977b8ec897c + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore ff66c263be7ed395794bdaf616322977b8ec897c diff --git a/eng/Versions.props b/eng/Versions.props index 91a09df773f..22d61962791 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -34,6 +34,7 @@ 9.0.9 + 9.0.9 9.0.9 9.0.9 9.0.9 @@ -53,6 +54,7 @@ 9.0.9 9.0.9 9.0.9 + 9.0.9 9.0.9 9.0.9 9.0.9 @@ -78,6 +80,7 @@ 9.0.9 9.0.9 9.0.9 + 9.0.9 9.0.9 9.0.9 @@ -107,6 +110,7 @@ 8.0.1 8.0.0 8.0.2 + 8.0.0 8.0.20 8.0.20 8.0.0 @@ -132,6 +136,7 @@ 8.0.20 8.0.20 8.0.20 + 8.0.20 8.0.20 8.0.20 diff --git a/eng/packages/General-LTS.props b/eng/packages/General-LTS.props index 884d874c5e1..e5e06d632de 100644 --- a/eng/packages/General-LTS.props +++ b/eng/packages/General-LTS.props @@ -4,6 +4,7 @@ of the framework, we should use the following LTS versions instead --> + @@ -17,6 +18,7 @@ + @@ -28,6 +30,7 @@ + diff --git a/eng/packages/General-net9.props b/eng/packages/General-net9.props index 341f69458a8..e3ff1198cec 100644 --- a/eng/packages/General-net9.props +++ b/eng/packages/General-net9.props @@ -4,6 +4,7 @@ of the framework, the following versions should be used. --> + @@ -17,6 +18,7 @@ + @@ -28,6 +30,7 @@ + diff --git a/eng/packages/General.props b/eng/packages/General.props index 5be4031ad4d..7a5bd0d46a0 100644 --- a/eng/packages/General.props +++ b/eng/packages/General.props @@ -5,6 +5,7 @@ + @@ -21,6 +22,7 @@ + @@ -33,6 +35,7 @@ + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IHostNameFeature.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IHostNameFeature.cs new file mode 100644 index 00000000000..c7489472374 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IHostNameFeature.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Exposes the host name of the end point. +/// +public interface IHostNameFeature +{ + /// + /// Gets the host name of the end point. + /// + public string HostName { get; } +} + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointBuilder.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointBuilder.cs new file mode 100644 index 00000000000..e051b2bf746 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointBuilder.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Builder to create a instances. +/// +public interface IServiceEndpointBuilder +{ + /// + /// Gets the endpoints. + /// + IList Endpoints { get; } + + /// + /// Gets the feature collection. + /// + IFeatureCollection Features { get; } + + /// + /// Adds a change token to the resulting . + /// + /// The change token. + void AddChangeToken(IChangeToken changeToken); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointProvider.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointProvider.cs new file mode 100644 index 00000000000..4a192180b66 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointProvider.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Provides details about a service's endpoints. +/// +public interface IServiceEndpointProvider : IAsyncDisposable +{ + /// + /// Resolves the endpoints for the service. + /// + /// The endpoint collection, which resolved endpoints will be added to. + /// The token to monitor for cancellation requests. + /// The resolution status. + ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationToken cancellationToken); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointProviderFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointProviderFactory.cs new file mode 100644 index 00000000000..009cbf05d76 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/IServiceEndpointProviderFactory.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Creates instances. +/// +public interface IServiceEndpointProviderFactory +{ + /// + /// Tries to create an instance for the specified . + /// + /// The service to create the provider for. + /// The provider. + /// if the provider was created, otherwise. + bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/Internal/ServiceEndpointImpl.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/Internal/ServiceEndpointImpl.cs new file mode 100644 index 00000000000..151a9309338 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/Internal/ServiceEndpointImpl.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.Extensions.ServiceDiscovery.Internal; + +internal sealed class ServiceEndpointImpl(EndPoint endPoint, IFeatureCollection? features = null) : ServiceEndpoint +{ + public override EndPoint EndPoint { get; } = endPoint; + + public override IFeatureCollection Features { get; } = features ?? new FeatureCollection(); + + public override string? ToString() => EndPoint switch + { + IPEndPoint ip when ip.Port == 0 && ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6 => $"[{ip.Address}]", + IPEndPoint ip when ip.Port == 0 => $"{ip.Address}", + DnsEndPoint dns when dns.Port == 0 => $"{dns.Host}", + DnsEndPoint dns => $"{dns.Host}:{dns.Port}", + _ => EndPoint.ToString()! + }; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/Microsoft.Extensions.ServiceDiscovery.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/Microsoft.Extensions.ServiceDiscovery.Abstractions.csproj new file mode 100644 index 00000000000..b3a9f892419 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/Microsoft.Extensions.ServiceDiscovery.Abstractions.csproj @@ -0,0 +1,30 @@ + + + + $(TargetFrameworks);netstandard2.0 + true + Provides abstractions for service discovery. Interfaces defined in this package are implemented in Microsoft.Extensions.ServiceDiscovery and other service discovery packages. + $(DefaultDotnetIconFullPath) + Microsoft.Extensions.ServiceDiscovery + + $(NoWarn);S1144;CA1002;S2365;SA1642;IDE0040;CA1307;EA0009;LA0003 + enable + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/README.md b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/README.md new file mode 100644 index 00000000000..0d97211313e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/README.md @@ -0,0 +1,7 @@ +# Microsoft.Extensions.ServiceDiscovery.Abstractions + +The `Microsoft.Extensions.ServiceDiscovery.Abstractions` library provides abstractions used by the `Microsoft.Extensions.ServiceDiscovery` library and other libraries which implement service discovery extensions, such as service endpoint providers. For more information, see [Service discovery in .NET](https://learn.microsoft.com/dotnet/core/extensions/service-discovery). + +## Feedback & contributing + +https://github.com/dotnet/aspire diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpoint.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpoint.cs new file mode 100644 index 00000000000..33e0eff4d69 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpoint.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.ServiceDiscovery.Internal; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Represents an endpoint for a service. +/// +public abstract class ServiceEndpoint +{ + /// + /// Gets the endpoint. + /// + public abstract EndPoint EndPoint { get; } + + /// + /// Gets the collection of endpoint features. + /// + public abstract IFeatureCollection Features { get; } + + /// + /// Creates a new . + /// + /// The endpoint being represented. + /// Features of the endpoint. + /// A newly initialized . + public static ServiceEndpoint Create(EndPoint endPoint, IFeatureCollection? features = null) + { + ArgumentNullException.ThrowIfNull(endPoint); + + return new ServiceEndpointImpl(endPoint, features); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpointQuery.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpointQuery.cs new file mode 100644 index 00000000000..36fca0893cc --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpointQuery.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Describes a query for endpoints of a service. +/// +public sealed class ServiceEndpointQuery +{ + private readonly string _originalString; + + /// + /// Initializes a new instance. + /// + /// The string which the query was constructed from. + /// The ordered list of included URI schemes. + /// The service name. + /// The optional endpoint name. + private ServiceEndpointQuery(string originalString, string[] includedSchemes, string serviceName, string? endpointName) + { + _originalString = originalString; + IncludedSchemes = includedSchemes; + ServiceName = serviceName; + EndpointName = endpointName; + } + + /// + /// Tries to parse the provided input as a service endpoint query. + /// + /// The value to parse. + /// The resulting query. + /// if the value was successfully parsed; otherwise . + public static bool TryParse(string input, [NotNullWhen(true)] out ServiceEndpointQuery? query) + { + ArgumentException.ThrowIfNullOrEmpty(input); + + bool hasScheme; + if (!input.Contains("://", StringComparison.Ordinal) + && Uri.TryCreate($"fakescheme://{input}", default, out var uri)) + { + hasScheme = false; + } + else if (Uri.TryCreate(input, default, out uri)) + { + hasScheme = true; + } + else + { + query = null; + return false; + } + + var uriHost = uri.Host; + var segmentSeparatorIndex = uriHost.IndexOf('.'); + string host; + string? endpointName = null; + if (uriHost.StartsWith('_') && segmentSeparatorIndex > 1 && uriHost[^1] != '.') + { + endpointName = uriHost[1..segmentSeparatorIndex]; + + // Skip the endpoint name, including its prefix ('_') and suffix ('.'). + host = uriHost[(segmentSeparatorIndex + 1)..]; + } + else + { + host = uriHost; + } + + // Allow multiple schemes to be separated by a '+', eg. "https+http://host:port". + var schemes = hasScheme ? uri.Scheme.Split('+') : []; + query = new(input, schemes, host, endpointName); + return true; + } + + /// + /// Gets the ordered list of included URI schemes. + /// + public IReadOnlyList IncludedSchemes { get; } + + /// + /// Gets the endpoint name, or if no endpoint name is specified. + /// + public string? EndpointName { get; } + + /// + /// Gets the service name. + /// + public string ServiceName { get; } + + /// + public override string? ToString() => _originalString; +} + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpointSource.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpointSource.cs new file mode 100644 index 00000000000..28d987a2f34 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Abstractions/ServiceEndpointSource.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Represents a collection of service endpoints. +/// +[DebuggerDisplay("{ToString(),nq}")] +[DebuggerTypeProxy(typeof(ServiceEndpointCollectionDebuggerView))] +public sealed class ServiceEndpointSource +{ + private readonly List? _endpoints; + + /// + /// Initializes a new instance. + /// + /// The endpoints. + /// The change token. + /// The feature collection. + public ServiceEndpointSource(List? endpoints, IChangeToken changeToken, IFeatureCollection features) + { + ArgumentNullException.ThrowIfNull(changeToken); + ArgumentNullException.ThrowIfNull(features); + + _endpoints = endpoints; + Features = features; + ChangeToken = changeToken; + } + + /// + /// Gets the endpoints. + /// + public IReadOnlyList Endpoints => _endpoints ?? (IReadOnlyList)[]; + + /// + /// Gets the change token which indicates when this collection should be refreshed. + /// + public IChangeToken ChangeToken { get; } + + /// + /// Gets the feature collection. + /// + public IFeatureCollection Features { get; } + + /// + public override string ToString() + { + if (_endpoints is not { } eps) + { + return "[]"; + } + + return $"[{string.Join(", ", eps)}]"; + } + + private sealed class ServiceEndpointCollectionDebuggerView(ServiceEndpointSource value) + { + public IChangeToken ChangeToken => value.ChangeToken; + + public IFeatureCollection Features => value.Features; + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public ServiceEndpoint[] Endpoints => value.Endpoints.ToArray(); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs new file mode 100644 index 00000000000..7a2d1b632e0 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProvider.cs @@ -0,0 +1,69 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +internal sealed partial class DnsServiceEndpointProvider( + ServiceEndpointQuery query, + string hostName, + IOptionsMonitor options, + ILogger logger, + IDnsResolver resolver, + TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature +{ + protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; + protected override TimeSpan MinRetryPeriod => options.CurrentValue.MinRetryPeriod; + protected override TimeSpan MaxRetryPeriod => options.CurrentValue.MaxRetryPeriod; + protected override TimeSpan DefaultRefreshPeriod => options.CurrentValue.DefaultRefreshPeriod; + + string IHostNameFeature.HostName => hostName; + + /// + public override string ToString() => "DNS"; + + protected override async Task ResolveAsyncCore() + { + var endpoints = new List(); + var ttl = DefaultRefreshPeriod; + Log.AddressQuery(logger, ServiceName, hostName); + + var now = _timeProvider.GetUtcNow().DateTime; + var addresses = await resolver.ResolveIPAddressesAsync(hostName, ShutdownToken).ConfigureAwait(false); + + foreach (var address in addresses) + { + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, port: 0))); + } + + if (endpoints.Count == 0) + { + throw new InvalidOperationException($"No DNS records were found for service '{ServiceName}' (DNS name: '{hostName}')."); + } + + SetResult(endpoints, ttl); + + static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing) + { + var candidate = expiresAt - now; + return candidate < existing ? candidate : existing; + } + + ServiceEndpoint CreateEndpoint(EndPoint endPoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endPoint); + serviceEndpoint.Features.Set(this); + if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint)) + { + serviceEndpoint.Features.Set(this); + } + + return serviceEndpoint; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs new file mode 100644 index 00000000000..29aaaf8e930 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.Log.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +partial class DnsServiceEndpointProviderBase +{ + internal static partial class Log + { + [LoggerMessage(1, LogLevel.Trace, "Resolving endpoints for service '{ServiceName}' using DNS SRV lookup for name '{RecordName}'.", EventName = "SrvQuery")] + public static partial void SrvQuery(ILogger logger, string serviceName, string recordName); + + [LoggerMessage(2, LogLevel.Trace, "Resolving endpoints for service '{ServiceName}' using host lookup for name '{RecordName}'.", EventName = "AddressQuery")] + public static partial void AddressQuery(ILogger logger, string serviceName, string recordName); + + [LoggerMessage(3, LogLevel.Debug, "Skipping endpoint resolution for service '{ServiceName}': '{Reason}'.", EventName = "SkippedResolution")] + public static partial void SkippedResolution(ILogger logger, string serviceName, string reason); + + [LoggerMessage(4, LogLevel.Debug, "Service name '{ServiceName}' is not a valid URI or DNS name.", EventName = "ServiceNameIsNotUriOrDnsName")] + public static partial void ServiceNameIsNotUriOrDnsName(ILogger logger, string serviceName); + + [LoggerMessage(5, LogLevel.Debug, "DNS SRV query cannot be constructed for service name '{ServiceName}' because no DNS namespace was configured or detected.", EventName = "NoDnsSuffixFound")] + public static partial void NoDnsSuffixFound(ILogger logger, string serviceName); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs new file mode 100644 index 00000000000..311c06f631a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderBase.cs @@ -0,0 +1,164 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +/// +/// A service end point provider that uses DNS to resolve the service end points. +/// +internal abstract partial class DnsServiceEndpointProviderBase : IServiceEndpointProvider +{ + private readonly object _lock = new(); + private readonly ILogger _logger; + private readonly CancellationTokenSource _disposeCancellation = new(); + protected readonly TimeProvider _timeProvider; + private long _lastRefreshTimeStamp; + private Task _resolveTask = Task.CompletedTask; + private bool _hasEndpoints; + private CancellationChangeToken _lastChangeToken; + private CancellationTokenSource _lastCollectionCancellation; + private List? _lastEndpointCollection; + private TimeSpan _nextRefreshPeriod; + + /// + /// Initializes a new instance. + /// + /// The service name. + /// The logger. + /// The time provider. + protected DnsServiceEndpointProviderBase( + ServiceEndpointQuery query, + ILogger logger, + TimeProvider timeProvider) + { + ServiceName = query.ToString()!; + _logger = logger; + _lastEndpointCollection = null; + _timeProvider = timeProvider; + _lastRefreshTimeStamp = _timeProvider.GetTimestamp(); + var cancellation = _lastCollectionCancellation = new CancellationTokenSource(); + _lastChangeToken = new CancellationChangeToken(cancellation.Token); + } + + private TimeSpan ElapsedSinceRefresh => _timeProvider.GetElapsedTime(_lastRefreshTimeStamp); + + protected string ServiceName { get; } + + protected abstract double RetryBackOffFactor { get; } + + protected abstract TimeSpan MinRetryPeriod { get; } + + protected abstract TimeSpan MaxRetryPeriod { get; } + + protected abstract TimeSpan DefaultRefreshPeriod { get; } + + protected CancellationToken ShutdownToken => _disposeCancellation.Token; + + /// + public async ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationToken cancellationToken) + { + // Only add endpoints to the collection if a previous provider (eg, a configuration override) did not add them. + if (endpoints.Endpoints.Count != 0) + { + Log.SkippedResolution(_logger, ServiceName, "Collection has existing endpoints"); + return; + } + + if (ShouldRefresh()) + { + Task resolveTask; + lock (_lock) + { + if (_resolveTask.IsCompleted && ShouldRefresh()) + { + _resolveTask = ResolveAsyncCore(); + } + + resolveTask = _resolveTask; + } + + await resolveTask.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + lock (_lock) + { + if (_lastEndpointCollection is { Count: > 0 } eps) + { + foreach (var ep in eps) + { + endpoints.Endpoints.Add(ep); + } + } + + endpoints.AddChangeToken(_lastChangeToken); + return; + } + } + + private bool ShouldRefresh() => _lastEndpointCollection is null || _lastChangeToken is { HasChanged: true } || ElapsedSinceRefresh >= _nextRefreshPeriod; + + protected abstract Task ResolveAsyncCore(); + + protected void SetResult(List endpoints, TimeSpan validityPeriod) + { + lock (_lock) + { + if (endpoints is { Count: > 0 }) + { + _lastRefreshTimeStamp = _timeProvider.GetTimestamp(); + _nextRefreshPeriod = DefaultRefreshPeriod; + _hasEndpoints = true; + } + else + { + _nextRefreshPeriod = GetRefreshPeriod(); + validityPeriod = TimeSpan.Zero; + _hasEndpoints = false; + } + + if (validityPeriod <= TimeSpan.Zero) + { + validityPeriod = _nextRefreshPeriod; + } + else if (validityPeriod > _nextRefreshPeriod) + { + validityPeriod = _nextRefreshPeriod; + } + + _lastCollectionCancellation.Cancel(); + var cancellation = _lastCollectionCancellation = new CancellationTokenSource(validityPeriod, _timeProvider); + _lastChangeToken = new CancellationChangeToken(cancellation.Token); + _lastEndpointCollection = endpoints; + } + + TimeSpan GetRefreshPeriod() + { + if (_hasEndpoints) + { + return MinRetryPeriod; + } + + var nextTicks = (long)(_nextRefreshPeriod.Ticks * RetryBackOffFactor); + if (nextTicks <= 0 || nextTicks > MaxRetryPeriod.Ticks) + { + return MaxRetryPeriod; + } + + return TimeSpan.FromTicks(nextTicks); + } + } + + /// + public async ValueTask DisposeAsync() + { + _disposeCancellation.Cancel(); + + if (_resolveTask is { } task) + { + await task.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs new file mode 100644 index 00000000000..1da21411e64 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderFactory.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +internal sealed partial class DnsServiceEndpointProviderFactory( + IOptionsMonitor options, + ILogger logger, + IDnsResolver resolver, + TimeProvider timeProvider) : IServiceEndpointProviderFactory +{ + /// + public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) + { + provider = new DnsServiceEndpointProvider(query, hostName: query.ServiceName, options, logger, resolver, timeProvider); + return true; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderOptions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderOptions.cs new file mode 100644 index 00000000000..b163afc76ff --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsServiceEndpointProviderOptions.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +/// +/// Options for configuring . +/// +public class DnsServiceEndpointProviderOptions +{ + /// + /// Gets or sets the default refresh period for endpoints resolved from DNS. + /// + public TimeSpan DefaultRefreshPeriod { get; set; } = TimeSpan.FromMinutes(1); + + /// + /// Gets or sets the initial period between retries. + /// + public TimeSpan MinRetryPeriod { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Gets or sets the maximum period between retries. + /// + public TimeSpan MaxRetryPeriod { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Gets or sets the retry period growth factor. + /// + public double RetryBackOffFactor { get; set; } = 2; + + /// + /// Gets or sets a delegate used to determine whether to apply host name metadata to each resolved endpoint. Defaults to false. + /// + public Func ShouldApplyHostNameMetadata { get; set; } = _ => false; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs new file mode 100644 index 00000000000..6d5ade5059e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProvider.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +internal sealed partial class DnsSrvServiceEndpointProvider( + ServiceEndpointQuery query, + string srvQuery, + string hostName, + IOptionsMonitor options, + ILogger logger, + IDnsResolver resolver, + TimeProvider timeProvider) : DnsServiceEndpointProviderBase(query, logger, timeProvider), IHostNameFeature +{ + protected override double RetryBackOffFactor => options.CurrentValue.RetryBackOffFactor; + + protected override TimeSpan MinRetryPeriod => options.CurrentValue.MinRetryPeriod; + + protected override TimeSpan MaxRetryPeriod => options.CurrentValue.MaxRetryPeriod; + + protected override TimeSpan DefaultRefreshPeriod => options.CurrentValue.DefaultRefreshPeriod; + + public override string ToString() => "DNS SRV"; + + string IHostNameFeature.HostName => hostName; + + protected override async Task ResolveAsyncCore() + { + var endpoints = new List(); + var ttl = DefaultRefreshPeriod; + Log.SrvQuery(logger, ServiceName, srvQuery); + + var now = _timeProvider.GetUtcNow().DateTime; + var result = await resolver.ResolveServiceAsync(srvQuery, cancellationToken: ShutdownToken).ConfigureAwait(false); + + foreach (var record in result) + { + ttl = MinTtl(now, record.ExpiresAt, ttl); + + if (record.Addresses.Length > 0) + { + foreach (var address in record.Addresses) + { + ttl = MinTtl(now, address.ExpiresAt, ttl); + endpoints.Add(CreateEndpoint(new IPEndPoint(address.Address, record.Port))); + } + } + else + { + endpoints.Add(CreateEndpoint(new DnsEndPoint(record.Target.TrimEnd('.'), record.Port))); + } + } + + SetResult(endpoints, ttl); + + static TimeSpan MinTtl(DateTime now, DateTime expiresAt, TimeSpan existing) + { + var candidate = expiresAt - now; + return candidate < existing ? candidate : existing; + } + + ServiceEndpoint CreateEndpoint(EndPoint endPoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endPoint); + serviceEndpoint.Features.Set(this); + if (options.CurrentValue.ShouldApplyHostNameMetadata(serviceEndpoint)) + { + serviceEndpoint.Features.Set(this); + } + + return serviceEndpoint; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs new file mode 100644 index 00000000000..57820560a63 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderFactory.cs @@ -0,0 +1,133 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +internal sealed partial class DnsSrvServiceEndpointProviderFactory( + IOptionsMonitor options, + ILogger logger, + IDnsResolver resolver, + TimeProvider timeProvider) : IServiceEndpointProviderFactory +{ + private static readonly string s_serviceAccountPath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount"); + private static readonly string s_serviceAccountNamespacePath = Path.Combine($"{Path.DirectorySeparatorChar}var", "run", "secrets", "kubernetes.io", "serviceaccount", "namespace"); + private static readonly string s_resolveConfPath = Path.Combine($"{Path.DirectorySeparatorChar}etc", "resolv.conf"); + private readonly string? _querySuffix = options.CurrentValue.QuerySuffix?.TrimStart('.') ?? GetKubernetesHostDomain(); + + /// + public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) + { + // If a default namespace is not specified, then this provider will attempt to infer the namespace from the service name, but only when running inside Kubernetes. + // Kubernetes DNS spec: https://github.com/kubernetes/dns/blob/master/docs/specification.md + // SRV records are available for headless services with named ports. + // They take the form $"_{portName}._{protocol}.{serviceName}.{namespace}.{suffix}" + // The suffix (after the service name) can be parsed from /etc/resolv.conf + // Otherwise, the namespace can be read from /var/run/secrets/kubernetes.io/serviceaccount/namespace and combined with an assumed suffix of "svc.cluster.local". + // The protocol is assumed to be "tcp". + // The portName is the name of the port in the service definition. If the serviceName parses as a URI, we use the scheme as the port name, otherwise "default". + if (string.IsNullOrWhiteSpace(_querySuffix)) + { + DnsServiceEndpointProviderBase.Log.NoDnsSuffixFound(logger, query.ToString()!); + provider = default; + return false; + } + + var portName = query.EndpointName ?? "default"; + var srvQuery = $"_{portName}._tcp.{query.ServiceName}.{_querySuffix}"; + provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, resolver, timeProvider); + return true; + } + + private static string? GetKubernetesHostDomain() + { + // Check that we are running in Kubernetes first. + if (!IsInKubernetesCluster()) + { + return null; + } + + if (!OperatingSystem.IsLinux()) + { + return null; + } + + var qualifiedNamespace = ReadQualifiedNamespaceFromResolvConf(); + if (!string.IsNullOrWhiteSpace(qualifiedNamespace)) + { + return qualifiedNamespace; + } + + var serviceAccountNamespace = ReadNamespaceFromKubernetesServiceAccount(); + if (!string.IsNullOrWhiteSpace(serviceAccountNamespace)) + { + // The zone is assumed to be "cluster.local" + return $"{serviceAccountNamespace}.svc.cluster.local"; + } + + return null; + } + + private static string? ReadNamespaceFromKubernetesServiceAccount() + { + // Read the namespace from the Kubernetes pod's service account. + if (File.Exists(s_serviceAccountNamespacePath)) + { + return File.ReadAllText(s_serviceAccountNamespacePath).Trim(); + } + + return null; + } + + private static string? ReadQualifiedNamespaceFromResolvConf() + { + if (!File.Exists(s_resolveConfPath)) + { + return default; + } + + // See https://manpages.debian.org/bookworm/manpages/resolv.conf.5.en.html#search for the format of /etc/resolv.conf's search option. + // In our case, we are interested in determining the domain name. + var lines = File.ReadAllLines(s_resolveConfPath); + foreach (var line in lines) + { + if (!line.StartsWith("search ", StringComparison.Ordinal)) + { + continue; + } + + var components = line.Split(' ', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries); + if (components.Length > 1) + { + return components[1]; + } + } + + return default; + } + + private static bool IsInKubernetesCluster() + { + // This logic is based on the Kubernetes C# client logic found here: + // https://github.com/kubernetes-client/csharp/blob/52c3c00d4c55b28bdb491a219f4967823a83df2d/src/KubernetesClient/KubernetesClientConfiguration.InCluster.cs#L21 + var host = Environment.GetEnvironmentVariable("KUBERNETES_SERVICE_HOST"); + var port = Environment.GetEnvironmentVariable("KUBERNETES_SERVICE_PORT"); + if (string.IsNullOrEmpty(host) || string.IsNullOrEmpty(port)) + { + return false; + } + + var tokenPath = Path.Combine(s_serviceAccountPath, "token"); + if (!File.Exists(tokenPath)) + { + return false; + } + + var certPath = Path.Combine(s_serviceAccountPath, "ca.crt"); + return File.Exists(certPath); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderOptions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderOptions.cs new file mode 100644 index 00000000000..c908c56d770 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/DnsSrvServiceEndpointProviderOptions.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +/// +/// Options for configuring . +/// +public class DnsSrvServiceEndpointProviderOptions +{ + /// + /// Gets or sets the default refresh period for endpoints resolved from DNS. + /// + public TimeSpan DefaultRefreshPeriod { get; set; } = TimeSpan.FromMinutes(1); + + /// + /// Gets or sets the initial period between retries. + /// + public TimeSpan MinRetryPeriod { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Gets or sets the maximum period between retries. + /// + public TimeSpan MaxRetryPeriod { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Gets or sets the retry period growth factor. + /// + public double RetryBackOffFactor { get; set; } = 2; + + /// + /// Gets or sets the default DNS query suffix for services resolved via this provider. + /// + /// + /// If not specified, the provider will attempt to infer the namespace. + /// + public string? QuerySuffix { get; set; } + + /// + /// Gets or sets a delegate used to determine whether to apply host name metadata to each resolved endpoint. Defaults to false. + /// + public Func ShouldApplyHostNameMetadata { get; set; } = _ => false; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/FallbackDnsResolver.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/FallbackDnsResolver.cs new file mode 100644 index 00000000000..1cdcab2f05d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/FallbackDnsResolver.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using DnsClient; +using DnsClient.Protocol; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns; + +internal sealed class FallbackDnsResolver : IDnsResolver +{ + private readonly LookupClient _lookupClient; + private readonly IOptionsMonitor _options; + private readonly TimeProvider _timeProvider; + + public FallbackDnsResolver(LookupClient lookupClient, IOptionsMonitor options, TimeProvider timeProvider) + { + _lookupClient = lookupClient; + _options = options; + _timeProvider = timeProvider; + } + + private TimeSpan DefaultRefreshPeriod => _options.CurrentValue.DefaultRefreshPeriod; + + public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) + { + DateTime expiresAt = _timeProvider.GetUtcNow().DateTime.Add(DefaultRefreshPeriod); + var addresses = await System.Net.Dns.GetHostAddressesAsync(name, cancellationToken).ConfigureAwait(false); + + var results = new AddressResult[addresses.Length]; + + for (int i = 0; i < addresses.Length; i++) + { + results[i] = new AddressResult + { + Address = addresses[i], + ExpiresAt = expiresAt + }; + } + + return results; + } + + public async ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) + { + DateTime now = _timeProvider.GetUtcNow().DateTime; + var queryResult = await _lookupClient.QueryAsync(name, DnsClient.QueryType.SRV, cancellationToken: cancellationToken).ConfigureAwait(false); + if (queryResult.HasError) + { + throw CreateException(name, queryResult.ErrorMessage); + } + + var lookupMapping = new Dictionary>(); + foreach (var record in queryResult.Additionals.OfType()) + { + if (!lookupMapping.TryGetValue(record.DomainName, out var addresses)) + { + addresses = new List(); + lookupMapping[record.DomainName] = addresses; + } + + addresses.Add(new AddressResult + { + Address = record.Address, + ExpiresAt = now.Add(TimeSpan.FromSeconds(record.TimeToLive)) + }); + } + + var srvRecords = queryResult.Answers.OfType().ToList(); + + var results = new ServiceResult[srvRecords.Count]; + for (int i = 0; i < srvRecords.Count; i++) + { + var record = srvRecords[i]; + + results[i] = new ServiceResult + { + ExpiresAt = now.Add(TimeSpan.FromSeconds(record.TimeToLive)), + Priority = record.Priority, + Weight = record.Weight, + Port = record.Port, + Target = record.Target, + Addresses = lookupMapping.TryGetValue(record.Target, out var addresses) + ? addresses.ToArray() + : Array.Empty() + }; + } + + return results; + } + + private static InvalidOperationException CreateException(string dnsName, string errorMessage) + { + var msg = errorMessage switch + { + { Length: > 0 } => $"No DNS SRV records were found for DNS name '{dnsName}': {errorMessage}.", + _ => $"No DNS SRV records were found for DNS name '{dnsName}'", + }; + return new InvalidOperationException(msg); + } +} \ No newline at end of file diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj new file mode 100644 index 00000000000..890f8daab3e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Microsoft.Extensions.ServiceDiscovery.Dns.csproj @@ -0,0 +1,34 @@ + + + + $(NetCoreTargetFrameworks) + true + Provides extensions to HttpClient to resolve well-known hostnames to concrete endpoints based on DNS records. Useful for service resolution in orchestrators such as Kubernetes. + $(DefaultDotnetIconFullPath) + + $(NoWarn);IDE0018;IDE0025;IDE0032;IDE0040;IDE0058;IDE0250;IDE0251;IDE1006;CA1304;CA1307;CA1309;CA1310;CA1849;CA2000;CA2213;CA2217;S125;S1135;S1226;S2344;S3626;S4022;SA1108;SA1120;SA1128;SA1129;SA1204;SA1205;SA1214;SA1400;SA1405;SA1408;SA1515;SA1600;SA1629;SA1642;SA1649;EA0001;EA0009;EA0014;LA0001;LA0003;LA0008;VSTHRD200 + enable + false + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/README.md b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/README.md new file mode 100644 index 00000000000..8be4560870b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/README.md @@ -0,0 +1,65 @@ +# Microsoft.Extensions.ServiceDiscovery.Dns + +This library provides support for resolving service endpoints using DNS (Domain Name System). It provides two service endpoint providers: + +- _DNS_, which resolves endpoints using DNS `A/AAAA` record queries. This means that it can resolve names to IP addresses, but cannot resolve port numbers endpoints. As such, port numbers are assumed to be the default for the protocol (for example, 80 for HTTP and 433 for HTTPS). The benefit of using the DNS provider is that for cases where these default ports are appropriate, clients can spread their requests across hosts. For more information, see _Load-balancing with endpoint selectors_. + +- _DNS SRV_, which resolves service names using DNS SRV record queries. This allows it to resolve both IP addresses and port numbers. This is useful for environments which support DNS SRV queries, such as Kubernetes (when configured accordingly). + +## Resolving service endpoints with DNS + +The _DNS_ service endpoint provider resolves endpoints using DNS `A/AAAA` record queries. This means that it can resolve names to IP addresses, but cannot resolve port numbers endpoints. As such, port numbers are assumed to be the default for the protocol (for example, 80 for HTTP and 433 for HTTPS). The benefit of using the DNS service endpoint provider is that for cases where these default ports are appropriate, clients can spread their requests across hosts. For more information, see _Load-balancing with endpoint selectors_. + +To configure the DNS service endpoint provider in your application, add the DNS service endpoint provider to your host builder's service collection using the `AddDnsServiceEndpointProvider` method. service discovery as follows: + +```csharp +builder.Services.AddServiceDiscoveryCore(); +builder.Services.AddDnsServiceEndpointProvider(); +``` + +## Resolving service endpoints in Kubernetes with DNS SRV + +When deploying to Kubernetes, the DNS SRV service endpoint provider can be used to resolve endpoints. For example, the following resource definition will result in a DNS SRV record being created for an endpoint named "default" and an endpoint named "dashboard", both on the service named "basket". + +```yml +apiVersion: v1 +kind: Service +metadata: + name: basket +spec: + selector: + name: basket-service + clusterIP: None + ports: + - name: default + port: 8080 + - name: dashboard + port: 8888 +``` + +To configure a service to resolve the "dashboard" endpoint on the "basket" service, add the DNS SRV service endpoint provider to the host builder as follows: + +```csharp +builder.Services.AddServiceDiscoveryCore(); +builder.Services.AddDnsSrvServiceEndpointProvider(); +``` + +The special port name "default" is used to specify the default endpoint, resolved using the URI `http://basket`. + +As in the previous example, add service discovery to an `HttpClient` for the basket service: + +```csharp +builder.Services.AddHttpClient( + static client => client.BaseAddress = new("http://basket")); +``` + +Similarly, the "dashboard" endpoint can be targeted as follows: + +```csharp +builder.Services.AddHttpClient( + static client => client.BaseAddress = new("http://_dashboard.basket")); +``` + +## Feedback & contributing + +https://github.com/dotnet/aspire diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs new file mode 100644 index 00000000000..094df3040d1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataReader.cs @@ -0,0 +1,133 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsDataReader : IDisposable +{ + public ArraySegment MessageBuffer { get; private set; } + bool _returnToPool; + private int _position; + + public DnsDataReader(ArraySegment buffer, bool returnToPool = false) + { + MessageBuffer = buffer; + _position = 0; + _returnToPool = returnToPool; + } + + public bool TryReadHeader(out DnsMessageHeader header) + { + Debug.Assert(_position == 0); + + if (!DnsPrimitives.TryReadMessageHeader(MessageBuffer.AsSpan(), out header, out int bytesRead)) + { + header = default; + return false; + } + + _position += bytesRead; + return true; + } + + internal bool TryReadQuestion(out EncodedDomainName name, out QueryType type, out QueryClass @class) + { + if (!TryReadDomainName(out name) || + !TryReadUInt16(out ushort typeAsInt) || + !TryReadUInt16(out ushort classAsInt)) + { + type = 0; + @class = 0; + return false; + } + + type = (QueryType)typeAsInt; + @class = (QueryClass)classAsInt; + return true; + } + + public bool TryReadUInt16(out ushort value) + { + if (MessageBuffer.Count - _position < 2) + { + value = 0; + return false; + } + + value = BinaryPrimitives.ReadUInt16BigEndian(MessageBuffer.AsSpan(_position)); + _position += 2; + return true; + } + + public bool TryReadUInt32(out uint value) + { + if (MessageBuffer.Count - _position < 4) + { + value = 0; + return false; + } + + value = BinaryPrimitives.ReadUInt32BigEndian(MessageBuffer.AsSpan(_position)); + _position += 4; + return true; + } + + public bool TryReadResourceRecord(out DnsResourceRecord record) + { + if (!TryReadDomainName(out EncodedDomainName name) || + !TryReadUInt16(out ushort type) || + !TryReadUInt16(out ushort @class) || + !TryReadUInt32(out uint ttl) || + !TryReadUInt16(out ushort dataLength) || + MessageBuffer.Count - _position < dataLength) + { + record = default; + return false; + } + + ReadOnlyMemory data = MessageBuffer.AsMemory(_position, dataLength); + _position += dataLength; + + record = new DnsResourceRecord(name, (QueryType)type, (QueryClass)@class, (int)ttl, data); + return true; + } + + public bool TryReadDomainName(out EncodedDomainName name) + { + if (DnsPrimitives.TryReadQName(MessageBuffer, _position, out name, out int bytesRead)) + { + _position += bytesRead; + return true; + } + + return false; + } + + public bool TryReadSpan(int length, out ReadOnlySpan name) + { + if (MessageBuffer.Count - _position < length) + { + name = default; + return false; + } + + name = MessageBuffer.AsSpan(_position, length); + _position += length; + return true; + } + + public void Dispose() + { + if (_returnToPool && MessageBuffer.Array != null) + { + ArrayPool.Shared.Return(MessageBuffer.Array); + } + + _returnToPool = false; + MessageBuffer = default; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs new file mode 100644 index 00000000000..a0a11f0b808 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsDataWriter.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; +using System.Diagnostics; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal sealed class DnsDataWriter +{ + private readonly Memory _buffer; + private int _position; + + internal DnsDataWriter(Memory buffer) + { + _buffer = buffer; + _position = 0; + } + + public int Position => _position; + + internal bool TryWriteHeader(in DnsMessageHeader header) + { + if (!DnsPrimitives.TryWriteMessageHeader(_buffer.Span.Slice(_position), header, out int written)) + { + return false; + } + + _position += written; + return true; + } + + internal bool TryWriteQuestion(EncodedDomainName name, QueryType type, QueryClass @class) + { + if (!TryWriteDomainName(name) || + !TryWriteUInt16((ushort)type) || + !TryWriteUInt16((ushort)@class)) + { + return false; + } + + return true; + } + + private bool TryWriteDomainName(EncodedDomainName name) + { + foreach (var label in name.Labels) + { + // this should be already validated by the caller + Debug.Assert(label.Length <= 63, "Label length must not exceed 63 bytes."); + + if (!TryWriteByte((byte)label.Length) || + !TryWriteRawData(label.Span)) + { + return false; + } + } + + // root label + return TryWriteByte(0); + } + + internal bool TryWriteDomainName(string name) + { + if (DnsPrimitives.TryWriteQName(_buffer.Span.Slice(_position), name, out int written)) + { + _position += written; + return true; + } + + return false; + } + + internal bool TryWriteByte(byte value) + { + if (_buffer.Length - _position < 1) + { + return false; + } + + _buffer.Span[_position] = value; + _position += 1; + return true; + } + + internal bool TryWriteUInt16(ushort value) + { + if (_buffer.Length - _position < 2) + { + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(_buffer.Span.Slice(_position), value); + _position += 2; + return true; + } + + internal bool TryWriteUInt32(uint value) + { + if (_buffer.Length - _position < 4) + { + return false; + } + + BinaryPrimitives.WriteUInt32BigEndian(_buffer.Span.Slice(_position), value); + _position += 4; + return true; + } + + internal bool TryWriteRawData(ReadOnlySpan value) + { + if (_buffer.Length - _position < value.Length) + { + return false; + } + + value.CopyTo(_buffer.Span.Slice(_position)); + _position += value.Length; + return true; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs new file mode 100644 index 00000000000..b22273a04f2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsMessageHeader.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +// RFC 1035 4.1.1. Header section format +internal struct DnsMessageHeader +{ + internal const int HeaderLength = 12; + public ushort TransactionId { get; set; } + + internal QueryFlags QueryFlags { get; set; } + + public ushort QueryCount { get; set; } + + public ushort AnswerCount { get; set; } + + public ushort AuthorityCount { get; set; } + + public ushort AdditionalRecordCount { get; set; } + + public QueryResponseCode ResponseCode + { + get => (QueryResponseCode)(QueryFlags & QueryFlags.ResponseCodeMask); + } + + public bool IsResultTruncated + { + get => (QueryFlags & QueryFlags.ResultTruncated) != 0; + } + + public bool IsResponse + { + get => (QueryFlags & QueryFlags.HasResponse) != 0; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs new file mode 100644 index 00000000000..e549abe2576 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsPrimitives.cs @@ -0,0 +1,318 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Globalization; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class DnsPrimitives +{ + // Maximum length of a domain name in ASCII (excluding trailing dot) + internal const int MaxDomainNameLength = 253; + + internal static bool TryReadMessageHeader(ReadOnlySpan buffer, out DnsMessageHeader header, out int bytesRead) + { + // RFC 1035 4.1.1. Header section format + if (buffer.Length < DnsMessageHeader.HeaderLength) + { + header = default; + bytesRead = 0; + return false; + } + + header = new DnsMessageHeader + { + TransactionId = BinaryPrimitives.ReadUInt16BigEndian(buffer), + QueryFlags = (QueryFlags)BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2)), + QueryCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(4)), + AnswerCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(6)), + AuthorityCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(8)), + AdditionalRecordCount = BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(10)) + }; + + bytesRead = DnsMessageHeader.HeaderLength; + return true; + } + + internal static bool TryWriteMessageHeader(Span buffer, DnsMessageHeader header, out int bytesWritten) + { + // RFC 1035 4.1.1. Header section format + if (buffer.Length < DnsMessageHeader.HeaderLength) + { + bytesWritten = 0; + return false; + } + + BinaryPrimitives.WriteUInt16BigEndian(buffer, header.TransactionId); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(2), (ushort)header.QueryFlags); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(4), header.QueryCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(6), header.AnswerCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(8), header.AuthorityCount); + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(10), header.AdditionalRecordCount); + + bytesWritten = DnsMessageHeader.HeaderLength; + return true; + } + + // https://www.rfc-editor.org/rfc/rfc1035#section-2.3.4 + // labels 63 octets or less + // name 255 octets or less + + private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."); + private static readonly IdnMapping s_idnMapping = new IdnMapping(); + internal static bool TryWriteQName(Span destination, string name, out int written) + { + written = 0; + + // + // RFC 1035 4.1.2. + // + // a domain name represented as a sequence of labels, where + // each label consists of a length octet followed by that + // number of octets. The domain name terminates with the + // zero length octet for the null label of the root. Note + // that this field may be an odd number of octets; no + // padding is used. + // + if (!Ascii.IsValid(name)) + { + // IDN name, apply punycode + try + { + // IdnMapping performs some validation internally (such as label + // and domain name lengths), but is more relaxed than RFC + // 1035 (e.g. allows ~ chars), so even if this conversion does + // not throw, we still need to perform additional validation + name = s_idnMapping.GetAscii(name); + } + catch + { + return false; + } + } + + if (name.Length > MaxDomainNameLength || + name.AsSpan().ContainsAnyExcept(s_domainNameValidChars) || + destination.IsEmpty || + !Encoding.ASCII.TryGetBytes(name, destination.Slice(1), out int length) || + destination.Length < length + 2) + { + // buffer too small + return false; + } + + Span nameBuffer = destination.Slice(0, 1 + length); + Span label; + while (true) + { + // figure out the next label and prepend the length + int index = nameBuffer.Slice(1).IndexOf((byte)'.'); + label = index == -1 ? nameBuffer.Slice(1) : nameBuffer.Slice(1, index); + + if (label.Length == 0) + { + // empty label (explicit root) is only allowed at the end + if (index != -1) + { + written = 0; + return false; + } + } + // Label restrictions: + // - maximum 63 octets long + // - must start with a letter or digit (digit is allowed by RFC 1123) + // - may start with an underscore (underscore may be present only + // at the start of the label to support SRV records) + // - must end with a letter or digit + else if (label.Length > 63 || + !char.IsAsciiLetterOrDigit((char)label[0]) && label[0] != '_' || + label.Slice(1).Contains((byte)'_') || + !char.IsAsciiLetterOrDigit((char)label[^1])) + { + written = 0; + return false; + } + + nameBuffer[0] = (byte)label.Length; + written += label.Length + 1; + + if (index == -1) + { + // this was the last label + break; + } + + nameBuffer = nameBuffer.Slice(index + 1); + } + + // Add root label if wasn't explicitly specified + if (label.Length != 0) + { + destination[written] = 0; + written++; + } + + return true; + } + + private static bool TryReadQNameCore(List> labels, int totalLength, ReadOnlyMemory messageBuffer, int offset, out int bytesRead, bool canStartWithPointer = true) + { + // + // domain name can be either + // - a sequence of labels, where each label consists of a length octet + // followed by that number of octets, terminated by a zero length octet + // (root label) + // - a pointer, where the first two bits are set to 1, and the remaining + // 14 bits are an offset (from the start of the message) to the true + // label + // + // It is not specified by the RFC if pointers must be backwards only, + // the code below prohibits forward (and self) pointers to avoid + // infinite loops. It also allows pointers only to point to a + // label, not to another pointer. + // + + bytesRead = 0; + bool allowPointer = canStartWithPointer; + + if (offset < 0 || offset >= messageBuffer.Length) + { + return false; + } + + int currentOffset = offset; + + while (true) + { + byte length = messageBuffer.Span[currentOffset]; + + if ((length & 0xC0) == 0x00) + { + // length followed by the label + if (length == 0) + { + // end of name + bytesRead = currentOffset - offset + 1; + return true; + } + + if (currentOffset + 1 + length >= messageBuffer.Length) + { + // too many labels or truncated data + break; + } + + // read next label/segment + labels.Add(messageBuffer.Slice(currentOffset + 1, length)); + totalLength += 1 + length; + + // subtract one for the length prefix of the first label + if (totalLength - 1 > MaxDomainNameLength) + { + // domain name is too long + return false; + } + + currentOffset += 1 + length; + bytesRead += 1 + length; + + // we read a label, they can be followed by pointer. + allowPointer = true; + } + else if ((length & 0xC0) == 0xC0) + { + // pointer, together with next byte gives the offset of the true label + if (!allowPointer || currentOffset + 1 >= messageBuffer.Length) + { + // pointer to pointer or truncated data + break; + } + + bytesRead += 2; + int pointer = ((length & 0x3F) << 8) | messageBuffer.Span[currentOffset + 1]; + + // we prohibit self-references and forward pointers to avoid + // infinite loops, we do this by truncating the + // messageBuffer at the offset where we started reading the + // name. We also ignore the bytesRead from the recursive + // call, as we are only interested on how many bytes we read + // from the initial start of the name. + return TryReadQNameCore(labels, totalLength, messageBuffer.Slice(0, offset), pointer, out int _, false); + } + else + { + // top two bits are reserved, this means invalid data + break; + } + } + + return false; + + } + + internal static bool TryReadQName(ReadOnlyMemory messageBuffer, int offset, out EncodedDomainName name, out int bytesRead) + { + List> labels = new List>(); + + if (TryReadQNameCore(labels, 0, messageBuffer, offset, out bytesRead)) + { + name = new EncodedDomainName(labels); + return true; + } + else + { + bytesRead = 0; + name = default; + return false; + } + } + + internal static bool TryReadService(ReadOnlyMemory buffer, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) + { + // https://www.rfc-editor.org/rfc/rfc2782 + if (!BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span, out priority) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(2), out weight) || + !BinaryPrimitives.TryReadUInt16BigEndian(buffer.Span.Slice(4), out port) || + !TryReadQName(buffer.Slice(6), 0, out target, out bytesRead)) + { + target = default; + priority = 0; + weight = 0; + port = 0; + bytesRead = 0; + return false; + } + + bytesRead += 6; + return true; + } + + internal static bool TryReadSoa(ReadOnlyMemory buffer, out EncodedDomainName primaryNameServer, out EncodedDomainName responsibleMailAddress, out uint serial, out uint refresh, out uint retry, out uint expire, out uint minimum, out int bytesRead) + { + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 + if (!TryReadQName(buffer, 0, out primaryNameServer, out int w1) || + !TryReadQName(buffer.Slice(w1), 0, out responsibleMailAddress, out int w2) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2), out serial) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 4), out refresh) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 8), out retry) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 12), out expire) || + !BinaryPrimitives.TryReadUInt32BigEndian(buffer.Span.Slice(w1 + w2 + 16), out minimum)) + { + primaryNameServer = default; + responsibleMailAddress = default; + serial = 0; + refresh = 0; + retry = 0; + expire = 0; + minimum = 0; + bytesRead = 0; + return false; + } + + bytesRead = w1 + w2 + 20; + return true; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs new file mode 100644 index 00000000000..adab9161737 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Log.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System.Net; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal partial class DnsResolver : IDnsResolver, IDisposable +{ + internal static partial class Log + { + [LoggerMessage(1, LogLevel.Debug, "Resolving {QueryType} {QueryName} on {Server} attempt {Attempt}", EventName = "Query")] + public static partial void Query(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(2, LogLevel.Debug, "Result truncated for {QueryType} {QueryName} from {Server} attempt {Attempt}. Restarting over TCP", EventName = "ResultTruncated")] + public static partial void ResultTruncated(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(3, LogLevel.Error, "Server {Server} replied with {ResponseCode} when querying {QueryType} {QueryName}", EventName = "ErrorResponseCode")] + public static partial void ErrorResponseCode(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, QueryResponseCode responseCode); + + [LoggerMessage(4, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} timed out.", EventName = "Timeout")] + public static partial void Timeout(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(5, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: no data matching given query type.", EventName = "NoData")] + public static partial void NoData(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(6, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt}: server indicates given name does not exist.", EventName = "NameError")] + public static partial void NameError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(7, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed to return a valid DNS response.", EventName = "MalformedResponse")] + public static partial void MalformedResponse(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt); + + [LoggerMessage(8, LogLevel.Warning, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed due to a network error.", EventName = "NetworkError")] + public static partial void NetworkError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); + + [LoggerMessage(9, LogLevel.Error, "Query {QueryType} {QueryName} on {Server} attempt {Attempt} failed.", EventName = "QueryError")] + public static partial void QueryError(ILogger logger, QueryType queryType, string queryName, IPEndPoint server, int attempt, Exception exception); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs new file mode 100644 index 00000000000..4be956cede9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.Telemetry.cs @@ -0,0 +1,115 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.Metrics; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal partial class DnsResolver +{ + internal static class Telemetry + { + private static readonly Meter s_meter = new Meter("Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"); + private static readonly Histogram s_queryDuration = s_meter.CreateHistogram("query.duration", "ms", "DNS query duration"); + + private static bool IsEnabled() => s_queryDuration.Enabled; + + public static NameResolutionActivity StartNameResolution(string hostName, QueryType queryType, long startingTimestamp) + { + if (IsEnabled()) + { + return new NameResolutionActivity(hostName, queryType, startingTimestamp); + } + + return default; + } + + public static void StopNameResolution(string hostName, QueryType queryType, in NameResolutionActivity activity, object? answers, SendQueryError error, long endingTimestamp) + { + activity.Stop(answers, error, endingTimestamp, out TimeSpan duration); + + if (!IsEnabled()) + { + return; + } + + var hostNameTag = KeyValuePair.Create("dns.question.name", (object?)hostName); + var queryTypeTag = KeyValuePair.Create("dns.question.type", (object?)queryType); + + if (answers is not null) + { + s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag); + } + else + { + var errorTypeTag = KeyValuePair.Create("error.type", (object?)error.ToString()); + s_queryDuration.Record(duration.TotalSeconds, hostNameTag, queryTypeTag, errorTypeTag); + } + } + } + + internal readonly struct NameResolutionActivity + { + private const string ActivitySourceName = "Microsoft.Extensions.ServiceDiscovery.Dns.Resolver"; + private const string ActivityName = ActivitySourceName + ".Resolve"; + private static readonly ActivitySource s_activitySource = new ActivitySource(ActivitySourceName); + + private readonly long _startingTimestamp; + private readonly Activity? _activity; // null if activity is not started + + public NameResolutionActivity(string hostName, QueryType queryType, long startingTimestamp) + { + _startingTimestamp = startingTimestamp; + _activity = s_activitySource.StartActivity(ActivityName, ActivityKind.Client); + if (_activity is not null) + { + _activity.DisplayName = $"Resolving {hostName}"; + if (_activity.IsAllDataRequested) + { + _activity.SetTag("dns.question.name", hostName); + _activity.SetTag("dns.question.type", queryType.ToString()); + } + } + } + + public void Stop(object? answers, SendQueryError error, long endingTimestamp, out TimeSpan duration) + { + duration = Stopwatch.GetElapsedTime(_startingTimestamp, endingTimestamp); + + if (_activity is null) + { + return; + } + + if (_activity.IsAllDataRequested) + { + if (answers is not null) + { + static string[] ToStringHelper(T[] array) => array.Select(a => a!.ToString()!).ToArray(); + + string[]? answersArray = answers switch + { + ServiceResult[] serviceResults => ToStringHelper(serviceResults), + AddressResult[] addressResults => ToStringHelper(addressResults), + _ => null + }; + + Debug.Assert(answersArray is not null); + _activity.SetTag("dns.answers", answersArray); + } + else + { + _activity.SetTag("error.type", error.ToString()); + } + } + + if (answers is null) + { + _activity.SetStatus(ActivityStatusCode.Error); + } + + _activity.Stop(); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs new file mode 100644 index 00000000000..bc290c6b907 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResolver.cs @@ -0,0 +1,931 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal sealed partial class DnsResolver : IDnsResolver, IDisposable +{ + private const int IPv4Length = 4; + private const int IPv6Length = 16; + + // CancellationTokenSource.CancelAfter has a maximum timeout of Int32.MaxValue milliseconds. + private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); + + private bool _disposed; + private readonly ResolverOptions _options; + private readonly CancellationTokenSource _pendingRequestsCts = new(); + private readonly TimeProvider _timeProvider; + private readonly ILogger _logger; + + public DnsResolver(TimeProvider timeProvider, ILogger logger) : this(timeProvider, logger, OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions()) + { + } + + internal DnsResolver(TimeProvider timeProvider, ILogger logger, ResolverOptions options) + { + _timeProvider = timeProvider; + _logger = logger; + _options = options; + Debug.Assert(_options.Servers.Count > 0); + + if (options.Timeout != Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(options.Timeout, TimeSpan.Zero); + ArgumentOutOfRangeException.ThrowIfGreaterThan(options.Timeout, s_maxTimeout); + } + } + + internal DnsResolver(ResolverOptions options) : this(TimeProvider.System, NullLogger.Instance, options) + { + } + + internal DnsResolver(IEnumerable servers) : this(new ResolverOptions(servers.ToArray())) + { + } + + internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server)) + { + } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + + // dnsSafeName is Disposed by SendQueryWithTelemetry + EncodedDomainName dnsSafeName = GetNormalizedHostName(name); + return SendQueryWithTelemetry(name, dnsSafeName, QueryType.SRV, ProcessResponse, cancellationToken); + + static (SendQueryError, ServiceResult[]) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) + { + var results = new List(response.Answers.Count); + + foreach (var answer in response.Answers) + { + if (answer.Type == QueryType.SRV) + { + if (!DnsPrimitives.TryReadService(answer.Data, out ushort priority, out ushort weight, out ushort port, out EncodedDomainName target, out int bytesRead) || bytesRead != answer.Data.Length) + { + return (SendQueryError.MalformedResponse, []); + } + + List addresses = new List(); + foreach (var additional in response.Additionals) + { + // From RFC 2782: + // + // Target + // The domain name of the target host. There MUST be one or more + // address records for this name, the name MUST NOT be an alias (in + // the sense of RFC 1034 or RFC 2181). Implementors are urged, but + // not required, to return the address record(s) in the Additional + // Data section. Unless and until permitted by future standards + // action, name compression is not to be used for this field. + // + // A Target of "." means that the service is decidedly not + // available at this domain. + if (additional.Name.Equals(target) && (additional.Type == QueryType.A || additional.Type == QueryType.AAAA)) + { + addresses.Add(new AddressResult(response.CreatedAt.AddSeconds(additional.Ttl), new IPAddress(additional.Data.Span))); + } + } + + results.Add(new ServiceResult(response.CreatedAt.AddSeconds(answer.Ttl), priority, weight, port, target.ToString(), addresses.ToArray())); + } + } + + return (SendQueryError.NoError, results.ToArray()); + } + } + + public async ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) + { + if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase)) + { + // name localhost exists outside of DNS and can't be resolved by a DNS server + int len = (Socket.OSSupportsIPv4 ? 1 : 0) + (Socket.OSSupportsIPv6 ? 1 : 0); + AddressResult[] res = new AddressResult[len]; + + int index = 0; + if (Socket.OSSupportsIPv6) // prefer IPv6 + { + res[index] = new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback); + index++; + } + if (Socket.OSSupportsIPv4) + { + res[index] = new AddressResult(DateTime.MaxValue, IPAddress.Loopback); + } + + return res; + } + + var ipv4AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetwork, cancellationToken); + var ipv6AddressesTask = ResolveIPAddressesAsync(name, AddressFamily.InterNetworkV6, cancellationToken); + + AddressResult[] ipv4Addresses = await ipv4AddressesTask.ConfigureAwait(false); + AddressResult[] ipv6Addresses = await ipv6AddressesTask.ConfigureAwait(false); + + AddressResult[] results = new AddressResult[ipv4Addresses.Length + ipv6Addresses.Length]; + ipv6Addresses.CopyTo(results, 0); + ipv4Addresses.CopyTo(results, ipv6Addresses.Length); + return results; + } + + internal ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + + if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6) + { + throw new ArgumentOutOfRangeException(nameof(addressFamily), addressFamily, "Invalid address family"); + } + + if (string.Equals(name, "localhost", StringComparison.OrdinalIgnoreCase)) + { + // name localhost exists outside of DNS and can't be resolved by a DNS server + if (addressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4) + { + return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.Loopback)]); + } + else if (addressFamily == AddressFamily.InterNetworkV6 && Socket.OSSupportsIPv6) + { + return ValueTask.FromResult([new AddressResult(DateTime.MaxValue, IPAddress.IPv6Loopback)]); + } + + return ValueTask.FromResult([]); + } + + // dnsSafeName is Disposed by SendQueryWithTelemetry + EncodedDomainName dnsSafeName = GetNormalizedHostName(name); + var queryType = addressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + return SendQueryWithTelemetry(name, dnsSafeName, queryType, ProcessResponse, cancellationToken); + + static (SendQueryError error, AddressResult[] result) ProcessResponse(EncodedDomainName dnsSafeName, QueryType queryType, DnsResponse response) + { + List results = new List(response.Answers.Count); + + // Servers send back CNAME records together with associated A/AAAA records. Servers + // send only those CNAME records relevant to the query, and if there is a CNAME record, + // there should not be other records associated with the name. Therefore, we simply follow + // the list of CNAME aliases until we get to the primary name and return the A/AAAA records + // associated. + // + // more info: https://datatracker.ietf.org/doc/html/rfc1034#section-3.6.2 + // + // Most of the servers send the CNAME records in order so that we can sequentially scan the + // answers, but nothing prevents the records from being in arbitrary order. Attempt the linear + // scan first and fallback to a slower but more robust method if necessary. + + bool success = true; + EncodedDomainName currentAlias = dnsSafeName; + + foreach (var answer in response.Answers) + { + switch (answer.Type) + { + case QueryType.CNAME: + if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (answer.Name.Equals(currentAlias)) + { + currentAlias = target; + continue; + } + + break; + + case var type when type == queryType: + if (!TryReadAddress(answer, queryType, out IPAddress? address)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (answer.Name.Equals(currentAlias)) + { + results.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); + continue; + } + + break; + } + + // unexpected name or record type, fall back to more robust path + results.Clear(); + success = false; + break; + } + + if (success) + { + return (SendQueryError.NoError, results.ToArray()); + } + + // more expensive path for uncommon (but valid) cases where CNAME records are out of order. Use of Dictionary + // allows us to stay within O(n) complexity for the number of answers, but we will use more memory. + Dictionary aliasMap = new(); + Dictionary> aRecordMap = new(); + foreach (var answer in response.Answers) + { + if (answer.Type == QueryType.CNAME) + { + // map the alias to the target name + if (!TryReadTarget(answer, response.RawMessageBytes, out EncodedDomainName target)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (!aliasMap.TryAdd(answer.Name, target)) + { + // Duplicate CNAME record + return (SendQueryError.MalformedResponse, []); + } + } + + if (answer.Type == queryType) + { + if (!TryReadAddress(answer, queryType, out IPAddress? address)) + { + return (SendQueryError.MalformedResponse, []); + } + + if (!aRecordMap.TryGetValue(answer.Name, out List? addressList)) + { + addressList = new List(); + aRecordMap.Add(answer.Name, addressList); + } + + addressList.Add(new AddressResult(response.CreatedAt.AddSeconds(answer.Ttl), address)); + } + } + + // follow the CNAME chain, limit the maximum number of iterations to avoid infinite loops. + int i = 0; + currentAlias = dnsSafeName; + while (aliasMap.TryGetValue(currentAlias, out EncodedDomainName nextAlias)) + { + if (i >= aliasMap.Count) + { + // circular CNAME chain + return (SendQueryError.MalformedResponse, []); + } + + i++; + + if (aRecordMap.ContainsKey(currentAlias)) + { + // both CNAME record and A/AAAA records exist for the current alias + return (SendQueryError.MalformedResponse, []); + } + + currentAlias = nextAlias; + } + + // Now we have the final target name, check if we have any A/AAAA records for it. + aRecordMap.TryGetValue(currentAlias, out List? finalAddressList); + return (SendQueryError.NoError, finalAddressList?.ToArray() ?? []); + + static bool TryReadTarget(in DnsResourceRecord record, ArraySegment messageBytes, out EncodedDomainName target) + { + Debug.Assert(record.Type == QueryType.CNAME, "Only CNAME records should be processed here."); + + target = default; + + // some servers use domain name compression even inside CNAME records. In order to decode those + // correctly, we need to pass the entire message to TryReadQName. The Data span inside the record + // should be backed by the array containing the entire DNS message. We just need to account for the + // 2 byte offset in case of TCP fallback. + var gotArray = MemoryMarshal.TryGetArray(record.Data, out ArraySegment segment); + Debug.Assert(gotArray, "Failed to get array segment"); + Debug.Assert(segment.Array == messageBytes.Array, "record data backed by different array than the original message"); + + int messageOffset = messageBytes.Offset; + + bool result = DnsPrimitives.TryReadQName(segment.Array.AsMemory(messageOffset, segment.Offset + segment.Count - messageOffset), segment.Offset - messageOffset, out EncodedDomainName targetName, out int bytesRead) && bytesRead == record.Data.Length; + if (result) + { + target = targetName; + } + + return result; + } + + static bool TryReadAddress(in DnsResourceRecord record, QueryType type, [NotNullWhen(true)] out IPAddress? target) + { + Debug.Assert(record.Type is QueryType.A or QueryType.AAAA, "Only CNAME records should be processed here."); + + target = null; + if (record.Type == QueryType.A && record.Data.Length != IPv4Length || + record.Type == QueryType.AAAA && record.Data.Length != IPv6Length) + { + return false; + } + + target = new IPAddress(record.Data.Span); + return true; + } + } + } + + private async ValueTask SendQueryWithTelemetry(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + { + NameResolutionActivity activity = Telemetry.StartNameResolution(name, queryType, _timeProvider.GetTimestamp()); + (SendQueryError error, TResult[] result) = await SendQueryWithRetriesAsync(name, dnsSafeName, queryType, processResponseFunc, cancellationToken).ConfigureAwait(false); + Telemetry.StopNameResolution(name, queryType, activity, null, error, _timeProvider.GetTimestamp()); + dnsSafeName.Dispose(); + + return result; + } + + internal struct SendQueryResult + { + public DnsResponse Response; + public SendQueryError Error; + } + + async ValueTask<(SendQueryError error, TResult[] result)> SendQueryWithRetriesAsync(string name, EncodedDomainName dnsSafeName, QueryType queryType, Func processResponseFunc, CancellationToken cancellationToken) + { + SendQueryError lastError = SendQueryError.InternalError; // will be overwritten by the first attempt + for (int index = 0; index < _options.Servers.Count; index++) + { + IPEndPoint serverEndPoint = _options.Servers[index]; + + for (int attempt = 1; attempt <= _options.Attempts; attempt++) + { + DnsResponse response = default; + try + { + TResult[] results = Array.Empty(); + + try + { + SendQueryResult queryResult = await SendQueryToServerWithTimeoutAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cancellationToken).ConfigureAwait(false); + lastError = queryResult.Error; + response = queryResult.Response; + + if (lastError == SendQueryError.NoError) + { + // Given that result.Error is NoError, there should be at least one answer. + Debug.Assert(response.Answers.Count > 0); + (lastError, results) = processResponseFunc(dnsSafeName, queryType, queryResult.Response); + } + } + catch (SocketException ex) + { + Log.NetworkError(_logger, queryType, name, serverEndPoint, attempt, ex); + lastError = SendQueryError.NetworkError; + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + // internal error, propagate + Log.QueryError(_logger, queryType, name, serverEndPoint, attempt, ex); + throw; + } + + switch (lastError) + { + // + // Definitive answers, no point retrying + // + case SendQueryError.NoError: + return (lastError, results); + + case SendQueryError.NameError: + // authoritative answer that the name does not exist, no point in retrying + Log.NameError(_logger, queryType, name, serverEndPoint, attempt); + return (lastError, results); + + case SendQueryError.NoData: + // no data available for the name from authoritative server + Log.NoData(_logger, queryType, name, serverEndPoint, attempt); + return (lastError, results); + + // + // Transient errors, retry on the same server + // + case SendQueryError.Timeout: + Log.Timeout(_logger, queryType, name, serverEndPoint, attempt); + continue; + + case SendQueryError.NetworkError: + // TODO: retry with exponential backoff? + continue; + + case SendQueryError.ServerError when response.Header.ResponseCode == QueryResponseCode.ServerFailure: + // ServerFailure may indicate transient failure with upstream DNS servers, retry on the same server + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode); + continue; + + // + // Persistent errors, skip to the next server + // + case SendQueryError.ServerError: + // this should cover all response codes except NoError, NameError which are definite and handled above, and + // ServerFailure which is a transient error and handled above. + Log.ErrorResponseCode(_logger, queryType, name, serverEndPoint, response.Header.ResponseCode); + break; + + case SendQueryError.MalformedResponse: + Log.MalformedResponse(_logger, queryType, name, serverEndPoint, attempt); + break; + + case SendQueryError.InternalError: + // exception logged above. + break; + } + + // actual break that causes skipping to the next server + break; + } + finally + { + response.Dispose(); + } + } + } + + // if we get here, we exhausted all servers and all attempts + return (lastError, []); + } + + internal async ValueTask SendQueryToServerWithTimeoutAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken) + { + (CancellationTokenSource cts, bool disposeTokenSource, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken); + + try + { + return await SendQueryToServerAsync(serverEndPoint, name, dnsSafeName, queryType, attempt, cts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when ( + !cancellationToken.IsCancellationRequested && // not cancelled by the caller + !pendingRequestsCts.IsCancellationRequested) // not cancelled by the global token (dispose) + // the only remaining token that could cancel this is the linked cts from the timeout. + { + Debug.Assert(cts.Token.IsCancellationRequested); + return new SendQueryResult { Error = SendQueryError.Timeout }; + } + catch (OperationCanceledException ex) when (cancellationToken.IsCancellationRequested && ex.CancellationToken != cancellationToken) + { + // cancellation was initiated by the caller, but exception was triggered by a linked token, + // rethrow the exception with the caller's token. + cancellationToken.ThrowIfCancellationRequested(); + throw new UnreachableException(); + } + finally + { + if (disposeTokenSource) + { + cts.Dispose(); + } + } + } + + private async ValueTask SendQueryToServerAsync(IPEndPoint serverEndPoint, string name, EncodedDomainName dnsSafeName, QueryType queryType, int attempt, CancellationToken cancellationToken) + { + Log.Query(_logger, queryType, name, serverEndPoint, attempt); + + SendQueryError sendError = SendQueryError.NoError; + DateTime queryStartedTime = _timeProvider.GetUtcNow().DateTime; + DnsDataReader responseReader = default; + DnsMessageHeader header; + + try + { + // use transport override if provided + if (_options._transportOverride != null) + { + (responseReader, header, sendError) = SendDnsQueryCustomTransport(_options._transportOverride, dnsSafeName, queryType); + } + else + { + (responseReader, header) = await SendDnsQueryCoreUdpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + + if (header.IsResultTruncated) + { + Log.ResultTruncated(_logger, queryType, name, serverEndPoint, 0); + responseReader.Dispose(); + // TCP fallback + (responseReader, header, sendError) = await SendDnsQueryCoreTcpAsync(serverEndPoint, dnsSafeName, queryType, cancellationToken).ConfigureAwait(false); + } + } + + if (sendError != SendQueryError.NoError) + { + // we failed to get back any response + return new SendQueryResult { Error = sendError }; + } + + if ((uint)header.ResponseCode > (uint)QueryResponseCode.Refused) + { + // Response code is outside of valid range + return new SendQueryResult + { + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + // Recheck that the server echoes back the DNS question + if (header.QueryCount != 1 || + !responseReader.TryReadQuestion(out var qName, out var qType, out var qClass) || + !dnsSafeName.Equals(qName) || qType != queryType || qClass != QueryClass.Internet) + { + // DNS Question mismatch + return new SendQueryResult + { + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + // Structurally separate the resource records, this will validate only the + // "outside structure" of the resource record, it will not validate the content. + int ttl = int.MaxValue; + if (!TryReadRecords(header.AnswerCount, ref ttl, ref responseReader, out List? answers) || + !TryReadRecords(header.AuthorityCount, ref ttl, ref responseReader, out List? authorities) || + !TryReadRecords(header.AdditionalRecordCount, ref ttl, ref responseReader, out List? additionals)) + { + return new SendQueryResult + { + Response = new DnsResponse(ArraySegment.Empty, header, queryStartedTime, queryStartedTime, null!, null!, null!), + Error = SendQueryError.MalformedResponse + }; + } + + DateTime expirationTime = + (answers.Count + authorities.Count + additionals.Count) > 0 ? queryStartedTime.AddSeconds(ttl) : queryStartedTime; + + SendQueryError validationError = ValidateResponse(header.ResponseCode, queryStartedTime, answers, authorities, ref expirationTime); + + // we transfer ownership of RawData to the response + DnsResponse response = new DnsResponse(responseReader.MessageBuffer, header, queryStartedTime, expirationTime, answers, authorities, additionals); + responseReader = default; // avoid disposing (and returning RawData to the pool) + + return new SendQueryResult { Response = response, Error = validationError }; + } + finally + { + responseReader.Dispose(); + } + + static bool TryReadRecords(int count, ref int ttl, ref DnsDataReader reader, out List records) + { + // Since `count` is attacker controlled, limit the initial capacity + // to 32 items to avoid excessive memory allocation. More than 32 + // records are unusual so we don't need to optimize for them. + records = new(Math.Min(count, 32)); + + for (int i = 0; i < count; i++) + { + if (!reader.TryReadResourceRecord(out var record)) + { + return false; + } + + ttl = Math.Min(ttl, record.Ttl); + records.Add(new DnsResourceRecord(record.Name, record.Type, record.Class, record.Ttl, record.Data)); + } + + return true; + } + } + + internal static bool GetNegativeCacheExpiration(DateTime createdAt, List authorities, out DateTime expiration) + { + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // Like normal answers negative answers have a time to live (TTL). As + // there is no record in the answer section to which this TTL can be + // applied, the TTL must be carried by another method. This is done by + // including the SOA record from the zone in the authority section of + // the reply. When the authoritative server creates this record its TTL + // is taken from the minimum of the SOA.MINIMUM field and SOA's TTL. + // This TTL decrements in a similar manner to a normal cached answer and + // upon reaching zero (0) indicates the cached negative answer MUST NOT + // be used again. + // + + DnsResourceRecord? soa = authorities.FirstOrDefault(r => r.Type == QueryType.SOA); + if (soa != null && DnsPrimitives.TryReadSoa(soa.Value.Data, out _, out _, out _, out _, out _, out _, out uint minimum, out _)) + { + expiration = createdAt.AddSeconds(Math.Min(minimum, soa.Value.Ttl)); + return true; + } + + expiration = default; + return false; + } + + internal static SendQueryError ValidateResponse(QueryResponseCode responseCode, DateTime createdAt, List answers, List authorities, ref DateTime expiration) + { + if (responseCode == QueryResponseCode.NoError) + { + if (answers.Count > 0) + { + return SendQueryError.NoError; + } + // + // RFC 2308 Section 2.2 - No Data + // + // NODATA is indicated by an answer with the RCODE set to NOERROR and no + // relevant answers in the answer section. The authority section will + // contain an SOA record, or there will be no NS records there. + // + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // A negative answer that resulted from a no data error (NODATA) should + // be cached such that it can be retrieved and returned in response to + // another query for the same that resulted in + // the cached negative response. + // + if (!authorities.Any(r => r.Type == QueryType.NS) && GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration)) + { + expiration = newExpiration; + // _cache.TryAdd(name, queryType, expiration, Array.Empty()); + } + return SendQueryError.NoData; + } + + if (responseCode == QueryResponseCode.NameError) + { + // + // RFC 2308 Section 5 - Caching Negative Answers + // + // A negative answer that resulted from a name error (NXDOMAIN) should + // be cached such that it can be retrieved and returned in response to + // another query for the same that resulted in the + // cached negative response. + // + if (GetNegativeCacheExpiration(createdAt, authorities, out DateTime newExpiration)) + { + expiration = newExpiration; + // _cache.TryAddNonexistent(name, expiration); + } + + return SendQueryError.NameError; + } + + return SendQueryError.ServerError; + } + + internal static (DnsDataReader reader, DnsMessageHeader header, SendQueryError sendError) SendDnsQueryCustomTransport(Func, int, int> callback, EncodedDomainName dnsSafeName, QueryType queryType) + { + byte[] buffer = ArrayPool.Shared.Rent(2048); + try + { + (ushort transactionId, int length) = EncodeQuestion(buffer, dnsSafeName, queryType); + length = callback(buffer, length); + + DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 0, length), true); + + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + return (default, default, SendQueryError.MalformedResponse); + } + + // transfer ownership of buffer to the caller + buffer = null!; + return (responseReader, header, SendQueryError.NoError); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header)> SendDnsQueryCoreUdpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(512); + try + { + Memory memory = buffer; + (ushort transactionId, int length) = EncodeQuestion(memory, dnsSafeName, queryType); + + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + await socket.SendToAsync(memory.Slice(0, length), SocketFlags.None, serverEndPoint, cancellationToken).ConfigureAwait(false); + + DnsDataReader responseReader; + DnsMessageHeader header; + + while (true) + { + // Because this is UDP, the response must be in a single packet, + // if the response does not fit into a single UDP packet, the server will + // set the Truncated flag in the header, and we will need to retry with TCP. + int packetLength = await socket.ReceiveAsync(memory, SocketFlags.None, cancellationToken).ConfigureAwait(false); + + if (packetLength < DnsMessageHeader.HeaderLength) + { + continue; + } + + responseReader = new DnsDataReader(new ArraySegment(buffer, 0, packetLength), true); + if (!responseReader.TryReadHeader(out header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + // header mismatch, this is not a response to our query + continue; + } + + // ownership of the buffer is transferred to the reader, caller will dispose. + buffer = null!; + return (responseReader, header); + } + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + internal static async ValueTask<(DnsDataReader reader, DnsMessageHeader header, SendQueryError error)> SendDnsQueryCoreTcpAsync(IPEndPoint serverEndPoint, EncodedDomainName dnsSafeName, QueryType queryType, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(8 * 1024); + try + { + // When sending over TCP, the message is prefixed by 2B length + (ushort transactionId, int length) = EncodeQuestion(buffer.AsMemory(2), dnsSafeName, queryType); + BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)length); + + using var socket = new Socket(serverEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(serverEndPoint, cancellationToken).ConfigureAwait(false); + await socket.SendAsync(buffer.AsMemory(0, length + 2), SocketFlags.None, cancellationToken).ConfigureAwait(false); + + int responseLength = -1; + int bytesRead = 0; + while (responseLength < 0 || bytesRead < responseLength + 2) + { + int read = await socket.ReceiveAsync(buffer.AsMemory(bytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); + bytesRead += read; + + if (read == 0) + { + // connection closed before receiving complete response message + return (default, default, SendQueryError.MalformedResponse); + } + + if (responseLength < 0 && bytesRead >= 2) + { + responseLength = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); + + if (responseLength + 2 > buffer.Length) + { + // even though this is user-controlled pre-allocation, it is limited to + // 64 kB, so it should be fine. + var largerBuffer = ArrayPool.Shared.Rent(responseLength + 2); + Array.Copy(buffer, largerBuffer, bytesRead); + ArrayPool.Shared.Return(buffer); + buffer = largerBuffer; + } + } + } + + DnsDataReader responseReader = new DnsDataReader(new ArraySegment(buffer, 2, responseLength), true); + if (!responseReader.TryReadHeader(out DnsMessageHeader header) || + header.TransactionId != transactionId || + !header.IsResponse) + { + // header mismatch on TCP fallback + return (default, default, SendQueryError.MalformedResponse); + } + + // transfer ownership of buffer to the caller + buffer = null!; + return (responseReader, header, SendQueryError.NoError); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + private static (ushort id, int length) EncodeQuestion(Memory buffer, EncodedDomainName dnsSafeName, QueryType queryType) + { + DnsMessageHeader header = new DnsMessageHeader + { + TransactionId = (ushort)RandomNumberGenerator.GetInt32(ushort.MaxValue + 1), + QueryFlags = QueryFlags.RecursionDesired, + QueryCount = 1 + }; + + DnsDataWriter writer = new DnsDataWriter(buffer); + if (!writer.TryWriteHeader(header) || + !writer.TryWriteQuestion(dnsSafeName, queryType, QueryClass.Internet)) + { + // should never happen since we validated the name length before + throw new InvalidOperationException("Buffer too small"); + } + return (header.TransactionId, writer.Position); + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + + // Cancel all pending requests (if any). Note that we don't call CancelPendingRequests() but cancel + // the CTS directly. The reason is that CancelPendingRequests() would cancel the current CTS and create + // a new CTS. We don't want a new CTS in this case. + _pendingRequestsCts.Cancel(); + _pendingRequestsCts.Dispose(); + } + } + + private (CancellationTokenSource TokenSource, bool DisposeTokenSource, CancellationTokenSource PendingRequestsCts) PrepareCancellationTokenSource(CancellationToken cancellationToken) + { + // We need a CancellationTokenSource to use with the request. We always have the global + // _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may + // have a timeout. If we have a timeout or a caller-provided token, we need to create a new + // CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other + // unrelated operations). Otherwise, we can use the pending requests CTS directly. + + // Snapshot the current pending requests cancellation source. It can change concurrently due to cancellation being requested + // and it being replaced, and we need a stable view of it: if cancellation occurs and the caller's token hasn't been canceled, + // it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas + // it's more approximate checking elapsed time. + CancellationTokenSource pendingRequestsCts = _pendingRequestsCts; + TimeSpan timeout = _options.Timeout; + + bool hasTimeout = timeout != System.Threading.Timeout.InfiniteTimeSpan; + if (hasTimeout || cancellationToken.CanBeCanceled) + { + CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token); + if (hasTimeout) + { + cts.CancelAfter(timeout); + } + + return (cts, DisposeTokenSource: true, pendingRequestsCts); + } + + return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts); + } + + private static EncodedDomainName GetNormalizedHostName(string name) + { + byte[] buffer = ArrayPool.Shared.Rent(256); + try + { + if (!DnsPrimitives.TryWriteQName(buffer, name, out _)) + { + throw new ArgumentException($"'{name}' is not a valid DNS name.", nameof(name)); + } + + List> labels = new(); + Memory memory = buffer.AsMemory(); + while (true) + { + int len = memory.Span[0]; + + if (len == 0) + { + // root label, we are finished + break; + } + + labels.Add(memory.Slice(1, len)); + memory = memory.Slice(len + 1); + } + + buffer = null!; // ownership transferred to the EncodedDomainName + return new EncodedDomainName(labels, buffer); + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs new file mode 100644 index 00000000000..914ff9aac17 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResourceRecord.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsResourceRecord +{ + public EncodedDomainName Name { get; } + public QueryType Type { get; } + public QueryClass Class { get; } + public int Ttl { get; } + public ReadOnlyMemory Data { get; } + + public DnsResourceRecord(EncodedDomainName name, QueryType type, QueryClass @class, int ttl, ReadOnlyMemory data) + { + Name = name; + Type = type; + Class = @class; + Ttl = ttl; + Data = data; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs new file mode 100644 index 00000000000..5a7fc8a0b52 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/DnsResponse.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct DnsResponse : IDisposable +{ + public DnsMessageHeader Header { get; } + public List Answers { get; } + public List Authorities { get; } + public List Additionals { get; } + public DateTime CreatedAt { get; } + public DateTime Expiration { get; } + public ArraySegment RawMessageBytes { get; private set; } + + public DnsResponse(ArraySegment rawData, DnsMessageHeader header, DateTime createdAt, DateTime expiration, List answers, List authorities, List additionals) + { + RawMessageBytes = rawData; + + Header = header; + CreatedAt = createdAt; + Expiration = expiration; + Answers = answers; + Authorities = authorities; + Additionals = additionals; + } + + public void Dispose() + { + if (RawMessageBytes.Array != null) + { + ArrayPool.Shared.Return(RawMessageBytes.Array); + } + + RawMessageBytes = default; // prevent further access to the raw data + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs new file mode 100644 index 00000000000..4c258cac3ac --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/EncodedDomainName.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal struct EncodedDomainName : IEquatable, IDisposable +{ + public IReadOnlyList> Labels { get; } + private byte[]? _pooledBuffer; + + public EncodedDomainName(List> labels, byte[]? pooledBuffer = null) + { + Labels = labels; + _pooledBuffer = pooledBuffer; + } + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + + foreach (var label in Labels) + { + if (sb.Length > 0) + { + sb.Append('.'); + } + sb.Append(Encoding.ASCII.GetString(label.Span)); + } + + return sb.ToString(); + } + + public bool Equals(EncodedDomainName other) + { + if (Labels.Count != other.Labels.Count) + { + return false; + } + + for (int i = 0; i < Labels.Count; i++) + { + if (!Ascii.EqualsIgnoreCase(Labels[i].Span, other.Labels[i].Span)) + { + return false; + } + } + + return true; + } + + public override bool Equals(object? obj) + { + return obj is EncodedDomainName other && Equals(other); + } + + public override int GetHashCode() + { + HashCode hash = new HashCode(); + + foreach (var label in Labels) + { + foreach (byte b in label.Span) + { + hash.Add((byte)char.ToLower((char)b)); + } + } + + return hash.ToHashCode(); + } + + public void Dispose() + { + if (_pooledBuffer != null) + { + ArrayPool.Shared.Return(_pooledBuffer); + } + + _pooledBuffer = null; + } +} \ No newline at end of file diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs new file mode 100644 index 00000000000..080fe3be8de --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/IDnsResolver.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal interface IDnsResolver +{ + ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default); + ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs new file mode 100644 index 00000000000..c2ef13f922e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/NetworkInfo.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.NetworkInformation; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class NetworkInfo +{ + // basic option to get DNS serves via NetworkInfo. We may get it directly later via proper APIs. + public static ResolverOptions GetOptions() + { + List servers = new List(); + + foreach (NetworkInterface nic in NetworkInterface.GetAllNetworkInterfaces()) + { + IPInterfaceProperties properties = nic.GetIPProperties(); + // avoid loopback, VPN etc. Should be re-visited. + + if (nic.NetworkInterfaceType == NetworkInterfaceType.Ethernet && nic.OperationalStatus == OperationalStatus.Up) + { + foreach (IPAddress server in properties.DnsAddresses) + { + IPEndPoint ep = new IPEndPoint(server, 53); // 53 is standard DNS port + if (!servers.Contains(ep)) + { + servers.Add(ep); + } + } + } + } + + return new ResolverOptions(servers); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs new file mode 100644 index 00000000000..732ca0216da --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryClass.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal enum QueryClass +{ + Internet = 1 +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs new file mode 100644 index 00000000000..02474b6cda1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryFlags.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +[Flags] +internal enum QueryFlags : ushort +{ + RecursionAvailable = 0x0080, + RecursionDesired = 0x0100, + ResultTruncated = 0x0200, + HasAuthorityAnswer = 0x0400, + HasResponse = 0x8000, + ResponseCodeMask = 0x000F, +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs new file mode 100644 index 00000000000..dd51c712112 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryResponseCode.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +/// +/// The response code (RCODE) in a DNS query response. +/// +internal enum QueryResponseCode : byte +{ + /// + /// No error condition + /// + NoError = 0, + + /// + /// The name server was unable to interpret the query. + /// + FormatError = 1, + + /// + /// The name server was unable to process this query due to a problem with the name server. + /// + ServerFailure = 2, + + /// + /// Meaningful only for responses from an authoritative name server, this + /// code signifies that the domain name referenced in the query does not + /// exist. + /// + NameError = 3, + + /// + /// The name server does not support the requested kind of query. + /// + NotImplemented = 4, + + /// + /// The name server refuses to perform the specified operation for policy reasons. + /// + Refused = 5, +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs new file mode 100644 index 00000000000..2ccc898a5b7 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/QueryType.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +/// +/// DNS Query Types. +/// +internal enum QueryType +{ + /// + /// A host address. + /// + A = 1, + + /// + /// An authoritative name server. + /// + NS = 2, + + /// + /// The canonical name for an alias. + /// + CNAME = 5, + + /// + /// Marks the start of a zone of authority. + /// + SOA = 6, + + /// + /// Mail exchange. + /// + MX = 15, + + /// + /// Text strings. + /// + TXT = 16, + + /// + /// IPv6 host address. (RFC 3596) + /// + AAAA = 28, + + /// + /// Location information. (RFC 2782) + /// + SRV = 33, + + /// + /// Wildcard match. + /// + All = 255 +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs new file mode 100644 index 00000000000..fbfdc5ae027 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolvConf.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Runtime.Versioning; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal static class ResolvConf +{ + [SupportedOSPlatform("linux")] + [SupportedOSPlatform("osx")] + public static ResolverOptions GetOptions() + { + return GetOptions(new StreamReader("/etc/resolv.conf")); + } + + public static ResolverOptions GetOptions(TextReader reader) + { + List serverList = new(); + + while (reader.ReadLine() is string line) + { + string[] tokens = line.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + + if (line.StartsWith("nameserver")) + { + if (tokens.Length >= 2 && IPAddress.TryParse(tokens[1], out IPAddress? address)) + { + serverList.Add(new IPEndPoint(address, 53)); // 53 is standard DNS port + + if (serverList.Count == 3) + { + break; // resolv.conf manpage allow max 3 nameservers anyway + } + } + } + } + + if (serverList.Count == 0) + { + // If no nameservers are configured, fall back to the default behavior of using the system resolver configuration. + return NetworkInfo.GetOptions(); + } + + return new ResolverOptions(serverList); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs new file mode 100644 index 00000000000..51d03f64bfd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResolverOptions.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal sealed class ResolverOptions +{ + public IReadOnlyList Servers; + public int Attempts = 2; + public TimeSpan Timeout = TimeSpan.FromSeconds(3); + + // override for testing purposes + internal Func, int, int>? _transportOverride; + + public ResolverOptions(IReadOnlyList servers) + { + if (servers.Count == 0) + { + throw new ArgumentException("At least one DNS server is required.", nameof(servers)); + } + + Servers = servers; + } + + public ResolverOptions(IPEndPoint server) + { + Servers = new IPEndPoint[] { server }; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs new file mode 100644 index 00000000000..aed799ac8d6 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/ResultTypes.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal record struct AddressResult(DateTime ExpiresAt, IPAddress Address); + +internal record struct ServiceResult(DateTime ExpiresAt, int Priority, int Weight, int Port, string Target, AddressResult[] Addresses); diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs new file mode 100644 index 00000000000..3ba5632e207 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/Resolver/SendQueryError.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +internal enum SendQueryError +{ + /// + /// DNS query was successful and returned response message with answers. + /// + NoError, + + /// + /// Server failed to respond to the query withing specified timeout. + /// + Timeout, + + /// + /// Server returned a response with an error code. + /// + ServerError, + + /// + /// Server returned a malformed response. + /// + MalformedResponse, + + /// + /// Server returned a response indicating that the name exists, but no data are available. + /// + NoData, + + /// + /// Server returned a response indicating the name does not exist. + /// + NameError, + + /// + /// Network-level error occurred during the query. + /// + NetworkError, + + /// + /// Internal error on part of the implementation. + /// + InternalError, +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs new file mode 100644 index 00000000000..42f220445b1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns/ServiceDiscoveryDnsServiceCollectionExtensions.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.ServiceDiscovery; +using Microsoft.Extensions.ServiceDiscovery.Dns; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; + +namespace Microsoft.Extensions.Hosting; + +/// +/// Extensions for to add service discovery. +/// +public static class ServiceDiscoveryDnsServiceCollectionExtensions +{ + /// + /// Adds DNS SRV service discovery to the . + /// + /// The service collection. + /// The provided . + /// + /// DNS SRV queries are able to provide port numbers for endpoints and can support multiple named endpoints per service. + /// However, not all environment support DNS SRV queries, and in some environments, additional configuration may be required. + /// + public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceCollection services) + { + ArgumentNullException.ThrowIfNull(services); + + return services.AddDnsSrvServiceEndpointProvider(_ => { }); + } + + /// + /// Adds DNS SRV service discovery to the . + /// + /// The service collection. + /// The DNS SRV service discovery configuration options. + /// The provided . + /// + /// DNS SRV queries are able to provide port numbers for endpoints and can support multiple named endpoints per service. + /// However, not all environment support DNS SRV queries, and in some environments, additional configuration may be required. + /// + public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceCollection services, Action configureOptions) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureOptions); + + services.AddServiceDiscoveryCore(); + + if (!GetDnsClientFallbackFlag()) + { + services.TryAddSingleton(); + } + else + { + services.TryAddSingleton(); + services.TryAddSingleton(); + } + + services.AddSingleton(); + var options = services.AddOptions(); + options.Configure(o => configureOptions?.Invoke(o)); + return services; + + static bool GetDnsClientFallbackFlag() + { + if (AppContext.TryGetSwitch("Microsoft.Extensions.ServiceDiscovery.Dns.UseDnsClientFallback", out var value)) + { + return value; + } + + var envVar = Environment.GetEnvironmentVariable("MICROSOFT_EXTENSIONS_SERVICE_DISCOVERY_DNS_USE_DNSCLIENT_FALLBACK"); + if (envVar is not null && (envVar.Equals("true", StringComparison.OrdinalIgnoreCase) || envVar.Equals("1"))) + { + return true; + } + + return false; + } + } + + /// + /// Adds DNS service discovery to the . + /// + /// The service collection. + /// The provided . + /// + /// DNS A/AAAA queries are widely available but are not able to provide port numbers for endpoints and cannot support multiple named endpoints per service. + /// + public static IServiceCollection AddDnsServiceEndpointProvider(this IServiceCollection services) + { + ArgumentNullException.ThrowIfNull(services); + + return services.AddDnsServiceEndpointProvider(_ => { }); + } + + /// + /// Adds DNS service discovery to the . + /// + /// The service collection. + /// The DNS SRV service discovery configuration options. + /// The provided . + /// + /// DNS A/AAAA queries are widely available but are not able to provide port numbers for endpoints and cannot support multiple named endpoints per service. + /// + public static IServiceCollection AddDnsServiceEndpointProvider(this IServiceCollection services, Action configureOptions) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureOptions); + + services.AddServiceDiscoveryCore(); + services.AddSingleton(); + var options = services.AddOptions(); + options.Configure(o => configureOptions?.Invoke(o)); + return services; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/Microsoft.Extensions.ServiceDiscovery.Yarp.csproj b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/Microsoft.Extensions.ServiceDiscovery.Yarp.csproj new file mode 100644 index 00000000000..e990866bd16 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/Microsoft.Extensions.ServiceDiscovery.Yarp.csproj @@ -0,0 +1,28 @@ + + + + $(NetCoreTargetFrameworks) + enable + enable + true + Provides extensions for service discovery for the YARP reverse proxy. + $(DefaultDotnetIconFullPath) + + $(NoWarn);IDE0018;IDE0025;IDE0032;IDE0040;IDE0058;IDE0250;IDE0251;IDE1006;CA1304;CA1307;CA1309;CA1310;CA1849;CA2000;CA2213;CA2217;S125;S1135;S1226;S2344;S2692;S3626;S4022;SA1108;SA1120;SA1128;SA1129;SA1204;SA1205;SA1214;SA1400;SA1405;SA1408;SA1414;SA1515;SA1600;SA1615;SA1629;SA1642;SA1649;EA0001;EA0009;EA0014;LA0001;LA0003;LA0008;VSTHRD200 + enable + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/README.md b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/README.md new file mode 100644 index 00000000000..a7175f0382c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/README.md @@ -0,0 +1,42 @@ +# Microsoft.Extensions.ServiceDiscovery.Yarp + +The `Microsoft.Extensions.ServiceDiscovery.Yarp` library adds support for resolving endpoints for YARP clusters, by implementing a [YARP destination resolver](https://github.com/microsoft/reverse-proxy/blob/main/docs/docfx/articles/destination-resolvers.md). + +## Usage + +### Resolving YARP cluster destinations using Service Discovery + +The `IReverseProxyBuilder.AddServiceDiscoveryDestinationResolver()` extension method configures a [YARP destination resolver](https://github.com/microsoft/reverse-proxy/blob/main/docs/docfx/articles/destination-resolvers.md). To use this method, you must also configure YARP itself as described in the YARP documentation, and you must configure .NET Service Discovery via the _Microsoft.Extensions.ServiceDiscovery_ library. + +### Direct HTTP forwarding using Service Discovery Forwarding HTTP requests using `IHttpForwarder` + +YARP supports _direct forwarding_ of specific requests using the `IHttpForwarder` interface. This, too, can benefit from service discovery using the _Microsoft.Extensions.ServiceDiscovery_ library. To take advantage of service discovery when using YARP Direct Forwarding, use the `IServiceCollection.AddHttpForwarderWithServiceDiscovery` method. + +For example, consider the following .NET Aspire application: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +// Configure service discovery +builder.Services.AddServiceDiscovery(); + +// Add YARP Direct Forwarding with Service Discovery support +builder.Services.AddHttpForwarderWithServiceDiscovery(); + +// ... other configuration ... + +var app = builder.Build(); + +// ... other configuration ... + +// Map a Direct Forwarder which forwards requests to the resolved "catalogservice" endpoints +app.MapForwarder("/catalog/images/{id}", "http://catalogservice", "/api/v1/catalog/items/{id}/image"); + +app.Run(); +``` + +In the above example, the YARP Direct Forwarder will resolve the _catalogservice_ using service discovery, forwarding request sent to the `/catalog/images/{id}` endpoint to the destination path on the resolved endpoints. + +## Feedback & contributing + +https://github.com/dotnet/aspire diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryDestinationResolver.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryDestinationResolver.cs new file mode 100644 index 00000000000..2ca456ec911 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryDestinationResolver.cs @@ -0,0 +1,128 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Yarp.ReverseProxy.Configuration; +using Yarp.ReverseProxy.ServiceDiscovery; + +namespace Microsoft.Extensions.ServiceDiscovery.Yarp; + +/// +/// Implementation of which resolves destinations using service discovery. +/// +/// +/// Initializes a new instance. +/// +/// The endpoint resolver registry. +/// The service discovery options. +internal sealed class ServiceDiscoveryDestinationResolver(ServiceEndpointResolver resolver, IOptions options) : IDestinationResolver +{ + private readonly ServiceDiscoveryOptions _options = options.Value; + + /// + public async ValueTask ResolveDestinationsAsync(IReadOnlyDictionary destinations, CancellationToken cancellationToken) + { + Dictionary results = new(); + var tasks = new List, IChangeToken ChangeToken)>>(destinations.Count); + foreach (var (destinationId, destinationConfig) in destinations) + { + tasks.Add(ResolveHostAsync(destinationId, destinationConfig, cancellationToken)); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + var changeTokens = new List(); + foreach (var task in tasks) + { + var (configs, changeToken) = await task.ConfigureAwait(false); + if (changeToken is not null) + { + changeTokens.Add(changeToken); + } + + foreach (var (name, config) in configs) + { + results[name] = config; + } + } + + return new ResolvedDestinationCollection(results, new CompositeChangeToken(changeTokens)); + } + + private async Task<(List<(string Name, DestinationConfig Config)>, IChangeToken ChangeToken)> ResolveHostAsync( + string originalName, + DestinationConfig originalConfig, + CancellationToken cancellationToken) + { + var originalUri = new Uri(originalConfig.Address); + var serviceName = originalUri.GetLeftPart(UriPartial.Authority); + + var result = await resolver.GetEndpointsAsync(serviceName, cancellationToken).ConfigureAwait(false); + var results = new List<(string Name, DestinationConfig Config)>(result.Endpoints.Count); + var uriBuilder = new UriBuilder(originalUri); + var healthUri = originalConfig.Health is { Length: > 0 } health ? new Uri(health) : null; + var healthUriBuilder = healthUri is { } ? new UriBuilder(healthUri) : null; + foreach (var endpoint in result.Endpoints) + { + var addressString = endpoint.ToString()!; + Uri uri; + if (!addressString.Contains("://")) + { + var scheme = GetDefaultScheme(originalUri); + uri = new Uri($"{scheme}://{addressString}"); + } + else + { + uri = new Uri(addressString); + } + + uriBuilder.Scheme = uri.Scheme; + uriBuilder.Host = uri.Host; + uriBuilder.Port = uri.Port; + var resolvedAddress = uriBuilder.Uri.ToString(); + var healthAddress = originalConfig.Health; + if (healthUriBuilder is not null) + { + healthUriBuilder.Host = uri.Host; + healthUriBuilder.Port = uri.Port; + healthAddress = healthUriBuilder.Uri.ToString(); + } + + var name = $"{originalName}[{addressString}]"; + string? resolvedHost = null; + + // Use the configured 'Host' value if it is provided. + if (!string.IsNullOrEmpty(originalConfig.Host)) + { + resolvedHost = originalConfig.Host; + } + + var config = originalConfig with { Host = resolvedHost, Address = resolvedAddress, Health = healthAddress }; + results.Add((name, config)); + } + + return (results, result.ChangeToken); + } + + private string GetDefaultScheme(Uri originalUri) + { + if (originalUri.Scheme.IndexOf('+') > 0) + { + // Use the first allowed scheme. + var specifiedSchemes = originalUri.Scheme.Split('+'); + foreach (var scheme in specifiedSchemes) + { + if (_options.AllowAllSchemes || _options.AllowedSchemes.Contains(scheme, StringComparer.OrdinalIgnoreCase)) + { + return scheme; + } + } + + throw new InvalidOperationException($"None of the specified schemes ('{string.Join(", ", specifiedSchemes)}') are allowed by configuration."); + } + else + { + return originalUri.Scheme; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryForwarderHttpClientFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryForwarderHttpClientFactory.cs new file mode 100644 index 00000000000..84aafe2a67e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryForwarderHttpClientFactory.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.ServiceDiscovery.Http; +using Yarp.ReverseProxy.Forwarder; + +namespace Microsoft.Extensions.ServiceDiscovery.Yarp; + +internal sealed class ServiceDiscoveryForwarderHttpClientFactory(IServiceDiscoveryHttpMessageHandlerFactory handlerFactory) + : ForwarderHttpClientFactory +{ + protected override HttpMessageHandler WrapHandler(ForwarderHttpClientContext context, HttpMessageHandler handler) + { + return handlerFactory.CreateHandler(handler); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryReverseProxyServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryReverseProxyServiceCollectionExtensions.cs new file mode 100644 index 00000000000..de74dc0fc24 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp/ServiceDiscoveryReverseProxyServiceCollectionExtensions.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.ServiceDiscovery.Yarp; +using Yarp.ReverseProxy.Forwarder; +using Yarp.ReverseProxy.ServiceDiscovery; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extensions for used to register the ReverseProxy's components. +/// +public static class ServiceDiscoveryReverseProxyServiceCollectionExtensions +{ + /// + /// Provides a implementation which uses service discovery to resolve destinations. + /// + public static IReverseProxyBuilder AddServiceDiscoveryDestinationResolver(this IReverseProxyBuilder builder) + { + ArgumentNullException.ThrowIfNull(builder); + + builder.Services.AddServiceDiscoveryCore(); + builder.Services.AddSingleton(); + return builder; + } + + /// + /// Adds the with service discovery support. + /// + public static IServiceCollection AddHttpForwarderWithServiceDiscovery(this IServiceCollection services) + { + ArgumentNullException.ThrowIfNull(services); + + return services.AddHttpForwarder().AddServiceDiscoveryForwarderFactory(); + } + + /// + /// Provides a implementation which uses service discovery to resolve service names. + /// + public static IServiceCollection AddServiceDiscoveryForwarderFactory(this IServiceCollection services) + { + ArgumentNullException.ThrowIfNull(services); + + services.AddServiceDiscoveryCore(); + services.AddSingleton(); + return services; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.Log.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.Log.cs new file mode 100644 index 00000000000..b27c5ea9190 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.Log.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.Configuration; + +internal sealed partial class ConfigurationServiceEndpointProvider +{ + private static partial class Log + { + [LoggerMessage(1, LogLevel.Debug, "Skipping endpoint resolution for service '{ServiceName}': '{Reason}'.", EventName = "SkippedResolution")] + public static partial void SkippedResolution(ILogger logger, string serviceName, string reason); + + [LoggerMessage(2, LogLevel.Debug, "Using configuration from path '{Path}' to resolve endpoint '{EndpointName}' for service '{ServiceName}'.", EventName = "UsingConfigurationPath")] + public static partial void UsingConfigurationPath(ILogger logger, string path, string endpointName, string serviceName); + + [LoggerMessage(3, LogLevel.Debug, "No valid endpoint configuration was found for service '{ServiceName}' from path '{Path}'.", EventName = "ServiceConfigurationNotFound")] + internal static partial void ServiceConfigurationNotFound(ILogger logger, string serviceName, string path); + + [LoggerMessage(4, LogLevel.Debug, "Endpoints configured for service '{ServiceName}' from path '{Path}': {ConfiguredEndpoints}.", EventName = "ConfiguredEndpoints")] + internal static partial void ConfiguredEndpoints(ILogger logger, string serviceName, string path, string configuredEndpoints); + + internal static void ConfiguredEndpoints(ILogger logger, string serviceName, string path, IList endpoints, int added) + { + if (!logger.IsEnabled(LogLevel.Debug)) + { + return; + } + + StringBuilder endpointValues = new(); + for (var i = endpoints.Count - added; i < endpoints.Count; i++) + { + if (endpointValues.Length > 0) + { + endpointValues.Append(", "); + } + + endpointValues.Append(endpoints[i].ToString()); + } + + var configuredEndpoints = endpointValues.ToString(); + ConfiguredEndpoints(logger, serviceName, path, configuredEndpoints); + } + + [LoggerMessage(5, LogLevel.Debug, "No valid endpoint configuration was found for endpoint '{EndpointName}' on service '{ServiceName}' from path '{Path}'.", EventName = "EndpointConfigurationNotFound")] + internal static partial void EndpointConfigurationNotFound(ILogger logger, string endpointName, string serviceName, string path); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs new file mode 100644 index 00000000000..e8c84b69ec8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs @@ -0,0 +1,242 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Net; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Configuration; + +/// +/// A service endpoint provider that uses configuration to resolve resolved. +/// +internal sealed partial class ConfigurationServiceEndpointProvider : IServiceEndpointProvider, IHostNameFeature +{ + private const string DefaultEndpointName = "default"; + private readonly string _serviceName; + private readonly string? _endpointName; + private readonly bool _includeAllSchemes; + private readonly string[] _schemes; + private readonly IConfiguration _configuration; + private readonly ILogger _logger; + private readonly IOptions _options; + + /// + /// Initializes a new instance. + /// + /// The query. + /// The configuration. + /// The logger. + /// Configuration provider options. + /// Service discovery options. + public ConfigurationServiceEndpointProvider( + ServiceEndpointQuery query, + IConfiguration configuration, + ILogger logger, + IOptions options, + IOptions serviceDiscoveryOptions) + { + _serviceName = query.ServiceName; + _endpointName = query.EndpointName; + _includeAllSchemes = serviceDiscoveryOptions.Value.AllowAllSchemes && query.IncludedSchemes.Count == 0; + _schemes = ServiceDiscoveryOptions.ApplyAllowedSchemes(query.IncludedSchemes, serviceDiscoveryOptions.Value.AllowedSchemes, serviceDiscoveryOptions.Value.AllowAllSchemes); + _configuration = configuration; + _logger = logger; + _options = options; + } + + /// + public ValueTask DisposeAsync() => default; + + /// + public ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationToken cancellationToken) + { + // Only add resolved to the collection if a previous provider (eg, an override) did not add them. + if (endpoints.Endpoints.Count != 0) + { + Log.SkippedResolution(_logger, _serviceName, "Collection has existing endpoints"); + return default; + } + + // Get the corresponding config section. + var section = _configuration.GetSection(_options.Value.SectionName).GetSection(_serviceName); + if (!section.Exists()) + { + endpoints.AddChangeToken(_configuration.GetReloadToken()); + Log.ServiceConfigurationNotFound(_logger, _serviceName, $"{_options.Value.SectionName}:{_serviceName}"); + return default; + } + + endpoints.AddChangeToken(section.GetReloadToken()); + + // Find an appropriate configuration section based on the input. + IConfigurationSection? namedSection = null; + string endpointName; + if (string.IsNullOrWhiteSpace(_endpointName)) + { + // Treat the scheme as the endpoint name and use the first section with a matching endpoint name which exists + endpointName = DefaultEndpointName; + ReadOnlySpan candidateNames = [DefaultEndpointName, .. _schemes]; + foreach (var scheme in candidateNames) + { + var candidate = section.GetSection(scheme); + if (candidate.Exists()) + { + endpointName = scheme; + namedSection = candidate; + break; + } + } + } + else + { + // Use the section corresponding to the endpoint name. + endpointName = _endpointName; + namedSection = section.GetSection(_endpointName); + } + + var configPath = $"{_options.Value.SectionName}:{_serviceName}:{endpointName}"; + if (!namedSection.Exists()) + { + Log.EndpointConfigurationNotFound(_logger, endpointName, _serviceName, configPath); + return default; + } + + List resolved = []; + Log.UsingConfigurationPath(_logger, configPath, endpointName, _serviceName); + + // Account for both the single and multi-value cases. + if (!string.IsNullOrWhiteSpace(namedSection.Value)) + { + // Single value case. + AddEndpoint(resolved, namedSection, endpointName); + } + else + { + // Multiple value case. + foreach (var child in namedSection.GetChildren()) + { + if (!int.TryParse(child.Key, out _)) + { + throw new KeyNotFoundException($"The endpoint configuration section for service '{_serviceName}' endpoint '{endpointName}' has non-numeric keys."); + } + + AddEndpoint(resolved, child, endpointName); + } + } + + int resolvedEndpointCount; + if (_includeAllSchemes) + { + // Include all endpoints. + foreach (var ep in resolved) + { + endpoints.Endpoints.Add(ep); + } + + resolvedEndpointCount = resolved.Count; + } + else + { + // Filter the resolved endpoints to only include those which match the specified, allowed schemes. + resolvedEndpointCount = 0; + var minIndex = _schemes.Length; + foreach (var ep in resolved) + { + if (ep.EndPoint is UriEndPoint uri && uri.Uri.Scheme is { } scheme) + { + var index = Array.IndexOf(_schemes, scheme); + if (index >= 0 && index < minIndex) + { + minIndex = index; + } + } + } + + foreach (var ep in resolved) + { + if (ep.EndPoint is UriEndPoint uri && uri.Uri.Scheme is { } scheme) + { + var index = Array.IndexOf(_schemes, scheme); + if (index >= 0 && index <= minIndex) + { + ++resolvedEndpointCount; + endpoints.Endpoints.Add(ep); + } + } + else + { + ++resolvedEndpointCount; + endpoints.Endpoints.Add(ep); + } + } + } + + if (resolvedEndpointCount == 0) + { + Log.ServiceConfigurationNotFound(_logger, _serviceName, configPath); + } + else + { + Log.ConfiguredEndpoints(_logger, _serviceName, configPath, endpoints.Endpoints, resolvedEndpointCount); + } + + return default; + } + + string IHostNameFeature.HostName => _serviceName; + + private void AddEndpoint(List endpoints, IConfigurationSection section, string endpointName) + { + var value = section.Value; + if (string.IsNullOrWhiteSpace(value) || !TryParseEndPoint(value, out var endPoint)) + { + throw new KeyNotFoundException($"The endpoint configuration section for service '{_serviceName}' endpoint '{endpointName}' has an invalid value with key '{section.Key}'."); + } + + endpoints.Add(CreateEndpoint(endPoint)); + } + + private static bool TryParseEndPoint(string value, [NotNullWhen(true)] out EndPoint? endPoint) + { + if (value.IndexOf("://") < 0 && Uri.TryCreate($"fakescheme://{value}", default, out var uri)) + { + var port = uri.Port > 0 ? uri.Port : 0; + if (IPAddress.TryParse(uri.Host, out var ip)) + { + endPoint = new IPEndPoint(ip, port); + } + else + { + endPoint = new DnsEndPoint(uri.Host, port); + } + } + else if (Uri.TryCreate(value, default, out uri)) + { + endPoint = new UriEndPoint(uri); + } + else + { + endPoint = null; + return false; + } + + return true; + } + + private ServiceEndpoint CreateEndpoint(EndPoint endPoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endPoint); + serviceEndpoint.Features.Set(this); + if (_options.Value.ShouldApplyHostNameMetadata(serviceEndpoint)) + { + serviceEndpoint.Features.Set(this); + } + + return serviceEndpoint; + } + + public override string ToString() => "Configuration"; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProviderFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProviderFactory.cs new file mode 100644 index 00000000000..a966cd44794 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProviderFactory.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Configuration; + +/// +/// implementation that resolves services using . +/// +internal sealed class ConfigurationServiceEndpointProviderFactory( + IConfiguration configuration, + IOptions options, + IOptions serviceDiscoveryOptions, + ILogger logger) : IServiceEndpointProviderFactory +{ + /// + public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) + { + provider = new ConfigurationServiceEndpointProvider(query, configuration, logger, options, serviceDiscoveryOptions); + return true; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProviderOptionsValidator.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProviderOptionsValidator.cs new file mode 100644 index 00000000000..f8092c4dd51 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProviderOptionsValidator.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Configuration; + +internal sealed class ConfigurationServiceEndpointProviderOptionsValidator : IValidateOptions +{ + public ValidateOptionsResult Validate(string? name, ConfigurationServiceEndpointProviderOptions options) + { + if (string.IsNullOrWhiteSpace(options.SectionName)) + { + return ValidateOptionsResult.Fail($"{nameof(options.SectionName)} must not be null or empty."); + } + + if (options.ShouldApplyHostNameMetadata is null) + { + return ValidateOptionsResult.Fail($"{nameof(options.ShouldApplyHostNameMetadata)} must not be null."); + } + + return ValidateOptionsResult.Success; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ConfigurationServiceEndpointProviderOptions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ConfigurationServiceEndpointProviderOptions.cs new file mode 100644 index 00000000000..29f28e359f7 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ConfigurationServiceEndpointProviderOptions.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.ServiceDiscovery.Configuration; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Options for . +/// +public sealed class ConfigurationServiceEndpointProviderOptions +{ + /// + /// The name of the configuration section which contains service endpoints. Defaults to "Services". + /// + public string SectionName { get; set; } = "Services"; + + /// + /// Gets or sets a delegate used to determine whether to apply host name metadata to each resolved endpoint. Defaults to a delegate which returns false. + /// + public Func ShouldApplyHostNameMetadata { get; set; } = _ => false; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/HttpServiceEndpointResolver.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/HttpServiceEndpointResolver.cs new file mode 100644 index 00000000000..e547ab14138 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/HttpServiceEndpointResolver.cs @@ -0,0 +1,263 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Diagnostics; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.ServiceDiscovery.LoadBalancing; + +namespace Microsoft.Extensions.ServiceDiscovery.Http; + +/// +/// Resolves endpoints for HTTP requests. +/// +internal sealed class HttpServiceEndpointResolver(ServiceEndpointWatcherFactory watcherFactory, IServiceProvider serviceProvider, TimeProvider timeProvider) : IAsyncDisposable +{ + private static readonly TimerCallback s_cleanupCallback = s => ((HttpServiceEndpointResolver)s!).CleanupResolvers(); + private static readonly TimeSpan s_cleanupPeriod = TimeSpan.FromSeconds(10); + + private readonly object _lock = new(); + private readonly ServiceEndpointWatcherFactory _watcherFactory = watcherFactory; + private readonly ConcurrentDictionary _resolvers = new(); + private ITimer? _cleanupTimer; + private Task? _cleanupTask; + + /// + /// Resolves and returns a service endpoint for the specified request. + /// + /// The request message. + /// A . + /// The resolved service endpoint. + /// The request had no set or a suitable endpoint could not be found. + public async ValueTask GetEndpointAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + ArgumentNullException.ThrowIfNull(request); + if (request.RequestUri is null) + { + throw new InvalidOperationException("Cannot resolve an endpoint for a request which has no RequestUri"); + } + + EnsureCleanupTimerStarted(); + + var key = request.RequestUri.GetLeftPart(UriPartial.Authority); + while (true) + { + var resolver = _resolvers.GetOrAdd( + key, + static (name, self) => self.CreateResolver(name), + this); + + var (valid, endpoint) = await resolver.TryGetEndpointAsync(request, cancellationToken).ConfigureAwait(false); + if (valid) + { + if (endpoint is null) + { + throw new InvalidOperationException($"Unable to resolve endpoint for service {resolver.ServiceName}"); + } + + return endpoint; + } + else + { + _resolvers.TryRemove(KeyValuePair.Create(resolver.ServiceName, resolver)); + } + } + } + + private void EnsureCleanupTimerStarted() + { + if (_cleanupTimer is not null) + { + return; + } + + lock (_lock) + { + if (_cleanupTimer is not null) + { + return; + } + + // Don't capture the current ExecutionContext and its AsyncLocals onto the timer + var restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + _cleanupTimer = timeProvider.CreateTimer(s_cleanupCallback, this, s_cleanupPeriod, s_cleanupPeriod); + } + finally + { + // Restore the current ExecutionContext + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } + } + + /// + public async ValueTask DisposeAsync() + { + lock (_lock) + { + _cleanupTimer?.Dispose(); + _cleanupTimer = null; + } + + foreach (var resolver in _resolvers) + { + await resolver.Value.DisposeAsync().ConfigureAwait(false); + } + + _resolvers.Clear(); + if (_cleanupTask is not null) + { + await _cleanupTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } + } + + private void CleanupResolvers() + { + lock (_lock) + { + if (_cleanupTask is null or { IsCompleted: true }) + { + _cleanupTask = CleanupResolversAsyncCore(); + } + } + } + + private async Task CleanupResolversAsyncCore() + { + List? cleanupTasks = null; + foreach (var (name, resolver) in _resolvers) + { + if (resolver.CanExpire() && _resolvers.TryRemove(name, out var _)) + { + cleanupTasks ??= new(); + cleanupTasks.Add(resolver.DisposeAsync().AsTask()); + } + } + + if (cleanupTasks is not null) + { + await Task.WhenAll(cleanupTasks).ConfigureAwait(false); + } + } + + private ResolverEntry CreateResolver(string serviceName) + { + var watcher = _watcherFactory.CreateWatcher(serviceName); + var selector = serviceProvider.GetService() ?? new RoundRobinServiceEndpointSelector(); + var result = new ResolverEntry(watcher, selector); + watcher.Start(); + return result; + } + + private sealed class ResolverEntry : IAsyncDisposable + { + private readonly ServiceEndpointWatcher _watcher; + private readonly IServiceEndpointSelector _selector; + private const ulong CountMask = ~(RecentUseFlag | DisposingFlag); + private const ulong RecentUseFlag = 1UL << 62; + private const ulong DisposingFlag = 1UL << 63; + private ulong _status; + private TaskCompletionSource? _onDisposed; + + public ResolverEntry(ServiceEndpointWatcher watcher, IServiceEndpointSelector selector) + { + _watcher = watcher; + _selector = selector; + _watcher.OnEndpointsUpdated += result => + { + if (result.ResolvedSuccessfully) + { + _selector.SetEndpoints(result.EndpointSource); + } + }; + } + + public string ServiceName => _watcher.ServiceName; + + public bool CanExpire() + { + // Read the status, clearing the recent use flag in the process. + var status = Interlocked.And(ref _status, ~RecentUseFlag); + + // The instance can be expired if there are no concurrent callers and the recent use flag was not set. + return (status & (CountMask | RecentUseFlag)) == 0; + } + + public async ValueTask<(bool Valid, ServiceEndpoint? Endpoint)> TryGetEndpointAsync(object? context, CancellationToken cancellationToken) + { + try + { + var status = Interlocked.Increment(ref _status); + if ((status & DisposingFlag) == 0) + { + // If the watcher is valid, resolve. + // We ensure that it will not be disposed while we are resolving. + await _watcher.GetEndpointsAsync(cancellationToken).ConfigureAwait(false); + var result = _selector.GetEndpoint(context); + return (true, result); + } + else + { + return (false, default); + } + } + finally + { + // Set the recent use flag to prevent the instance from being disposed. + Interlocked.Or(ref _status, RecentUseFlag); + + // If we are the last concurrent request to complete and the Disposing flag has been set, + // dispose the resolver now. DisposeAsync was prevented by concurrent requests. + var status = Interlocked.Decrement(ref _status); + if ((status & DisposingFlag) == DisposingFlag && (status & CountMask) == 0) + { + await DisposeAsyncCore().ConfigureAwait(false); + } + } + } + + public async ValueTask DisposeAsync() + { + if (_onDisposed is null) + { + Interlocked.CompareExchange(ref _onDisposed, new(TaskCreationOptions.RunContinuationsAsynchronously), null); + } + + var status = Interlocked.Or(ref _status, DisposingFlag); + if ((status & DisposingFlag) != DisposingFlag && (status & CountMask) == 0) + { + // If we are the one who flipped the Disposing flag and there are no concurrent requests, + // dispose the instance now. Concurrent requests are prevented from starting by the Disposing flag. + await DisposeAsyncCore().ConfigureAwait(false); + } + else + { + await _onDisposed.Task.ConfigureAwait(false); + } + } + + private async Task DisposeAsyncCore() + { + try + { + await _watcher.DisposeAsync().ConfigureAwait(false); + } + finally + { + Debug.Assert(_onDisposed is not null); + _onDisposed.SetResult(); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/IServiceDiscoveryHttpMessageHandlerFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/IServiceDiscoveryHttpMessageHandlerFactory.cs new file mode 100644 index 00000000000..0c5bd02d10d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/IServiceDiscoveryHttpMessageHandlerFactory.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Http; + +/// +/// Factory which creates instances which resolve endpoints using service discovery +/// before delegating to a provided handler. +/// +public interface IServiceDiscoveryHttpMessageHandlerFactory +{ + /// + /// Creates an instance which resolve endpoints using service discovery before + /// delegating to a provided handler. + /// + /// The handler to delegate to. + /// The new . + HttpMessageHandler CreateHandler(HttpMessageHandler handler); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ResolvingHttpClientHandler.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ResolvingHttpClientHandler.cs new file mode 100644 index 00000000000..a0063ae476b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ResolvingHttpClientHandler.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Http; + +/// +/// which resolves endpoints using service discovery. +/// +internal sealed class ResolvingHttpClientHandler(HttpServiceEndpointResolver resolver, IOptions options) : HttpClientHandler +{ + private readonly HttpServiceEndpointResolver _resolver = resolver; + private readonly ServiceDiscoveryOptions _options = options.Value; + + /// + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + var originalUri = request.RequestUri; + + try + { + if (originalUri?.Host is not null) + { + var result = await _resolver.GetEndpointAsync(request, cancellationToken).ConfigureAwait(false); + request.RequestUri = ResolvingHttpDelegatingHandler.GetUriWithEndpoint(originalUri, result, _options); + request.Headers.Host ??= result.Features.Get()?.HostName; + } + + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + finally + { + request.RequestUri = originalUri; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ResolvingHttpDelegatingHandler.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ResolvingHttpDelegatingHandler.cs new file mode 100644 index 00000000000..8f13bb60ab5 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ResolvingHttpDelegatingHandler.cs @@ -0,0 +1,128 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Http; + +/// +/// HTTP message handler which resolves endpoints using service discovery. +/// +internal sealed class ResolvingHttpDelegatingHandler : DelegatingHandler +{ + private readonly HttpServiceEndpointResolver _resolver; + private readonly ServiceDiscoveryOptions _options; + + /// + /// Initializes a new instance. + /// + /// The endpoint resolver. + /// The service discovery options. + public ResolvingHttpDelegatingHandler(HttpServiceEndpointResolver resolver, IOptions options) + { + _resolver = resolver; + _options = options.Value; + } + + /// + /// Initializes a new instance. + /// + /// The endpoint resolver. + /// The service discovery options. + /// The inner handler. + public ResolvingHttpDelegatingHandler(HttpServiceEndpointResolver resolver, IOptions options, HttpMessageHandler innerHandler) : base(innerHandler) + { + _resolver = resolver; + _options = options.Value; + } + + /// + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + var originalUri = request.RequestUri; + if (originalUri?.Host is not null) + { + var result = await _resolver.GetEndpointAsync(request, cancellationToken).ConfigureAwait(false); + request.RequestUri = GetUriWithEndpoint(originalUri, result, _options); + request.Headers.Host ??= result.Features.Get()?.HostName; + } + + try + { + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); + } + finally + { + request.RequestUri = originalUri; + } + } + + internal static Uri GetUriWithEndpoint(Uri uri, ServiceEndpoint serviceEndpoint, ServiceDiscoveryOptions options) + { + var endPoint = serviceEndpoint.EndPoint; + UriBuilder result; + if (endPoint is UriEndPoint { Uri: { } ep }) + { + result = new UriBuilder(uri) + { + Scheme = ep.Scheme, + Host = ep.Host, + }; + + if (ep.Port > 0) + { + result.Port = ep.Port; + } + + if (ep.AbsolutePath.Length > 1) + { + result.Path = $"{ep.AbsolutePath.TrimEnd('/')}/{uri.AbsolutePath.TrimStart('/')}"; + } + } + else + { + string host; + int port; + switch (endPoint) + { + case IPEndPoint ip: + host = ip.Address.ToString(); + port = ip.Port; + break; + case DnsEndPoint dns: + host = dns.Host; + port = dns.Port; + break; + default: + throw new InvalidOperationException($"Endpoints of type {endPoint.GetType()} are not supported"); + } + + result = new UriBuilder(uri) + { + Host = host, + }; + + // Default to the default port for the scheme. + if (port > 0) + { + result.Port = port; + } + + if (uri.Scheme.IndexOf('+') > 0) + { + var scheme = uri.Scheme.Split('+')[0]; + if (options.AllowAllSchemes || options.AllowedSchemes.Contains(scheme, StringComparer.OrdinalIgnoreCase)) + { + result.Scheme = scheme; + } + else + { + throw new InvalidOperationException($"The scheme '{scheme}' is not allowed."); + } + } + } + + return result.Uri; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ServiceDiscoveryHttpMessageHandlerFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ServiceDiscoveryHttpMessageHandlerFactory.cs new file mode 100644 index 00000000000..e5e7f7587bb --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Http/ServiceDiscoveryHttpMessageHandlerFactory.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Http; + +internal sealed class ServiceDiscoveryHttpMessageHandlerFactory( + TimeProvider timeProvider, + IServiceProvider serviceProvider, + ServiceEndpointWatcherFactory factory, + IOptions options) : IServiceDiscoveryHttpMessageHandlerFactory +{ + public HttpMessageHandler CreateHandler(HttpMessageHandler handler) + { + var registry = new HttpServiceEndpointResolver(factory, serviceProvider, timeProvider); + return new ResolvingHttpDelegatingHandler(registry, options, handler); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Internal/ServiceDiscoveryOptionsValidator.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Internal/ServiceDiscoveryOptionsValidator.cs new file mode 100644 index 00000000000..fae7bd6f4fc --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Internal/ServiceDiscoveryOptionsValidator.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ServiceDiscovery.Internal; + +internal sealed class ServiceDiscoveryOptionsValidator : IValidateOptions +{ + public ValidateOptionsResult Validate(string? name, ServiceDiscoveryOptions options) + { + if (options.AllowedSchemes is null) + { + return ValidateOptionsResult.Fail("At least one allowed scheme must be specified."); + } + + return ValidateOptionsResult.Success; + } +} + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Internal/ServiceEndpointResolverResult.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Internal/ServiceEndpointResolverResult.cs new file mode 100644 index 00000000000..675941bb955 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Internal/ServiceEndpointResolverResult.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.ServiceDiscovery.Internal; + +/// +/// Represents the result of service endpoint resolution. +/// +/// The endpoint collection. +/// The exception which occurred during resolution. +internal sealed class ServiceEndpointResolverResult(ServiceEndpointSource? endpointSource, Exception? exception) +{ + /// + /// Gets the exception which occurred during resolution. + /// + public Exception? Exception { get; } = exception; + + /// + /// Gets a value indicating whether resolution completed successfully. + /// + [MemberNotNullWhen(true, nameof(EndpointSource))] + public bool ResolvedSuccessfully => Exception is null; + + /// + /// Gets the endpoints. + /// + public ServiceEndpointSource? EndpointSource { get; } = endpointSource; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/LoadBalancing/IServiceEndpointSelector.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/LoadBalancing/IServiceEndpointSelector.cs new file mode 100644 index 00000000000..2d81ff38601 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/LoadBalancing/IServiceEndpointSelector.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.LoadBalancing; + +/// +/// Selects endpoints from a collection of endpoints. +/// +internal interface IServiceEndpointSelector +{ + /// + /// Sets the collection of endpoints which this instance will select from. + /// + /// The collection of endpoints to select from. + void SetEndpoints(ServiceEndpointSource endpoints); + + /// + /// Selects an endpoints from the collection provided by the most recent call to . + /// + /// The context. + /// An endpoint. + ServiceEndpoint GetEndpoint(object? context); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/LoadBalancing/RoundRobinServiceEndpointSelector.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/LoadBalancing/RoundRobinServiceEndpointSelector.cs new file mode 100644 index 00000000000..e7e51bc6021 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/LoadBalancing/RoundRobinServiceEndpointSelector.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.LoadBalancing; + +/// +/// Selects endpoints by iterating through the list of endpoints in a round-robin fashion. +/// +internal sealed class RoundRobinServiceEndpointSelector : IServiceEndpointSelector +{ + private uint _next; + private IReadOnlyList? _endpoints; + + /// + public void SetEndpoints(ServiceEndpointSource endpoints) + { + _endpoints = endpoints.Endpoints; + } + + /// + public ServiceEndpoint GetEndpoint(object? context) + { + if (_endpoints is not { Count: > 0 } collection) + { + throw new InvalidOperationException("The endpoint collection contains no endpoints"); + } + + return collection[(int)(Interlocked.Increment(ref _next) % collection.Count)]; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Microsoft.Extensions.ServiceDiscovery.csproj b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Microsoft.Extensions.ServiceDiscovery.csproj new file mode 100644 index 00000000000..8ff69baf576 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/Microsoft.Extensions.ServiceDiscovery.csproj @@ -0,0 +1,30 @@ + + + + $(TargetFrameworks);netstandard2.0 + true + Provides extensions to HttpClient that enable service discovery based on configuration. + $(DefaultDotnetIconFullPath) + + $(NoWarn);CS8600;CS8602;CS8604;IDE0040;IDE0055;IDE0058;IDE1006;CA1307;CA1310;CA1849;CA2007;CA2213;SA1204;SA1128;SA1205;SA1405;SA1612;SA1623;SA1625;SA1642;S1144;S1449;S2302;S2692;S3872;S4457;EA0000;EA0009;EA0014;LA0001;LA0003;LA0008;VSTHRD200 + enable + false + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProvider.Log.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProvider.Log.cs new file mode 100644 index 00000000000..9f6e9ce0ccb --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProvider.Log.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.PassThrough; + +internal sealed partial class PassThroughServiceEndpointProvider +{ + private static partial class Log + { + [LoggerMessage(1, LogLevel.Debug, "Using pass-through service endpoint provider for service '{ServiceName}'.", EventName = "UsingPassThrough")] + internal static partial void UsingPassThrough(ILogger logger, string serviceName); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProvider.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProvider.cs new file mode 100644 index 00000000000..478d81d42dc --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProvider.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.PassThrough; + +/// +/// Service endpoint provider which passes through the provided value. +/// +internal sealed partial class PassThroughServiceEndpointProvider(ILogger logger, string serviceName, EndPoint endPoint) : IServiceEndpointProvider +{ + public ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationToken cancellationToken) + { + if (endpoints.Endpoints.Count == 0) + { + Log.UsingPassThrough(logger, serviceName); + var ep = ServiceEndpoint.Create(endPoint); + ep.Features.Set(this); + endpoints.Endpoints.Add(ep); + } + + return default; + } + + public ValueTask DisposeAsync() => default; + + public override string ToString() => "Pass-through"; +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProviderFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProviderFactory.cs new file mode 100644 index 00000000000..2bf8c0cb481 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/PassThrough/PassThroughServiceEndpointProviderFactory.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Net; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery.PassThrough; + +/// +/// Service endpoint provider factory which creates pass-through providers. +/// +internal sealed class PassThroughServiceEndpointProviderFactory(ILogger logger) : IServiceEndpointProviderFactory +{ + /// + public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider) + { + var serviceName = query.ToString()!; + if (!TryCreateEndPoint(serviceName, out var endPoint)) + { + // Propagate the value through regardless, leaving it to the caller to interpret it. + endPoint = new DnsEndPoint(serviceName, 0); + } + + provider = new PassThroughServiceEndpointProvider(logger, serviceName, endPoint); + return true; + } + + private static bool TryCreateEndPoint(string serviceName, [NotNullWhen(true)] out EndPoint? endPoint) + { + if ((serviceName.Contains("://", StringComparison.Ordinal) || !Uri.TryCreate($"fakescheme://{serviceName}", default, out var uri)) && !Uri.TryCreate(serviceName, default, out uri)) + { + endPoint = null; + return false; + } + + var uriHost = uri.Host; + var segmentSeparatorIndex = uriHost.IndexOf('.'); + string host; + if (uriHost.StartsWith('_') && segmentSeparatorIndex > 1 && uriHost[^1] != '.') + { + // Skip the endpoint name, including its prefix ('_') and suffix ('.'). + host = uriHost[(segmentSeparatorIndex + 1)..]; + } + else + { + host = uriHost; + } + + var port = uri.Port > 0 ? uri.Port : 0; + if (IPAddress.TryParse(host, out var ip)) + { + endPoint = new IPEndPoint(ip, port); + } + else if (!string.IsNullOrEmpty(host)) + { + endPoint = new DnsEndPoint(host, port); + } + else + { + endPoint = null; + return false; + } + + return true; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/README.md b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/README.md new file mode 100644 index 00000000000..b767bb41e83 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/README.md @@ -0,0 +1,276 @@ +# Microsoft.Extensions.ServiceDiscovery + +The `Microsoft.Extensions.ServiceDiscovery` library is designed to simplify the integration of service discovery patterns in .NET applications. Service discovery is a key component of most distributed systems and microservices architectures. This library provides a straightforward way to resolve service names to endpoint addresses. + +In typical systems, service configuration changes over time. Service discovery accounts for by monitoring endpoint configuration using push-based notifications where supported, falling back to polling in other cases. When endpoints are refreshed, callers are notified so that they can observe the refreshed results. + +## How it works + +Service discovery uses configured _providers_ to resolve service endpoints. When service endpoints are resolved, each registered provider is called in the order of registration to contribute to a collection of service endpoints (an instance of `ServiceEndpointSource`). + +Providers implement the `IServiceEndpointProvider` interface. They are created by an instance of `IServiceEndpointProviderProvider`, which are registered with the [.NET dependency injection](https://learn.microsoft.com/dotnet/core/extensions/dependency-injection) system. + +Developers typically add service discovery to their [`HttpClient`](https://learn.microsoft.com/dotnet/fundamentals/networking/http/httpclient) using the [`IHttpClientFactory`](https://learn.microsoft.com/dotnet/core/extensions/httpclient-factory) with the `AddServiceDiscovery` extension method. + +Services can be resolved directly by calling `ServiceEndpointResolver`'s `GetEndpointsAsync` method, which returns a collection of resolved endpoints. + +### Change notifications + +Service configuration can change over time. Service discovery accounts for by monitoring endpoint configuration using push-based notifications where supported, falling back to polling in other cases. When endpoints are refreshed, callers are notified so that they can observe the refreshed results. To subscribe to notifications, callers use the `ChangeToken` property of `ServiceEndpointCollection`. For more information on change tokens, see [Detect changes with change tokens in ASP.NET Core](https://learn.microsoft.com/aspnet/core/fundamentals/change-tokens?view=aspnetcore-7.0). + +### Extensibility using features + +Service endpoints (`ServiceEndpoint` instances) and collections of service endpoints (`ServiceEndpointCollection` instances) expose an extensible [`IFeatureCollection`](https://learn.microsoft.com/dotnet/api/microsoft.aspnetcore.http.features.ifeaturecollection) via their `Features` property. Features are exposed as interfaces accessible on the feature collection. These interfaces can be added, modified, wrapped, replaced or even removed at resolution time by providers. Features which may be available on a `ServiceEndpoint` include: + +* `IHostNameFeature`: exposes the host name of the resolved endpoint, intended for use with [Server Name Identification (SNI)](https://en.wikipedia.org/wiki/Server_Name_Indication) and [Transport Layer Security (TLS)](https://en.wikipedia.org/wiki/Transport_Layer_Security). + +### Resolution order + +The providers included in the `Microsoft.Extensions.ServiceDiscovery` series of packages skip resolution if there are existing endpoints in the collection when they are called. For example, consider a case where the following providers are registered: _Configuration_, _DNS SRV_, _Pass-through_. When resolution occurs, the providers will be called in-order. If the _Configuration_ providers discovers no endpoints, the _DNS SRV_ provider will perform resolution and may add one or more endpoints. If the _DNS SRV_ provider adds an endpoint to the collection, the _Pass-through_ provider will skip its resolution and will return immediately instead. + +## Getting Started + +### Installation + +To install the library, use the following NuGet command: + +```dotnetcli +dotnet add package Microsoft.Extensions.ServiceDiscovery +``` + +### Usage example + +In the _AppHost.cs_ file of your project, call the `AddServiceDiscovery` extension method to add service discovery to the host, configuring default service endpoint providers. + +```csharp +builder.Services.AddServiceDiscovery(); +``` + +Add service discovery to an individual `IHttpClientBuilder` by calling the `AddServiceDiscovery` extension method: + +```csharp +builder.Services.AddHttpClient(c => +{ + c.BaseAddress = new("https://catalog")); +}).AddServiceDiscovery(); +``` + +Alternatively, you can add service discovery to all `HttpClient` instances by default: + +```csharp +builder.Services.ConfigureHttpClientDefaults(http => +{ + // Turn on service discovery by default + http.AddServiceDiscovery(); +}); +``` + +### Resolving service endpoints from configuration + +The `AddServiceDiscovery` extension method adds a configuration-based endpoint provider by default. +This provider reads endpoints from the [.NET Configuration system](https://learn.microsoft.com/dotnet/core/extensions/configuration). +The library supports configuration through `appsettings.json`, environment variables, or any other `IConfiguration` source. + +Here is an example demonstrating how to configure endpoints for the service named _catalog_ via `appsettings.json`: + +```json +{ + "Services": { + "catalog": { + "https": [ + "https://localhost:8443", + "https://10.46.24.90:443" + ] + } + } +} +``` + +The above example adds two endpoints for the service named _catalog_: `https://localhost:8443`, and `"https://10.46.24.90:443"`. +Each time the _catalog_ is resolved, one of these endpoints will be selected. + +If service discovery was added to the host using the `AddServiceDiscoveryCore` extension method on `IServiceCollection`, the configuration-based endpoint provider can be added by calling the `AddConfigurationServiceEndpointProvider` extension method on `IServiceCollection`. + +### Configuration + +The configuration provider is configured using the `ConfigurationServiceEndpointProviderOptions` class, which offers these configuration options: + +* **`SectionName`**: The name of the configuration section that contains service endpoints. It defaults to `"Services"`. + +* **`ShouldApplyHostNameMetadata`**: A delegate used to determine if host name metadata should be applied to resolved endpoints. It defaults to a function that returns `false`. + +To configure these options, you can use the `Configure` extension method on the `IServiceCollection` within your application's startup class or main program file: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.Configure(options => +{ + options.SectionName = "MyServiceEndpoints"; + + // Configure the logic for applying host name metadata + options.ShouldApplyHostNameMetadata = endpoint => + { + // Your custom logic here. For example: + return endpoint.Endpoint is DnsEndPoint dnsEp && dnsEp.Host.StartsWith("internal"); + }; +}); +``` + +This example demonstrates setting a custom section name for your service endpoints and providing a custom logic for applying host name metadata based on a condition. + +## Scheme selection when resolving HTTP(S) endpoints + +It is common to use HTTP while developing and testing a service locally and HTTPS when the service is deployed. Service Discovery supports this by allowing for a priority list of URI schemes to be specified in the input string given to Service Discovery. Service Discovery will attempt to resolve the services for the schemes in order and will stop after an endpoint is found. URI schemes are separated by a `+` character, for example: `"https+http://basket"`. Service Discovery will first try to find HTTPS endpoints for the `"basket"` service and will then fall back to HTTP endpoints. If any HTTPS endpoint is found, Service Discovery will not include HTTP endpoints. +Schemes can be filtered by configuring the `AllowedSchemes` and `AllowAllSchemes` properties on `ServiceDiscoveryOptions`. The `AllowAllSchemes` property is used to indicate that all schemes are allowed. By default, `AllowAllSchemes` is `true` and all schemes are allowed. Schemes can be restricted by setting `AllowAllSchemes` to `false` and adding allowed schemes to the `AllowedSchemes` property. For example, to allow only HTTPS: + +```csharp +services.Configure(options => +{ + options.AllowAllSchemes = false; + options.AllowedSchemes = ["https"]; +}); +``` + +To explicitly allow all schemes, set the `ServiceDiscoveryOptions.AllowAllSchemes` property to `true`: + +```csharp +services.Configure(options => options.AllowAllSchemes = true); +``` + +## Resolving service endpoints using platform-provided service discovery + +Some platforms, such as Azure Container Apps and Kubernetes (if configured), provide functionality for service discovery without the need for a service discovery client library. When an application is deployed to one of these environments, it may be preferable to use the platform's existing functionality instead. The pass-through provider exists to support this scenario while still allowing other provider (such as configuration) to be used in other environments, such as on the developer's machine, without requiring a code change or conditional guards. + +The pass-through provider performs no external resolution and instead resolves endpoints by returning the input service name represented as a `DnsEndPoint`. + +The pass-through provider is configured by-default when adding service discovery via the `AddServiceDiscovery` extension method. + +If service discovery was added to the host using the `AddServiceDiscoveryCore` extension method on `IServiceCollection`, the pass-through provider can be added by calling the `AddPassThroughServiceEndpointProvider` extension method on `IServiceCollection`. + +In the case of Azure Container Apps, the service name should match the app name. For example, if you have a service named "basket", then you should have a corresponding Azure Container App named "basket". + +## Service discovery in .NET Aspire + +.NET Aspire includes functionality for configuring the service discovery at development and testing time. This functionality works by providing configuration in the format expected by the _configuration-based endpoint provider_ described above from the .NET Aspire AppHost project to the individual service projects added to the application model. + +Configuration for service discovery is only added for services which are referenced by a given project. For example, consider the following AppHost program: + +```csharp +var builder = DistributedApplication.CreateBuilder(args); + +var catalog = builder.AddProject("catalog"); +var basket = builder.AddProject("basket"); + +var frontend = builder.AddProject("frontend") + .WithReference(basket) + .WithReference(catalog); +``` + +In the above example, the _frontend_ project references the _catalog_ project and the _basket_ project. The two `WithReference` calls instruct the .NET Aspire application to pass service discovery information for the referenced projects (_catalog_, and _basket_) into the _frontend_ project. + +## Named endpoints + +Some services expose multiple, named endpoints. Named endpoints can be resolved by specifying the endpoint name in the host portion of the HTTP request URI, following the format `scheme://_endpointName.serviceName`. For example, if a service named "basket" exposes an endpoint named "dashboard", then the URI `https+http://_dashboard.basket` can be used to specify this endpoint, for example: + +```csharp +builder.Services.AddHttpClient( + static client => client.BaseAddress = new("https+http://basket")); +builder.Services.AddHttpClient( + static client => client.BaseAddress = new("https+http://_dashboard.basket")); +``` + +In the above example, two `HttpClient`s are added: one for the core basket service and one for the basket service's dashboard. + +### Named endpoints using configuration + +With the configuration-based endpoint provider, named endpoints can be specified in configuration by prefixing the endpoint value with `_endpointName.`, where `endpointName` is the endpoint name. For example, consider this _appsettings.json_ configuration which defined a default endpoint (with no name) and an endpoint named "dashboard": + +```json +{ + "Services": { + "basket": { + "https": "https://10.2.3.4:8080", /* the https endpoint, requested via https://basket */ + "dashboard": "https://10.2.3.4:9999" /* the "dashboard" endpoint, requested via https://_dashboard.basket */ + } + } +} +``` + +### Named endpoints in .NET Aspire + +.NET Aspire uses the configuration-based provider at development and testing time, providing convenient APIs for configuring named endpoints which are then translated into configuration for the target services. For example: + +```csharp +var builder = DistributedApplication.CreateBuilder(args); + +var basket = builder.AddProject("basket") + .WithEndpoint(hostPort: 9999, scheme: "https", name: "admin"); + +var adminDashboard = builder.AddProject("admin-dashboard") + .WithReference(basket.GetEndpoint("admin")); + +var frontend = builder.AddProject("frontend") + .WithReference(basket); +``` + +In the above example, the "basket" service exposes an "admin" endpoint in addition to the default "http" endpoint which it exposes. This endpoint is consumed by the "admin-dashboard" project, while the "frontend" project consumes all endpoints from "basket". Alternatively, the "frontend" project could be made to consume only the default "http" endpoint from "basket" by using the `GetEndpoint(string name)` method, as in the following example: + +```csharp + +// The preceding code is the same as in the above sample + +var frontend = builder.AddProject("frontend") + .WithReference(basket.GetEndpoint("https")); +``` + +### Named endpoints in Kubernetes using DNS SRV + +When deploying to Kubernetes, the DNS SRV service endpoint provider can be used to resolve named endpoints. For example, the following resource definition will result in a DNS SRV record being created for an endpoint named "default" and an endpoint named "dashboard", both on the service named "basket". + +```yml +apiVersion: v1 +kind: Service +metadata: + name: basket +spec: + selector: + name: basket-service + clusterIP: None + ports: + - name: default + port: 8080 + - name: dashboard + port: 8888 +``` + +To configure a service to resolve the "dashboard" endpoint on the "basket" service, add the DNS SRV service endpoint provider to the host builder as follows: + +```csharp +builder.Services.AddServiceDiscoveryCore(); +builder.Services.AddDnsSrvServiceEndpointProvider(); +``` + +The special port name "default" is used to specify the default endpoint, resolved using the URI `https://basket`. + +As in the previous example, add service discovery to an `HttpClient` for the basket service: + +```csharp +builder.Services.AddHttpClient( + static client => client.BaseAddress = new("https://basket")); +``` + +Similarly, the "dashboard" endpoint can be targeted as follows: + +```csharp +builder.Services.AddHttpClient( + static client => client.BaseAddress = new("https://_dashboard.basket")); +``` + +### Named endpoints in Azure Container Apps + +Named endpoints are not currently supported for services deployed to Azure Container Apps. + +## Feedback & contributing + +https://github.com/dotnet/aspire diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryHttpClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryHttpClientBuilderExtensions.cs new file mode 100644 index 00000000000..d2890ae8c8d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryHttpClientBuilderExtensions.cs @@ -0,0 +1,95 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery; +using Microsoft.Extensions.ServiceDiscovery.Http; + +#if NET +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Http; +#endif + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extensions for configuring with service discovery. +/// +public static class ServiceDiscoveryHttpClientBuilderExtensions +{ + /// + /// Adds service discovery to the . + /// + /// The builder. + /// The builder. + public static IHttpClientBuilder AddServiceDiscovery(this IHttpClientBuilder httpClientBuilder) + { + ArgumentNullException.ThrowIfNull(httpClientBuilder); + + var services = httpClientBuilder.Services; + services.AddServiceDiscoveryCore(); + httpClientBuilder.AddHttpMessageHandler(services => + { + var timeProvider = services.GetService() ?? TimeProvider.System; + var watcherFactory = services.GetRequiredService(); + var registry = new HttpServiceEndpointResolver(watcherFactory, services, timeProvider); + var options = services.GetRequiredService>(); + return new ResolvingHttpDelegatingHandler(registry, options); + }); + +#if NET + // Configure the HttpClient to disable gRPC load balancing. + // This is done on all HttpClient instances but only impacts gRPC clients. + AddDisableGrpcLoadBalancingFilter(httpClientBuilder.Services, httpClientBuilder.Name); +#endif + return httpClientBuilder; + } + +#if NET + private static void AddDisableGrpcLoadBalancingFilter(IServiceCollection services, string? name) + { + // A filter is used because it will always run last. This is important because the disable + // property needs to be added to all SocketsHttpHandler instances, including those specified + // with ConfigurePrimaryHttpMessageHandler. + services.TryAddEnumerable(ServiceDescriptor.Singleton()); + services.Configure(o => o.ClientNames.Add(name)); + } + + private sealed class DisableGrpcLoadBalancingFilterOptions + { + // Names of clients. A null value means it is applied globally to all clients. + public HashSet ClientNames { get; } = new HashSet(); + } + + private sealed class DisableGrpcLoadBalancingFilter : IHttpMessageHandlerBuilderFilter + { + private readonly DisableGrpcLoadBalancingFilterOptions _options; + private readonly bool _global; + + public DisableGrpcLoadBalancingFilter(IOptions options) + { + _options = options.Value; + _global = _options.ClientNames.Contains(null); + } + + public Action Configure(Action next) + { + return (builder) => + { + // Run other configuration first, we want to decorate. + next(builder); + if (_global || _options.ClientNames.Contains(builder.Name)) + { + if (builder.PrimaryHandler is SocketsHttpHandler socketsHttpHandler) + { + // gRPC knows about this property and uses it to check whether + // load balancing is disabled when the GrpcChannel is created. + // see https://github.com/grpc/grpc-dotnet/blob/1625f8942791c82d700802fc7278c543025f0fd3/src/Grpc.Net.Client/GrpcChannel.cs#L286 + socketsHttpHandler.Properties["__GrpcLoadBalancingDisabled"] = true; + } + } + }; + } + } +#endif +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs new file mode 100644 index 00000000000..edc652507d9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Options for service endpoint resolution. +/// +public sealed class ServiceDiscoveryOptions +{ + /// + /// Gets or sets a value indicating whether all URI schemes for URIs resolved by the service discovery system are allowed. + /// If this value is , all URI schemes are allowed. + /// If this value is , only the schemes specified in are allowed. + /// + public bool AllowAllSchemes { get; set; } = true; + + /// + /// Gets or sets the period between polling attempts for providers which do not support refresh notifications via . + /// + public TimeSpan RefreshPeriod { get; set; } = TimeSpan.FromSeconds(60); + + /// + /// Gets or sets the collection of allowed URI schemes for URIs resolved by the service discovery system when multiple schemes are specified, for example "https+http://_endpoint.service". + /// + /// + /// When is , this property is ignored. + /// + public IList AllowedSchemes { get; set; } = new List(); + + internal static string[] ApplyAllowedSchemes(IReadOnlyList schemes, IList allowedSchemes, bool allowAllSchemes) + { + if (schemes.Count > 0) + { + if (allowAllSchemes) + { + if (schemes is string[] array && array.Length > 0) + { + return array; + } + + return schemes.ToArray(); + } + + List result = []; + foreach (var s in schemes) + { + foreach (var allowed in allowedSchemes) + { + if (string.Equals(s, allowed, StringComparison.OrdinalIgnoreCase)) + { + result.Add(s); + break; + } + } + } + + return result.ToArray(); + } + + // If no schemes were specified, but a set of allowed schemes were specified, allow those. + return allowedSchemes.ToArray(); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryServiceCollectionExtensions.cs new file mode 100644 index 00000000000..8de759af1f6 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryServiceCollectionExtensions.cs @@ -0,0 +1,119 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery; +using Microsoft.Extensions.ServiceDiscovery.Configuration; +using Microsoft.Extensions.ServiceDiscovery.Http; +using Microsoft.Extensions.ServiceDiscovery.Internal; +using Microsoft.Extensions.ServiceDiscovery.LoadBalancing; +using Microsoft.Extensions.ServiceDiscovery.PassThrough; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for configuring service discovery. +/// +public static class ServiceDiscoveryServiceCollectionExtensions +{ + /// + /// Adds the core service discovery services and configures defaults. + /// + /// The service collection. + /// The service collection. + public static IServiceCollection AddServiceDiscovery(this IServiceCollection services) + => AddServiceDiscoveryCore(services) + .AddConfigurationServiceEndpointProvider() + .AddPassThroughServiceEndpointProvider(); + + /// + /// Adds the core service discovery services and configures defaults. + /// + /// The service collection. + /// The delegate used to configure service discovery options. + /// The service collection. + public static IServiceCollection AddServiceDiscovery(this IServiceCollection services, Action configureOptions) + => AddServiceDiscoveryCore(services, configureOptions: configureOptions) + .AddConfigurationServiceEndpointProvider() + .AddPassThroughServiceEndpointProvider(); + + /// + /// Adds the core service discovery services. + /// + /// The service collection. + /// The service collection. + public static IServiceCollection AddServiceDiscoveryCore(this IServiceCollection services) => AddServiceDiscoveryCore(services, configureOptions: _ => { }); + + /// + /// Adds the core service discovery services. + /// + /// The service collection. + /// The delegate used to configure service discovery options. + /// The service collection. + public static IServiceCollection AddServiceDiscoveryCore(this IServiceCollection services, Action configureOptions) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureOptions); + + services.AddOptions(); + services.AddLogging(); + services.TryAddTransient, ServiceDiscoveryOptionsValidator>(); + services.TryAddSingleton(_ => TimeProvider.System); + services.TryAddTransient(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(sp => new ServiceEndpointResolver(sp.GetRequiredService(), sp.GetRequiredService())); + if (configureOptions is not null) + { + services.Configure(configureOptions); + } + + return services; + } + + /// + /// Configures a service discovery endpoint provider which uses to resolve endpoints. + /// + /// The service collection. + /// The service collection. + public static IServiceCollection AddConfigurationServiceEndpointProvider(this IServiceCollection services) + => AddConfigurationServiceEndpointProvider(services, configureOptions: _ => { }); + + /// + /// Configures a service discovery endpoint provider which uses to resolve endpoints. + /// + /// The delegate used to configure the provider. + /// The service collection. + /// The service collection. + public static IServiceCollection AddConfigurationServiceEndpointProvider(this IServiceCollection services, Action configureOptions) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureOptions); + + services.AddServiceDiscoveryCore(); + services.AddSingleton(); + services.AddTransient, ConfigurationServiceEndpointProviderOptionsValidator>(); + if (configureOptions is not null) + { + services.Configure(configureOptions); + } + + return services; + } + + /// + /// Configures a service discovery endpoint provider which passes through the input without performing resolution. + /// + /// The service collection. + /// The service collection. + public static IServiceCollection AddPassThroughServiceEndpointProvider(this IServiceCollection services) + { + ArgumentNullException.ThrowIfNull(services); + + services.AddServiceDiscoveryCore(); + services.AddSingleton(); + return services; + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointBuilder.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointBuilder.cs new file mode 100644 index 00000000000..947f24b2f81 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointBuilder.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// A mutable collection of service endpoints. +/// +internal sealed class ServiceEndpointBuilder : IServiceEndpointBuilder +{ + private readonly List _endpoints = new(); + private readonly List _changeTokens = new(); + private readonly FeatureCollection _features = new FeatureCollection(); + + /// + /// Adds a change token. + /// + /// The change token. + public void AddChangeToken(IChangeToken changeToken) + { + _changeTokens.Add(changeToken); + } + + /// + /// Gets the feature collection. + /// + public IFeatureCollection Features => _features; + + /// + /// Gets the endpoints. + /// + public IList Endpoints => _endpoints; + + /// + /// Creates a from the provided instance. + /// + /// The service endpoint source. + public ServiceEndpointSource Build() + { + return new ServiceEndpointSource(_endpoints, new CompositeChangeToken(_changeTokens), _features); + } +} + diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointResolver.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointResolver.cs new file mode 100644 index 00000000000..e928980700c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointResolver.cs @@ -0,0 +1,255 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Diagnostics; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Resolves service names to collections of endpoints. +/// +public sealed class ServiceEndpointResolver : IAsyncDisposable +{ + private static readonly TimerCallback s_cleanupCallback = s => ((ServiceEndpointResolver)s!).CleanupResolvers(); + private static readonly TimeSpan s_cleanupPeriod = TimeSpan.FromSeconds(10); + + private readonly object _lock = new(); + private readonly ServiceEndpointWatcherFactory _watcherFactory; + private readonly TimeProvider _timeProvider; + private readonly ConcurrentDictionary _resolvers = new(); + private ITimer? _cleanupTimer; + private Task? _cleanupTask; + private bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// The watcher factory. + /// The time provider. + internal ServiceEndpointResolver(ServiceEndpointWatcherFactory watcherFactory, TimeProvider timeProvider) + { + _watcherFactory = watcherFactory; + _timeProvider = timeProvider; + } + + /// + /// Resolves and returns service endpoints for the specified service. + /// + /// The service name. + /// A . + /// The resolved service endpoints. + public async ValueTask GetEndpointsAsync(string serviceName, CancellationToken cancellationToken) + { + ArgumentNullException.ThrowIfNull(serviceName); + ObjectDisposedException.ThrowIf(_disposed, this); + + EnsureCleanupTimerStarted(); + + while (true) + { + ObjectDisposedException.ThrowIf(_disposed, this); + cancellationToken.ThrowIfCancellationRequested(); + var resolver = _resolvers.GetOrAdd( + serviceName, + static (name, self) => self.CreateResolver(name), + this); + + var (valid, result) = await resolver.GetEndpointsAsync(cancellationToken).ConfigureAwait(false); + if (valid) + { + if (result is null) + { + throw new InvalidOperationException($"Unable to resolve endpoints for service {resolver.ServiceName}"); + } + + return result; + } + else + { + _resolvers.TryRemove(KeyValuePair.Create(resolver.ServiceName, resolver)); + } + } + } + + private void EnsureCleanupTimerStarted() + { + if (_cleanupTimer is not null) + { + return; + } + + lock (_lock) + { + if (_cleanupTimer is not null) + { + return; + } + + // Don't capture the current ExecutionContext and its AsyncLocals onto the timer + var restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + _cleanupTimer = _timeProvider.CreateTimer(s_cleanupCallback, this, s_cleanupPeriod, s_cleanupPeriod); + } + finally + { + // Restore the current ExecutionContext + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } + } + + /// + public async ValueTask DisposeAsync() + { + lock (_lock) + { + _disposed = true; + _cleanupTimer?.Dispose(); + _cleanupTimer = null; + } + + foreach (var resolver in _resolvers) + { + await resolver.Value.DisposeAsync().ConfigureAwait(false); + } + + _resolvers.Clear(); + if (_cleanupTask is not null) + { + await _cleanupTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } + } + + private void CleanupResolvers() + { + lock (_lock) + { + if (_cleanupTask is null or { IsCompleted: true }) + { + _cleanupTask = CleanupResolversAsyncCore(); + } + } + } + + private async Task CleanupResolversAsyncCore() + { + List? cleanupTasks = null; + foreach (var (name, resolver) in _resolvers) + { + if (resolver.CanExpire() && _resolvers.TryRemove(name, out var _)) + { + cleanupTasks ??= new(); + cleanupTasks.Add(resolver.DisposeAsync().AsTask()); + } + } + + if (cleanupTasks is not null) + { + await Task.WhenAll(cleanupTasks).ConfigureAwait(false); + } + } + + private ResolverEntry CreateResolver(string serviceName) + { + var resolver = _watcherFactory.CreateWatcher(serviceName); + resolver.Start(); + return new ResolverEntry(resolver); + } + + private sealed class ResolverEntry(ServiceEndpointWatcher watcher) : IAsyncDisposable + { + private readonly ServiceEndpointWatcher _watcher = watcher; + private const ulong CountMask = ~(RecentUseFlag | DisposingFlag); + private const ulong RecentUseFlag = 1UL << 62; + private const ulong DisposingFlag = 1UL << 63; + private ulong _status; + private TaskCompletionSource? _onDisposed; + + public string ServiceName => _watcher.ServiceName; + + public bool CanExpire() + { + // Read the status, clearing the recent use flag in the process. + var status = Interlocked.And(ref _status, ~RecentUseFlag); + + // The instance can be expired if there are no concurrent callers and the recent use flag was not set. + return (status & (CountMask | RecentUseFlag)) == 0; + } + + public async ValueTask<(bool Valid, ServiceEndpointSource? Endpoints)> GetEndpointsAsync(CancellationToken cancellationToken) + { + try + { + var status = Interlocked.Increment(ref _status); + if ((status & DisposingFlag) == 0) + { + // If the watcher is valid, resolve. + // We ensure that it will not be disposed while we are resolving. + var endpoints = await _watcher.GetEndpointsAsync(cancellationToken).ConfigureAwait(false); + return (true, endpoints); + } + else + { + return (false, default); + } + } + finally + { + // Set the recent use flag to prevent the instance from being disposed. + Interlocked.Or(ref _status, RecentUseFlag); + + // If we are the last concurrent request to complete and the Disposing flag has been set, + // dispose the resolver now. DisposeAsync was prevented by concurrent requests. + var status = Interlocked.Decrement(ref _status); + if ((status & DisposingFlag) == DisposingFlag && (status & CountMask) == 0) + { + await DisposeAsyncCore().ConfigureAwait(false); + } + } + } + + public async ValueTask DisposeAsync() + { + if (_onDisposed is null) + { + Interlocked.CompareExchange(ref _onDisposed, new(TaskCreationOptions.RunContinuationsAsynchronously), null); + } + + var status = Interlocked.Or(ref _status, DisposingFlag); + if ((status & DisposingFlag) != DisposingFlag && (status & CountMask) == 0) + { + // If we are the one who flipped the Disposing flag and there are no concurrent requests, + // dispose the instance now. Concurrent requests are prevented from starting by the Disposing flag. + await DisposeAsyncCore().ConfigureAwait(false); + } + else + { + await _onDisposed.Task.ConfigureAwait(false); + } + } + + private async Task DisposeAsyncCore() + { + try + { + await _watcher.DisposeAsync().ConfigureAwait(false); + } + finally + { + Debug.Assert(_onDisposed is not null); + _onDisposed.SetResult(); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcher.Log.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcher.Log.cs new file mode 100644 index 00000000000..8acaa55ee73 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcher.Log.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery; + +partial class ServiceEndpointWatcher +{ + private static partial class Log + { + [LoggerMessage(1, LogLevel.Debug, "Resolving endpoints for service '{ServiceName}'.", EventName = "ResolvingEndpoints")] + public static partial void ResolvingEndpoints(ILogger logger, string serviceName); + + [LoggerMessage(2, LogLevel.Debug, "Endpoint resolution is pending for service '{ServiceName}'.", EventName = "ResolutionPending")] + public static partial void ResolutionPending(ILogger logger, string serviceName); + + [LoggerMessage(3, LogLevel.Debug, "Resolved {Count} endpoints for service '{ServiceName}': {Endpoints}.", EventName = "ResolutionSucceeded")] + public static partial void ResolutionSucceededCore(ILogger logger, int count, string serviceName, string endpoints); + + public static void ResolutionSucceeded(ILogger logger, string serviceName, ServiceEndpointSource endpointSource) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + ResolutionSucceededCore(logger, endpointSource.Endpoints.Count, serviceName, string.Join(", ", endpointSource.Endpoints.Select(GetEndpointString))); + } + + static string GetEndpointString(ServiceEndpoint ep) + { + if (ep.Features.Get() is { } provider) + { + return $"{ep} ({provider})"; + } + + return ep.ToString()!; + } + } + + [LoggerMessage(4, LogLevel.Error, "Error resolving endpoints for service '{ServiceName}'.", EventName = "ResolutionFailed")] + public static partial void ResolutionFailed(ILogger logger, Exception exception, string serviceName); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcher.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcher.cs new file mode 100644 index 00000000000..a94b7b7a3c1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcher.cs @@ -0,0 +1,302 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.ExceptionServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.Internal; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Watches for updates to the collection of resolved endpoints for a specified service. +/// +internal sealed partial class ServiceEndpointWatcher( + IServiceEndpointProvider[] providers, + ILogger logger, + string serviceName, + TimeProvider timeProvider, + IOptions options) : IAsyncDisposable +{ + private static readonly TimerCallback s_pollingAction = static state => _ = ((ServiceEndpointWatcher)state!).RefreshAsync(force: true); + + private readonly object _lock = new(); + private readonly ILogger _logger = logger; + private readonly TimeProvider _timeProvider = timeProvider; + private readonly ServiceDiscoveryOptions _options = options.Value; + private readonly IServiceEndpointProvider[] _providers = providers; + private readonly CancellationTokenSource _disposalCancellation = new(); + private ITimer? _pollingTimer; + private ServiceEndpointSource? _cachedEndpoints; + private Task _refreshTask = Task.CompletedTask; + private volatile CacheStatus _cacheState; + private IDisposable? _changeTokenRegistration; + + /// + /// Gets the service name. + /// + public string ServiceName { get; } = serviceName; + + /// + /// Gets or sets the action called when endpoints are updated. + /// + public Action? OnEndpointsUpdated { get; set; } + + /// + /// Starts the endpoint resolver. + /// + public void Start() + { + ThrowIfNoProviders(); + _ = RefreshAsync(force: false); + } + + /// + /// Returns a collection of resolved endpoints for the service. + /// + /// A . + /// A collection of resolved endpoints for the service. + public ValueTask GetEndpointsAsync(CancellationToken cancellationToken = default) + { + ThrowIfNoProviders(); + ObjectDisposedException.ThrowIf(_disposalCancellation.IsCancellationRequested, this); + cancellationToken.ThrowIfCancellationRequested(); + + // If the cache is valid, return the cached value. + if (_cachedEndpoints is { ChangeToken.HasChanged: false } cached) + { + return new ValueTask(cached); + } + + // Otherwise, ensure the cache is being refreshed + // Wait for the cache refresh to complete and return the cached value. + return GetEndpointsInternal(cancellationToken); + + async ValueTask GetEndpointsInternal(CancellationToken cancellationToken) + { + ServiceEndpointSource? result; + var disposalToken = _disposalCancellation.Token; + do + { + disposalToken.ThrowIfCancellationRequested(); + cancellationToken.ThrowIfCancellationRequested(); + await RefreshAsync(force: false).WaitAsync(cancellationToken).ConfigureAwait(false); + result = _cachedEndpoints; + } while (result is null); + + return result; + } + } + + // Ensures that there is a refresh operation running, if needed, and returns the task which represents the completion of the operation + private Task RefreshAsync(bool force) + { + lock (_lock) + { + // If the cache is invalid or needs invalidation, refresh the cache. + if (!_disposalCancellation.IsCancellationRequested && _refreshTask.IsCompleted && (_cacheState == CacheStatus.Invalid || _cachedEndpoints is null or { ChangeToken.HasChanged: true } || force)) + { + // Indicate that the cache is being updated and start a new refresh task. + _cacheState = CacheStatus.Refreshing; + + // Don't capture the current ExecutionContext and its AsyncLocals onto the callback. + var restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + _refreshTask = RefreshAsyncInternal(); + } + finally + { + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } + + return _refreshTask; + } + } + + private async Task RefreshAsyncInternal() + { + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); + var cancellationToken = _disposalCancellation.Token; + Exception? error = null; + ServiceEndpointSource? newEndpoints = null; + CacheStatus newCacheState; + try + { + lock (_lock) + { + // Dispose the existing change token registration, if any. + _changeTokenRegistration?.Dispose(); + _changeTokenRegistration = null; + } + + Log.ResolvingEndpoints(_logger, ServiceName); + var builder = new ServiceEndpointBuilder(); + foreach (var provider in _providers) + { + cancellationToken.ThrowIfCancellationRequested(); + await provider.PopulateAsync(builder, cancellationToken).ConfigureAwait(false); + } + + var endpoints = builder.Build(); + newCacheState = CacheStatus.Valid; + + lock (_lock) + { + // Check if we need to poll for updates or if we can register for change notification callbacks. + if (endpoints.ChangeToken.ActiveChangeCallbacks) + { + // Initiate a background refresh when the change token fires. + _changeTokenRegistration = endpoints.ChangeToken.RegisterChangeCallback(static state => _ = ((ServiceEndpointWatcher)state!).RefreshAsync(force: false), this); + + // Dispose the existing timer, if any, since we are reliant on change tokens for updates. + _pollingTimer?.Dispose(); + _pollingTimer = null; + } + else + { + SchedulePollingTimer(); + } + + // The cache is valid + newEndpoints = endpoints; + newCacheState = CacheStatus.Valid; + } + } + catch (Exception exception) + { + error = exception; + newCacheState = CacheStatus.Invalid; + SchedulePollingTimer(); + } + + // If there was an error, the cache must be invalid. + Debug.Assert(error is null || newCacheState is CacheStatus.Invalid); + + // To ensure coherence between the value returned by calls made to GetEndpointsAsync and value passed to the callback, + // we invalidate the cache before invoking the callback. This causes callers to wait on the refresh task + // before receiving the updated value. An alternative approach is to lock access to _cachedEndpoints, but + // that will have more overhead in the common case. + if (newCacheState is CacheStatus.Valid) + { + Interlocked.Exchange(ref _cachedEndpoints, null); + } + + if (OnEndpointsUpdated is { } callback) + { + try + { + callback(new(newEndpoints, error)); + } + catch (Exception exception) + { + _logger.LogError(exception, "Error notifying observers of updated endpoints."); + } + } + + lock (_lock) + { + if (newCacheState is CacheStatus.Valid) + { + Debug.Assert(newEndpoints is not null); + _cachedEndpoints = newEndpoints; + } + + _cacheState = newCacheState; + } + + if (error is not null) + { + Log.ResolutionFailed(_logger, error, ServiceName); + ExceptionDispatchInfo.Throw(error); + } + else if (newEndpoints is not null) + { + Log.ResolutionSucceeded(_logger, ServiceName, newEndpoints); + } + } + + private void SchedulePollingTimer() + { + lock (_lock) + { + if (_disposalCancellation.IsCancellationRequested) + { + _pollingTimer?.Dispose(); + _pollingTimer = null; + return; + } + + if (_pollingTimer is null) + { + _pollingTimer = _timeProvider.CreateTimer(s_pollingAction, this, _options.RefreshPeriod, TimeSpan.Zero); + } + else + { + _pollingTimer.Change(_options.RefreshPeriod, TimeSpan.Zero); + } + } + } + + /// + public async ValueTask DisposeAsync() + { + try + { + _disposalCancellation.Cancel(); + } + catch (Exception exception) + { + _logger.LogError(exception, "Error cancelling disposal cancellation token."); + } + + lock (_lock) + { + _changeTokenRegistration?.Dispose(); + _changeTokenRegistration = null; + + _pollingTimer?.Dispose(); + _pollingTimer = null; + } + + if (_refreshTask is { } task) + { + await task.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } + + foreach (var provider in _providers) + { + await provider.DisposeAsync().ConfigureAwait(false); + } + } + + private enum CacheStatus + { + Invalid, + Refreshing, + Valid + } + + private void ThrowIfNoProviders() + { + if (_providers.Length == 0) + { + ThrowNoProvidersConfigured(); + } + } + + [DoesNotReturn] + private static void ThrowNoProvidersConfigured() => throw new InvalidOperationException("No service endpoint providers are configured."); +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcherFactory.Log.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcherFactory.Log.cs new file mode 100644 index 00000000000..449ee6920de --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcherFactory.Log.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.ServiceDiscovery; + +partial class ServiceEndpointWatcherFactory +{ + private static partial class Log + { + [LoggerMessage(1, LogLevel.Debug, "Creating endpoint resolver for service '{ServiceName}' with {Count} providers: {Providers}.", EventName = "CreatingResolver")] + public static partial void ServiceEndpointProviderListCore(ILogger logger, string serviceName, int count, string providers); + + public static void CreatingResolver(ILogger logger, string serviceName, List providers) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + ServiceEndpointProviderListCore(logger, serviceName, providers.Count, string.Join(", ", providers.Select(static r => r.ToString()))); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcherFactory.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcherFactory.cs new file mode 100644 index 00000000000..6cc7cb2cbc5 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/ServiceEndpointWatcherFactory.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.ServiceDiscovery.PassThrough; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// Creates service endpoint watchers. +/// +internal sealed partial class ServiceEndpointWatcherFactory( + IEnumerable providerFactories, + ILogger logger, + IOptions options, + TimeProvider timeProvider) +{ + private readonly IServiceEndpointProviderFactory[] _providerFactories = providerFactories + .Where(r => r is not PassThroughServiceEndpointProviderFactory) + .Concat(providerFactories.Where(static r => r is PassThroughServiceEndpointProviderFactory)).ToArray(); + private readonly ILogger _logger = logger; + private readonly TimeProvider _timeProvider = timeProvider; + private readonly IOptions _options = options; + + /// + /// Creates a service endpoint watcher for the provided service name. + /// + public ServiceEndpointWatcher CreateWatcher(string serviceName) + { + ArgumentNullException.ThrowIfNull(serviceName); + + if (!ServiceEndpointQuery.TryParse(serviceName, out var query)) + { + throw new ArgumentException("The provided input was not in a valid format. It must be a valid URI.", nameof(serviceName)); + } + + List? providers = null; + foreach (var factory in _providerFactories) + { + if (factory.TryCreateProvider(query, out var provider)) + { + providers ??= []; + providers.Add(provider); + } + } + + if (providers is not { Count: > 0 }) + { + throw new InvalidOperationException($"No provider which supports the provided service name, '{serviceName}', has been configured."); + } + + Log.CreatingResolver(_logger, serviceName, providers); + return new ServiceEndpointWatcher( + providers: [.. providers], + logger: _logger, + serviceName: serviceName, + timeProvider: _timeProvider, + options: _options); + } +} diff --git a/src/Libraries/Microsoft.Extensions.ServiceDiscovery/UriEndPoint.cs b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/UriEndPoint.cs new file mode 100644 index 00000000000..6b5b07d199e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.ServiceDiscovery/UriEndPoint.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery; + +/// +/// An endpoint represented by a . +/// +/// The . +internal sealed class UriEndPoint(Uri uri) : EndPoint +{ + /// + /// Gets the associated with this endpoint. + /// + public Uri Uri => uri; + + /// + public override bool Equals(object? obj) + { + return obj is UriEndPoint other && Uri.Equals(other.Uri); + } + + /// + public override int GetHashCode() => Uri.GetHashCode(); + + /// + public override string? ToString() => uri.ToString(); +} diff --git a/src/Shared/FxPolyfills/ArgumentException.cs b/src/Shared/FxPolyfills/ArgumentException.cs new file mode 100644 index 00000000000..aafc737ab13 --- /dev/null +++ b/src/Shared/FxPolyfills/ArgumentException.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace System; + +internal static partial class FxPolyfillArgumentException +{ + extension(ArgumentException) + { + public static void ThrowIfNullOrEmpty([NotNull] string? argument, [CallerArgumentExpression(nameof(argument))] string? paramName = null) + { + if (string.IsNullOrEmpty(argument)) + { + ThrowNullOrEmptyException(argument, paramName); + } + } + } + + [DoesNotReturn] + private static void ThrowNullOrEmptyException(string? argument, string? paramName) + { + ArgumentNullException.ThrowIfNull(argument, paramName); + throw new ArgumentException("The value cannot be an empty string.", paramName); + } +} diff --git a/src/Shared/FxPolyfills/ArgumentNullException.cs b/src/Shared/FxPolyfills/ArgumentNullException.cs new file mode 100644 index 00000000000..5585b1e66f8 --- /dev/null +++ b/src/Shared/FxPolyfills/ArgumentNullException.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace System; + +internal static partial class FxPolyfillArgumentNullException +{ + extension(ArgumentNullException) + { + public static void ThrowIfNull([NotNull] object? argument, [CallerArgumentExpression(nameof(argument))] string? paramName = null) + { + if (argument is null) + { + Throw(paramName); + } + } + } + + [DoesNotReturn] + internal static void Throw(string? paramName) => throw new ArgumentNullException(paramName); +} diff --git a/src/Shared/FxPolyfills/ConcurrentDictionary.cs b/src/Shared/FxPolyfills/ConcurrentDictionary.cs new file mode 100644 index 00000000000..92e4a2195df --- /dev/null +++ b/src/Shared/FxPolyfills/ConcurrentDictionary.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Collections.Concurrent; + +internal static partial class FxPolyfillConcurrentDictionary +{ + extension(ConcurrentDictionary dictionary) + { + public TValue GetOrAdd(TKey key, Func valueFactory) + { + if (dictionary.TryGetValue(key, out var existing)) + { + return existing; + } + + return dictionary.GetOrAdd(key, valueFactory(key)); + } + + public TValue GetOrAdd(TKey key, Func valueFactory, TState state) + { + if (dictionary.TryGetValue(key, out var existing)) + { + return existing; + } + + return dictionary.GetOrAdd(key, valueFactory(key, state)); + } + + public void TryRemove(TKey key) + { + dictionary.TryRemove(key, out _); + } + + public void TryRemove(KeyValuePair pair) + { + if (dictionary.TryRemove(pair.Key, out var existing) && !EqualityComparer.Default.Equals(existing, pair.Value)) + { + dictionary.TryAdd(pair.Key, pair.Value); + } + } + } +} diff --git a/src/Shared/FxPolyfills/ExceptionDispatchInfo.cs b/src/Shared/FxPolyfills/ExceptionDispatchInfo.cs new file mode 100644 index 00000000000..81cee7cba9a --- /dev/null +++ b/src/Shared/FxPolyfills/ExceptionDispatchInfo.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace System.Runtime.ExceptionServices; + +internal static partial class FxPolyfillExceptionDispatchInfo +{ + extension(ExceptionDispatchInfo) + { + [DoesNotReturn] + public static void Throw(Exception ex) + { + ExceptionDispatchInfo.Capture(ex).Throw(); + } + } +} diff --git a/src/Shared/FxPolyfills/FxPolyfills.targets b/src/Shared/FxPolyfills/FxPolyfills.targets new file mode 100644 index 00000000000..ca38f9ab986 --- /dev/null +++ b/src/Shared/FxPolyfills/FxPolyfills.targets @@ -0,0 +1,25 @@ + + + $(MSBuildThisFileDirectory) + + $(NoWarn);CS8763;CS8777;CS8603;CA1031;IDE0058;S108;S2166;S2302;S2333;S2486;S3400;SA1402;SA1509;SA1515;SA1649;EA0014;LA0001;VSTHRD003 + + + + + + + + + + + + + + + + diff --git a/src/Shared/FxPolyfills/IPEndPoint.cs b/src/Shared/FxPolyfills/IPEndPoint.cs new file mode 100644 index 00000000000..8571b675bb5 --- /dev/null +++ b/src/Shared/FxPolyfills/IPEndPoint.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Globalization; + +namespace System.Net; + +internal static partial class FxPolyfillIPEndPoint +{ + extension(IPEndPoint) + { + public static IPEndPoint Parse(string endpoint) + { + if (TryParse(endpoint.AsSpan(), out var result)) + { + return result; + } + + throw new FormatException("The endpoint format is invalid."); + } + + public static bool TryParse(ReadOnlySpan s, out IPEndPoint? result) + { + const int MaxPort = 0x0000FFFF; + + int addressLength = s.Length; // If there's no port then send the entire string to the address parser + int lastColonPos = s.LastIndexOf(':'); + + // Look to see if this is an IPv6 address with a port. + if (lastColonPos > 0) + { + if (s[lastColonPos - 1] == ']') + { + addressLength = lastColonPos; + } + // Look to see if this is IPv4 with a port (IPv6 will have another colon) + else if (s.Slice(0, lastColonPos).LastIndexOf(':') == -1) + { + addressLength = lastColonPos; + } + } + + if (IPAddress.TryParse(s.Slice(0, addressLength).ToString(), out IPAddress? address)) + { + uint port = 0; + if (addressLength == s.Length || + (uint.TryParse(s.Slice(addressLength + 1).ToString(), NumberStyles.None, CultureInfo.InvariantCulture, out port) && port <= MaxPort)) + + { + result = new IPEndPoint(address, (int)port); + return true; + } + } + + result = null; + return false; + } + } +} diff --git a/src/Shared/FxPolyfills/Interlocked.cs b/src/Shared/FxPolyfills/Interlocked.cs new file mode 100644 index 00000000000..6177e411c35 --- /dev/null +++ b/src/Shared/FxPolyfills/Interlocked.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; + +namespace System.Threading; + +internal static partial class FxPolyfillInterlocked +{ + extension(Interlocked) + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Decrement(ref uint location) => + (uint)Interlocked.Add(ref Unsafe.As(ref location), -1); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong Decrement(ref ulong location) => + (ulong)Interlocked.Add(ref Unsafe.As(ref location), -1); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Increment(ref uint location) => + Add(ref location, 1); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong Increment(ref ulong location) => + Add(ref location, 1); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Add(ref uint location1, uint value) => + (uint)Interlocked.Add(ref Unsafe.As(ref location1), (int)value); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong Add(ref ulong location1, ulong value) => + (ulong)Interlocked.Add(ref Unsafe.As(ref location1), (long)value); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long Or(ref long location1, long value) + { + long current = location1; + while (true) + { + long newValue = current | value; + long oldValue = Interlocked.CompareExchange(ref location1, newValue, current); + if (oldValue == current) + { + return oldValue; + } + current = oldValue; + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong And(ref ulong location1, ulong value) => + (ulong)Interlocked.And(ref Unsafe.As(ref location1), (long)value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint And(ref uint location1, uint value) => + (uint)Interlocked.And(ref Unsafe.As(ref location1), (int)value); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int And(ref int location1, int value) + { + int current = location1; + while (true) + { + int newValue = current & value; + int oldValue = Interlocked.CompareExchange(ref location1, newValue, current); + if (oldValue == current) + { + return oldValue; + } + current = oldValue; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long And(ref long location1, long value) + { + long current = location1; + while (true) + { + long newValue = current & value; + long oldValue = Interlocked.CompareExchange(ref location1, newValue, current); + if (oldValue == current) + { + return oldValue; + } + current = oldValue; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong Or(ref ulong location1, ulong value) => + (ulong)Or(ref Unsafe.As(ref location1), (long)value); + } +} diff --git a/src/Shared/FxPolyfills/KeyValuePair.cs b/src/Shared/FxPolyfills/KeyValuePair.cs new file mode 100644 index 00000000000..64c79606bf4 --- /dev/null +++ b/src/Shared/FxPolyfills/KeyValuePair.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Collections.Generic; + +internal static partial class FxPolyfillKeyValuePair +{ + extension(KeyValuePair pair) + { + public void Deconstruct(out TKey key, out TValue value) + { + key = pair.Key; + value = pair.Value; + } + } +} + +internal static class KeyValuePair +{ + public static KeyValuePair Create(TKey key, TValue value) + { + return new KeyValuePair(key, value); + } +} diff --git a/src/Shared/FxPolyfills/ObjectDisposedException.cs b/src/Shared/FxPolyfills/ObjectDisposedException.cs new file mode 100644 index 00000000000..85ec090dd6c --- /dev/null +++ b/src/Shared/FxPolyfills/ObjectDisposedException.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace System; + +internal static partial class FxPolyfillObjectDisposedException +{ + extension(ObjectDisposedException) + { + public static void ThrowIf([DoesNotReturnIf(true)] bool condition, object instance) + { + if (condition) + { + throw new ObjectDisposedException(instance?.GetType().FullName); + } + } + } +} diff --git a/src/Shared/FxPolyfills/OperatingSystem.cs b/src/Shared/FxPolyfills/OperatingSystem.cs new file mode 100644 index 00000000000..4c88e7909aa --- /dev/null +++ b/src/Shared/FxPolyfills/OperatingSystem.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System; + +internal static class FrameworkExtensions +{ + extension(OperatingSystem) + { + public static bool IsLinux() => false; + public static bool IsWindows() => true; + public static bool IsMacOS() => false; + } +} diff --git a/src/Shared/FxPolyfills/String.cs b/src/Shared/FxPolyfills/String.cs new file mode 100644 index 00000000000..92df065be7e --- /dev/null +++ b/src/Shared/FxPolyfills/String.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System; + +internal static partial class FxPolyfillString +{ + extension(string s) + { + public bool StartsWith(char c) => s is [{ } first, ..] && first == c; + } +} diff --git a/src/Shared/FxPolyfills/Task.TimeProvider.cs b/src/Shared/FxPolyfills/Task.TimeProvider.cs new file mode 100644 index 00000000000..7e3a9c85bf0 --- /dev/null +++ b/src/Shared/FxPolyfills/Task.TimeProvider.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Threading.Tasks; + +internal static partial class FxPolyfillTask +{ + extension(Task task) + { + public Task WaitAsync(CancellationToken token) + { + return task.WaitAsync(Timeout.InfiniteTimeSpan, TimeProvider.System, token); + } + } +} diff --git a/src/Shared/FxPolyfills/Task.cs b/src/Shared/FxPolyfills/Task.cs new file mode 100644 index 00000000000..0035cde4b6f --- /dev/null +++ b/src/Shared/FxPolyfills/Task.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Threading.Tasks; + +internal enum ConfigureAwaitOptions +{ + None, + ContinueOnCapturedContext, + ForceYielding, + SuppressThrowing, +} + +internal static partial class FxPolyfillTask +{ + extension(Task task) + { + public async Task ConfigureAwait(ConfigureAwaitOptions options) + { + if (options == ConfigureAwaitOptions.None) + { + await task.ConfigureAwait(false); + } + else if (options == ConfigureAwaitOptions.ContinueOnCapturedContext) + { + await task.ConfigureAwait(true); + } + else if (options == ConfigureAwaitOptions.ForceYielding) + { + await Task.Yield(); + await task.ConfigureAwait(false); + } + else if (options == ConfigureAwaitOptions.SuppressThrowing) + { + try + { + await task.ConfigureAwait(false); + } + catch + { + } + } + else + { + throw new InvalidOperationException(); + } + } + } +} + +internal sealed class TaskCompletionSource(TaskCreationOptions options) : TaskCompletionSource(options) +{ + public void SetResult() => SetResult(true); +} diff --git a/src/Shared/Shared.csproj b/src/Shared/Shared.csproj index d25c011a05f..ecffd480a44 100644 --- a/src/Shared/Shared.csproj +++ b/src/Shared/Shared.csproj @@ -26,6 +26,10 @@ 85 + + + + diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore new file mode 100644 index 00000000000..0151cc4e360 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/.gitignore @@ -0,0 +1,2 @@ +# corpuses generated by the fuzzing engine +corpuses/** \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs new file mode 100644 index 00000000000..1b180d74b9d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/DnsResponseFuzzer.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class DnsResponseFuzzer : IFuzzer +{ + DnsResolver? _resolver; + byte[]? _buffer; + int _length; + + public void FuzzTarget(ReadOnlySpan data) + { + // lazy init + if (_resolver == null) + { + _buffer = new byte[4096]; + _resolver = new DnsResolver(new ResolverOptions(new IPEndPoint(IPAddress.Loopback, 53)) + { + Timeout = TimeSpan.FromSeconds(5), + Attempts = 1, + _transportOverride = (buffer, length) => + { + // the first two bytes are the random transaction ID, so we keep that + // and use the fuzzing payload for the rest of the DNS response + _buffer.AsSpan(0, Math.Min(_length, buffer.Length - 2)).CopyTo(buffer.Span.Slice(2)); + return _length + 2; + } + }); + } + + data.CopyTo(_buffer!); + _length = data.Length; + + // the _transportOverride makes the execution synchronous + ValueTask task = _resolver!.ResolveIPAddressesAsync("www.example.com", AddressFamily.InterNetwork, CancellationToken.None); + Debug.Assert(task.IsCompleted, "Task should be completed synchronously"); + task.GetAwaiter().GetResult(); + } +} \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs new file mode 100644 index 00000000000..72f84b3c959 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/EncodedDomainNameFuzzer.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class EncodedDomainNameFuzzer : IFuzzer +{ + public void FuzzTarget(ReadOnlySpan data) + { + byte[] buffer = ArrayPool.Shared.Rent(data.Length); + try + { + data.CopyTo(buffer); + + // attempt to read at any offset to really stress the parser + for (int i = 0; i < data.Length; i++) + { + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, data.Length), i, out EncodedDomainName name, out _)) + { + continue; + } + + // the domain name should be readable + _ = name.ToString(); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + + } +} \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs new file mode 100644 index 00000000000..f657245a842 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Fuzzers/WriteDomainNameRoundTripFuzzer.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +internal sealed class WriteDomainNameRoundTripFuzzer : IFuzzer +{ + private static readonly System.Globalization.IdnMapping s_idnMapping = new(); + public void FuzzTarget(ReadOnlySpan data) + { + // first byte is the offset of the domain name, rest is the actual + // (simulated) DNS message payload + + byte[] buffer = ArrayPool.Shared.Rent(data.Length * 2); + + try + { + string domainName = Encoding.UTF8.GetString(data); + if (!DnsPrimitives.TryWriteQName(buffer, domainName, out int written)) + { + return; + } + + if (!DnsPrimitives.TryReadQName(buffer.AsMemory(0, written), 0, out EncodedDomainName name, out int read)) + { + return; + } + + if (read != written) + { + throw new InvalidOperationException($"Read {read} bytes, but wrote {written} bytes"); + } + + string readName = name.ToString(); + + if (!string.Equals(s_idnMapping.GetAscii(domainName).TrimEnd('.'), readName, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException($"Domain name mismatch: {readName} != {domainName}"); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs new file mode 100644 index 00000000000..2ff9d86b2ce --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/GlobalUsings.cs @@ -0,0 +1,5 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +global using System.Buffers; +global using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs new file mode 100644 index 00000000000..4b4c8c99b4b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/IFuzzer.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +public interface IFuzzer +{ + string Name => GetType().Name; + void FuzzTarget(ReadOnlySpan data); +} \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj new file mode 100644 index 00000000000..c291dff12c8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj @@ -0,0 +1,20 @@ + + + + $(TestNetCoreTargetFrameworks) + enable + enable + Exe + + $(NoWarn);IDE0040;IDE0061;IDE1006;S5034;SA1400;VSTHRD002 + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs new file mode 100644 index 00000000000..22b1580d1ac --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/Program.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using SharpFuzz; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing; + +public static class Program +{ + public static void Main(string[] args) + { + IFuzzer[] fuzzers = typeof(Program).Assembly.GetTypes() + .Where(t => t.IsClass && !t.IsAbstract) + .Where(t => t.GetInterfaces().Contains(typeof(IFuzzer))) + .Select(t => (IFuzzer)Activator.CreateInstance(t)!) + .OrderBy(f => f.Name, StringComparer.OrdinalIgnoreCase) + .ToArray(); + + void PrintUsage() + { + Console.Error.WriteLine($""" + Usage: + DotnetFuzzing list + DotnetFuzzing [input file/directory] + // DotnetFuzzing prepare-onefuzz + + Available fuzzers: + {string.Join(Environment.NewLine, fuzzers.Select(f => $" {f.Name}"))} + """); + } + + if (args.Length == 0) + { + PrintUsage(); + return; + } + + string arg = args[0]; + IFuzzer? fuzzer = fuzzers.FirstOrDefault(f => string.Equals(f.Name, arg, StringComparison.OrdinalIgnoreCase)); + if (fuzzer == null) + { + Console.Error.WriteLine($"Unknown fuzzer: {arg}"); + PrintUsage(); + return; + } + + string? inputFiles = args.Length > 1 ? args[1] : null; + if (string.IsNullOrEmpty(inputFiles)) + { + // no input files, let the fuzzer generate + Fuzzer.LibFuzzer.Run(fuzzer.FuzzTarget); + return; + } + + string[] files = Directory.Exists(inputFiles) + ? Directory.GetFiles(inputFiles) + : [inputFiles]; + + foreach (string inputFile in files) + { + fuzzer.FuzzTarget(File.ReadAllBytes(inputFile)); + } + } +} \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com new file mode 100644 index 00000000000..bb40cd100c6 Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/ip-www.example.com differ diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error new file mode 100644 index 00000000000..92a307a0f72 Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error differ diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 new file mode 100644 index 00000000000..5b37565190e Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/name-error-2 differ diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data new file mode 100644 index 00000000000..23265fc7a8f Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/no-data differ diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error new file mode 100644 index 00000000000..27f1054b9b1 Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/DnsResponseFuzzer/server-error differ diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com new file mode 100644 index 00000000000..c227840c168 Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/EncodedDomainNameFuzzer/ip-www.example.com differ diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example new file mode 100644 index 00000000000..8642ba7b94c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/example @@ -0,0 +1 @@ +www.example.com \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii new file mode 100644 index 00000000000..692a5a8aee0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/nonascii @@ -0,0 +1 @@ +www.řffwefw.com \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong new file mode 100644 index 00000000000..bb6d3a722e8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/corpus-seed/WriteDomainNameRoundTripFuzzer/toolong @@ -0,0 +1 @@ +aa.efaw.ef.wef.ef.wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww.fafeww.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.fwefefefefwefwf.wzzefwefwefwefwfeewfwefwefw.ffff \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 new file mode 100644 index 00000000000..17b3d27055d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing/run.ps1 @@ -0,0 +1,106 @@ +param( + # Name of the fuzzing target, see Fuzzers/*.cs files + [Parameter(Mandatory = $true, Position = 0)] + [ArgumentCompleter({ + param($commandName, $parameterName, $wordToComplete, $commandAst, $fakeBoundParameters) + $corpusSeedPath = Join-Path $PSScriptRoot "Fuzzers" + if (Test-Path $corpusSeedPath) { + Get-ChildItem -Path $corpusSeedPath -Filter "$wordToComplete*.cs" | ForEach-Object { $_.BaseName } + } + })] + [string] $Target, + + # Number of parallel jobs to run + [int] $Jobs, + + # Maximum length of the input + [int] $MaxLength = 512, + + # Ignore timeouts when running the fuzzer + [switch] $IgnoreTimeouts, + + # Skip the build of the project useful for reruning the fuzzer without recompiling + [switch] $NoBuild, + + # Path to the libfuzzer driver + [string] $LibFuzzer = "libfuzzer-dotnet-windows" +) + +$timeout = 30 +$SharpFuzz = "sharpfuzz" +$dict = $null + +$corpus = Join-Path $PSScriptRoot "corpuses" $Target +$null = New-Item -Path $corpus -ItemType Directory -Force + +$CorpusSeed = Join-Path $PSScriptRoot "corpus-seed" $Target + +if (Test-Path $CorpusSeed -ErrorAction SilentlyContinue) { + Write-Output "Copying corpus seed from $CorpusSeed to $corpus" + Get-ChildItem -Path $CorpusSeed | Copy-Item -Destination $corpus +} + +$project = Join-Path $PSScriptRoot "Microsoft.Extensions.ServiceDiscovery.Dns.Tests.Fuzzing.csproj" + +Set-StrictMode -Version Latest + +$outputDir = "bin" +$projectName = (Get-Item $project).BaseName +$projectDll = "$projectName.dll" +$executable = if ($IsWindows) { Join-Path $outputDir "$projectName.exe" } +else { Join-Path $outputDir "$projectName" } + +if (!$NoBuild) { + + if (Test-Path $outputDir) { + Remove-Item -Recurse -Force $outputDir + } + + dotnet publish $project -c release -o $outputDir + + $exclusions = @( + "dnlib.dll", + "SharpFuzz.dll", + "SharpFuzz.Common.dll", + $projectDll + ) + + $fuzzingTargets = @(Get-Item "$outputDir/Microsoft.Extensions.ServiceDiscovery.Dns.dll") + + if (($fuzzingTargets | Measure-Object).Count -eq 0) { + Write-Error "No fuzzing targets found" + exit 1 + } + + foreach ($fuzzingTarget in $fuzzingTargets) { + Write-Output "Instrumenting $fuzzingTarget" + & $SharpFuzz $fuzzingTarget.FullName + + if ($LastExitCode -ne 0) { + Write-Error "An error occurred while instrumenting $fuzzingTarget" + exit 1 + } + } +} + +$parameters = @( + "-timeout=$timeout" +) + +if ($Jobs) { + $parameters += "-fork=$Jobs" +} + +if ($IgnoreTimeouts) { + $parameters += "-ignore_timeouts=1" +} + +if ($MaxLength) { + $parameters += "-max_len=$MaxLength" +} + +if ($dict) { + $parameters += "-dict=$dict" +} + +& $LibFuzzer @parameters --target_path=$executable --target_arg=$Target $corpus \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs new file mode 100644 index 00000000000..b949e713999 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServiceEndpointResolverTests.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Time.Testing; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; + +public class DnsServiceEndpointResolverTests +{ + [Fact] + public async Task ResolveServiceEndpoint_Dns_MultiShot() + { + var timeProvider = new FakeTimeProvider(); + var services = new ServiceCollection() + .AddSingleton(timeProvider) + .AddSingleton() + .AddServiceDiscoveryCore() + .AddDnsServiceEndpointProvider(o => o.DefaultRefreshPeriod = TimeSpan.FromSeconds(30)) + .BuildServiceProvider(); + var resolver = services.GetRequiredService(); + var initialResult = await resolver.GetEndpointsAsync("https://localhost", CancellationToken.None); + Assert.NotNull(initialResult); + Assert.True(initialResult.Endpoints.Count > 0); + timeProvider.Advance(TimeSpan.FromSeconds(7)); + var secondResult = await resolver.GetEndpointsAsync("https://localhost", CancellationToken.None); + Assert.NotNull(secondResult); + Assert.True(initialResult.Endpoints.Count > 0); + timeProvider.Advance(TimeSpan.FromSeconds(80)); + var thirdResult = await resolver.GetEndpointsAsync("https://localhost", CancellationToken.None); + Assert.NotNull(thirdResult); + Assert.True(initialResult.Endpoints.Count > 0); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServicePublicApiTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServicePublicApiTests.cs new file mode 100644 index 00000000000..e347deb9822 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsServicePublicApiTests.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; + +public class DnsServicePublicApiTests +{ + [Fact] + public void AddDnsSrvServiceEndpointProviderShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddDnsSrvServiceEndpointProvider(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddDnsSrvServiceEndpointProviderWithConfigureOptionsShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + Action configureOptions = (_) => { }; + + var action = () => services.AddDnsSrvServiceEndpointProvider(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddDnsSrvServiceEndpointProviderWithConfigureOptionsShouldThrowWhenConfigureOptionsIsNull() + { + IServiceCollection services = new ServiceCollection(); + Action configureOptions = null!; + + var action = () => services.AddDnsSrvServiceEndpointProvider(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(configureOptions), exception.ParamName); + } + + [Fact] + public void AddDnsServiceEndpointProviderShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddDnsServiceEndpointProvider(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddDnsServiceEndpointProviderWithConfigureOptionsShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + Action configureOptions = (_) => { }; + + var action = () => services.AddDnsServiceEndpointProvider(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddDnsServiceEndpointProviderWithConfigureOptionsShouldThrowWhenConfigureOptionsIsNull() + { + IServiceCollection services = new ServiceCollection(); + Action configureOptions = null!; + + var action = () => services.AddDnsServiceEndpointProvider(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(configureOptions), exception.ParamName); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs new file mode 100644 index 00000000000..ec21bf9fa9c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/DnsSrvServiceEndpointResolverTests.cs @@ -0,0 +1,177 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Sockets; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Configuration.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; +using Microsoft.Extensions.ServiceDiscovery.Internal; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; + +/// +/// Tests for and . +/// These also cover and by extension. +/// +public class DnsSrvServiceEndpointResolverTests +{ + private sealed class FakeDnsResolver : IDnsResolver + { + public Func>? ResolveIPAddressesAsyncFunc { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc!.Invoke(name, addressFamily, cancellationToken); + + public Func>? ResolveIPAddressesAsyncFunc2 { get; set; } + + public ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc2!.Invoke(name, cancellationToken); + + public Func>? ResolveServiceAsyncFunc { get; set; } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) => ResolveServiceAsyncFunc!.Invoke(name, cancellationToken); + } + + [Fact] + public async Task ResolveServiceEndpoint_DnsSrv() + { + var dnsClientMock = new FakeDnsResolver + { + ResolveServiceAsyncFunc = (name, cancellationToken) => + { + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", []) + ]; + + return ValueTask.FromResult(response); + } + }; + var services = new ServiceCollection() + .AddSingleton(dnsClientMock) + .AddServiceDiscoveryCore() + .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns") + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(3, initialResult.EndpointSource.Endpoints.Count); + var eps = initialResult.EndpointSource.Endpoints; + Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint); + Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint); + Assert.Equal(new DnsEndPoint("srv-c", 7777), eps[2].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.Null(hostNameFeature); + }); + } + } + + /// + /// Tests that when there are multiple resolvers registered, they are consulted in registration order and each provider only adds endpoints if the providers before it did not. + /// + [InlineData(true)] + [InlineData(false)] + [Theory] + public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing(bool dnsFirst) + { + var dnsClientMock = new FakeDnsResolver + { + ResolveServiceAsyncFunc = (name, cancellationToken) => + { + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", []) + ]; + + return ValueTask.FromResult(response); + } + }; + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:http:0"] = "localhost:8080", + ["services:basket:http:1"] = "remotehost:9090", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var serviceCollection = new ServiceCollection() + .AddSingleton(dnsClientMock) + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore(); + if (dnsFirst) + { + serviceCollection + .AddDnsSrvServiceEndpointProvider(options => + { + options.QuerySuffix = ".ns"; + options.ShouldApplyHostNameMetadata = _ => true; + }) + .AddConfigurationServiceEndpointProvider(); + } + else + { + serviceCollection + .AddConfigurationServiceEndpointProvider() + .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns"); + }; + var services = serviceCollection.BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.Null(initialResult.Exception); + Assert.True(initialResult.ResolvedSuccessfully); + + if (dnsFirst) + { + // We expect only the results from the DNS provider. + Assert.Equal(3, initialResult.EndpointSource.Endpoints.Count); + var eps = initialResult.EndpointSource.Endpoints; + Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint); + Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint); + Assert.Equal(new DnsEndPoint("srv-c", 7777), eps[2].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.NotNull(hostNameFeature); + Assert.Equal("basket", hostNameFeature.HostName); + }); + } + else + { + // We expect only the results from the Configuration provider. + Assert.Equal(2, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new DnsEndPoint("localhost", 8080), initialResult.EndpointSource.Endpoints[0].EndPoint); + Assert.Equal(new DnsEndPoint("remotehost", 9090), initialResult.EndpointSource.Endpoints[1].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.Null(hostNameFeature); + }); + } + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj new file mode 100644 index 00000000000..4911d854007 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj @@ -0,0 +1,26 @@ + + + + $(TestNetCoreTargetFrameworks) + enable + enable + + $(NoWarn);IDE0004;IDE0017;IDE0040;IDE0055;IDE1006;CA1012;CA1031;CA1063;CA1816;CA2000;S103;S107;S1067;S1121;S1128;S1135;S1144;S1186;S2148;S3442;S3459;S4136;SA1106;SA1127;SA1204;SA1208;SA1210;SA1128;SA1316;SA1400;SA1402;SA1407;SA1414;SA1500;SA1513;SA1515;VSTHRD003 + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs new file mode 100644 index 00000000000..786882afc1d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/CancellationTests.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Sockets; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class CancellationTests : LoopbackDnsTestBase +{ + public CancellationTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task PreCanceledToken_Throws() + { + CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + var ex = await Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + + Assert.Equal(cts.Token, ex.CancellationToken); + } + + [Fact] + public async Task CancellationInProgress_Throws() + { + CancellationTokenSource cts = new CancellationTokenSource(); + + var task = Assert.ThrowsAnyAsync(async () => await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork, cts.Token)); + + await DnsServer.ProcessUdpRequest(_ => + { + cts.Cancel(); + return Task.CompletedTask; + }); + + OperationCanceledException ex = await task; + Assert.Equal(cts.Token, ex.CancellationToken); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs new file mode 100644 index 00000000000..aad32fe785f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataReaderTests.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsDataReaderTests +{ + [Fact] + public void ReadResourceRecord_Success() + { + // example A record for example.com + byte[] buffer = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsDataReader reader = new DnsDataReader(buffer); + Assert.True(reader.TryReadResourceRecord(out DnsResourceRecord record)); + + Assert.Equal("www.example.com", record.Name.ToString()); + Assert.Equal(QueryType.A, record.Type); + Assert.Equal(QueryClass.Internet, record.Class); + Assert.Equal(3600, record.Ttl); + Assert.Equal(4, record.Data.Length); + } + + [Fact] + public void ReadResourceRecord_Truncated_Fails() + { + // example A record for example.com + byte[] buffer = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + for (int i = 0; i < buffer.Length; i++) + { + DnsDataReader reader = new DnsDataReader(new ArraySegment(buffer, 0, i)); + Assert.False(reader.TryReadResourceRecord(out _)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs new file mode 100644 index 00000000000..b2039ce5a4c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsDataWriterTests.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsDataWriterTests +{ + [Fact] + public void WriteResourceRecord_Success() + { + // example A record for example.com + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsResourceRecord record = new DnsResourceRecord(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet, 3600, new byte[4]); + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteResourceRecord(record)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + [Fact] + public void WriteResourceRecord_Truncated_Fails() + { + // example A record for example.com + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01, + // TTL (3600) + 0x00, 0x00, 0x0e, 0x10, + // data length (4) + 0x00, 0x04, + // data (placeholder) + 0x00, 0x00, 0x00, 0x00 + ]; + + DnsResourceRecord record = new DnsResourceRecord(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet, 3600, new byte[4]); + + byte[] buffer = new byte[512]; + for (int i = 0; i < expected.Length; i++) + { + DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); + Assert.False(writer.TryWriteResourceRecord(record)); + } + } + + [Fact] + public void WriteQuestion_Success() + { + // example question for example.com (A record) + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01 + ]; + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteQuestion(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + [Fact] + public void WriteQuestion_Truncated_Fails() + { + // example question for example.com (A record) + byte[] expected = [ + // name (www.example.com) + 0x03, 0x77, 0x77, 0x77, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + // type (A) + 0x00, 0x01, + // class (IN) + 0x00, 0x01 + ]; + + byte[] buffer = new byte[512]; + for (int i = 0; i < expected.Length; i++) + { + DnsDataWriter writer = new DnsDataWriter(buffer.AsMemory(0, i)); + Assert.False(writer.TryWriteQuestion(EncodeDomainName("www.example.com"), QueryType.A, QueryClass.Internet)); + } + } + + [Fact] + public void WriteHeader_Success() + { + // example header + byte[] expected = [ + // ID (0x1234) + 0x12, 0x34, + // Flags (0x5678) + 0x56, 0x78, + // Question count (1) + 0x00, 0x01, + // Answer count (0) + 0x00, 0x02, + // Authority count (0) + 0x00, 0x03, + // Additional count (0) + 0x00, 0x04 + ]; + + DnsMessageHeader header = new() + { + TransactionId = 0x1234, + QueryFlags = (QueryFlags)0x5678, + QueryCount = 1, + AnswerCount = 2, + AuthorityCount = 3, + AdditionalRecordCount = 4, + }; + + byte[] buffer = new byte[512]; + DnsDataWriter writer = new DnsDataWriter(buffer); + Assert.True(writer.TryWriteHeader(header)); + Assert.Equal(expected, buffer.AsSpan().Slice(0, writer.Position).ToArray()); + } + + private static EncodedDomainName EncodeDomainName(string name) + { + byte[] nameBuffer = new byte[512]; + Assert.True(DnsPrimitives.TryWriteQName(nameBuffer, name, out int nameLength)); + Assert.True(DnsPrimitives.TryReadQName(nameBuffer.AsMemory(0, nameLength), 0, out EncodedDomainName encodedDomainName, out _)); + return encodedDomainName; + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs new file mode 100644 index 00000000000..6733a553bad --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/DnsPrimitivesTests.cs @@ -0,0 +1,195 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class DnsPrimitivesTests +{ + public static TheoryData QNameData => new() + { + { "www.example.com", "\x0003www\x0007example\x0003com\x0000"u8.ToArray() }, + { "example.com", "\x0007example\x0003com\x0000"u8.ToArray() }, + { "com", "\x0003com\x0000"u8.ToArray() }, + { "example", "\x0007example\x0000"u8.ToArray() }, + { "www", "\x0003www\x0000"u8.ToArray() }, + { "a", "\x0001a\x0000"u8.ToArray() }, + }; + + [Theory] + [MemberData(nameof(QNameData))] + public void TryWriteQName_Success(string name, byte[] expected) + { + byte[] buffer = new byte[512]; + + Assert.True(DnsPrimitives.TryWriteQName(buffer, name, out int written)); + Assert.Equal(name.Length + 2, written); + Assert.Equal(expected, buffer.AsSpan().Slice(0, written).ToArray()); + } + + [Fact] + public void TryWriteQName_LabelTooLong_False() + { + byte[] buffer = new byte[512]; + + Assert.False(DnsPrimitives.TryWriteQName(buffer, new string('a', 70), out _)); + } + + [Fact] + public void TryWriteQName_BufferTooShort_Fails() + { + byte[] buffer = new byte[512]; + string name = "www.example.com"; + + for (int i = 0; i < name.Length + 2; i++) + { + Assert.False(DnsPrimitives.TryWriteQName(buffer.AsSpan(0, i), name, out _)); + } + } + + [Theory] + [InlineData("www.-0.com")] + [InlineData("www.-a.com")] + [InlineData("www.a-.com")] + [InlineData("www.a_a.com")] + [InlineData("www.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.com")] // 64 occurrences of 'a' (too long) + [InlineData("www.a~a.com")] // 64 occurrences of 'a' (too long) + [InlineData("www..com")] + [InlineData("www..")] + public void TryWriteQName_InvalidName_ReturnsFalse(string name) + { + byte[] buffer = new byte[512]; + Assert.False(DnsPrimitives.TryWriteQName(buffer, name, out _)); + } + + [Fact] + public void TryWriteQName_ExplicitRoot_Success() + { + string name1 = "www.example.com"; + string name2 = "www.example.com."; + + byte[] buffer1 = new byte[512]; + byte[] buffer2 = new byte[512]; + + Assert.True(DnsPrimitives.TryWriteQName(buffer1, name1, out int written1)); + Assert.True(DnsPrimitives.TryWriteQName(buffer2, name2, out int written2)); + Assert.Equal(written1, written2); + Assert.Equal(buffer1.AsSpan().Slice(0, written1).ToArray(), buffer2.AsSpan().Slice(0, written2).ToArray()); + } + + [Theory] + [MemberData(nameof(QNameData))] + public void TryReadQName_Success(string expected, byte[] serialized) + { + Assert.True(DnsPrimitives.TryReadQName(serialized, 0, out EncodedDomainName actual, out int bytesRead)); + Assert.Equal(expected, actual.ToString()); + Assert.Equal(serialized.Length, bytesRead); + } + + [Fact] + public void TryReadQName_TruncatedData_Fails() + { + ReadOnlyMemory data = "\x0003www\x0007example\x0003com\x0000"u8.ToArray(); + + for (int i = 0; i < data.Length; i++) + { + Assert.False(DnsPrimitives.TryReadQName(data.Slice(0, i), 0, out _, out _)); + } + } + + [Fact] + public void TryReadQName_Pointer_Success() + { + // [7B padding], example.com. www->[ptr to example.com.] + Memory data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data.Span[^2] = 0xc0; + + Assert.True(DnsPrimitives.TryReadQName(data, data.Length - 6, out EncodedDomainName actual, out int bytesRead)); + Assert.Equal("www.example.com", actual.ToString()); + Assert.Equal(6, bytesRead); + } + + [Fact] + public void TryReadQName_PointerTruncated_Fails() + { + // [7B padding], example.com. www->[ptr to example.com.] + Memory data = "padding\x0007example\x0003com\x0000\x0003www\x00\x07"u8.ToArray(); + data.Span[^2] = 0xc0; + + for (int i = 0; i < data.Length; i++) + { + Assert.False(DnsPrimitives.TryReadQName(data.Slice(0, i), data.Length - 6, out _, out _)); + } + } + + [Fact] + public void TryReadQName_ForwardPointer_Fails() + { + // www->[ptr to example.com], [7B padding], example.com. + Memory data = "\x03www\x00\x000dpadding\x0007example\x0003com\x00"u8.ToArray(); + data.Span[4] = 0xc0; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Fact] + public void TryReadQName_PointerToSelf_Fails() + { + // www->[ptr to www->...] + Memory data = "\x0003www\0\0"u8.ToArray(); + data.Span[4] = 0xc0; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Fact] + public void TryReadQName_PointerToPointer_Fails() + { + // com, example[->com], example2[->[->com]] + Memory data = "\x0003com\0\x0007example\0\0\x0008example2\0\0"u8.ToArray(); + data.Span[13] = 0xc0; + data.Span[14] = 0x00; // -> com + data.Span[24] = 0xc0; + data.Span[25] = 13; // -> -> com + + Assert.False(DnsPrimitives.TryReadQName(data, 15, out _, out _)); + } + + [Fact] + public void TryReadQName_ReservedBits() + { + Memory data = "\x0003www\x00c0"u8.ToArray(); + data.Span[0] = 0x40; + + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + + [Theory] + [InlineData(253)] + [InlineData(254)] + [InlineData(255)] + public void TryReadQName_NameTooLong(int length) + { + // longest possible label is 63 bytes + 1 byte for length + byte[] labelData = new byte[64]; + Array.Fill(labelData, (byte)'a'); + labelData[0] = 63; + + int remainder = length - 3 * 64; + + byte[] lastLabelData = new byte[remainder + 1]; + Array.Fill(lastLabelData, (byte)'a'); + lastLabelData[0] = (byte)remainder; + + byte[] data = Enumerable.Repeat(labelData, 3).SelectMany(x => x).Concat(lastLabelData).Concat(new byte[1]).ToArray(); + if (length > 253) + { + Assert.False(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + else + { + Assert.True(DnsPrimitives.TryReadQName(data, 0, out _, out _)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs new file mode 100644 index 00000000000..4789e21c575 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsServer.cs @@ -0,0 +1,331 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Globalization; +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +internal sealed class LoopbackDnsServer : IDisposable +{ + private readonly Socket _dnsSocket; + private Socket? _tcpSocket; + + public IPEndPoint DnsEndPoint => (IPEndPoint)_dnsSocket.LocalEndPoint!; + + public LoopbackDnsServer() + { + _dnsSocket = new(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + _dnsSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + } + + public void Dispose() + { + _dnsSocket.Dispose(); + _tcpSocket?.Dispose(); + } + + private static async Task ProcessRequestCore(IPEndPoint remoteEndPoint, ArraySegment message, Func action, Memory responseBuffer) + { + DnsDataReader reader = new DnsDataReader(message); + + if (!reader.TryReadHeader(out DnsMessageHeader header) || + !reader.TryReadQuestion(out var name, out var type, out var @class)) + { + return 0; + } + + LoopbackDnsResponseBuilder responseBuilder = new(name.ToString(), type, @class); + responseBuilder.TransactionId = header.TransactionId; + responseBuilder.Flags = header.QueryFlags | QueryFlags.HasResponse; + responseBuilder.ResponseCode = QueryResponseCode.NoError; + + await action(responseBuilder, remoteEndPoint); + + return responseBuilder.Write(responseBuffer); + } + + public async Task ProcessUdpRequest(Func action) + { + byte[] buffer = ArrayPool.Shared.Rent(512); + try + { + EndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); + SocketReceiveFromResult result = await _dnsSocket.ReceiveFromAsync(buffer, remoteEndPoint); + + int bytesWritten = await ProcessRequestCore((IPEndPoint)result.RemoteEndPoint, new ArraySegment(buffer, 0, result.ReceivedBytes), action, buffer.AsMemory(0, 512)); + + await _dnsSocket.SendToAsync(buffer.AsMemory(0, bytesWritten), SocketFlags.None, result.RemoteEndPoint); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public Task ProcessUdpRequest(Func action) + { + return ProcessUdpRequest((builder, _) => action(builder)); + } + + public async Task ProcessTcpRequest(Func action) + { + if (_tcpSocket is null) + { + _tcpSocket = new(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _tcpSocket.Bind(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)_dnsSocket.LocalEndPoint!).Port)); + _tcpSocket.Listen(); + } + + using Socket tcpClient = await _tcpSocket.AcceptAsync(); + + byte[] buffer = ArrayPool.Shared.Rent(8 * 1024); + try + { + int bytesRead = 0; + int length = -1; + while (length < 0 || bytesRead < length + 2) + { + int toRead = length < 0 ? 2 : length + 2 - bytesRead; + int read = await tcpClient.ReceiveAsync(buffer.AsMemory(bytesRead, toRead), SocketFlags.None); + bytesRead += read; + + if (length < 0 && bytesRead >= 2) + { + length = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(0, 2)); + } + } + + int bytesWritten = await ProcessRequestCore((IPEndPoint)tcpClient.RemoteEndPoint!, new ArraySegment(buffer, 2, length), action, buffer.AsMemory(2)); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(0, 2), (ushort)bytesWritten); + await tcpClient.SendAsync(buffer.AsMemory(0, bytesWritten + 2), SocketFlags.None); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public Task ProcessTcpRequest(Func action) + { + return ProcessTcpRequest((builder, _) => action(builder)); + } +} + +internal sealed class LoopbackDnsResponseBuilder +{ + private static readonly SearchValues s_domainNameValidChars = SearchValues.Create("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."); + + public LoopbackDnsResponseBuilder(string name, QueryType type, QueryClass @class) + { + Name = name; + Type = type; + Class = @class; + Questions.Add((name, type, @class)); + + if (name.AsSpan().ContainsAnyExcept(s_domainNameValidChars)) + { + throw new ArgumentException($"Invalid characters in domain name '{name}'"); + } + } + + public ushort TransactionId { get; set; } + public QueryFlags Flags { get; set; } + public QueryResponseCode ResponseCode { get; set; } + + public string Name { get; } + public QueryType Type { get; } + public QueryClass Class { get; } + + public List<(string, QueryType, QueryClass)> Questions { get; } = new List<(string, QueryType, QueryClass)>(); + public List Answers { get; } = new List(); + public List Authorities { get; } = new List(); + public List Additionals { get; } = new List(); + + public int Write(Memory responseBuffer) + { + DnsDataWriter writer = new(responseBuffer); + if (!writer.TryWriteHeader(new DnsMessageHeader + { + TransactionId = TransactionId, + QueryFlags = Flags | (QueryFlags)ResponseCode, + QueryCount = (ushort)Questions.Count, + AnswerCount = (ushort)Answers.Count, + AuthorityCount = (ushort)Authorities.Count, + AdditionalRecordCount = (ushort)Additionals.Count + })) + { + throw new InvalidOperationException("Failed to write header"); + } + + byte[] buffer = ArrayPool.Shared.Rent(512); + foreach (var (questionName, questionType, questionClass) in Questions) + { + if (!DnsPrimitives.TryWriteQName(buffer, questionName, out int length) || + !DnsPrimitives.TryReadQName(buffer.AsMemory(0, length), 0, out EncodedDomainName encodedName, out _)) + { + throw new InvalidOperationException("Failed to encode domain name"); + } + if (!writer.TryWriteQuestion(encodedName, questionType, questionClass)) + { + throw new InvalidOperationException("Failed to write question"); + } + } + ArrayPool.Shared.Return(buffer); + + foreach (var answer in Answers) + { + if (!writer.TryWriteResourceRecord(answer)) + { + throw new InvalidOperationException("Failed to write answer"); + } + } + + foreach (var authority in Authorities) + { + if (!writer.TryWriteResourceRecord(authority)) + { + throw new InvalidOperationException("Failed to write authority"); + } + } + + foreach (var additional in Additionals) + { + if (!writer.TryWriteResourceRecord(additional)) + { + throw new InvalidOperationException("Failed to write additional records"); + } + } + + return writer.Position; + } + + public byte[] GetMessageBytes() + { + byte[] buffer = ArrayPool.Shared.Rent(512); + try + { + int bytesWritten = Write(buffer.AsMemory(0, 512)); + return buffer.AsSpan(0, bytesWritten).ToArray(); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} + +internal static class LoopbackDnsServerExtensions +{ + private static readonly IdnMapping s_idnMapping = new IdnMapping(); + + private static EncodedDomainName EncodeDomainName(string name) + { + var encodedLabels = name.Split('.', StringSplitOptions.RemoveEmptyEntries).Select(label => (ReadOnlyMemory)Encoding.UTF8.GetBytes(s_idnMapping.GetAscii(label))) + .ToList(); + + return new EncodedDomainName(encodedLabels); + } + + public static List AddAddress(this List records, string name, int ttl, IPAddress address) + { + QueryType type = address.AddressFamily == AddressFamily.InterNetwork ? QueryType.A : QueryType.AAAA; + records.Add(new DnsResourceRecord(EncodeDomainName(name), type, QueryClass.Internet, ttl, address.GetAddressBytes())); + return records; + } + + public static List AddCname(this List records, string name, int ttl, string alias) + { + byte[] buff = new byte[256]; + if (!DnsPrimitives.TryWriteQName(buff, alias, out int length)) + { + throw new InvalidOperationException("Failed to encode domain name"); + } + + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.CNAME, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } + + public static List AddService(this List records, string name, int ttl, ushort priority, ushort weight, ushort port, string target) + { + byte[] buff = new byte[256]; + + // https://www.rfc-editor.org/rfc/rfc2782 + if (!BinaryPrimitives.TryWriteUInt16BigEndian(buff, priority) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buff.AsSpan(2), weight) || + !BinaryPrimitives.TryWriteUInt16BigEndian(buff.AsSpan(4), port) || + !DnsPrimitives.TryWriteQName(buff.AsSpan(6), target, out int length)) + { + throw new InvalidOperationException("Failed to encode SRV record"); + } + + length += 6; + + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.SRV, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } + + public static List AddStartOfAuthority(this List records, string name, int ttl, string mname, string rname, uint serial, uint refresh, uint retry, uint expire, uint minimum) + { + byte[] buff = new byte[256]; + + // https://www.rfc-editor.org/rfc/rfc1035#section-3.3.13 + if (!DnsPrimitives.TryWriteQName(buff, mname, out int w1) || + !DnsPrimitives.TryWriteQName(buff.AsSpan(w1), rname, out int w2) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2), serial) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 4), refresh) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 8), retry) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 12), expire) || + !BinaryPrimitives.TryWriteUInt32BigEndian(buff.AsSpan(w1 + w2 + 16), minimum)) + { + throw new InvalidOperationException("Failed to encode SOA record"); + } + + int length = w1 + w2 + 20; + + records.Add(new DnsResourceRecord(EncodeDomainName(name), QueryType.SOA, QueryClass.Internet, ttl, buff.AsMemory(0, length))); + return records; + } +} + +internal static class DnsDataWriterExtensions +{ + internal static bool TryWriteResourceRecord(this DnsDataWriter writer, DnsResourceRecord record) + { + if (!TryWriteDomainName(writer, record.Name) || + !writer.TryWriteUInt16((ushort)record.Type) || + !writer.TryWriteUInt16((ushort)record.Class) || + !writer.TryWriteUInt32((uint)record.Ttl) || + !writer.TryWriteUInt16((ushort)record.Data.Length) || + !writer.TryWriteRawData(record.Data.Span)) + { + return false; + } + + return true; + } + + internal static bool TryWriteDomainName(this DnsDataWriter writer, EncodedDomainName name) + { + foreach (var label in name.Labels) + { + if (label.Length > 63) + { + throw new InvalidOperationException("Label length exceeds maximum of 63 bytes"); + } + + if (!writer.TryWriteByte((byte)label.Length) || + !writer.TryWriteRawData(label.Span)) + { + return false; + } + } + + // root label + return writer.TryWriteByte(0); + } +} \ No newline at end of file diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs new file mode 100644 index 00000000000..14abd659029 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/LoopbackDnsTestBase.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Globalization; +using System.Text; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.ServiceDiscovery.Dns.Tests; +using Microsoft.Extensions.Time.Testing; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public abstract class LoopbackDnsTestBase : IDisposable +{ + protected readonly ITestOutputHelper Output; + + internal LoopbackDnsServer DnsServer { get; } + private readonly Lazy _resolverLazy; + internal DnsResolver Resolver => _resolverLazy.Value; + internal ResolverOptions Options { get; } + protected readonly FakeTimeProvider TimeProvider; + + public LoopbackDnsTestBase(ITestOutputHelper output) + { + Output = output; + DnsServer = new(); + TimeProvider = new(); + Options = new([DnsServer.DnsEndPoint]) + { + Timeout = TimeSpan.FromSeconds(5), + Attempts = 1, + }; + _resolverLazy = new(InitializeResolver); + } + + DnsResolver InitializeResolver() + { + ServiceCollection services = new(); + services.AddXunitLogging(Output); + + // construct DnsResolver manually via internal constructor which accepts ResolverOptions + var resolver = new DnsResolver(TimeProvider, services.BuildServiceProvider().GetRequiredService>(), Options); + return resolver; + } + + public void Dispose() + { + DnsServer.Dispose(); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs new file mode 100644 index 00000000000..281ffbecd24 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolvConfTests.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Net; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolvConfTests +{ + [Fact] + public void GetOptions() + { + var contents = @" +nameserver 10.96.0.10 +search default.svc.cluster.local svc.cluster.local cluster.local +options ndots:5 +@"; + + var reader = new StringReader(contents); + ResolverOptions options = ResolvConf.GetOptions(reader); + + IPEndPoint ipAddress = Assert.Single(options.Servers); + Assert.Equal(new IPEndPoint(IPAddress.Parse("10.96.0.10"), 53), ipAddress); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs new file mode 100644 index 00000000000..c2d033ecdae --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveAddressesTests.cs @@ -0,0 +1,307 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Sockets; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolveAddressesTests : LoopbackDnsTestBase +{ + public ResolveAddressesTests(ITestOutputHelper output) : base(output) + { + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ResolveIPv4_NoData_Success(bool includeSoa) + { + string hostName = "nodata.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + if (includeSoa) + { + builder.Authorities.AddStartOfAuthority("ns.com", 240, "ns.com", "admin.ns.com", 1, 900, 180, 6048000, 3600); + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(results); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ResolveIPv4_NoSuchName_Success(bool includeSoa) + { + string hostName = "nosuchname.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.ResponseCode = QueryResponseCode.NameError; + if (includeSoa) + { + builder.Authorities.AddStartOfAuthority("ns.com", 240, "ns.com", "admin.ns.com", 1, 900, 180, 6048000, 3600); + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(results); + } + + [Theory] + [InlineData("www.resolveipv4.com")] + [InlineData("www.resolveipv4.com.")] + [InlineData("www.ř.com")] + public async Task ResolveIPv4_Simple_Success(string name) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddAddress(name, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_InOrder_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-in-order.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example3.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_OutOfOrder_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-out-of-order.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example3.com", 3600, address); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task ResolveIPv4_Aliases_Loop_ReturnsEmpty() + { + string hostName = "alias-loop2.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example3.com", 3600, hostName); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Aliases_Loop_Reverse_ReturnsEmpty() + { + string hostName = "alias-loop2.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname("www.example3.com", 3600, hostName); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Alias_And_Address() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-address.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddAddress("www.example2.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_DuplicateAlias() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "duplicate-alias.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example4.com"); + builder.Answers.AddAddress("www.example2.com", 3600, address); + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIPv4_Aliases_NotFound_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "alias-no-found.test"; + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddCname(hostName, 3600, "www.example2.com"); + builder.Answers.AddCname("www.example2.com", 3600, "www.example3.com"); + + // extra address in the answer not connected to the above + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + Assert.Empty(results); + } + + [Fact] + public async Task ResolveIP_InvalidAddressFamily_Throws() + { + await Assert.ThrowsAsync(async () => await Resolver.ResolveIPAddressesAsync("invalid-af.test", AddressFamily.Unknown)); + } + + [Theory] + [InlineData(AddressFamily.InterNetwork, "127.0.0.1")] + [InlineData(AddressFamily.InterNetworkV6, "::1")] + public async Task ResolveIP_Localhost_ReturnsLoopback(AddressFamily family, string addressAsString) + { + IPAddress address = IPAddress.Parse(addressAsString); + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("localhost", family); + AddressResult result = Assert.Single(results); + + Assert.Equal(address, result.Address); + } + + [Fact] + public async Task Resolve_Timeout_ReturnsEmpty() + { + Options.Timeout = TimeSpan.FromSeconds(1); + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("timeout-empty.test", AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Theory] + [InlineData("not-example.com", (int)QueryType.A, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.AAAA, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.A, 0)] + public async Task Resolve_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) + { + Options.Timeout = TimeSpan.FromSeconds(1); + + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); + builder.Answers.AddAddress("www.example4.com", 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync("example.com", AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Fact] + public async Task Resolve_HeaderMismatch_Ignores() + { + string name = "header-mismatch.test"; + Options.Timeout = TimeSpan.FromSeconds(5); + + SemaphoreSlim responseSemaphore = new SemaphoreSlim(0, 1); + SemaphoreSlim requestSemaphore = new SemaphoreSlim(0, 1); + + IPEndPoint clientAddress = null!; + + IPAddress address = IPAddress.Parse("172.213.245.111"); + ushort transactionId = 0x1234; + _ = DnsServer.ProcessUdpRequest((builder, clientAddr) => + { + clientAddress = clientAddr; + transactionId = (ushort)(builder.TransactionId + 1); + + builder.Answers.AddAddress(name, 3600, address); + requestSemaphore.Release(); + return responseSemaphore.WaitAsync(); + }); + + ValueTask task = Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + + await requestSemaphore.WaitAsync().WaitAsync(Options.Timeout); + + using Socket socket = new Socket(clientAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + LoopbackDnsResponseBuilder responseBuilder = new LoopbackDnsResponseBuilder(name, QueryType.A, QueryClass.Internet) + { + TransactionId = transactionId, + ResponseCode = QueryResponseCode.NoError + }; + + responseBuilder.Questions.Add((name, QueryType.A, QueryClass.Internet)); + responseBuilder.Answers.AddAddress(name, 3600, IPAddress.Loopback); + socket.SendTo(responseBuilder.GetMessageBytes(), clientAddress); + + responseSemaphore.Release(); + + AddressResult[] results = await task; + AddressResult result = Assert.Single(results); + + Assert.Equal(address, result.Address); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs new file mode 100644 index 00000000000..82ca3175789 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/ResolveServiceTests.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class ResolveServiceTests : LoopbackDnsTestBase +{ + public ResolveServiceTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task ResolveService_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Answers.AddService("_s0._tcp.example.com", 3600, 1, 2, 8080, "www.example.com"); + builder.Additionals.AddAddress("www.example.com", 3600, address); + return Task.CompletedTask; + }); + + ServiceResult[] results = await Resolver.ResolveServiceAsync("_s0._tcp.example.com"); + + ServiceResult result = Assert.Single(results); + Assert.Equal("www.example.com", result.Target); + Assert.Equal(1, result.Priority); + Assert.Equal(2, result.Weight); + Assert.Equal(8080, result.Port); + + AddressResult addressResult = Assert.Single(result.Addresses); + Assert.Equal(address, addressResult.Address); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs new file mode 100644 index 00000000000..49985846570 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/RetryTests.cs @@ -0,0 +1,309 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Sockets; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class RetryTests : LoopbackDnsTestBase +{ + public RetryTests(ITestOutputHelper output) : base(output) + { + Options.Attempts = 3; + } + + private Task SetupUdpProcessFunction(LoopbackDnsServer server, Func func) + { + return Task.Run(async () => + { + try + { + while (true) + { + await server.ProcessUdpRequest(func); + } + } + catch (Exception ex) + { + Output.WriteLine($"UDP server stopped with exception: {ex}"); + // Test teardown closed the socket, ignore + } + }); + } + + private Task SetupUdpProcessFunction(Func func) + { + return SetupUdpProcessFunction(DnsServer, func); + } + + [Fact] + public async Task Retry_Simple_Success() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "retry-simple-success.com"; + + int attempt = 0; + + Task t = SetupUdpProcessFunction(builder => + { + attempt++; + if (attempt == Options.Attempts) + { + builder.Answers.AddAddress(hostName, 3600, address); + } + else + { + builder.ResponseCode = QueryResponseCode.ServerFailure; + } + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum PersistentErrorType + { + NotImplemented, + Refused, + MalformedResponse + } + + [Theory] + [InlineData(PersistentErrorType.NotImplemented)] + [InlineData(PersistentErrorType.Refused)] + [InlineData(PersistentErrorType.MalformedResponse)] + public async Task PersistentErrorsResponseCode_FailoverToNextServer(PersistentErrorType type) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.persistent.com"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + builder => + { + primaryAttempt++; + switch (type) + { + case PersistentErrorType.NotImplemented: + builder.ResponseCode = QueryResponseCode.NotImplemented; + break; + + case PersistentErrorType.Refused: + builder.ResponseCode = QueryResponseCode.Refused; + break; + + case PersistentErrorType.MalformedResponse: + builder.ResponseCode = (QueryResponseCode)0xFF; + break; + } + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + Assert.Equal(1, primaryAttempt); + Assert.Equal(1, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum DefinitveAnswerType + { + NoError, + NoData, + NameError, + } + + [Theory] + [InlineData(DefinitveAnswerType.NoError, false)] + [InlineData(DefinitveAnswerType.NoData, false)] + [InlineData(DefinitveAnswerType.NoData, true)] + [InlineData(DefinitveAnswerType.NameError, false)] + [InlineData(DefinitveAnswerType.NameError, true)] + public async Task DefinitiveAnswers_NoRetryOrFailover(DefinitveAnswerType type, bool additionalData) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.retry.com"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + builder => + { + primaryAttempt++; + switch (type) + { + case DefinitveAnswerType.NoError: + builder.ResponseCode = QueryResponseCode.NoError; + builder.Answers.AddAddress(hostName, 3600, address); + break; + + case DefinitveAnswerType.NoData: + builder.ResponseCode = QueryResponseCode.NoError; + break; + + case DefinitveAnswerType.NameError: + builder.ResponseCode = QueryResponseCode.NameError; + break; + } + + if (additionalData) + { + builder.Authorities.AddStartOfAuthority(hostName, 300, "ns1.example.com", "hostmaster.example.com", 2023101001, 1, 3600, 300, 86400); + } + + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.ResponseCode = QueryResponseCode.Refused; + return Task.CompletedTask; + }); + + Assert.Equal(1, primaryAttempt); + Assert.Equal(0, secondaryAttempt); + + if (type == DefinitveAnswerType.NoError) + { + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + else + { + Assert.Empty(results); + } + } + + [Fact] + public async Task ExhaustedRetries_FailoverToNextServer() + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "ExhaustedRetriesFailoverToNextServer"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + builder => + { + primaryAttempt++; + builder.ResponseCode = QueryResponseCode.ServerFailure; + return Task.CompletedTask; + }, + builder => + { + secondaryAttempt++; + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + Assert.Equal(Options.Attempts, primaryAttempt); + Assert.Equal(1, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + public enum TransientErrorType + { + Timeout, + ServerFailure, + // TODO: simulate NetworkErrors + } + + [Theory] + [InlineData(TransientErrorType.Timeout)] + [InlineData(TransientErrorType.ServerFailure)] + public async Task TransientError_RetryOnSameServer(TransientErrorType type) + { + IPAddress address = IPAddress.Parse("172.213.245.111"); + string hostName = "www.transient.com"; + + int primaryAttempt = 0; + int secondaryAttempt = 0; + + AddressResult[] results = await RunWithFallbackServerHelper(hostName, + async builder => + { + primaryAttempt++; + if (primaryAttempt == 1) + { + switch (type) + { + case TransientErrorType.Timeout: + await Task.Delay(Options.Timeout.Multiply(1.5)); + builder.Answers.AddAddress(hostName, 3600, address); + break; + + case TransientErrorType.ServerFailure: + builder.ResponseCode = QueryResponseCode.ServerFailure; + break; + } + } + else + { + builder.Answers.AddAddress(hostName, 3600, address); + } + }, + builder => + { + secondaryAttempt++; + builder.ResponseCode = QueryResponseCode.Refused; + return Task.CompletedTask; + }); + + Assert.Equal(2, primaryAttempt); + Assert.Equal(0, secondaryAttempt); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + private async Task RunWithFallbackServerHelper(string name, Func primaryHandler, Func fallbackHandler) + { + Task t = SetupUdpProcessFunction(primaryHandler); + using LoopbackDnsServer fallbackServer = new LoopbackDnsServer(); + Task t2 = SetupUdpProcessFunction(fallbackServer, fallbackHandler); + + Options.Servers = [DnsServer.DnsEndPoint, fallbackServer.DnsEndPoint]; + + return await Resolver.ResolveIPAddressesAsync(name, AddressFamily.InterNetwork); + } + + [Fact] + public async Task NameError_NoRetry() + { + int counter = 0; + Task t = SetupUdpProcessFunction(builder => + { + counter++; + // authoritative answer that the name does not exist + builder.ResponseCode = QueryResponseCode.NameError; + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync("nameerror-noretry", AddressFamily.InterNetwork); + + Assert.Empty(results); + Assert.Equal(1, counter); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs new file mode 100644 index 00000000000..cbdb5e282e9 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/Resolver/TcpFailoverTests.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Sockets; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver.Tests; + +public class TcpFailoverTests : LoopbackDnsTestBase +{ + public TcpFailoverTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public async Task TcpFailover_Simple_Success() + { + string hostName = "tcp-simple.test"; + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + + AddressResult res = Assert.Single(results); + Assert.Equal(address, res.Address); + Assert.Equal(TimeProvider.GetUtcNow().DateTime.AddSeconds(3600), res.ExpiresAt); + } + + [Fact] + public async Task TcpFailover_ServerClosesWithoutData_EmptyResult() + { + string hostName = "tcp-server-closes.test"; + Options.Attempts = 1; + Options.Timeout = TimeSpan.FromSeconds(60); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + Task serverTask = DnsServer.ProcessTcpRequest(builder => + { + throw new InvalidOperationException("This forces closing the socket without writing any data"); + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + Assert.Empty(results); + + await Assert.ThrowsAsync(() => serverTask); + } + + [Fact] + public async Task TcpFailover_TcpNotAvailable_EmptyResult() + { + string hostName = "tcp-not-available.test"; + Options.Attempts = 1; + Options.Timeout = TimeSpan.FromMilliseconds(100000); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + AddressResult[] results = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(results); + } + + [Fact] + public async Task TcpFailover_HeaderMismatch_ReturnsEmpty() + { + string hostName = "tcp-header-mismatch.test"; + Options.Timeout = TimeSpan.FromSeconds(1); + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.TransactionId++; + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(result); + } + + [Theory] + [InlineData("not-example.com", (int)QueryType.A, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.AAAA, (int)QueryClass.Internet)] + [InlineData("example.com", (int)QueryType.A, 0)] + public async Task TcpFailover_QuestionMismatch_ReturnsEmpty(string name, int type, int @class) + { + string hostName = "tcp-question-mismatch.test"; + Options.Timeout = TimeSpan.FromSeconds(1); + IPAddress address = IPAddress.Parse("172.213.245.111"); + + _ = DnsServer.ProcessUdpRequest(builder => + { + builder.Flags |= QueryFlags.ResultTruncated; + return Task.CompletedTask; + }); + + _ = DnsServer.ProcessTcpRequest(builder => + { + builder.Questions[0] = (name, (QueryType)type, (QueryClass)@class); + builder.Answers.AddAddress(hostName, 3600, address); + return Task.CompletedTask; + }); + + AddressResult[] result = await Resolver.ResolveIPAddressesAsync(hostName, AddressFamily.InterNetwork); + Assert.Empty(result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/XunitLoggerFactoryExtensions.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/XunitLoggerFactoryExtensions.cs new file mode 100644 index 00000000000..6667688f16e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Dns.Tests/XunitLoggerFactoryExtensions.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests; + +internal static class XunitLoggerFactoryExtensions +{ + public static ILoggingBuilder AddXunit(this ILoggingBuilder builder, ITestOutputHelper output) + { + builder.Services.AddSingleton(new XunitLoggerProvider(output)); + return builder; + } + + public static IServiceCollection AddXunitLogging(this IServiceCollection services, ITestOutputHelper output) => + services.AddLogging(b => b.AddXunit(output)); +} + +internal class XunitLoggerProvider : ILoggerProvider +{ + private readonly ITestOutputHelper _output; + private readonly LogLevel _minLevel; + private readonly DateTimeOffset? _logStart; + + public XunitLoggerProvider(ITestOutputHelper output) + : this(output, LogLevel.Trace) + { + } + + public XunitLoggerProvider(ITestOutputHelper output, LogLevel minLevel) + : this(output, minLevel, null) + { + } + + public XunitLoggerProvider(ITestOutputHelper output, LogLevel minLevel, DateTimeOffset? logStart) + { + _output = output; + _minLevel = minLevel; + _logStart = logStart; + } + + public ILogger CreateLogger(string categoryName) + { + return new XunitLogger(_output, categoryName, _minLevel, _logStart); + } + + public void Dispose() + { + } +} + +internal class XunitLogger : ILogger +{ + private static readonly string[] s_newLineChars = new[] { Environment.NewLine }; + private readonly string _category; + private readonly LogLevel _minLogLevel; + private readonly ITestOutputHelper _output; + private readonly DateTimeOffset? _logStart; + + public XunitLogger(ITestOutputHelper output, string category, LogLevel minLogLevel, DateTimeOffset? logStart) + { + _minLogLevel = minLogLevel; + _category = category; + _output = output; + _logStart = logStart; + } + + public void Log( + LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + // Buffer the message into a single string in order to avoid shearing the message when running across multiple threads. + var messageBuilder = new StringBuilder(); + + var timestamp = _logStart.HasValue ? + $"{(DateTimeOffset.UtcNow - _logStart.Value).TotalSeconds.ToString("N3", CultureInfo.InvariantCulture)}s" : + DateTimeOffset.UtcNow.ToString("s", CultureInfo.InvariantCulture); + + var firstLinePrefix = $"| [{timestamp}] {_category} {logLevel}: "; + var lines = formatter(state, exception).Split(s_newLineChars, StringSplitOptions.RemoveEmptyEntries); + messageBuilder.AppendLine(firstLinePrefix + lines.FirstOrDefault() ?? string.Empty); + + var additionalLinePrefix = "|" + new string(' ', firstLinePrefix.Length - 1); + foreach (var line in lines.Skip(1)) + { + messageBuilder.AppendLine(additionalLinePrefix + line); + } + + if (exception != null) + { + lines = exception.ToString().Split(s_newLineChars, StringSplitOptions.RemoveEmptyEntries); + additionalLinePrefix = "| "; + foreach (var line in lines) + { + messageBuilder.AppendLine(additionalLinePrefix + line); + } + } + + // Remove the last line-break, because ITestOutputHelper only has WriteLine. + var message = messageBuilder.ToString(); + if (message.EndsWith(Environment.NewLine, StringComparison.Ordinal)) + { + message = message.Substring(0, message.Length - Environment.NewLine.Length); + } + + try + { + _output.WriteLine(message); + } + catch (Exception) + { + // We could fail because we're on a background thread and our captured ITestOutputHelper is + // busted (if the test "completed" before the background thread fired). + // So, ignore this. There isn't really anything we can do but hope the + // caller has additional loggers registered + } + } + + public bool IsEnabled(LogLevel logLevel) + => logLevel >= _minLogLevel; + + public IDisposable BeginScope(TState state) where TState : notnull + => new NullScope(); + + private sealed class NullScope : IDisposable + { + public void Dispose() + { + } + } +} + diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs new file mode 100644 index 00000000000..9fc8832fa68 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs @@ -0,0 +1,430 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Configuration.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.ServiceDiscovery.Configuration; +using Microsoft.Extensions.ServiceDiscovery.Internal; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Tests; + +/// +/// Tests for . +/// These also cover and by extension. +/// +public class ConfigurationServiceEndpointResolverTests +{ + [Fact] + public async Task ResolveServiceEndpoint_Configuration_SingleResult_NoScheme() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:http"] = "localhost:8080", + }); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + var ep = Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new DnsEndPoint("localhost", 8080), ep.EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.Null(hostNameFeature); + }); + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Configuration_DisallowedScheme() + { + // Try to resolve an http endpoint when only https is allowed. + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:foo:0"] = "http://localhost:8080", + ["services:basket:foo:1"] = "https://localhost", + }); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .Configure(o => + { + o.AllowAllSchemes = false; + o.AllowedSchemes = ["https"]; + }) + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + + // Explicitly specifying http. + // We should get no endpoint back because http is not allowed by configuration. + await using ((watcher = watcherFactory.CreateWatcher("http://_foo.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Empty(initialResult.EndpointSource.Endpoints); + } + + // Specifying no scheme. + // We should get the HTTPS endpoint back, since it is explicitly allowed + await using ((watcher = watcherFactory.CreateWatcher("_foo.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + var ep = Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost")), ep.EndPoint); + } + + // Specifying either https or http. + // We should only get the https endpoint back. + await using ((watcher = watcherFactory.CreateWatcher("https+http://_foo.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + var ep = Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost")), ep.EndPoint); + } + + // Specifying either https or http, but in reverse. + // We should only get the https endpoint back. + await using ((watcher = watcherFactory.CreateWatcher("http+https://_foo.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + var ep = Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost")), ep.EndPoint); + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Configuration_DefaultEndpointName() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:default:0"] = "https://localhost:8080", + ["services:basket:otlp:0"] = "https://localhost:8888", + }); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider(o => + { + o.ShouldApplyHostNameMetadata = _ => true; + }) + .Configure(o => + { + o.AllowAllSchemes = false; + o.AllowedSchemes = ["https"]; + }) + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + + // Explicitly specifying https as the scheme, but the endpoint section in configuration is the default value ("default"). + // We should get the endpoint back because it is an https endpoint (allowed) with the default endpoint name. + await using ((watcher = watcherFactory.CreateWatcher("https://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.NotNull(hostNameFeature); + Assert.Equal("basket", hostNameFeature.HostName); + }); + } + + // Not specifying the scheme or endpoint name. + // We should get the endpoint back because it is an https endpoint (allowed) with the default endpoint name. + await using ((watcher = watcherFactory.CreateWatcher("basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + + // Not specifying the scheme, but specifying the default endpoint name. + // We should get the endpoint back because it is an https endpoint (allowed) with the default endpoint name. + await using ((watcher = watcherFactory.CreateWatcher("_default.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + } + + /// + /// Checks that when there is no named endpoint, configuration resolves first from the "default" section, then sections named by the scheme names. + /// + [Theory] + [InlineData(true, true, "https://basket", "https://default-host:8080")] + [InlineData(false, true, "https://basket","https://https-host:8080")] + [InlineData(true, false, "https://basket", "https://default-host:8080")] + [InlineData(true, true, "basket", "https://default-host:8080")] + [InlineData(false, true, "basket", null)] + [InlineData(true, false, "basket", "https://default-host:8080")] + [InlineData(true, true, "http+https://basket", "https://default-host:8080")] + [InlineData(false, true, "http+https://basket","https://https-host:8080")] + [InlineData(true, false, "http+https://basket", "https://default-host:8080")] + public async Task ResolveServiceEndpoint_Configuration_DefaultEndpointName_ResolutionOrder( + bool includeDefault, + bool includeSchemeNamed, + string serviceName, + string? expectedResult) + { + var data = new Dictionary(); + if (includeDefault) + { + data["services:basket:default:0"] = "https://default-host:8080"; + } + + if (includeSchemeNamed) + { + data["services:basket:https:0"] = "https://https-host:8080"; + } + + var config = new ConfigurationBuilder().AddInMemoryCollection(data); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + + // Scheme in query + await using ((watcher = watcherFactory.CreateWatcher(serviceName)).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + if (expectedResult is not null) + { + Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri(expectedResult)), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + else + { + Assert.Empty(initialResult.EndpointSource.Endpoints); + } + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Configuration_MultipleResults() + { + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:default:0"] = "http://localhost:8080", + ["services:basket:default:1"] = "http://remotehost:9090", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider(options => options.ShouldApplyHostNameMetadata = _ => true) + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(2, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new UriEndPoint(new Uri("http://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + Assert.Equal(new UriEndPoint(new Uri("http://remotehost:9090")), initialResult.EndpointSource.Endpoints[1].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.NotNull(hostNameFeature); + Assert.Equal("basket", hostNameFeature.HostName); + }); + } + + // Request either https or http. Since there are only http endpoints, we should get only http endpoints back. + await using ((watcher = watcherFactory.CreateWatcher("https+http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(2, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new UriEndPoint(new Uri("http://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + Assert.Equal(new UriEndPoint(new Uri("http://remotehost:9090")), initialResult.EndpointSource.Endpoints[1].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.NotNull(hostNameFeature); + Assert.Equal("basket", hostNameFeature.HostName); + }); + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Configuration_MultipleProtocols() + { + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:http:0"] = "http://localhost:8080", + ["services:basket:https:1"] = "https://remotehost:9090", + ["services:basket:grpc:0"] = "localhost:2222", + ["services:basket:grpc:1"] = "127.0.0.1:3333", + ["services:basket:grpc:2"] = "http://remotehost:4444", + ["services:basket:grpc:3"] = "https://remotehost:5555", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://_grpc.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(3, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new DnsEndPoint("localhost", 2222), initialResult.EndpointSource.Endpoints[0].EndPoint); + Assert.Equal(new IPEndPoint(IPAddress.Loopback, 3333), initialResult.EndpointSource.Endpoints[1].EndPoint); + Assert.Equal(new UriEndPoint(new Uri("http://remotehost:4444")), initialResult.EndpointSource.Endpoints[2].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.Null(hostNameFeature); + }); + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Configuration_MultipleProtocols_WithSpecificationByConsumer() + { + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:default:0"] = "http://localhost:8080", + ["services:basket:default:1"] = "remotehost:9090", + ["services:basket:grpc:0"] = "localhost:2222", + ["services:basket:grpc:1"] = "127.0.0.1:3333", + ["services:basket:grpc:2"] = "http://remotehost:4444", + ["services:basket:grpc:3"] = "https://remotehost:5555", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("https+http://_grpc.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(3, initialResult.EndpointSource.Endpoints.Count); + + // These must be treated as HTTPS by the HttpClient middleware, but that is not the responsibility of the resolver. + Assert.Equal(new DnsEndPoint("localhost", 2222), initialResult.EndpointSource.Endpoints[0].EndPoint); + Assert.Equal(new IPEndPoint(IPAddress.Loopback, 3333), initialResult.EndpointSource.Endpoints[1].EndPoint); + + // We expect the HTTPS endpoint back but not the HTTP one. + Assert.Equal(new UriEndPoint(new Uri("https://remotehost:5555")), initialResult.EndpointSource.Endpoints[2].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.Null(hostNameFeature); + }); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ExtensionsServicePublicApiTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ExtensionsServicePublicApiTests.cs new file mode 100644 index 00000000000..31781cf6722 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ExtensionsServicePublicApiTests.cs @@ -0,0 +1,218 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.FileProviders; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Tests; + +#pragma warning disable IDE0200 + +public class ExtensionsServicePublicApiTests +{ + [Fact] + public void AddServiceDiscoveryShouldThrowWhenHttpClientBuilderIsNull() + { + IHttpClientBuilder httpClientBuilder = null!; + + var action = () => httpClientBuilder.AddServiceDiscovery(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(httpClientBuilder), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddServiceDiscovery(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryWithConfigureOptionsShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + Action configureOptions = (_) => { }; + + var action = () => services.AddServiceDiscovery(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryWithConfigureOptionsShouldThrowWhenConfigureOptionsIsNull() + { + var services = new ServiceCollection(); + Action configureOptions = null!; + + var action = () => services.AddServiceDiscovery(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(configureOptions), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryCoreShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddServiceDiscoveryCore(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryCoreWithConfigureOptionsShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + Action configureOptions = (_) => { }; + + var action = () => services.AddServiceDiscoveryCore(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryCoreWithConfigureOptionsShouldThrowWhenConfigureOptionsIsNull() + { + var services = new ServiceCollection(); + Action configureOptions = null!; + + var action = () => services.AddServiceDiscoveryCore(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(configureOptions), exception.ParamName); + } + + [Fact] + public void AddConfigurationServiceEndpointProviderShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddConfigurationServiceEndpointProvider(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddConfigurationServiceEndpointProviderWithConfigureOptionsShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + Action configureOptions = (_) => { }; + + var action = () => services.AddConfigurationServiceEndpointProvider(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddConfigurationServiceEndpointProviderWithConfigureOptionsShouldThrowWhenConfigureOptionsIsNull() + { + var services = new ServiceCollection(); + Action configureOptions = null!; + + var action = () => services.AddConfigurationServiceEndpointProvider(configureOptions); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(configureOptions), exception.ParamName); + } + + [Fact] + public void AddPassThroughServiceEndpointProviderShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddPassThroughServiceEndpointProvider(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public async Task GetEndpointsAsyncShouldThrowWhenServiceNameIsNull() + { + var serviceEndpointWatcherFactory = new ServiceEndpointWatcherFactory( + new List(), + new Logger(new NullLoggerFactory()), + Options.Options.Create(new ServiceDiscoveryOptions()), + TimeProvider.System); + + var serviceEndpointResolver = new ServiceEndpointResolver(serviceEndpointWatcherFactory, TimeProvider.System); + string serviceName = null!; + + var action = async () => await serviceEndpointResolver.GetEndpointsAsync(serviceName, CancellationToken.None); + + var exception = await Assert.ThrowsAsync(action); + Assert.Equal(nameof(serviceName), exception.ParamName); + } + + [Fact] + public void CreateShouldThrowWhenEndPointIsNull() + { + EndPoint endPoint = null!; + + var action = () => ServiceEndpoint.Create(endPoint); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(endPoint), exception.ParamName); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TryParseShouldThrowWhenEndPointIsNullOrEmpty(bool isNull) + { + var input = isNull ? null! : string.Empty; + + var action = () => + { + _ = ServiceEndpointQuery.TryParse(input, out _); + }; + + var exception = isNull + ? Assert.Throws(action) + : Assert.Throws(action); + Assert.Equal(nameof(input), exception.ParamName); + } + + [Fact] + public void CtorServiceEndpointSourceShouldThrowWhenChangeTokenIsNull() + { + IChangeToken changeToken = null!; + var features = new FeatureCollection(); + List? endpoints = null; + + var action = () => new ServiceEndpointSource(endpoints, changeToken, features); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(changeToken), exception.ParamName); + } + + [Fact] + public void CtorServiceEndpointSourceShouldThrowWhenFeaturesIsNull() + { + var changeToken = NullChangeToken.Singleton; + IFeatureCollection features = null!; + List? endpoints = null; + + var action = () => new ServiceEndpointSource(endpoints, changeToken, features); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(features), exception.ParamName); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/Microsoft.Extensions.ServiceDiscovery.Tests.csproj b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/Microsoft.Extensions.ServiceDiscovery.Tests.csproj new file mode 100644 index 00000000000..6a39ff1b9af --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/Microsoft.Extensions.ServiceDiscovery.Tests.csproj @@ -0,0 +1,25 @@ + + + + enable + enable + + $(NoWarn);IDE0004;IDE0040;IDE0055;IDE1006;CA2000;S1121;S1128;SA1316;SA1500;SA1513 + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/PassThroughServiceEndpointResolverTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/PassThroughServiceEndpointResolverTests.cs new file mode 100644 index 00000000000..f8cc2f282e1 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/PassThroughServiceEndpointResolverTests.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Configuration.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.ServiceDiscovery.Internal; +using Microsoft.Extensions.ServiceDiscovery.PassThrough; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Tests; + +/// +/// Tests for . +/// These also cover and by extension. +/// +public class PassThroughServiceEndpointResolverTests +{ + [Fact] + public async Task ResolveServiceEndpoint_PassThrough() + { + var services = new ServiceCollection() + .AddServiceDiscoveryCore() + .AddPassThroughServiceEndpointProvider() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + var ep = Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new DnsEndPoint("basket", 80), ep.EndPoint); + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Superseded() + { + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:http:0"] = "http://localhost:8080", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscovery() // Adds the configuration and pass-through providers. + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + + // We expect the basket service to be resolved from Configuration, not the pass-through provider. + Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("http://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + } + + [Fact] + public async Task ResolveServiceEndpoint_Fallback() + { + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:default:0"] = "http://localhost:8080", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscovery() // Adds the configuration and pass-through providers. + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://catalog")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task; + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + + // We expect the CATALOG service to be resolved from the pass-through provider. + Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new DnsEndPoint("catalog", 80), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + } + + // Ensures that pass-through resolution succeeds in scenarios where no scheme is specified during resolution. + [Fact] + public async Task ResolveServiceEndpoint_Fallback_NoScheme() + { + var configSource = new MemoryConfigurationSource + { + InitialData = new Dictionary + { + ["services:basket:default:0"] = "http://localhost:8080", + } + }; + var config = new ConfigurationBuilder().Add(configSource); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscovery() // Adds the configuration and pass-through providers. + .BuildServiceProvider(); + + var resolver = services.GetRequiredService(); + var result = await resolver.GetEndpointsAsync("catalog", default); + Assert.Equal(new DnsEndPoint("catalog", 0), result.Endpoints[0].EndPoint); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ServiceEndpointResolverTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ServiceEndpointResolverTests.cs new file mode 100644 index 00000000000..c91f07c9300 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ServiceEndpointResolverTests.cs @@ -0,0 +1,292 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Net; +using System.Threading.Channels; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Primitives; +using Microsoft.Extensions.ServiceDiscovery.Http; +using Microsoft.Extensions.ServiceDiscovery.Internal; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Tests; + +/// +/// Tests for and . +/// +public class ServiceEndpointResolverTests +{ + [Fact] + public void ResolveServiceEndpoint_NoProvidersConfigured_Throws() + { + var services = new ServiceCollection() + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var resolverFactory = services.GetRequiredService(); + var exception = Assert.Throws(() => resolverFactory.CreateWatcher("https://basket")); + Assert.Equal("No provider which supports the provided service name, 'https://basket', has been configured.", exception.Message); + } + + [Fact] + public async Task ServiceEndpointResolver_NoProvidersConfigured_Throws() + { + var services = new ServiceCollection() + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var watcher = new ServiceEndpointWatcher([], NullLogger.Instance, "foo", TimeProvider.System, Options.Options.Create(new ServiceDiscoveryOptions())); + var exception = Assert.Throws(watcher.Start); + Assert.Equal("No service endpoint providers are configured.", exception.Message); + exception = await Assert.ThrowsAsync(async () => await watcher.GetEndpointsAsync()); + Assert.Equal("No service endpoint providers are configured.", exception.Message); + } + + [Fact] + public void ResolveServiceEndpoint_NullServiceName_Throws() + { + var services = new ServiceCollection() + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var resolverFactory = services.GetRequiredService(); + Assert.Throws(() => resolverFactory.CreateWatcher(null!)); + } + + [Fact] + public async Task AddServiceDiscovery_NoProviders_Throws() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddHttpClient("foo", c => c.BaseAddress = new("http://foo")).AddServiceDiscovery(); + var services = serviceCollection.BuildServiceProvider(); + var client = services.GetRequiredService().CreateClient("foo"); + var exception = await Assert.ThrowsAsync(async () => await client.GetStringAsync("/")); + Assert.Equal("No provider which supports the provided service name, 'http://foo', has been configured.", exception.Message); + } + + private sealed class FakeEndpointResolverProvider(Func createResolverDelegate) : IServiceEndpointProviderFactory + { +#pragma warning disable CS0436 // Type conflicts with imported type + public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? resolver) +#pragma warning restore CS0436 // Type conflicts with imported type + { + bool result; + (result, resolver) = createResolverDelegate(query); + return result; + } + } + + private sealed class FakeEndpointResolver(Func resolveAsync, Func disposeAsync) : IServiceEndpointProvider + { + public ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationToken cancellationToken) => resolveAsync(endpoints, cancellationToken); + public ValueTask DisposeAsync() => disposeAsync(); + } + + [Fact] + public async Task ResolveServiceEndpoint() + { + var cts = new[] { new CancellationTokenSource() }; + var innerResolver = new FakeEndpointResolver( + resolveAsync: (collection, ct) => + { + collection.AddChangeToken(new CancellationChangeToken(cts[0].Token)); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080))); + + if (cts[0].Token.IsCancellationRequested) + { + cts[0] = new(); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888))); + } + return default; + }, + disposeAsync: () => default); + var resolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver)); + var services = new ServiceCollection() + .AddSingleton(resolverProvider) + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var initialResult = await watcher.GetEndpointsAsync(CancellationToken.None); + Assert.NotNull(initialResult); + var sep = Assert.Single(initialResult.Endpoints); + var ip = Assert.IsType(sep.EndPoint); + Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address); + Assert.Equal(8080, ip.Port); + + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + Assert.False(tcs.Task.IsCompleted); + + cts[0].Cancel(); + var resolverResult = await tcs.Task; + Assert.NotNull(resolverResult); + Assert.True(resolverResult.ResolvedSuccessfully); + Assert.Equal(2, resolverResult.EndpointSource.Endpoints.Count); + var endpoints = resolverResult.EndpointSource.Endpoints.Select(ep => ep.EndPoint).OfType().ToList(); + endpoints.Sort((l, r) => l.Port - r.Port); + Assert.Equal(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080), endpoints[0]); + Assert.Equal(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888), endpoints[1]); + } + } + + [Fact] + public async Task ResolveServiceEndpointOneShot() + { + var cts = new[] { new CancellationTokenSource() }; + var innerResolver = new FakeEndpointResolver( + resolveAsync: (collection, ct) => + { + collection.AddChangeToken(new CancellationChangeToken(cts[0].Token)); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080))); + + if (cts[0].Token.IsCancellationRequested) + { + cts[0] = new(); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888))); + } + return default; + }, + disposeAsync: () => default); + var resolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver)); + var services = new ServiceCollection() + .AddSingleton(resolverProvider) + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var resolver = services.GetRequiredService(); + + Assert.NotNull(resolver); + var initialResult = await resolver.GetEndpointsAsync("http://basket", CancellationToken.None); + Assert.NotNull(initialResult); + var sep = Assert.Single(initialResult.Endpoints); + var ip = Assert.IsType(sep.EndPoint); + Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address); + Assert.Equal(8080, ip.Port); + + await services.DisposeAsync(); + } + + [Fact] + public async Task ResolveHttpServiceEndpointOneShot() + { + var cts = new[] { new CancellationTokenSource() }; + var innerResolver = new FakeEndpointResolver( + resolveAsync: (collection, ct) => + { + collection.AddChangeToken(new CancellationChangeToken(cts[0].Token)); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080))); + + if (cts[0].Token.IsCancellationRequested) + { + cts[0] = new(); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888))); + } + return default; + }, + disposeAsync: () => default); + var fakeResolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver)); + var services = new ServiceCollection() + .AddSingleton(fakeResolverProvider) + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var resolverProvider = services.GetRequiredService(); + await using var resolver = new HttpServiceEndpointResolver(resolverProvider, services, TimeProvider.System); + + Assert.NotNull(resolver); + var httpRequest = new HttpRequestMessage(HttpMethod.Get, "http://basket"); + var endpoint = await resolver.GetEndpointAsync(httpRequest, CancellationToken.None); + Assert.NotNull(endpoint); + var ip = Assert.IsType(endpoint.EndPoint); + Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address); + Assert.Equal(8080, ip.Port); + + await services.DisposeAsync(); + } + + [Fact] + public async Task ResolveServiceEndpoint_ThrowOnReload() + { + var sem = new SemaphoreSlim(0); + var cts = new[] { new CancellationTokenSource() }; + var throwOnNextResolve = new[] { false }; + var innerResolver = new FakeEndpointResolver( + resolveAsync: async (collection, ct) => + { + await sem.WaitAsync(ct).ConfigureAwait(false); + if (cts[0].IsCancellationRequested) + { + // Always be sure to have a fresh token. + cts[0] = new(); + } + + if (throwOnNextResolve[0]) + { + throwOnNextResolve[0] = false; + throw new InvalidOperationException("throwing"); + } + + collection.AddChangeToken(new CancellationChangeToken(cts[0].Token)); + collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080))); + }, + disposeAsync: () => default); + var resolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver)); + var services = new ServiceCollection() + .AddSingleton(resolverProvider) + .AddServiceDiscoveryCore() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + + ServiceEndpointWatcher watcher; + await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var initialEndpointsTask = watcher.GetEndpointsAsync(CancellationToken.None); + sem.Release(1); + var initialEndpoints = await initialEndpointsTask; + Assert.NotNull(initialEndpoints); + Assert.Single(initialEndpoints.Endpoints); + + // Tell the resolver to throw on the next resolve call and then trigger a reload. + throwOnNextResolve[0] = true; + cts[0].Cancel(); + + var exception = await Assert.ThrowsAsync(async () => + { + var resolveTask = watcher.GetEndpointsAsync(CancellationToken.None); + sem.Release(1); + await resolveTask.ConfigureAwait(false); + }); + + Assert.Equal("throwing", exception.Message); + + var channel = Channel.CreateUnbounded(); + watcher.OnEndpointsUpdated = result => channel.Writer.TryWrite(result); + + do + { + cts[0].Cancel(); + sem.Release(1); + var resolveTask = watcher.GetEndpointsAsync(CancellationToken.None); + await resolveTask; + var next = await channel.Reader.ReadAsync(CancellationToken.None); + if (next.ResolvedSuccessfully) + { + break; + } + } while (true); + + var task = watcher.GetEndpointsAsync(CancellationToken.None); + sem.Release(1); + var result = await task; + Assert.NotSame(initialEndpoints, result); + var sep = Assert.Single(result.Endpoints); + var ip = Assert.IsType(sep.EndPoint); + Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address); + Assert.Equal(8080, ip.Port); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ServiceEndpointTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ServiceEndpointTests.cs new file mode 100644 index 00000000000..2943074c2b3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Tests/ServiceEndpointTests.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Tests; + +public class ServiceEndpointTests +{ + public static TheoryData ZeroPortEndPoints => new() + { + (EndPoint)IPEndPoint.Parse("127.0.0.1:0"), + (EndPoint)new DnsEndPoint("microsoft.com", 0), + (EndPoint)new UriEndPoint(new Uri("https://microsoft.com")) + }; + + public static TheoryData NonZeroPortEndPoints => new() + { + (EndPoint)IPEndPoint.Parse("127.0.0.1:8443"), + (EndPoint)new DnsEndPoint("microsoft.com", 8443), + (EndPoint)new UriEndPoint(new Uri("https://microsoft.com:8443")) + }; + + [Theory] + [MemberData(nameof(ZeroPortEndPoints))] + public void ServiceEndpointToStringOmitsUnspecifiedPort(EndPoint endpoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endpoint); + var epString = serviceEndpoint.ToString(); + Assert.DoesNotContain(":0", epString); + } + + [Theory] + [MemberData(nameof(NonZeroPortEndPoints))] + public void ServiceEndpointToStringContainsSpecifiedPort(EndPoint endpoint) + { + var serviceEndpoint = ServiceEndpoint.Create(endpoint); + var epString = serviceEndpoint.ToString(); + Assert.Contains(":8443", epString); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests.csproj b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests.csproj new file mode 100644 index 00000000000..00211519268 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests.csproj @@ -0,0 +1,23 @@ + + + + $(TestNetCoreTargetFrameworks) + enable + enable + + $(NoWarn);CA2000;S103;S1144;S3459;S4136;SA1208;SA1210;VSTHRD003 + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryPublicApiTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryPublicApiTests.cs new file mode 100644 index 00000000000..a3b694c6d70 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryPublicApiTests.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.ServiceDiscovery.Yarp.Tests; + +#pragma warning disable IDE0200 + +public class YarpServiceDiscoveryPublicApiTests +{ + [Fact] + public void AddServiceDiscoveryDestinationResolverShouldThrowWhenBuilderIsNull() + { + IReverseProxyBuilder builder = null!; + + var action = () => builder.AddServiceDiscoveryDestinationResolver(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(builder), exception.ParamName); + } + + [Fact] + public void AddHttpForwarderWithServiceDiscoveryShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddHttpForwarderWithServiceDiscovery(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } + + [Fact] + public void AddServiceDiscoveryForwarderFactoryShouldThrowWhenServicesIsNull() + { + IServiceCollection services = null!; + + var action = () => services.AddServiceDiscoveryForwarderFactory(); + + var exception = Assert.Throws(action); + Assert.Equal(nameof(services), exception.ParamName); + } +} diff --git a/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs new file mode 100644 index 00000000000..c2751823c65 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.ServiceDiscovery.Yarp.Tests/YarpServiceDiscoveryTests.cs @@ -0,0 +1,323 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Xunit; +using Yarp.ReverseProxy.Configuration; +using System.Net; +using System.Net.Sockets; +using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Extensions.ServiceDiscovery.Yarp.Tests; + +/// +/// Tests for YARP with Service Discovery enabled. +/// +public class YarpServiceDiscoveryTests +{ + private static ServiceDiscoveryDestinationResolver CreateResolver(IServiceProvider serviceProvider) + { + var coreResolver = serviceProvider.GetRequiredService(); + return new ServiceDiscoveryDestinationResolver( + coreResolver, + serviceProvider.GetRequiredService>()); + } + + [Fact] + public async Task ServiceDiscoveryDestinationResolverTests_PassThrough() + { + await using var services = new ServiceCollection() + .AddServiceDiscoveryCore() + .AddPassThroughServiceEndpointProvider() + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https://my-svc", + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + + Assert.Single(result.Destinations); + Assert.Collection(result.Destinations.Select(d => d.Value.Address), + a => Assert.Equal("https://my-svc/", a)); + } + + [Fact] + public async Task ServiceDiscoveryDestinationResolverTests_Configuration() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:default:0"] = "ftp://localhost:2121", + ["services:basket:default:1"] = "https://localhost:8888", + ["services:basket:default:2"] = "http://localhost:1111", + }); + await using var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https+http://basket", + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + + Assert.Single(result.Destinations); + Assert.Collection(result.Destinations.Select(d => d.Value.Address), + a => Assert.Equal("https://localhost:8888/", a)); + } + + [Fact] + public async Task ServiceDiscoveryDestinationResolverTests_Configuration_NonPreferredScheme() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:default:0"] = "ftp://localhost:2121", + ["services:basket:default:1"] = "http://localhost:1111", + }); + await using var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https+http://basket", + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + + Assert.Single(result.Destinations); + Assert.Collection(result.Destinations.Select(d => d.Value.Address), + a => Assert.Equal("http://localhost:1111/", a)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ServiceDiscoveryDestinationResolverTests_Configuration_Host_Value(bool configHasHost) + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:default:0"] = "https://localhost:1111", + ["services:basket:default:1"] = "https://127.0.0.1:2222", + ["services:basket:default:2"] = "https://[::1]:3333", + ["services:basket:default:3"] = "https://baskets-galore.faketld", + }); + await using var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https://basket", + Host = configHasHost ? "my-basket-svc.faketld" : null + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + + Assert.Equal(4, result.Destinations.Count); + Assert.Collection(result.Destinations.Values, + a => + { + Assert.Equal("https://localhost:1111/", a.Address); + if (configHasHost) + { + Assert.Equal("my-basket-svc.faketld", a.Host); + } + else + { + Assert.Null(a.Host); + } + }, + a => + { + Assert.Equal("https://127.0.0.1:2222/", a.Address); + if (configHasHost) + { + Assert.Equal("my-basket-svc.faketld", a.Host); + } + else + { + Assert.Null(a.Host); + } + }, + a => + { + Assert.Equal("https://[::1]:3333/", a.Address); + if (configHasHost) + { + Assert.Equal("my-basket-svc.faketld", a.Host); + } + else + { + Assert.Null(a.Host); + } + }, + a => + { + Assert.Equal("https://baskets-galore.faketld/", a.Address); + if (configHasHost) + { + Assert.Equal("my-basket-svc.faketld", a.Host); + } + else + { + Assert.Null(a.Host); + } + }); + } + + [Fact] + public async Task ServiceDiscoveryDestinationResolverTests_Configuration_DisallowedScheme() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:default:0"] = "ftp://localhost:2121", + ["services:basket:default:1"] = "http://localhost:1111", + }); + await using var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .Configure(o => + { + // Allow only "https://" + o.AllowAllSchemes = false; + o.AllowedSchemes = ["https"]; + }) + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https+http://basket", + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + + // No results: there are no 'https' endpoints in config and 'http' is disallowed. + Assert.Empty(result.Destinations); + } + + [Fact] + public async Task ServiceDiscoveryDestinationResolverTests_Dns() + { + DnsResolver resolver = new DnsResolver(TimeProvider.System, NullLogger.Instance); + + await using var services = new ServiceCollection() + .AddSingleton(resolver) + .AddServiceDiscoveryCore() + .AddDnsServiceEndpointProvider() + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https://microsoft.com", + }, + ["dest-b"] = new() + { + Address = "http://msn.com", + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + Assert.NotNull(result); + Assert.NotEmpty(result.Destinations); + Assert.All(result.Destinations, d => + { + var address = d.Value.Address; + Assert.True(Uri.TryCreate(address, default, out var uri), $"Failed to parse address '{address}' as URI."); + Assert.True(uri.IsDefaultPort, "URI should use the default port when resolved via DNS."); + var expectedScheme = d.Key.StartsWith("dest-a") ? "https" : "http"; + Assert.Equal(expectedScheme, uri.Scheme); + }); + } + + [Fact] + public async Task ServiceDiscoveryDestinationResolverTests_DnsSrv() + { + var dnsClientMock = new FakeDnsResolver + { + ResolveServiceAsyncFunc = (name, cancellationToken) => + { + ServiceResult[] response = [ + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 66, 8888, "srv-a", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Parse("10.10.10.10"))]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 9999, "srv-b", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.IPv6Loopback)]), + new ServiceResult(DateTime.UtcNow.AddSeconds(60), 99, 62, 7777, "srv-c", [new AddressResult(DateTime.UtcNow.AddSeconds(64), IPAddress.Loopback)]) + ]; + + return ValueTask.FromResult(response); + } + }; + + await using var services = new ServiceCollection() + .AddSingleton(dnsClientMock) + .AddServiceDiscoveryCore() + .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns") + .BuildServiceProvider(); + var yarpResolver = CreateResolver(services); + + var destinationConfigs = new Dictionary + { + ["dest-a"] = new() + { + Address = "https://my-svc", + }, + }; + + var result = await yarpResolver.ResolveDestinationsAsync(destinationConfigs, CancellationToken.None); + + Assert.Equal(3, result.Destinations.Count); + Assert.Collection(result.Destinations.Select(d => d.Value.Address), + a => Assert.Equal("https://10.10.10.10:8888/", a), + a => Assert.Equal("https://[::1]:9999/", a), + a => Assert.Equal("https://127.0.0.1:7777/", a)); + } + + private sealed class FakeDnsResolver : IDnsResolver + { + public Func>? ResolveIPAddressesAsyncFunc { get; set; } + public ValueTask ResolveIPAddressesAsync(string name, AddressFamily addressFamily, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc!.Invoke(name, addressFamily, cancellationToken); + + public Func>? ResolveIPAddressesAsyncFunc2 { get; set; } + + public ValueTask ResolveIPAddressesAsync(string name, CancellationToken cancellationToken = default) => ResolveIPAddressesAsyncFunc2!.Invoke(name, cancellationToken); + + public Func>? ResolveServiceAsyncFunc { get; set; } + + public ValueTask ResolveServiceAsync(string name, CancellationToken cancellationToken = default) => ResolveServiceAsyncFunc!.Invoke(name, cancellationToken); + } +}