diff --git a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs new file mode 100644 index 0000000000..41a05e137f --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs @@ -0,0 +1,56 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Authorization +{ + using System; + using global::Azure.Core; + + internal sealed class CosmosScopeProvider : IScopeProvider + { + private const string AadInvalidScopeErrorMessage = "AADSTS500011"; + private const string AadDefaultScope = "https://cosmos.azure.com/.default"; + private const string ScopeFormat = "https://{0}/.default"; + + private readonly string accountScope; + private readonly string overrideScope; + private string currentScope; + + public CosmosScopeProvider(Uri accountEndpoint) + { + this.overrideScope = ConfigurationManager.AADScopeOverrideValue(defaultValue: null); + this.accountScope = string.Format(ScopeFormat, accountEndpoint.Host); + this.currentScope = this.overrideScope ?? this.accountScope; + } + + public TokenRequestContext GetTokenRequestContext() + { + return new TokenRequestContext(new[] { this.currentScope }); + } + + public bool TryFallback(Exception exception) + { + // If override scope is set, never fallback + if (!string.IsNullOrEmpty(this.overrideScope)) + { + return false; + } + + // If already using fallback scope, do not fallback again + if (this.currentScope == CosmosScopeProvider.AadDefaultScope) + { + return false; + } + +#pragma warning disable CDX1003 // DontUseExceptionToString + if (exception.ToString().Contains(CosmosScopeProvider.AadInvalidScopeErrorMessage) == true) + { + this.currentScope = CosmosScopeProvider.AadDefaultScope; + return true; + } +#pragma warning restore CDX1003 // DontUseExceptionToString + + return false; + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs new file mode 100644 index 0000000000..545270f9eb --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs @@ -0,0 +1,14 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Authorization +{ + using System; + using global::Azure.Core; + + internal interface IScopeProvider + { + TokenRequestContext GetTokenRequestContext(); + bool TryFallback(Exception ex); + } +} diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index e844dbdec8..efb10277d8 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -10,7 +10,8 @@ namespace Microsoft.Azure.Cosmos using System.Threading; using System.Threading.Tasks; using global::Azure; - using global::Azure.Core; + using global::Azure.Core; + using Microsoft.Azure.Cosmos.Authorization; using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.Azure.Cosmos.Tracing; @@ -36,9 +37,7 @@ internal sealed class TokenCredentialCache : IDisposable // If the background refresh fails with less than a minute then just allow the request to hit the exception. public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1); - private const string ScopeFormat = "https://{0}/.default"; - - private readonly TokenRequestContext tokenRequestContext; + private readonly IScopeProvider scopeProvider; private readonly TokenCredential tokenCredential; private readonly CancellationTokenSource cancellationTokenSource; private readonly CancellationToken cancellationToken; @@ -51,7 +50,7 @@ internal sealed class TokenCredentialCache : IDisposable private Task? currentRefreshOperation = null; private AccessToken? cachedAccessToken = null; private bool isBackgroundTaskRunning = false; - private bool isDisposed = false; + private bool isDisposed = false; internal TokenCredentialCache( TokenCredential tokenCredential, @@ -65,14 +64,7 @@ internal TokenCredentialCache( throw new ArgumentNullException(nameof(accountEndpoint)); } - string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null); - - this.tokenRequestContext = new TokenRequestContext(new string[] - { - !string.IsNullOrEmpty(scopeOverride) - ? scopeOverride - : string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host) - }); + this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint); if (backgroundTokenCredentialRefreshInterval.HasValue) { @@ -129,7 +121,7 @@ public void Dispose() } this.cancellationTokenSource.Cancel(); - this.cancellationTokenSource.Dispose(); + this.cancellationTokenSource.Dispose(); this.isDisposed = true; } @@ -171,11 +163,13 @@ private async Task GetNewTokenAsync( private async ValueTask RefreshCachedTokenWithRetryHelperAsync( ITrace trace) - { + { + Exception? lastException = null; + const int totalRetryCount = 2; + TokenRequestContext tokenRequestContext = default; + try { - Exception? lastException = null; - const int totalRetryCount = 2; for (int retry = 0; retry < totalRetryCount; retry++) { if (this.cancellationToken.IsCancellationRequested) @@ -190,11 +184,13 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), component: TraceComponent.Authorization, level: Tracing.TraceLevel.Info)) - { + { try - { + { + tokenRequestContext = this.scopeProvider.GetTokenRequestContext(); + this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( - requestContext: this.tokenRequestContext, + requestContext: tokenRequestContext, cancellationToken: this.cancellationToken); if (!this.cachedAccessToken.HasValue) @@ -219,32 +215,15 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( return this.cachedAccessToken.Value; } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; - } - } catch (OperationCanceledException operationCancelled) { lastException = operationCancelled; getTokenTrace.AddDatum( $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelled.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + operationCancelled.Message); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); throw CosmosExceptionFactory.CreateRequestTimeoutException( message: ClientResources.FailedToGetAadToken, @@ -255,15 +234,29 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( innerException: lastException, trace: getTokenTrace); } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception.Message); + + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + // Don't retry on auth failures + if (exception is RequestFailedException requestFailedException && + (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden)) + { + this.cachedAccessToken = default; + throw; + } + bool didFallback = this.scopeProvider.TryFallback(exception); + + if (didFallback) + { + DefaultTrace.TraceInformation($"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}"); + } } } } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs index 40e9a27e90..1e53f82a44 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs @@ -13,7 +13,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests using System.Web; using Documents.Client; using global::Azure; - using global::Azure.Core; + using global::Azure.Core; using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.IdentityModel.Tokens; using static Microsoft.Azure.Cosmos.SDK.EmulatorTests.TransportClientHelper; @@ -263,6 +263,185 @@ void GetAadTokenCallBack( Assert.IsTrue(ce.ToString().Contains(errorMessage)); } } - } + } + + [TestMethod] + public async Task Aad_OverrideScope_NoFallback_OnFailure_E2E() + { + // Arrange + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); + string databaseId = "db-" + Guid.NewGuid(); + using (CosmosClient setupClient = TestCommon.CreateCosmosClient()) + { + await setupClient.CreateDatabaseAsync(databaseId); + } + + string overrideScope = "https://override/.default"; + string accountScope = $"https://{new Uri(endpoint).Host}/.default"; + int overrideScopeCount = 0; + int accountScopeCount = 0; + + string previous = Environment.GetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE"); + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope); + + void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) + { + string scope = context.Scopes[0]; + if (scope == overrideScope) + { + overrideScopeCount++; + throw new RequestFailedException(408, "Simulated override scope failure"); + } + if (scope == accountScope) + { + accountScopeCount++; + } + } + + LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential( + expectedScopes: new[] { overrideScope, accountScope }, + masterKey: authKey, + getTokenCallback: GetAadTokenCallBack); + + CosmosClientOptions clientOptions = new CosmosClientOptions + { + ConnectionMode = ConnectionMode.Gateway, + TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60) + }; + + try + { + using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions); + + try + { + // Act + ResponseMessage r = await aadClient.GetDatabase(databaseId).ReadStreamAsync(); + Assert.Fail("Expected failure when override scope token acquisition fails."); + } + catch (RequestFailedException ex) when (ex.Status == (int)HttpStatusCode.RequestTimeout || ex.Status == 408) + { + // Assert + Assert.IsTrue(overrideScopeCount > 0, "Override scope should have been attempted."); + Assert.AreEqual(0, accountScopeCount, "No fallback to account scope must occur when override is configured."); + } + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous); + using CosmosClient cleanup = TestCommon.CreateCosmosClient(); + await cleanup.GetDatabase(databaseId).DeleteAsync(); + } + } + + [TestMethod] + public async Task Aad_AccountScope_Fallbacks_ToCosmosScope() + { + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); + + string previous = Environment.GetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE"); + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + + string accountScope = $"https://{new Uri(endpoint).Host}/.default"; + string aadScope = "https://cosmos.azure.com/.default"; + + int accountScopeCount = 0; + int cosmosScopeCount = 0; + + void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) + { + string scope = context.Scopes[0]; + + if (string.Equals(scope, accountScope, StringComparison.OrdinalIgnoreCase)) + { + accountScopeCount++; + throw new Exception( + message: "AADSTS500011", + innerException: new Exception("AADSTS500011")); + } + + if (string.Equals(scope, aadScope, StringComparison.OrdinalIgnoreCase)) + { + cosmosScopeCount++; + } + } + + LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential( + expectedScopes: new[] { accountScope, aadScope }, + masterKey: authKey, + getTokenCallback: GetAadTokenCallBack); + + CosmosClientOptions clientOptions = new CosmosClientOptions + { + ConnectionMode = ConnectionMode.Gateway, + TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60) + }; + + try + { + using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions); + TokenCredentialCache tokenCredentialCache = + ((AuthorizationTokenProviderTokenCredential)aadClient.AuthorizationTokenProvider).tokenCredentialCache; + + string token = await tokenCredentialCache.GetTokenAsync(Tracing.Trace.GetRootTrace("account-fallback-to-cosmos-test")); + Assert.IsFalse(string.IsNullOrEmpty(token), "Fallback should succeed and produce a token."); + + Assert.IsTrue(accountScopeCount >= 1, "Account scope must be attempted first."); + Assert.IsTrue(cosmosScopeCount >= 1, "The client must fall back to cosmos.azure.com scope."); + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous); + } + } + + [TestMethod] + public async Task Aad_AccountScope_Success_NoFallback() + { + // Arrange + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); + + string accountScope = $"https://{new Uri(endpoint).Host}/.default"; + string aadScope = "https://cosmos.azure.com/.default"; + + int accountScopeCount = 0; + int cosmosScopeCount = 0; + + void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) + { + string scope = context.Scopes[0]; + + if (string.Equals(scope, accountScope, StringComparison.OrdinalIgnoreCase)) + { + accountScopeCount++; + } + + if (string.Equals(scope, aadScope, StringComparison.OrdinalIgnoreCase)) + { + cosmosScopeCount++; + } + } + + LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential( + expectedScopes: new[] { accountScope }, + masterKey: authKey, + getTokenCallback: GetAadTokenCallBack); + + CosmosClientOptions clientOptions = new CosmosClientOptions + { + ConnectionMode = ConnectionMode.Gateway, + TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60) + }; + + using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions); + TokenCredentialCache tokenCredentialCache = + ((AuthorizationTokenProviderTokenCredential)aadClient.AuthorizationTokenProvider).tokenCredentialCache; + + string token = await tokenCredentialCache.GetTokenAsync(Tracing.Trace.GetRootTrace("account-scope-success-no-fallback")); + Assert.IsFalse(string.IsNullOrEmpty(token), "Token should be acquired successfully with account scope."); + + Assert.AreEqual(1, accountScopeCount, "Account scope must be used exactly once."); + Assert.AreEqual(0, cosmosScopeCount, "Cosmos scope must not be used (no fallback)."); + } } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs index c40a5e331f..0d8da95a0b 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs @@ -18,7 +18,7 @@ public class LocalEmulatorTokenCredential : TokenCredential private readonly DateTime? DefaultDateTime = null; private readonly Action GetTokenCallback; private readonly string masterKey; - private readonly string expectedScope; + private readonly string[] expectedScopes; internal LocalEmulatorTokenCredential( string expectedScope, @@ -29,7 +29,19 @@ internal LocalEmulatorTokenCredential( this.masterKey = masterKey; this.GetTokenCallback = getTokenCallback; this.DefaultDateTime = defaultDateTime; - this.expectedScope = expectedScope; + this.expectedScopes = new string[] { expectedScope }; + } + + internal LocalEmulatorTokenCredential( + string[] expectedScopes, + string masterKey = null, + Action getTokenCallback = null, + DateTime? defaultDateTime = null) + { + this.masterKey = masterKey; + this.GetTokenCallback = getTokenCallback; + this.DefaultDateTime = defaultDateTime; + this.expectedScopes = expectedScopes; } public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) @@ -44,7 +56,7 @@ public override ValueTask GetTokenAsync(TokenRequestContext request private AccessToken GetAccessToken(TokenRequestContext requestContext, CancellationToken cancellationToken) { - Assert.AreEqual(this.expectedScope, requestContext.Scopes.First()); + Assert.IsTrue(this.expectedScopes.Contains(requestContext.Scopes.First())); this.GetTokenCallback?.Invoke( requestContext, diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs new file mode 100644 index 0000000000..b4c99fbf8b --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs @@ -0,0 +1,78 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Tests.Authorization +{ + using System; + using global::Azure.Core; + using Microsoft.Azure.Cosmos.Authorization; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class CosmosScopeProviderTests + { + private static readonly Uri TestAccountEndpoint = new Uri("https://testaccount.documents.azure.com:443/"); + + [DataTestMethod] + [DataRow("https://override/.default", "https://override/.default", DisplayName = "OverrideScope_Used")] + [DataRow(null, "https://testaccount.documents.azure.com/.default", DisplayName = "AccountScope_Used_WhenNoOverride")] + public void GetTokenRequestContext_UsesExpectedScope(string overrideScope, string expectedScope) + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope); + + try + { + CosmosScopeProvider provider = new CosmosScopeProvider(TestAccountEndpoint); + TokenRequestContext context = provider.GetTokenRequestContext(); + Assert.AreEqual(expectedScope, context.Scopes[0]); + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + } + } + + [DataTestMethod] + [DataRow("https://override/.default", false, "AADSTS500011", "https://override/.default", DisplayName = "OverrideScope_NeverFallback")] + [DataRow(null, true, "AADSTS500011", "https://cosmos.azure.com/.default", DisplayName = "AccountScope_FallbacksToAadDefault")] + [DataRow(null, false, "SomeOtherError", "https://testaccount.documents.azure.com/.default", DisplayName = "AccountScope_NoFallbackOnOtherError")] + public void Test_TryFallback_Behavior( + string overrideScope, + bool expectFallback, + string exceptionMessage, + string expectedScope) + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope); + + try + { + CosmosScopeProvider provider = new CosmosScopeProvider(TestAccountEndpoint); + + bool didFallback = provider.TryFallback(new Exception(exceptionMessage)); + + Assert.AreEqual(expectFallback, didFallback, "Fallback result mismatch."); + Assert.AreEqual(expectedScope, provider.GetTokenRequestContext().Scopes[0]); + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + } + } + + [TestMethod] + public void TryFallback_DoesNotFallback_WhenAlreadyUsingAadDefault() + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + CosmosScopeProvider provider = new CosmosScopeProvider(TestAccountEndpoint); + + provider.TryFallback(new Exception("AADSTS500011")); + Assert.AreEqual("https://cosmos.azure.com/.default", provider.GetTokenRequestContext().Scopes[0]); + + // Act + bool didFallbackAgain = provider.TryFallback(new Exception("AADSTS500011")); + + Assert.IsFalse(didFallbackAgain, "Should not fallback again when already using AadDefault scope."); + Assert.AreEqual("https://cosmos.azure.com/.default", provider.GetTokenRequestContext().Scopes[0]); + } + } +} \ No newline at end of file