diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/AfdConfigurationClientManager.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/AfdConfigurationClientManager.cs new file mode 100644 index 000000000..87c8c24bb --- /dev/null +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/AfdConfigurationClientManager.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +using Azure.Data.AppConfiguration; +using Microsoft.Extensions.Azure; +using System; +using System.Collections.Generic; +namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.Afd +{ + internal class AfdConfigurationClientManager : IConfigurationClientManager + { + private readonly ConfigurationClientWrapper _clientWrapper; + + public AfdConfigurationClientManager( + IAzureClientFactory clientFactory, + Uri endpoint) + { + if (clientFactory == null) + { + throw new ArgumentNullException(nameof(clientFactory)); + } + + if (endpoint == null) + { + throw new ArgumentNullException(nameof(endpoint)); + } + + _clientWrapper = new ConfigurationClientWrapper(endpoint, clientFactory.CreateClient(endpoint.AbsoluteUri)); + } + + public IEnumerable GetClients() + { + return new List { _clientWrapper.Client }; + } + + public void RefreshClients() + { + return; + } + + public bool UpdateSyncToken(Uri endpoint, string syncToken) + { + if (endpoint == null) + { + throw new ArgumentNullException(nameof(endpoint)); + } + + if (string.IsNullOrWhiteSpace(syncToken)) + { + throw new ArgumentNullException(nameof(syncToken)); + } + + if (new EndpointComparer().Equals(_clientWrapper.Endpoint, endpoint)) + { + _clientWrapper.Client.UpdateSyncToken(syncToken); + return true; + } + + return false; + } + + public Uri GetEndpointForClient(ConfigurationClient client) + { + if (client == null) + { + throw new ArgumentNullException(nameof(client)); + } + + return _clientWrapper.Client == client ? _clientWrapper.Endpoint : null; + } + } +} diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/AfdPolicy.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/AfdPolicy.cs new file mode 100644 index 000000000..b311cf5ef --- /dev/null +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/AfdPolicy.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +using Azure.Core; +using Azure.Core.Pipeline; +using System; +using System.Collections.Generic; +using System.Collections.Specialized; +using System.Linq; +using System.Web; + +namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.Afd +{ + /// + /// HTTP pipeline policy that removes Authorization headers from requests and orders query parameters by lowercase. + /// + internal class AfdPolicy : HttpPipelinePolicy + { + /// + /// Initializes a new instance of the class. + /// + public AfdPolicy() + { + } + + /// + /// Processes the HTTP message, removes the Authorization header, and orders query parameters by lowercase. + /// + /// The HTTP message. + /// The pipeline. + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + message.Request.Headers.Remove("Authorization"); + + message.Request.Uri.Reset(OrderQueryParameters(message.Request.Uri.ToUri())); + + ProcessNext(message, pipeline); + } + + /// + /// Processes the HTTP message asynchronously, removes the Authorization header, and orders query parameters by lowercase. + /// + /// The HTTP message. + /// The pipeline. + /// A task representing the asynchronous operation. + public override async System.Threading.Tasks.ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + message.Request.Headers.Remove("Authorization"); + + message.Request.Uri.Reset(OrderQueryParameters(message.Request.Uri.ToUri())); + + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + + private static Uri OrderQueryParameters(Uri uri) + { + var uriBuilder = new UriBuilder(uri); + + NameValueCollection query = HttpUtility.ParseQueryString(uriBuilder.Query); + + if (query.Count > 0) + { + IEnumerable orderedParams = query.AllKeys + .Where(key => key != null) + .OrderBy(key => key.ToLowerInvariant()) + .Select(key => + { + string value = query[key]; + + if (value == null) + { + return Uri.EscapeDataString(key); + } + + return $"{Uri.EscapeDataString(key)}={Uri.EscapeDataString(value)}"; + }); + + uriBuilder.Query = string.Join("&", orderedParams); + } + + return uriBuilder.Uri; + } + } +} diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/EmptyTokenCredential.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/EmptyTokenCredential.cs new file mode 100644 index 000000000..1239676f7 --- /dev/null +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Afd/EmptyTokenCredential.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +using Azure.Core; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.Afd +{ + /// + /// A token credential that provides an empty token. + /// + internal class EmptyTokenCredential : TokenCredential + { + /// + /// Gets an empty token. + /// + /// The context of the token request. + /// A cancellation token to cancel the operation. + /// An empty access token. + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return new AccessToken(string.Empty, DateTimeOffset.MaxValue); + } + + /// + /// Asynchronously gets an empty token. + /// + /// The context of the token request. + /// A cancellation token to cancel the operation. + /// A task that represents the asynchronous operation. The task result contains an empty access token. + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return new ValueTask(new AccessToken(string.Empty, DateTimeOffset.MaxValue)); + } + } +} diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationOptions.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationOptions.cs index bb48372ab..37e330720 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationOptions.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationOptions.cs @@ -5,6 +5,7 @@ using Azure.Data.AppConfiguration; using Microsoft.Extensions.Azure; using Microsoft.Extensions.Configuration.AzureAppConfiguration.AzureKeyVault; +using Microsoft.Extensions.Configuration.AzureAppConfiguration.Afd; using Microsoft.Extensions.Configuration.AzureAppConfiguration.Extensions; using Microsoft.Extensions.Configuration.AzureAppConfiguration.FeatureManagement; using Microsoft.Extensions.Configuration.AzureAppConfiguration.Models; @@ -127,7 +128,7 @@ internal IEnumerable Adapters /// /// Options used to configure the client used to communicate with Azure App Configuration. /// - internal ConfigurationClientOptions ClientOptions { get; } = GetDefaultClientOptions(); + internal ConfigurationClientOptions ClientOptions { get; private set; } = GetDefaultClientOptions(); /// /// Flag to indicate whether Key Vault options have been configured. @@ -154,6 +155,11 @@ internal IEnumerable Adapters /// internal IAzureClientFactory ClientFactory { get; private set; } + /// + /// Gets a value indicating whether AFD is enabled. + /// + internal bool IsAfdEnabled { get; private set; } + /// /// Initializes a new instance of the class. /// @@ -181,6 +187,7 @@ public AzureAppConfigurationOptions() public AzureAppConfigurationOptions SetClientFactory(IAzureClientFactory factory) { ClientFactory = factory ?? throw new ArgumentNullException(nameof(factory)); + return this; } @@ -357,6 +364,11 @@ public AzureAppConfigurationOptions Connect(string connectionString) /// public AzureAppConfigurationOptions Connect(IEnumerable connectionStrings) { + if (IsAfdEnabled) + { + throw new InvalidOperationException("Cannot connect to both Azure App Configuration and AFD at the same time."); + } + if (connectionStrings == null || !connectionStrings.Any()) { throw new ArgumentNullException(nameof(connectionStrings)); @@ -373,6 +385,32 @@ public AzureAppConfigurationOptions Connect(IEnumerable connectionString return this; } + /// + /// Connect the provider to Azure Front Door endpoint. + /// + /// The endpoint of the Azure Front Door (AFD) instance to connect to. + public AzureAppConfigurationOptions ConnectAzureFrontDoor(Uri endpoint) + { + if ((Credential != null && !(Credential is EmptyTokenCredential)) || (ConnectionStrings?.Any() ?? false)) + { + throw new InvalidOperationException("Cannot connect to both Azure App Configuration and Azure Front Door at the same time."); + } + + if (endpoint == null) + { + throw new ArgumentNullException(nameof(endpoint)); + } + + Credential ??= new EmptyTokenCredential(); + + Endpoints = new List() { endpoint }; + ConnectionStrings = null; + + IsAfdEnabled = true; + + return this; + } + /// /// Connect the provider to Azure App Configuration using endpoint and token credentials. /// @@ -400,6 +438,11 @@ public AzureAppConfigurationOptions Connect(Uri endpoint, TokenCredential creden /// Token credential to use to connect. public AzureAppConfigurationOptions Connect(IEnumerable endpoints, TokenCredential credential) { + if (IsAfdEnabled) + { + throw new InvalidOperationException("Cannot connect to both Azure App Configuration and AFD at the same time."); + } + if (endpoints == null || !endpoints.Any()) { throw new ArgumentNullException(nameof(endpoints)); diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationProvider.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationProvider.cs index 6b100f8da..ee045fd42 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationProvider.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationProvider.cs @@ -314,23 +314,54 @@ await ExecuteWithFailOverPolicyAsync(clients, async (client) => // Get key value collection changes if RegisterAll was called if (isRefreshDue) { - refreshAll = await HaveCollectionsChanged( - _options.Selectors.Where(selector => !selector.IsFeatureFlagSelector), - _kvEtags, - client, - cancellationToken).ConfigureAwait(false); + foreach (KeyValueSelector selector in _options.Selectors.Where(selector => !selector.IsFeatureFlagSelector)) + { + if (_kvEtags.TryGetValue(selector, out IEnumerable matchConditions)) + { + await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Watch, _requestTracingOptions, + async () => refreshAll = await client.HaveCollectionsChanged( + selector, + matchConditions, + _options.ConfigurationSettingPageIterator, + makeConditionalRequest: !_options.IsAfdEnabled, + cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); + } + + if (refreshAll) + { + break; + } + } } } else { - refreshAll = await RefreshIndividualKvWatchers( - client, - keyValueChanges, - refreshableIndividualKvWatchers, - endpoint, - logDebugBuilder, - logInfoBuilder, - cancellationToken).ConfigureAwait(false); + foreach (KeyValueWatcher kvWatcher in refreshableIndividualKvWatchers) + { + KeyValueChange change = await CheckForChange(client, kvWatcher, cancellationToken).ConfigureAwait(false); + + // + // Skip if no change detected + if (change.ChangeType == KeyValueChangeType.None) + { + logDebugBuilder.AppendLine(LogHelper.BuildKeyValueReadMessage(change.ChangeType, change.Key, change.Label, endpoint.ToString())); + + continue; + } + + logDebugBuilder.AppendLine(LogHelper.BuildKeyValueReadMessage(change.ChangeType, change.Key, change.Label, endpoint.ToString())); + + logInfoBuilder.AppendLine(LogHelper.BuildKeyValueSettingUpdatedMessage(change.Key)); + + keyValueChanges.Add(change); + + if (kvWatcher.RefreshAll) + { + refreshAll = true; + + break; + } + } } if (refreshAll) @@ -347,18 +378,31 @@ await ExecuteWithFailOverPolicyAsync(clients, async (client) => return; } - // Get feature flag changes - ffCollectionUpdated = await HaveCollectionsChanged( - refreshableFfWatchers.Select(watcher => new KeyValueSelector + var ffSelectors = refreshableFfWatchers.Select(watcher => new KeyValueSelector + { + KeyFilter = watcher.Key, + LabelFilter = watcher.Label, + IsFeatureFlagSelector = true + }); + + foreach (KeyValueSelector selector in ffSelectors) + { + if (_ffEtags.TryGetValue(selector, out IEnumerable matchConditions)) { - KeyFilter = watcher.Key, - LabelFilter = watcher.Label, - TagFilters = watcher.Tags, - IsFeatureFlagSelector = true - }), - _ffEtags, - client, - cancellationToken).ConfigureAwait(false); + await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Watch, _requestTracingOptions, + async () => ffCollectionUpdated = await client.HaveCollectionsChanged( + selector, + matchConditions, + _options.ConfigurationSettingPageIterator, + makeConditionalRequest: !_options.IsAfdEnabled, + cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); + } + + if (ffCollectionUpdated) + { + break; + } + } if (ffCollectionUpdated) { @@ -974,76 +1018,47 @@ private async Task> LoadKey return watchedIndividualKvs; } - private async Task RefreshIndividualKvWatchers( - ConfigurationClient client, - List keyValueChanges, - IEnumerable refreshableIndividualKvWatchers, - Uri endpoint, - StringBuilder logDebugBuilder, - StringBuilder logInfoBuilder, - CancellationToken cancellationToken) + private async Task CheckForChange(ConfigurationClient client, KeyValueWatcher kvWatcher, CancellationToken cancellationToken) { - foreach (KeyValueWatcher kvWatcher in refreshableIndividualKvWatchers) - { - string watchedKey = kvWatcher.Key; - string watchedLabel = kvWatcher.Label; + Debug.Assert(client != null); + Debug.Assert(kvWatcher != null); - KeyValueIdentifier watchedKeyLabel = new KeyValueIdentifier(watchedKey, watchedLabel); + KeyValueChange change = default; - KeyValueChange change = default; + // + // Find if there is a change associated with watcher + if (_watchedIndividualKvs.TryGetValue(new KeyValueIdentifier(kvWatcher.Key, kvWatcher.Label), out ConfigurationSetting watchedKv)) + { + await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Watch, _requestTracingOptions, + async () => change = await client.GetKeyValueChange(watchedKv, makeConditionalRequest: !_options.IsAfdEnabled, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); + } + else + { + // Load the key-value in case the previous load attempts had failed - // - // Find if there is a change associated with watcher - if (_watchedIndividualKvs.TryGetValue(watchedKeyLabel, out ConfigurationSetting watchedKv)) + try { - await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Watch, _requestTracingOptions, - async () => change = await client.GetKeyValueChange(watchedKv, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); + await CallWithRequestTracing( + async () => watchedKv = await client.GetConfigurationSettingAsync(kvWatcher.Key, kvWatcher.Label, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); } - else + catch (RequestFailedException e) when (e.Status == (int)HttpStatusCode.NotFound) { - // Load the key-value in case the previous load attempts had failed - - try - { - await CallWithRequestTracing( - async () => watchedKv = await client.GetConfigurationSettingAsync(watchedKey, watchedLabel, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); - } - catch (RequestFailedException e) when (e.Status == (int)HttpStatusCode.NotFound) - { - watchedKv = null; - } - - if (watchedKv != null) - { - change = new KeyValueChange() - { - Key = watchedKv.Key, - Label = watchedKv.Label.NormalizeNull(), - Current = watchedKv, - ChangeType = KeyValueChangeType.Modified - }; - } + watchedKv = null; } - // Check if a change has been detected in the key-value registered for refresh - if (change.ChangeType != KeyValueChangeType.None) + if (watchedKv != null) { - logDebugBuilder.AppendLine(LogHelper.BuildKeyValueReadMessage(change.ChangeType, change.Key, change.Label, endpoint.ToString())); - logInfoBuilder.AppendLine(LogHelper.BuildKeyValueSettingUpdatedMessage(change.Key)); - keyValueChanges.Add(change); - - if (kvWatcher.RefreshAll) + change = new KeyValueChange() { - return true; - } - } - else - { - logDebugBuilder.AppendLine(LogHelper.BuildKeyValueReadMessage(change.ChangeType, change.Key, change.Label, endpoint.ToString())); + Key = watchedKv.Key, + Label = watchedKv.Label.NormalizeNull(), + Current = watchedKv, + ChangeType = KeyValueChangeType.Modified + }; } } - return false; + return change; } private void SetData(IDictionary data) @@ -1099,7 +1114,8 @@ private void SetRequestTracingOptions() IsKeyVaultConfigured = _options.IsKeyVaultConfigured, IsKeyVaultRefreshConfigured = _options.IsKeyVaultRefreshConfigured, FeatureFlagTracing = _options.FeatureFlagTracing, - IsLoadBalancingEnabled = _options.LoadBalancingEnabled + IsLoadBalancingEnabled = _options.LoadBalancingEnabled, + IsAfdEnabled = _options.IsAfdEnabled }; } @@ -1362,35 +1378,6 @@ private void UpdateClientBackoffStatus(Uri endpoint, bool successful) _configClientBackoffs[endpoint] = clientBackoffStatus; } - private async Task HaveCollectionsChanged( - IEnumerable selectors, - Dictionary> pageEtags, - ConfigurationClient client, - CancellationToken cancellationToken) - { - bool haveCollectionsChanged = false; - - foreach (KeyValueSelector selector in selectors) - { - if (pageEtags.TryGetValue(selector, out IEnumerable matchConditions)) - { - await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Watch, _requestTracingOptions, - async () => haveCollectionsChanged = await client.HaveCollectionsChanged( - selector, - matchConditions, - _options.ConfigurationSettingPageIterator, - cancellationToken).ConfigureAwait(false)).ConfigureAwait(false); - } - - if (haveCollectionsChanged) - { - return true; - } - } - - return haveCollectionsChanged; - } - private async Task ProcessKeyValueChangesAsync( IEnumerable keyValueChanges, Dictionary mappedData, diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationSource.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationSource.cs index 83d20e2fb..52faab570 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationSource.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureAppConfigurationSource.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // +using Azure.Core; using Azure.Data.AppConfiguration; using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration.AzureAppConfiguration.Afd; using System; using System.Collections.Generic; using System.Linq; @@ -34,13 +36,29 @@ public IConfigurationProvider Build(IConfigurationBuilder builder) { AzureAppConfigurationOptions options = _optionsProvider(); + IAzureClientFactory clientFactory = options.ClientFactory; + + if (options.IsAfdEnabled) + { + if (options.LoadBalancingEnabled) + { + throw new InvalidOperationException("Load balancing is not supported when connecting to AFD."); + } + + if (clientFactory != null) + { + throw new InvalidOperationException($"Custom client factory is not supported when connecting to AFD."); + } + + options.ClientOptions.AddPolicy(new AfdPolicy(), HttpPipelinePosition.PerRetry); + } + if (options.ClientManager != null) { return new AzureAppConfigurationProvider(options.ClientManager, options, _optional); } IEnumerable endpoints; - IAzureClientFactory clientFactory = options.ClientFactory; if (options.ConnectionStrings != null) { @@ -56,10 +74,17 @@ public IConfigurationProvider Build(IConfigurationBuilder builder) } else { - throw new ArgumentException($"Please call {nameof(AzureAppConfigurationOptions)}.{nameof(AzureAppConfigurationOptions.Connect)} to specify how to connect to Azure App Configuration."); + throw new ArgumentException($"Please call {nameof(AzureAppConfigurationOptions)}.{nameof(AzureAppConfigurationOptions.Connect)} or {nameof(AzureAppConfigurationOptions)}.{nameof(AzureAppConfigurationOptions.ConnectAzureFrontDoor)} to specify how to connect to Azure App Configuration."); } - provider = new AzureAppConfigurationProvider(new ConfigurationClientManager(clientFactory, endpoints, options.ReplicaDiscoveryEnabled, options.LoadBalancingEnabled), options, _optional); + if (options.IsAfdEnabled) + { + provider = new AzureAppConfigurationProvider(new AfdConfigurationClientManager(clientFactory, endpoints.First()), options, _optional); + } + else + { + provider = new AzureAppConfigurationProvider(new ConfigurationClientManager(clientFactory, endpoints, options.ReplicaDiscoveryEnabled, options.LoadBalancingEnabled), options, _optional); + } } catch (InvalidOperationException ex) // InvalidOperationException is thrown when any problems are found while configuring AzureAppConfigurationOptions or when SDK fails to create a configurationClient. { diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Constants/RequestTracingConstants.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Constants/RequestTracingConstants.cs index 612e1bccf..c23774d8d 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Constants/RequestTracingConstants.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Constants/RequestTracingConstants.cs @@ -37,6 +37,7 @@ internal class RequestTracingConstants public const string SignalRUsedTag = "SignalR"; public const string FailoverRequestTag = "Failover"; public const string PushRefreshTag = "PushRefresh"; + public const string AfdTag = "AFD"; public const string FeatureFlagFilterTypeKey = "Filter"; public const string CustomFilter = "CSTM"; diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Extensions/ConfigurationClientExtensions.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Extensions/ConfigurationClientExtensions.cs index c4edfb0ee..b4ae0bc31 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Extensions/ConfigurationClientExtensions.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/Extensions/ConfigurationClientExtensions.cs @@ -14,7 +14,7 @@ namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.Extensions { internal static class ConfigurationClientExtensions { - public static async Task GetKeyValueChange(this ConfigurationClient client, ConfigurationSetting setting, CancellationToken cancellationToken) + public static async Task GetKeyValueChange(this ConfigurationClient client, ConfigurationSetting setting, bool makeConditionalRequest, CancellationToken cancellationToken) { if (setting == null) { @@ -28,7 +28,7 @@ public static async Task GetKeyValueChange(this ConfigurationCli try { - Response response = await client.GetConfigurationSettingAsync(setting, onlyIfChanged: true, cancellationToken).ConfigureAwait(false); + Response response = await client.GetConfigurationSettingAsync(setting, onlyIfChanged: makeConditionalRequest, cancellationToken).ConfigureAwait(false); if (response.GetRawResponse().Status == (int)HttpStatusCode.OK && !response.Value.ETag.Equals(setting.ETag)) { @@ -64,7 +64,7 @@ public static async Task GetKeyValueChange(this ConfigurationCli }; } - public static async Task HaveCollectionsChanged(this ConfigurationClient client, KeyValueSelector keyValueSelector, IEnumerable matchConditions, IConfigurationSettingPageIterator pageIterator, CancellationToken cancellationToken) + public static async Task HaveCollectionsChanged(this ConfigurationClient client, KeyValueSelector keyValueSelector, IEnumerable matchConditions, IConfigurationSettingPageIterator pageIterator, bool makeConditionalRequest, CancellationToken cancellationToken) { if (matchConditions == null) { @@ -91,7 +91,9 @@ public static async Task HaveCollectionsChanged(this ConfigurationClient c using IEnumerator existingMatchConditionsEnumerator = matchConditions.GetEnumerator(); - await foreach (Page page in pageable.AsPages(pageIterator, matchConditions).ConfigureAwait(false)) + IAsyncEnumerable> pages = makeConditionalRequest ? pageable.AsPages(pageIterator, matchConditions) : pageable.AsPages(pageIterator); + + await foreach (Page page in pages.ConfigureAwait(false)) { using Response response = page.GetRawResponse(); diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/KeyValueChange.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/KeyValueChange.cs index 2286016d7..a430e7a02 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/KeyValueChange.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/KeyValueChange.cs @@ -2,6 +2,9 @@ // Licensed under the MIT license. // using Azure.Data.AppConfiguration; +using Microsoft.Extensions.Configuration.AzureAppConfiguration.Extensions; +using System.Security.Cryptography; +using System.Text; namespace Microsoft.Extensions.Configuration.AzureAppConfiguration { diff --git a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/RequestTracingOptions.cs b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/RequestTracingOptions.cs index 21582db1c..96c9b843a 100644 --- a/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/RequestTracingOptions.cs +++ b/src/Microsoft.Extensions.Configuration.AzureAppConfiguration/RequestTracingOptions.cs @@ -70,6 +70,11 @@ internal class RequestTracingOptions /// public bool IsPushRefreshUsed { get; set; } = false; + /// + /// Flag to indicate wether the request is sent to a AFD. + /// + public bool IsAfdEnabled { get; set; } = false; + /// /// Flag to indicate whether any key-value uses the json content type and contains /// a parameter indicating an AI profile. @@ -120,7 +125,8 @@ public bool UsesAnyTracingFeature() return IsLoadBalancingEnabled || IsSignalRUsed || UsesAIConfiguration || - UsesAIChatCompletionConfiguration; + UsesAIChatCompletionConfiguration || + IsAfdEnabled; } /// @@ -171,6 +177,16 @@ public string CreateFeaturesString() sb.Append(RequestTracingConstants.AIChatCompletionConfigurationTag); } + if (IsAfdEnabled) + { + if (sb.Length > 0) + { + sb.Append(RequestTracingConstants.Delimiter); + } + + sb.Append(RequestTracingConstants.AfdTag); + } + return sb.ToString(); } } diff --git a/tests/Tests.AzureAppConfiguration/Azure.Core.Testing/MockResponse.cs b/tests/Tests.AzureAppConfiguration/Azure.Core.Testing/MockResponse.cs index c60c2a255..5322bb1df 100644 --- a/tests/Tests.AzureAppConfiguration/Azure.Core.Testing/MockResponse.cs +++ b/tests/Tests.AzureAppConfiguration/Azure.Core.Testing/MockResponse.cs @@ -13,14 +13,14 @@ public class MockResponse : Response { private readonly Dictionary> _headers = new Dictionary>(StringComparer.OrdinalIgnoreCase); - public MockResponse(int status, string reasonPhrase = null) + public MockResponse(int status, string etag = null, string reasonPhrase = null) { Status = status; ReasonPhrase = reasonPhrase; if (status == 200) { - AddHeader(new HttpHeader(HttpHeader.Names.ETag, "\"" + Guid.NewGuid().ToString() + "\"")); + AddHeader(new HttpHeader(HttpHeader.Names.ETag, etag ?? "\"" + Guid.NewGuid().ToString() + "\"")); } } diff --git a/tests/Tests.AzureAppConfiguration/Unit/AfdTests.cs b/tests/Tests.AzureAppConfiguration/Unit/AfdTests.cs new file mode 100644 index 000000000..02ec5b1f2 --- /dev/null +++ b/tests/Tests.AzureAppConfiguration/Unit/AfdTests.cs @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +using Azure; +using Azure.Core.Testing; +using Azure.Data.AppConfiguration; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Configuration.AzureAppConfiguration; +using Moq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Tests.AzureAppConfiguration +{ + public class AfdTests + { + List _kvCollection = new List + { + ConfigurationModelFactory.ConfigurationSetting( + key: "TestKey1", + label: "label", + value: "TestValue1", + eTag: new ETag("0a76e3d7-7ec1-4e37-883c-9ea6d0d89e63"), + contentType: "text"), + + ConfigurationModelFactory.ConfigurationSetting( + key: "TestKey2", + label: "label", + value: "TestValue2", + eTag: new ETag("31c38369-831f-4bf1-b9ad-79db56c8b989"), + contentType: "text"), + + ConfigurationModelFactory.ConfigurationSetting( + key: "TestKey3", + label: "label", + value: "TestValue3", + eTag: new ETag("bb203f2b-c113-44fc-995d-b933c2143339"), + contentType: "text"), + + ConfigurationModelFactory.ConfigurationSetting( + key: "TestKeyWithMultipleLabels", + label: "label1", + value: "TestValueForLabel1", + eTag: new ETag("bb203f2b-c113-44fc-995d-b933c2143339"), + contentType: "text"), + + ConfigurationModelFactory.ConfigurationSetting( + key: "TestKeyWithMultipleLabels", + label: "label2", + value: "TestValueForLabel2", + eTag: new ETag("bb203f2b-c113-44fc-995d-b933c2143339"), + contentType: "text") + }; + + [Fact] + public void AfdTests_DoesNotSupportCustomClientFactory() + { + var mockClientFactory = new Mock>(); + + var configBuilder = new ConfigurationBuilder() + .AddAzureAppConfiguration(options => + { + options.SetClientFactory(mockClientFactory.Object) + .ConnectAzureFrontDoor(TestHelpers.MockAfdEndpoint); + }); + + Exception exception = Assert.Throws(() => configBuilder.Build()); + Assert.IsType(exception.InnerException); + } + + [Fact] + public void AfdTests_DoesNotSupportLoadBalancing() + { + var configBuilder = new ConfigurationBuilder() + .AddAzureAppConfiguration(options => + { + options.ConnectAzureFrontDoor(TestHelpers.MockAfdEndpoint) + .LoadBalancingEnabled = true; + }); + + Exception exception = Assert.Throws(() => configBuilder.Build()); + Assert.IsType(exception.InnerException); + } + + [Fact] + public async Task AfdTests_RefreshWithRegisterAll() + { + var keyValueCollection = new List(_kvCollection); + var mockResponse = new Mock(); + var mockClient = new Mock(MockBehavior.Strict); + var mockAsyncPageable = new MockAsyncPageable(keyValueCollection); + + mockClient.Setup(c => c.GetConfigurationSettingsAsync(It.IsAny(), It.IsAny())) + .Returns(mockAsyncPageable); + + IConfigurationRefresher refresher = null; + AzureAppConfigurationOptions capturedOptions = null; + + var config = new ConfigurationBuilder() + .AddAzureAppConfiguration(options => + { + options.ConnectAzureFrontDoor(TestHelpers.MockAfdEndpoint) + .Select("TestKey*") + .ConfigureRefresh(refreshOptions => + { + refreshOptions.RegisterAll() + .SetRefreshInterval(TimeSpan.FromSeconds(1)); + }); + + options.ClientManager = TestHelpers.CreateMockedConfigurationClientManager(mockClient.Object); + + refresher = options.GetRefresher(); + capturedOptions = options; + }) + .Build(); + + Assert.Equal("TestValue1", config["TestKey1"]); + Assert.Equal("TestValue2", config["TestKey2"]); + Assert.Equal("TestValue3", config["TestKey3"]); + + // Verify AFD is enabled + Assert.True(capturedOptions.IsAfdEnabled); + + keyValueCollection[0] = TestHelpers.ChangeValue(keyValueCollection[0], "newValue"); + + mockAsyncPageable.UpdateCollection(keyValueCollection); + + // Wait for the cache to expire + await Task.Delay(1500); + + // Trigger refresh - this should set a token in the AFD token accessor + await refresher.RefreshAsync(); + + // Verify the configuration was updated + Assert.Equal("newValue", config["TestKey1"]); + } + + [Fact] + public async Task AfdTests_RefreshWithRegister() + { + var keyValueCollection = new List(_kvCollection); + var mockResponse = new Mock(); + var mockClient = new Mock(MockBehavior.Strict); + + Response GetSettingFromService(string k, string l, CancellationToken ct) + { + return Response.FromValue(keyValueCollection.FirstOrDefault(s => s.Key == k && s.Label == l), mockResponse.Object); + } + + Response GetIfChanged(ConfigurationSetting setting, bool _, CancellationToken cancellationToken) + { + var currentSetting = keyValueCollection.FirstOrDefault(s => s.Key == setting.Key && s.Label == setting.Label); + + if (currentSetting == null) + { + throw new RequestFailedException(new MockResponse(404)); + } + + return Response.FromValue(currentSetting, new MockResponse(200)); + } + + mockClient.Setup(c => c.GetConfigurationSettingsAsync(It.IsAny(), It.IsAny())) + .Returns(() => + { + var copy = new List(); + foreach (var setting in keyValueCollection) + { + copy.Add(TestHelpers.CloneSetting(setting)); + } + + return new MockAsyncPageable(copy); + }); + + mockClient.Setup(c => c.GetConfigurationSettingAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((Func>)GetSettingFromService); + + mockClient.Setup(c => c.GetConfigurationSettingAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((Func>)GetIfChanged); + + IConfigurationRefresher refresher = null; + AzureAppConfigurationOptions capturedOptions = null; + + var config = new ConfigurationBuilder() + .AddAzureAppConfiguration(options => + { + options.ConnectAzureFrontDoor(TestHelpers.MockAfdEndpoint) + .Select("TestKey*") + .ConfigureRefresh(refreshOptions => + { + refreshOptions.Register("TestKey1", "label", refreshAll: true) + .SetRefreshInterval(TimeSpan.FromSeconds(1)); + }); + + options.ClientManager = TestHelpers.CreateMockedConfigurationClientManager(mockClient.Object); + + refresher = options.GetRefresher(); + capturedOptions = options; + }) + .Build(); + + Assert.Equal("TestValue1", config["TestKey1"]); + Assert.Equal("TestValue2", config["TestKey2"]); + Assert.Equal("TestValue3", config["TestKey3"]); + + // Verify AFD is enabled + Assert.True(capturedOptions.IsAfdEnabled); + + keyValueCollection[0] = TestHelpers.ChangeValue(keyValueCollection[0], "newValue"); + + // Wait for the cache to expire + await Task.Delay(1500); + + // Trigger refresh - this should set a token in the AFD token accessor + await refresher.RefreshAsync(); + + // Verify the configuration was updated + Assert.Equal("newValue", config["TestKey1"]); + } + } +} diff --git a/tests/Tests.AzureAppConfiguration/Unit/TestHelper.cs b/tests/Tests.AzureAppConfiguration/Unit/TestHelper.cs index 9fd3f388e..53ff0162f 100644 --- a/tests/Tests.AzureAppConfiguration/Unit/TestHelper.cs +++ b/tests/Tests.AzureAppConfiguration/Unit/TestHelper.cs @@ -72,6 +72,8 @@ static public string CreateMockEndpointString(string endpoint = "https://azure.a return $"Endpoint={endpoint};Id=b1d9b31;Secret={returnValue}"; } + static public Uri MockAfdEndpoint => new Uri("https://afd.azurefd.net"); + static public void SerializeSetting(ref Utf8JsonWriter json, ConfigurationSetting setting) { json.WriteStartObject(); @@ -164,6 +166,7 @@ class MockAsyncPageable : AsyncPageable { private readonly List _collection = new List(); private int _status; + private string _etag; private readonly TimeSpan? _delay; public MockAsyncPageable(List collection, TimeSpan? delay = null) @@ -178,6 +181,7 @@ public MockAsyncPageable(List collection, TimeSpan? delay } _status = 200; + _etag = "\"" + Guid.NewGuid().ToString() + "\""; _delay = delay; } @@ -206,6 +210,8 @@ public void UpdateCollection(List newCollection) _collection.Add(newSetting); } + + _etag = "\"" + Guid.NewGuid().ToString() + "\""; } } @@ -216,7 +222,7 @@ public override async IAsyncEnumerable> AsPages(strin await Task.Delay(_delay.Value); } - yield return Page.FromValues(_collection, null, new MockResponse(_status)); + yield return Page.FromValues(_collection, null, new MockResponse(_status, _etag)); } }