From 5b957a2e5a785ab994306c221a40a3555f9205f2 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Tue, 19 Aug 2025 15:31:54 -0700 Subject: [PATCH 01/13] Add fallback mechanism for scope override. --- .../src/Authorization/TokenCredentialCache.cs | 337 +++++++++++------- .../CosmosAadTests.cs | 134 ++++++- .../Utils/LocalEmulatorTokenCredential.cs | 22 +- 3 files changed, 365 insertions(+), 128 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index e844dbdec8..3e0dda90cd 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -37,12 +37,16 @@ internal sealed class TokenCredentialCache : IDisposable public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1); private const string ScopeFormat = "https://{0}/.default"; + private const string AadInvalidScopeErrorMessage = "AADSTS500011"; + private const string AadDefaultScope = "https://cosmos.azure.com/.default"; private readonly TokenRequestContext tokenRequestContext; private readonly TokenCredential tokenCredential; private readonly CancellationTokenSource cancellationTokenSource; private readonly CancellationToken cancellationToken; - private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; + private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; + private readonly string accountScope; + private readonly bool isOverrideScopeProvided; private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1); private readonly object backgroundRefreshLock = new object(); @@ -67,11 +71,12 @@ internal TokenCredentialCache( string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null); + this.accountScope = string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host); + this.isOverrideScopeProvided = !string.IsNullOrEmpty(scopeOverride); + this.tokenRequestContext = new TokenRequestContext(new string[] { - !string.IsNullOrEmpty(scopeOverride) - ? scopeOverride - : string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host) + this.isOverrideScopeProvided ? scopeOverride! : this.accountScope }); if (backgroundTokenCredentialRefreshInterval.HasValue) @@ -167,127 +172,211 @@ private async Task GetNewTokenAsync( } return await currentTask; - } - - private async ValueTask RefreshCachedTokenWithRetryHelperAsync( - ITrace trace) - { - try - { - Exception? lastException = null; - const int totalRetryCount = 2; - for (int retry = 0; retry < totalRetryCount; retry++) - { - if (this.cancellationToken.IsCancellationRequested) - { - DefaultTrace.TraceInformation( - "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - - break; - } - - using (ITrace getTokenTrace = trace.StartChild( - name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), - component: TraceComponent.Authorization, - level: Tracing.TraceLevel.Info)) - { - try - { - this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( - requestContext: this.tokenRequestContext, - cancellationToken: this.cancellationToken); - - if (!this.cachedAccessToken.HasValue) - { - throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); - } - - if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) - { - throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); - } - - if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) - { - double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - - // Ensure the background refresh interval is a valid range. - refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); - refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); - this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); - } - - return this.cachedAccessToken.Value; - } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; - } - } - catch (OperationCanceledException operationCancelled) - { - lastException = operationCancelled; - getTokenTrace.AddDatum( - $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelled.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() - { - SubStatusCode = SubStatusCodes.FailedToGetAadToken, - }, - innerException: lastException, - trace: getTokenTrace); - } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - } - } - } - - if (lastException == null) - { - throw new ArgumentException("Last exception is null."); - } - + } + + private void ApplyTokenAndSetRefreshInterval(AccessToken token) + { + this.cachedAccessToken = token; + + if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + { + double refreshIntervalInSeconds = + (token.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; + + refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); + refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); + + this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); + } + } + + private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITrace trace) + { + try + { + Exception? lastException = null; + const int totalRetryCount = 2; + for (int retry = 0; retry < totalRetryCount; retry++) + { + if (this.cancellationToken.IsCancellationRequested) + { + DefaultTrace.TraceInformation("Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); + break; + } + + using (ITrace getTokenTrace = trace.StartChild( + name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), + component: TraceComponent.Authorization, + level: Tracing.TraceLevel.Info)) + { + bool shouldAttemptAadFallback = false; + + try + { + AccessToken? tokenNullable = await this.tokenCredential.GetTokenAsync( + requestContext: this.tokenRequestContext, + cancellationToken: this.cancellationToken); + + if (!tokenNullable.HasValue) + { + throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); + } + + if (tokenNullable.Value.ExpiresOn < DateTimeOffset.UtcNow) + { + throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{tokenNullable.Value.ExpiresOn:O}"); + } + + this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value); + return tokenNullable.Value; + } + catch (RequestFailedException requestFailedException) + { + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException.Message); + + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + if (this.isOverrideScopeProvided) + { + if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden) + { + this.cachedAccessToken = default; + throw; + } + + continue; + } + } + catch (OperationCanceledException operationCancelled) + { + lastException = operationCancelled; + getTokenTrace.AddDatum( + $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + operationCancelled.Message); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() { SubStatusCode = SubStatusCodes.FailedToGetAadToken, }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception.Message); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + if (this.isOverrideScopeProvided) + { + continue; + } + + if (exception.InnerException?.Message.Contains(AadInvalidScopeErrorMessage) == true) + { + shouldAttemptAadFallback = true; + } + } + + if (!this.isOverrideScopeProvided && shouldAttemptAadFallback) + { + TokenRequestContext fallbackContext = new TokenRequestContext(new[] { AadDefaultScope }); + + try + { + AccessToken? tokenNullable = await this.tokenCredential.GetTokenAsync( + requestContext: fallbackContext, + cancellationToken: this.cancellationToken); + + if (!tokenNullable.HasValue) + { + throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); + } + + if (tokenNullable.Value.ExpiresOn < DateTimeOffset.UtcNow) + { + throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{tokenNullable.Value.ExpiresOn:O}"); + } + + this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value); + return tokenNullable.Value; + } + catch (RequestFailedException requestFailedExceptionFallback) + { + lastException = requestFailedExceptionFallback; + getTokenTrace.AddDatum( + $"RequestFailedException (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedExceptionFallback.Message); + + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + if (requestFailedExceptionFallback.Status == (int)HttpStatusCode.Unauthorized || + requestFailedExceptionFallback.Status == (int)HttpStatusCode.Forbidden) + { + this.cachedAccessToken = default; + throw; + } + } + catch (OperationCanceledException operationCancelledFallback) + { + lastException = operationCancelledFallback; + getTokenTrace.AddDatum( + $"OperationCanceledException (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + operationCancelledFallback.Message); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() { SubStatusCode = SubStatusCodes.FailedToGetAadToken, }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exceptionFallback) + { + lastException = exceptionFallback; + getTokenTrace.AddDatum( + $"Exception (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exceptionFallback.Message); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + } + } + } + } + + if (lastException == null) + { + throw new ArgumentException("Last exception is null."); + } + // The retries have been exhausted. Throw the last exception. - throw lastException; - } - finally - { - try - { - await this.isTokenRefreshingLock.WaitAsync(); - this.currentRefreshOperation = null; - } - finally - { - this.isTokenRefreshingLock.Release(); - } - } + throw lastException; + } + finally + { + try + { + await this.isTokenRefreshingLock.WaitAsync(); + this.currentRefreshOperation = null; + } + finally + { + this.isTokenRefreshingLock.Release(); + } + } } #pragma warning disable VSTHRD100 // Avoid async void methods diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs index 40e9a27e90..9c1c56a357 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs @@ -13,7 +13,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests using System.Web; using Documents.Client; using global::Azure; - using global::Azure.Core; + using global::Azure.Core; using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.IdentityModel.Tokens; using static Microsoft.Azure.Cosmos.SDK.EmulatorTests.TransportClientHelper; @@ -263,6 +263,136 @@ void GetAadTokenCallBack( Assert.IsTrue(ce.ToString().Contains(errorMessage)); } } - } + } + + [TestMethod] + public async Task Aad_OverrideScope_NoFallback_OnFailure_E2E() + { + // Arrange + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); + string databaseId = "db-" + Guid.NewGuid(); + using (CosmosClient setupClient = TestCommon.CreateCosmosClient()) + { + await setupClient.CreateDatabaseAsync(databaseId); + } + + string overrideScope = "https://override/.default"; + string accountScope = $"https://{new Uri(endpoint).Host}/.default"; + int overrideScopeCount = 0; + int accountScopeCount = 0; + + string previous = Environment.GetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE"); + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope); + + void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) + { + string scope = context.Scopes[0]; + if (scope == overrideScope) + { + overrideScopeCount++; + throw new RequestFailedException(408, "Simulated override scope failure"); + } + if (scope == accountScope) + { + accountScopeCount++; + } + } + + LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential( + expectedScopes: new[] { overrideScope, accountScope }, + masterKey: authKey, + getTokenCallback: GetAadTokenCallBack); + + CosmosClientOptions clientOptions = new CosmosClientOptions + { + ConnectionMode = ConnectionMode.Gateway, + TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60) + }; + + try + { + using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions); + + try + { + // Act + ResponseMessage r = await aadClient.GetDatabase(databaseId).ReadStreamAsync(); + Assert.Fail("Expected failure when override scope token acquisition fails."); + } + catch (RequestFailedException ex) when (ex.Status == (int)HttpStatusCode.RequestTimeout || ex.Status == 408) + { + // Assert + Assert.IsTrue(overrideScopeCount > 0, "Override scope should have been attempted."); + Assert.AreEqual(0, accountScopeCount, "No fallback to account scope must occur when override is configured."); + } + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous); + using CosmosClient cleanup = TestCommon.CreateCosmosClient(); + await cleanup.GetDatabase(databaseId).DeleteAsync(); + } + } + + [TestMethod] + public async Task Aad_AccountScope_Fallbacks_ToCosmos_OnSniCertRevoked_Unit() + { + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); + + string previous = Environment.GetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE"); + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + + string accountScope = $"https://{new Uri(endpoint).Host}/.default"; + string aadScope = "https://cosmos.azure.com/.default"; + + int accountScopeCount = 0; + int cosmosScopeCount = 0; + + void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) + { + string scope = context.Scopes[0]; + + if (string.Equals(scope, accountScope, StringComparison.OrdinalIgnoreCase)) + { + accountScopeCount++; + throw new Exception( + message: "AADSTS500011", + innerException: new Exception("AADSTS500011")); + } + + if (string.Equals(scope, aadScope, StringComparison.OrdinalIgnoreCase)) + { + cosmosScopeCount++; + } + } + + LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential( + expectedScopes: new[] { accountScope, aadScope }, + masterKey: authKey, + getTokenCallback: GetAadTokenCallBack); + + CosmosClientOptions clientOptions = new CosmosClientOptions + { + ConnectionMode = ConnectionMode.Gateway, + TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60) + }; + + try + { + using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions); + TokenCredentialCache tokenCredentialCache = + ((AuthorizationTokenProviderTokenCredential)aadClient.AuthorizationTokenProvider).tokenCredentialCache; + + string token = await tokenCredentialCache.GetTokenAsync(Tracing.Trace.GetRootTrace("account-fallback-to-cosmos-test")); + Assert.IsFalse(string.IsNullOrEmpty(token), "Fallback should succeed and produce a token."); + + Assert.IsTrue(accountScopeCount >= 1, "Account scope must be attempted first."); + Assert.IsTrue(cosmosScopeCount >= 1, "The client must fall back to cosmos.azure.com scope."); + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous); + } + } } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs index c40a5e331f..5fda33b046 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs @@ -19,6 +19,7 @@ public class LocalEmulatorTokenCredential : TokenCredential private readonly Action GetTokenCallback; private readonly string masterKey; private readonly string expectedScope; + private readonly string[] expectedScopes; internal LocalEmulatorTokenCredential( string expectedScope, @@ -29,7 +30,21 @@ internal LocalEmulatorTokenCredential( this.masterKey = masterKey; this.GetTokenCallback = getTokenCallback; this.DefaultDateTime = defaultDateTime; - this.expectedScope = expectedScope; + this.expectedScope = expectedScope; + this.expectedScopes = null; + } + + internal LocalEmulatorTokenCredential( + string[] expectedScopes, + string masterKey = null, + Action getTokenCallback = null, + DateTime? defaultDateTime = null) + { + this.masterKey = masterKey; + this.GetTokenCallback = getTokenCallback; + this.DefaultDateTime = defaultDateTime; + this.expectedScope = null; + this.expectedScopes = expectedScopes; } public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) @@ -44,7 +59,10 @@ public override ValueTask GetTokenAsync(TokenRequestContext request private AccessToken GetAccessToken(TokenRequestContext requestContext, CancellationToken cancellationToken) { - Assert.AreEqual(this.expectedScope, requestContext.Scopes.First()); + if (this.expectedScope != null) + { + Assert.AreEqual(this.expectedScope, requestContext.Scopes.First()); + } this.GetTokenCallback?.Invoke( requestContext, From a7d3c647e00882956d831457e3475ca79b08082d Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Tue, 19 Aug 2025 15:52:41 -0700 Subject: [PATCH 02/13] Fix formatting. --- .../src/Authorization/TokenCredentialCache.cs | 166 +++++++++--------- 1 file changed, 86 insertions(+), 80 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 3e0dda90cd..4bf8d36bf6 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -36,9 +36,9 @@ internal sealed class TokenCredentialCache : IDisposable // If the background refresh fails with less than a minute then just allow the request to hit the exception. public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1); - private const string ScopeFormat = "https://{0}/.default"; + private const string ScopeFormat = "https://{0}/.default"; private const string AadInvalidScopeErrorMessage = "AADSTS500011"; - private const string AadDefaultScope = "https://cosmos.azure.com/.default"; + private const string AadDefaultScope = "https://cosmos.azure.com/.default"; private readonly TokenRequestContext tokenRequestContext; private readonly TokenCredential tokenCredential; @@ -46,7 +46,7 @@ internal sealed class TokenCredentialCache : IDisposable private readonly CancellationToken cancellationToken; private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; private readonly string accountScope; - private readonly bool isOverrideScopeProvided; + private readonly bool isOverrideScopeProvided; private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1); private readonly object backgroundRefreshLock = new object(); @@ -188,26 +188,29 @@ private void ApplyTokenAndSetRefreshInterval(AccessToken token) this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); } - } - - private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITrace trace) - { - try - { - Exception? lastException = null; - const int totalRetryCount = 2; - for (int retry = 0; retry < totalRetryCount; retry++) - { - if (this.cancellationToken.IsCancellationRequested) - { - DefaultTrace.TraceInformation("Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - break; - } - - using (ITrace getTokenTrace = trace.StartChild( - name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), - component: TraceComponent.Authorization, - level: Tracing.TraceLevel.Info)) + } + + private async ValueTask RefreshCachedTokenWithRetryHelperAsync( + ITrace trace) + { + try + { + Exception? lastException = null; + const int totalRetryCount = 2; + for (int retry = 0; retry < totalRetryCount; retry++) + { + if (this.cancellationToken.IsCancellationRequested) + { + DefaultTrace.TraceInformation( + "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); + + break; + } + + using (ITrace getTokenTrace = trace.StartChild( + name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), + component: TraceComponent.Authorization, + level: Tracing.TraceLevel.Info)) { bool shouldAttemptAadFallback = false; @@ -228,17 +231,17 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITra } this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value); - return tokenNullable.Value; - } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - + return tokenNullable.Value; + } + catch (RequestFailedException requestFailedException) + { + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException.Message); + + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + if (this.isOverrideScopeProvided) { if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || @@ -250,31 +253,34 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITra continue; } - } - catch (OperationCanceledException operationCancelled) - { - lastException = operationCancelled; - getTokenTrace.AddDatum( - $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelled.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() { SubStatusCode = SubStatusCodes.FailedToGetAadToken, }, - innerException: lastException, - trace: getTokenTrace); - } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception.Message); - - DefaultTrace.TraceError( + } + catch (OperationCanceledException operationCancelled) + { + lastException = operationCancelled; + getTokenTrace.AddDatum( + $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + operationCancelled.Message); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() + { + SubStatusCode = SubStatusCodes.FailedToGetAadToken, + }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception.Message); + + DefaultTrace.TraceError( $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); if (this.isOverrideScopeProvided) @@ -355,28 +361,28 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITra } } } - } - - if (lastException == null) - { - throw new ArgumentException("Last exception is null."); - } - + } + + if (lastException == null) + { + throw new ArgumentException("Last exception is null."); + } + // The retries have been exhausted. Throw the last exception. - throw lastException; - } - finally - { - try - { - await this.isTokenRefreshingLock.WaitAsync(); - this.currentRefreshOperation = null; - } - finally - { - this.isTokenRefreshingLock.Release(); - } - } + throw lastException; + } + finally + { + try + { + await this.isTokenRefreshingLock.WaitAsync(); + this.currentRefreshOperation = null; + } + finally + { + this.isTokenRefreshingLock.Release(); + } + } } #pragma warning disable VSTHRD100 // Avoid async void methods From 889fd4d95e1c1d1154eb8c39968e0d56d533b7e5 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Wed, 20 Aug 2025 22:45:38 -0700 Subject: [PATCH 03/13] Update the fallback logic --- .../src/Authorization/CosmosScopeProvider.cs | 63 ++++ .../src/Authorization/IScopeProvider.cs | 16 + .../src/Authorization/TokenCredentialCache.cs | 350 +++++++----------- 3 files changed, 209 insertions(+), 220 deletions(-) create mode 100644 Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs create mode 100644 Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs diff --git a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs new file mode 100644 index 0000000000..c5357b67f7 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs @@ -0,0 +1,63 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Authorization +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using global::Azure.Core; + + internal sealed class CosmosScopeProvider : IScopeProvider + { + private const string AadInvalidScopeErrorMessage = "AADSTS500011"; + private const string AadDefaultScope = "https://cosmos.azure.com/.default"; + private const string ScopeFormat = "https://{0}/.default"; + + private readonly string accountScope; + private readonly string overrideScope; + private string currentScope; + private bool fallbackAttempted = false; + + public CosmosScopeProvider(Uri accountEndpoint) + { + this.overrideScope = ConfigurationManager.AADScopeOverrideValue(defaultValue: null); + this.accountScope = string.Format(ScopeFormat, accountEndpoint.Host); + this.currentScope = this.overrideScope ?? this.accountScope; + } + + public TokenRequestContext GetTokenRequestContext() + { + return new TokenRequestContext(new[] { this.currentScope }); + } + + public bool TryFallback(Exception ex) + { + // If override scope is set, never fallback + if (!string.IsNullOrEmpty(this.overrideScope)) + { + return false; + } + + // If already attempted fallback, do not fallback again + if (this.fallbackAttempted) + { + return false; + } + + if (ex.InnerException?.Message.Contains(AadInvalidScopeErrorMessage) == true) + { + this.currentScope = AadDefaultScope; + this.fallbackAttempted = true; + return true; + } + + return false; + } + + public void Dispose() + { + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs new file mode 100644 index 0000000000..0964b6aec5 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs @@ -0,0 +1,16 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Authorization +{ + using System; + using System.Collections.Generic; + using System.Text; + using global::Azure.Core; + + internal interface IScopeProvider : IDisposable + { + TokenRequestContext GetTokenRequestContext(); + bool TryFallback(Exception ex); + } +} diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 4bf8d36bf6..26dd40cbff 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -10,7 +10,8 @@ namespace Microsoft.Azure.Cosmos using System.Threading; using System.Threading.Tasks; using global::Azure; - using global::Azure.Core; + using global::Azure.Core; + using Microsoft.Azure.Cosmos.Authorization; using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.Azure.Cosmos.Tracing; @@ -36,67 +37,53 @@ internal sealed class TokenCredentialCache : IDisposable // If the background refresh fails with less than a minute then just allow the request to hit the exception. public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1); - private const string ScopeFormat = "https://{0}/.default"; - private const string AadInvalidScopeErrorMessage = "AADSTS500011"; - private const string AadDefaultScope = "https://cosmos.azure.com/.default"; - - private readonly TokenRequestContext tokenRequestContext; private readonly TokenCredential tokenCredential; private readonly CancellationTokenSource cancellationTokenSource; private readonly CancellationToken cancellationToken; private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; - private readonly string accountScope; - private readonly bool isOverrideScopeProvided; private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1); - private readonly object backgroundRefreshLock = new object(); + private readonly object backgroundRefreshLock = new object(); + private readonly IScopeProvider scopeProvider; private TimeSpan? systemBackgroundTokenCredentialRefreshInterval; private Task? currentRefreshOperation = null; private AccessToken? cachedAccessToken = null; private bool isBackgroundTaskRunning = false; private bool isDisposed = false; - - internal TokenCredentialCache( - TokenCredential tokenCredential, - Uri accountEndpoint, - TimeSpan? backgroundTokenCredentialRefreshInterval) - { - this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); - - if (accountEndpoint == null) - { - throw new ArgumentNullException(nameof(accountEndpoint)); - } - string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null); + internal TokenCredentialCache( + TokenCredential tokenCredential, + Uri accountEndpoint, + TimeSpan? backgroundTokenCredentialRefreshInterval) + { + this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); + + if (accountEndpoint == null) + { + throw new ArgumentNullException(nameof(accountEndpoint)); + } - this.accountScope = string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host); - this.isOverrideScopeProvided = !string.IsNullOrEmpty(scopeOverride); + this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint); - this.tokenRequestContext = new TokenRequestContext(new string[] + if (backgroundTokenCredentialRefreshInterval.HasValue) { - this.isOverrideScopeProvided ? scopeOverride! : this.accountScope - }); + if (backgroundTokenCredentialRefreshInterval.Value <= TimeSpan.Zero) + { + throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be a positive value greater than 0. Value '{backgroundTokenCredentialRefreshInterval.Value.TotalMilliseconds}'."); + } - if (backgroundTokenCredentialRefreshInterval.HasValue) - { - if (backgroundTokenCredentialRefreshInterval.Value <= TimeSpan.Zero) - { - throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be a positive value greater than 0. Value '{backgroundTokenCredentialRefreshInterval.Value.TotalMilliseconds}'."); - } - - // TimeSpan.MaxValue disables the background refresh - if (backgroundTokenCredentialRefreshInterval.Value > TokenCredentialCache.MaxBackgroundRefreshInterval && - backgroundTokenCredentialRefreshInterval.Value != TimeSpan.MaxValue) - { - throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be less than or equal to {TokenCredentialCache.MaxBackgroundRefreshInterval}. Value '{backgroundTokenCredentialRefreshInterval.Value}'."); - } - } - - this.userDefinedBackgroundTokenCredentialRefreshInterval = backgroundTokenCredentialRefreshInterval; - this.cancellationTokenSource = new CancellationTokenSource(); - this.cancellationToken = this.cancellationTokenSource.Token; + // TimeSpan.MaxValue disables the background refresh + if (backgroundTokenCredentialRefreshInterval.Value > TokenCredentialCache.MaxBackgroundRefreshInterval && + backgroundTokenCredentialRefreshInterval.Value != TimeSpan.MaxValue) + { + throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be less than or equal to {TokenCredentialCache.MaxBackgroundRefreshInterval}. Value '{backgroundTokenCredentialRefreshInterval.Value}'."); + } + } + + this.userDefinedBackgroundTokenCredentialRefreshInterval = backgroundTokenCredentialRefreshInterval; + this.cancellationTokenSource = new CancellationTokenSource(); + this.cancellationToken = this.cancellationTokenSource.Token; } public TimeSpan? BackgroundTokenCredentialRefreshInterval => @@ -172,217 +159,140 @@ private async Task GetNewTokenAsync( } return await currentTask; - } - - private void ApplyTokenAndSetRefreshInterval(AccessToken token) + } + + private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITrace trace) { - this.cachedAccessToken = token; - - if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + try { - double refreshIntervalInSeconds = - (token.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; + Exception? lastException = null; + const int totalRetryCount = 2; + for (int retry = 0; retry < totalRetryCount; retry++) + { + if (this.cancellationToken.IsCancellationRequested) + { + DefaultTrace.TraceInformation( + "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); - refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); + break; + } - this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); - } - } - - private async ValueTask RefreshCachedTokenWithRetryHelperAsync( - ITrace trace) - { - try - { - Exception? lastException = null; - const int totalRetryCount = 2; - for (int retry = 0; retry < totalRetryCount; retry++) - { - if (this.cancellationToken.IsCancellationRequested) - { - DefaultTrace.TraceInformation( - "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - - break; - } - - using (ITrace getTokenTrace = trace.StartChild( - name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), - component: TraceComponent.Authorization, - level: Tracing.TraceLevel.Info)) + using (ITrace getTokenTrace = trace.StartChild( + name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), + component: TraceComponent.Authorization, + level: Tracing.TraceLevel.Info)) { - bool shouldAttemptAadFallback = false; - try { - AccessToken? tokenNullable = await this.tokenCredential.GetTokenAsync( - requestContext: this.tokenRequestContext, + TokenRequestContext tokenRequestContext = this.scopeProvider.GetTokenRequestContext(); + + this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( + requestContext: tokenRequestContext, cancellationToken: this.cancellationToken); - if (!tokenNullable.HasValue) + if (!this.cachedAccessToken.HasValue) { throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); } - if (tokenNullable.Value.ExpiresOn < DateTimeOffset.UtcNow) + if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) { - throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{tokenNullable.Value.ExpiresOn:O}"); + throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); } - this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value); - return tokenNullable.Value; - } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - if (this.isOverrideScopeProvided) + if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) { - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; - } + double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - continue; + // Ensure the background refresh interval is a valid range. + refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); + refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); + this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); } - } - catch (OperationCanceledException operationCancelled) - { - lastException = operationCancelled; - getTokenTrace.AddDatum( - $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelled.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() - { - SubStatusCode = SubStatusCodes.FailedToGetAadToken, - }, - innerException: lastException, - trace: getTokenTrace); - } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - if (this.isOverrideScopeProvided) + return this.cachedAccessToken.Value; + } + catch (RequestFailedException requestFailedException) + { + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException.Message); + + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + // Don't retry on auth failures + if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden) { - continue; + this.cachedAccessToken = default; + throw; } - if (exception.InnerException?.Message.Contains(AadInvalidScopeErrorMessage) == true) + // Fallback logic + if (this.scopeProvider.TryFallback(requestFailedException)) { - shouldAttemptAadFallback = true; + continue; } } - - if (!this.isOverrideScopeProvided && shouldAttemptAadFallback) + catch (OperationCanceledException operationCancelled) { - TokenRequestContext fallbackContext = new TokenRequestContext(new[] { AadDefaultScope }); + lastException = operationCancelled; + getTokenTrace.AddDatum( + $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + operationCancelled.Message); - try - { - AccessToken? tokenNullable = await this.tokenCredential.GetTokenAsync( - requestContext: fallbackContext, - cancellationToken: this.cancellationToken); - - if (!tokenNullable.HasValue) - { - throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); - } - - if (tokenNullable.Value.ExpiresOn < DateTimeOffset.UtcNow) - { - throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{tokenNullable.Value.ExpiresOn:O}"); - } - - this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value); - return tokenNullable.Value; - } - catch (RequestFailedException requestFailedExceptionFallback) - { - lastException = requestFailedExceptionFallback; - getTokenTrace.AddDatum( - $"RequestFailedException (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedExceptionFallback.Message); - - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - if (requestFailedExceptionFallback.Status == (int)HttpStatusCode.Unauthorized || - requestFailedExceptionFallback.Status == (int)HttpStatusCode.Forbidden) + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() { - this.cachedAccessToken = default; - throw; - } - } - catch (OperationCanceledException operationCancelledFallback) - { - lastException = operationCancelledFallback; - getTokenTrace.AddDatum( - $"OperationCanceledException (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelledFallback.Message); + SubStatusCode = SubStatusCodes.FailedToGetAadToken, + }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception.Message); - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() { SubStatusCode = SubStatusCodes.FailedToGetAadToken, }, - innerException: lastException, - trace: getTokenTrace); - } - catch (Exception exceptionFallback) + // Fallback logic + if (this.scopeProvider.TryFallback(exception)) { - lastException = exceptionFallback; - getTokenTrace.AddDatum( - $"Exception (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exceptionFallback.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + continue; } } } - } - - if (lastException == null) - { - throw new ArgumentException("Last exception is null."); - } - - // The retries have been exhausted. Throw the last exception. - throw lastException; - } - finally - { - try - { - await this.isTokenRefreshingLock.WaitAsync(); - this.currentRefreshOperation = null; - } - finally - { - this.isTokenRefreshingLock.Release(); - } - } + } + + if (lastException == null) + { + throw new ArgumentException("Last exception is null."); + } + + // The retries have been exhausted. Throw the last exception. + throw lastException; + } + finally + { + try + { + await this.isTokenRefreshingLock.WaitAsync(); + this.currentRefreshOperation = null; + } + finally + { + this.isTokenRefreshingLock.Release(); + } + } } #pragma warning disable VSTHRD100 // Avoid async void methods From 3c8d3d30f36d131afc4fe747daae6937e8f3948b Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Wed, 20 Aug 2025 22:52:43 -0700 Subject: [PATCH 04/13] Fux formatting. --- .../src/Authorization/TokenCredentialCache.cs | 291 +++++++++--------- 1 file changed, 146 insertions(+), 145 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 26dd40cbff..b3f63c7fd8 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -11,7 +11,7 @@ namespace Microsoft.Azure.Cosmos using System.Threading.Tasks; using global::Azure; using global::Azure.Core; - using Microsoft.Azure.Cosmos.Authorization; + using Microsoft.Azure.Cosmos.Authorization; using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.Azure.Cosmos.Tracing; @@ -37,53 +37,53 @@ internal sealed class TokenCredentialCache : IDisposable // If the background refresh fails with less than a minute then just allow the request to hit the exception. public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1); + private readonly IScopeProvider scopeProvider; private readonly TokenCredential tokenCredential; private readonly CancellationTokenSource cancellationTokenSource; private readonly CancellationToken cancellationToken; - private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; + private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1); - private readonly object backgroundRefreshLock = new object(); - private readonly IScopeProvider scopeProvider; + private readonly object backgroundRefreshLock = new object(); private TimeSpan? systemBackgroundTokenCredentialRefreshInterval; private Task? currentRefreshOperation = null; private AccessToken? cachedAccessToken = null; private bool isBackgroundTaskRunning = false; private bool isDisposed = false; - - internal TokenCredentialCache( - TokenCredential tokenCredential, - Uri accountEndpoint, - TimeSpan? backgroundTokenCredentialRefreshInterval) - { - this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); - - if (accountEndpoint == null) - { - throw new ArgumentNullException(nameof(accountEndpoint)); + + internal TokenCredentialCache( + TokenCredential tokenCredential, + Uri accountEndpoint, + TimeSpan? backgroundTokenCredentialRefreshInterval) + { + this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); + + if (accountEndpoint == null) + { + throw new ArgumentNullException(nameof(accountEndpoint)); } this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint); - if (backgroundTokenCredentialRefreshInterval.HasValue) - { - if (backgroundTokenCredentialRefreshInterval.Value <= TimeSpan.Zero) - { - throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be a positive value greater than 0. Value '{backgroundTokenCredentialRefreshInterval.Value.TotalMilliseconds}'."); - } - - // TimeSpan.MaxValue disables the background refresh - if (backgroundTokenCredentialRefreshInterval.Value > TokenCredentialCache.MaxBackgroundRefreshInterval && - backgroundTokenCredentialRefreshInterval.Value != TimeSpan.MaxValue) - { - throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be less than or equal to {TokenCredentialCache.MaxBackgroundRefreshInterval}. Value '{backgroundTokenCredentialRefreshInterval.Value}'."); - } - } - - this.userDefinedBackgroundTokenCredentialRefreshInterval = backgroundTokenCredentialRefreshInterval; - this.cancellationTokenSource = new CancellationTokenSource(); - this.cancellationToken = this.cancellationTokenSource.Token; + if (backgroundTokenCredentialRefreshInterval.HasValue) + { + if (backgroundTokenCredentialRefreshInterval.Value <= TimeSpan.Zero) + { + throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be a positive value greater than 0. Value '{backgroundTokenCredentialRefreshInterval.Value.TotalMilliseconds}'."); + } + + // TimeSpan.MaxValue disables the background refresh + if (backgroundTokenCredentialRefreshInterval.Value > TokenCredentialCache.MaxBackgroundRefreshInterval && + backgroundTokenCredentialRefreshInterval.Value != TimeSpan.MaxValue) + { + throw new ArgumentException($"{nameof(backgroundTokenCredentialRefreshInterval)} must be less than or equal to {TokenCredentialCache.MaxBackgroundRefreshInterval}. Value '{backgroundTokenCredentialRefreshInterval.Value}'."); + } + } + + this.userDefinedBackgroundTokenCredentialRefreshInterval = backgroundTokenCredentialRefreshInterval; + this.cancellationTokenSource = new CancellationTokenSource(); + this.cancellationToken = this.cancellationTokenSource.Token; } public TimeSpan? BackgroundTokenCredentialRefreshInterval => @@ -161,106 +161,107 @@ private async Task GetNewTokenAsync( return await currentTask; } - private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITrace trace) - { - try - { - Exception? lastException = null; - const int totalRetryCount = 2; - for (int retry = 0; retry < totalRetryCount; retry++) - { - if (this.cancellationToken.IsCancellationRequested) - { - DefaultTrace.TraceInformation( - "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - - break; - } - - using (ITrace getTokenTrace = trace.StartChild( - name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), - component: TraceComponent.Authorization, - level: Tracing.TraceLevel.Info)) - { - try + private async ValueTask RefreshCachedTokenWithRetryHelperAsync( + ITrace trace) + { + try + { + Exception? lastException = null; + const int totalRetryCount = 2; + for (int retry = 0; retry < totalRetryCount; retry++) + { + if (this.cancellationToken.IsCancellationRequested) + { + DefaultTrace.TraceInformation( + "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); + + break; + } + + using (ITrace getTokenTrace = trace.StartChild( + name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), + component: TraceComponent.Authorization, + level: Tracing.TraceLevel.Info)) + { + try { TokenRequestContext tokenRequestContext = this.scopeProvider.GetTokenRequestContext(); - - this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( - requestContext: tokenRequestContext, - cancellationToken: this.cancellationToken); - - if (!this.cachedAccessToken.HasValue) - { - throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); - } - - if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) - { - throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); - } - - if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) - { - double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - - // Ensure the background refresh interval is a valid range. - refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); - refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); - this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); - } - - return this.cachedAccessToken.Value; - } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - + + this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( + requestContext: tokenRequestContext, + cancellationToken: this.cancellationToken); + + if (!this.cachedAccessToken.HasValue) + { + throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); + } + + if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) + { + throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); + } + + if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + { + double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; + + // Ensure the background refresh interval is a valid range. + refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); + refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); + this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); + } + + return this.cachedAccessToken.Value; + } + catch (RequestFailedException requestFailedException) + { + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException.Message); + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; + + // Don't retry on auth failures + if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden) + { + this.cachedAccessToken = default; + throw; } // Fallback logic if (this.scopeProvider.TryFallback(requestFailedException)) { continue; - } - } - catch (OperationCanceledException operationCancelled) - { - lastException = operationCancelled; - getTokenTrace.AddDatum( - $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + } + } + catch (OperationCanceledException operationCancelled) + { + lastException = operationCancelled; + getTokenTrace.AddDatum( + $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", operationCancelled.Message); DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() - { - SubStatusCode = SubStatusCodes.FailedToGetAadToken, - }, - innerException: lastException, - trace: getTokenTrace); - } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception.Message); - + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() + { + SubStatusCode = SubStatusCodes.FailedToGetAadToken, + }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception.Message); + DefaultTrace.TraceError( $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); @@ -269,30 +270,30 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(ITra { continue; } - } - } - } - - if (lastException == null) - { - throw new ArgumentException("Last exception is null."); - } - - // The retries have been exhausted. Throw the last exception. - throw lastException; - } - finally - { - try - { - await this.isTokenRefreshingLock.WaitAsync(); - this.currentRefreshOperation = null; - } - finally - { - this.isTokenRefreshingLock.Release(); - } - } + } + } + } + + if (lastException == null) + { + throw new ArgumentException("Last exception is null."); + } + + // The retries have been exhausted. Throw the last exception. + throw lastException; + } + finally + { + try + { + await this.isTokenRefreshingLock.WaitAsync(); + this.currentRefreshOperation = null; + } + finally + { + this.isTokenRefreshingLock.Release(); + } + } } #pragma warning disable VSTHRD100 // Avoid async void methods From 90029c07abfe8e7585df06c86eb18ef499152564 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Thu, 21 Aug 2025 11:51:01 -0700 Subject: [PATCH 05/13] Code cleanup. --- .../src/Authorization/CosmosScopeProvider.cs | 6 ++---- .../Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs index c5357b67f7..680e255a25 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs @@ -18,7 +18,6 @@ internal sealed class CosmosScopeProvider : IScopeProvider private readonly string accountScope; private readonly string overrideScope; private string currentScope; - private bool fallbackAttempted = false; public CosmosScopeProvider(Uri accountEndpoint) { @@ -40,8 +39,8 @@ public bool TryFallback(Exception ex) return false; } - // If already attempted fallback, do not fallback again - if (this.fallbackAttempted) + // If already using fallback scope, do not fallback again + if (this.currentScope == AadDefaultScope) { return false; } @@ -49,7 +48,6 @@ public bool TryFallback(Exception ex) if (ex.InnerException?.Message.Contains(AadInvalidScopeErrorMessage) == true) { this.currentScope = AadDefaultScope; - this.fallbackAttempted = true; return true; } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs index 9c1c56a357..b9bb314807 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs @@ -335,7 +335,7 @@ void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) } [TestMethod] - public async Task Aad_AccountScope_Fallbacks_ToCosmos_OnSniCertRevoked_Unit() + public async Task Aad_AccountScope_Fallbacks_ToCosmosScope() { (string endpoint, string authKey) = TestCommon.GetAccountInfo(); From 0e709adddf27f77b6d8f6639e979fb243cb0b5b2 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Thu, 21 Aug 2025 18:48:53 -0700 Subject: [PATCH 06/13] Code cleanup. --- .../src/Authorization/CosmosScopeProvider.cs | 10 ++++++---- .../src/Authorization/TokenCredentialCache.cs | 20 ++++++++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs index 680e255a25..7950b64fba 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs @@ -31,7 +31,7 @@ public TokenRequestContext GetTokenRequestContext() return new TokenRequestContext(new[] { this.currentScope }); } - public bool TryFallback(Exception ex) + public bool TryFallback(Exception exception) { // If override scope is set, never fallback if (!string.IsNullOrEmpty(this.overrideScope)) @@ -40,16 +40,18 @@ public bool TryFallback(Exception ex) } // If already using fallback scope, do not fallback again - if (this.currentScope == AadDefaultScope) + if (this.currentScope == CosmosScopeProvider.AadDefaultScope) { return false; } - if (ex.InnerException?.Message.Contains(AadInvalidScopeErrorMessage) == true) +#pragma warning disable CDX1003 // DontUseExceptionToString + if (exception.ToString().Contains(CosmosScopeProvider.AadInvalidScopeErrorMessage) == true) { - this.currentScope = AadDefaultScope; + this.currentScope = CosmosScopeProvider.AadDefaultScope; return true; } +#pragma warning restore CDX1003 // DontUseExceptionToString return false; } diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index b3f63c7fd8..b267d9c742 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -50,7 +50,7 @@ internal sealed class TokenCredentialCache : IDisposable private Task? currentRefreshOperation = null; private AccessToken? cachedAccessToken = null; private bool isBackgroundTaskRunning = false; - private bool isDisposed = false; + private bool isDisposed = false; internal TokenCredentialCache( TokenCredential tokenCredential, @@ -163,11 +163,13 @@ private async Task GetNewTokenAsync( private async ValueTask RefreshCachedTokenWithRetryHelperAsync( ITrace trace) - { + { try { Exception? lastException = null; - const int totalRetryCount = 2; + const int totalRetryCount = 2; + TokenRequestContext tokenRequestContext = default; + for (int retry = 0; retry < totalRetryCount; retry++) { if (this.cancellationToken.IsCancellationRequested) @@ -182,10 +184,10 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), component: TraceComponent.Authorization, level: Tracing.TraceLevel.Info)) - { + { try { - TokenRequestContext tokenRequestContext = this.scopeProvider.GetTokenRequestContext(); + tokenRequestContext = this.scopeProvider.GetTokenRequestContext(); this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( requestContext: tokenRequestContext, @@ -214,13 +216,13 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( return this.cachedAccessToken.Value; } catch (RequestFailedException requestFailedException) - { + { lastException = requestFailedException; getTokenTrace.AddDatum( $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", requestFailedException.Message); - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty())}, retry = {retry}, Exception = {lastException.Message}"); // Don't retry on auth failures if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || @@ -244,7 +246,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( operationCancelled.Message); DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty())}, retry = {retry}, Exception = {lastException.Message}"); throw CosmosExceptionFactory.CreateRequestTimeoutException( message: ClientResources.FailedToGetAadToken, @@ -263,7 +265,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( exception.Message); DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.scopeProvider.GetTokenRequestContext().Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty())}, retry = {retry}, Exception = {lastException.Message}"); // Fallback logic if (this.scopeProvider.TryFallback(exception)) From c48e1fb5bd29ebcb3928725de2aabaf500a916ec Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Fri, 22 Aug 2025 14:40:07 -0700 Subject: [PATCH 07/13] Added updated tests for cosmos scope provider. --- .../src/Authorization/TokenCredentialCache.cs | 9 ++- .../CosmosAadTests.cs | 49 ++++++++++++ .../Authorization/CosmosScopeProviderTests.cs | 78 +++++++++++++++++++ 3 files changed, 132 insertions(+), 4 deletions(-) create mode 100644 Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index b267d9c742..2e5ee5f58a 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -121,7 +121,8 @@ public void Dispose() } this.cancellationTokenSource.Cancel(); - this.cancellationTokenSource.Dispose(); + this.cancellationTokenSource.Dispose(); + this.scopeProvider.Dispose(); this.isDisposed = true; } @@ -222,7 +223,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", requestFailedException.Message); - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty())}, retry = {retry}, Exception = {lastException.Message}"); + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); // Don't retry on auth failures if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || @@ -246,7 +247,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( operationCancelled.Message); DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty())}, retry = {retry}, Exception = {lastException.Message}"); + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); throw CosmosExceptionFactory.CreateRequestTimeoutException( message: ClientResources.FailedToGetAadToken, @@ -265,7 +266,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( exception.Message); DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty())}, retry = {retry}, Exception = {lastException.Message}"); + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); // Fallback logic if (this.scopeProvider.TryFallback(exception)) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs index b9bb314807..1e53f82a44 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs @@ -394,5 +394,54 @@ void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous); } } + + [TestMethod] + public async Task Aad_AccountScope_Success_NoFallback() + { + // Arrange + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); + + string accountScope = $"https://{new Uri(endpoint).Host}/.default"; + string aadScope = "https://cosmos.azure.com/.default"; + + int accountScopeCount = 0; + int cosmosScopeCount = 0; + + void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token) + { + string scope = context.Scopes[0]; + + if (string.Equals(scope, accountScope, StringComparison.OrdinalIgnoreCase)) + { + accountScopeCount++; + } + + if (string.Equals(scope, aadScope, StringComparison.OrdinalIgnoreCase)) + { + cosmosScopeCount++; + } + } + + LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential( + expectedScopes: new[] { accountScope }, + masterKey: authKey, + getTokenCallback: GetAadTokenCallBack); + + CosmosClientOptions clientOptions = new CosmosClientOptions + { + ConnectionMode = ConnectionMode.Gateway, + TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60) + }; + + using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions); + TokenCredentialCache tokenCredentialCache = + ((AuthorizationTokenProviderTokenCredential)aadClient.AuthorizationTokenProvider).tokenCredentialCache; + + string token = await tokenCredentialCache.GetTokenAsync(Tracing.Trace.GetRootTrace("account-scope-success-no-fallback")); + Assert.IsFalse(string.IsNullOrEmpty(token), "Token should be acquired successfully with account scope."); + + Assert.AreEqual(1, accountScopeCount, "Account scope must be used exactly once."); + Assert.AreEqual(0, cosmosScopeCount, "Cosmos scope must not be used (no fallback)."); + } } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs new file mode 100644 index 0000000000..b4c99fbf8b --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Authorization/CosmosScopeProviderTests.cs @@ -0,0 +1,78 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Tests.Authorization +{ + using System; + using global::Azure.Core; + using Microsoft.Azure.Cosmos.Authorization; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class CosmosScopeProviderTests + { + private static readonly Uri TestAccountEndpoint = new Uri("https://testaccount.documents.azure.com:443/"); + + [DataTestMethod] + [DataRow("https://override/.default", "https://override/.default", DisplayName = "OverrideScope_Used")] + [DataRow(null, "https://testaccount.documents.azure.com/.default", DisplayName = "AccountScope_Used_WhenNoOverride")] + public void GetTokenRequestContext_UsesExpectedScope(string overrideScope, string expectedScope) + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope); + + try + { + CosmosScopeProvider provider = new CosmosScopeProvider(TestAccountEndpoint); + TokenRequestContext context = provider.GetTokenRequestContext(); + Assert.AreEqual(expectedScope, context.Scopes[0]); + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + } + } + + [DataTestMethod] + [DataRow("https://override/.default", false, "AADSTS500011", "https://override/.default", DisplayName = "OverrideScope_NeverFallback")] + [DataRow(null, true, "AADSTS500011", "https://cosmos.azure.com/.default", DisplayName = "AccountScope_FallbacksToAadDefault")] + [DataRow(null, false, "SomeOtherError", "https://testaccount.documents.azure.com/.default", DisplayName = "AccountScope_NoFallbackOnOtherError")] + public void Test_TryFallback_Behavior( + string overrideScope, + bool expectFallback, + string exceptionMessage, + string expectedScope) + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope); + + try + { + CosmosScopeProvider provider = new CosmosScopeProvider(TestAccountEndpoint); + + bool didFallback = provider.TryFallback(new Exception(exceptionMessage)); + + Assert.AreEqual(expectFallback, didFallback, "Fallback result mismatch."); + Assert.AreEqual(expectedScope, provider.GetTokenRequestContext().Scopes[0]); + } + finally + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + } + } + + [TestMethod] + public void TryFallback_DoesNotFallback_WhenAlreadyUsingAadDefault() + { + Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null); + CosmosScopeProvider provider = new CosmosScopeProvider(TestAccountEndpoint); + + provider.TryFallback(new Exception("AADSTS500011")); + Assert.AreEqual("https://cosmos.azure.com/.default", provider.GetTokenRequestContext().Scopes[0]); + + // Act + bool didFallbackAgain = provider.TryFallback(new Exception("AADSTS500011")); + + Assert.IsFalse(didFallbackAgain, "Should not fallback again when already using AadDefault scope."); + Assert.AreEqual("https://cosmos.azure.com/.default", provider.GetTokenRequestContext().Scopes[0]); + } + } +} \ No newline at end of file From 74b031120912e72a4e4d0cfa087184e85132dfc2 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Fri, 22 Aug 2025 18:04:14 -0700 Subject: [PATCH 08/13] Code cleanup. --- .../Utils/LocalEmulatorTokenCredential.cs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs index 5fda33b046..0d8da95a0b 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/LocalEmulatorTokenCredential.cs @@ -18,7 +18,6 @@ public class LocalEmulatorTokenCredential : TokenCredential private readonly DateTime? DefaultDateTime = null; private readonly Action GetTokenCallback; private readonly string masterKey; - private readonly string expectedScope; private readonly string[] expectedScopes; internal LocalEmulatorTokenCredential( @@ -30,8 +29,7 @@ internal LocalEmulatorTokenCredential( this.masterKey = masterKey; this.GetTokenCallback = getTokenCallback; this.DefaultDateTime = defaultDateTime; - this.expectedScope = expectedScope; - this.expectedScopes = null; + this.expectedScopes = new string[] { expectedScope }; } internal LocalEmulatorTokenCredential( @@ -43,7 +41,6 @@ internal LocalEmulatorTokenCredential( this.masterKey = masterKey; this.GetTokenCallback = getTokenCallback; this.DefaultDateTime = defaultDateTime; - this.expectedScope = null; this.expectedScopes = expectedScopes; } @@ -59,10 +56,7 @@ public override ValueTask GetTokenAsync(TokenRequestContext request private AccessToken GetAccessToken(TokenRequestContext requestContext, CancellationToken cancellationToken) { - if (this.expectedScope != null) - { - Assert.AreEqual(this.expectedScope, requestContext.Scopes.First()); - } + Assert.IsTrue(this.expectedScopes.Contains(requestContext.Scopes.First())); this.GetTokenCallback?.Invoke( requestContext, From 4f5a0175f2dc9b8fc49f2494e459dbe7b8a7643c Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Mon, 25 Aug 2025 10:05:24 -0700 Subject: [PATCH 09/13] Code cleanup --- .../src/Authorization/TokenCredentialCache.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 2e5ee5f58a..4a6be6e7b4 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -165,12 +165,12 @@ private async Task GetNewTokenAsync( private async ValueTask RefreshCachedTokenWithRetryHelperAsync( ITrace trace) { + Exception? lastException = null; + const int totalRetryCount = 2; + TokenRequestContext tokenRequestContext = default; + try { - Exception? lastException = null; - const int totalRetryCount = 2; - TokenRequestContext tokenRequestContext = default; - for (int retry = 0; retry < totalRetryCount; retry++) { if (this.cancellationToken.IsCancellationRequested) From 481e76bb64c596a95dacce4b68cd93cb6f3fa74d Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Mon, 25 Aug 2025 11:33:41 -0700 Subject: [PATCH 10/13] nit updates based on review. --- Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs | 2 +- .../src/Authorization/TokenCredentialCache.cs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs index 0964b6aec5..987c6565ba 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs @@ -8,7 +8,7 @@ namespace Microsoft.Azure.Cosmos.Authorization using System.Text; using global::Azure.Core; - internal interface IScopeProvider : IDisposable + internal interface IScopeProvider { TokenRequestContext GetTokenRequestContext(); bool TryFallback(Exception ex); diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 4a6be6e7b4..f7d40fcaf3 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -122,7 +122,6 @@ public void Dispose() this.cancellationTokenSource.Cancel(); this.cancellationTokenSource.Dispose(); - this.scopeProvider.Dispose(); this.isDisposed = true; } @@ -236,6 +235,8 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( // Fallback logic if (this.scopeProvider.TryFallback(requestFailedException)) { + DefaultTrace.TraceWarning( + $"TokenCredentialCache: Fallback to default scope triggered due to exception: {requestFailedException.Message}"); continue; } } From afcf4bd25512880774298e6b737f9fc4303ff067 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Mon, 25 Aug 2025 13:00:36 -0700 Subject: [PATCH 11/13] Update logging for fallback. --- .../src/Authorization/TokenCredentialCache.cs | 67 +++++++++++-------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index f7d40fcaf3..ce1c8750fd 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -215,30 +215,35 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( return this.cachedAccessToken.Value; } - catch (RequestFailedException requestFailedException) + catch (RequestFailedException requestFailedException) { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException.Message); + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; + + // Don't retry on auth failures + if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden) + { + this.cachedAccessToken = default; + throw; } - // Fallback logic - if (this.scopeProvider.TryFallback(requestFailedException)) + bool didFallback = this.scopeProvider.TryFallback(requestFailedException); + string logMessage = $"TokenCredential.GetToken() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {requestFailedException.Message}. Fallback attempted: {didFallback}"; + + if (didFallback) { - DefaultTrace.TraceWarning( - $"TokenCredentialCache: Fallback to default scope triggered due to exception: {requestFailedException.Message}"); + DefaultTrace.TraceInformation(logMessage); continue; - } + } + else + { + DefaultTrace.TraceWarning(logMessage); + } } catch (OperationCanceledException operationCancelled) { @@ -259,21 +264,25 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( innerException: lastException, trace: getTokenTrace); } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception.Message); - - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception.Message); + + bool didFallback = this.scopeProvider.TryFallback(exception); + string logMessage = $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}"; - // Fallback logic - if (this.scopeProvider.TryFallback(exception)) + if (didFallback) { + DefaultTrace.TraceInformation(logMessage); continue; } + else + { + DefaultTrace.TraceWarning(logMessage); + } } } } From 19c526f92c1cd2f0819e2e4445c8d6ea2949f253 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Mon, 25 Aug 2025 16:37:18 -0700 Subject: [PATCH 12/13] Update logging for AAD fallback. --- .../src/Authorization/CosmosScopeProvider.cs | 7 ------- .../src/Authorization/IScopeProvider.cs | 2 -- .../src/Authorization/TokenCredentialCache.cs | 16 +--------------- 3 files changed, 1 insertion(+), 24 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs index 7950b64fba..41a05e137f 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs @@ -4,9 +4,6 @@ namespace Microsoft.Azure.Cosmos.Authorization { using System; - using System.Collections.Generic; - using System.Linq; - using System.Text; using global::Azure.Core; internal sealed class CosmosScopeProvider : IScopeProvider @@ -55,9 +52,5 @@ public bool TryFallback(Exception exception) return false; } - - public void Dispose() - { - } } } diff --git a/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs index 987c6565ba..545270f9eb 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs @@ -4,8 +4,6 @@ namespace Microsoft.Azure.Cosmos.Authorization { using System; - using System.Collections.Generic; - using System.Text; using global::Azure.Core; internal interface IScopeProvider diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index ce1c8750fd..4c9d89bcc0 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -214,7 +214,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( } return this.cachedAccessToken.Value; - } + } catch (RequestFailedException requestFailedException) { lastException = requestFailedException; @@ -231,19 +231,6 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( this.cachedAccessToken = default; throw; } - - bool didFallback = this.scopeProvider.TryFallback(requestFailedException); - string logMessage = $"TokenCredential.GetToken() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {requestFailedException.Message}. Fallback attempted: {didFallback}"; - - if (didFallback) - { - DefaultTrace.TraceInformation(logMessage); - continue; - } - else - { - DefaultTrace.TraceWarning(logMessage); - } } catch (OperationCanceledException operationCancelled) { @@ -277,7 +264,6 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( if (didFallback) { DefaultTrace.TraceInformation(logMessage); - continue; } else { From 554887061ff0165efeab8519be8340d51bb1cf2d Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Mon, 25 Aug 2025 17:39:21 -0700 Subject: [PATCH 13/13] Update logging for AAD fallback. --- .../src/Authorization/TokenCredentialCache.cs | 34 ++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 4c9d89bcc0..efb10277d8 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -214,23 +214,6 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( } return this.cachedAccessToken.Value; - } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException.Message); - - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; - } } catch (OperationCanceledException operationCancelled) { @@ -258,16 +241,21 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync( $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", exception.Message); - bool didFallback = this.scopeProvider.TryFallback(exception); - string logMessage = $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}"; + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); - if (didFallback) + // Don't retry on auth failures + if (exception is RequestFailedException requestFailedException && + (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden)) { - DefaultTrace.TraceInformation(logMessage); + this.cachedAccessToken = default; + throw; } - else + bool didFallback = this.scopeProvider.TryFallback(exception); + + if (didFallback) { - DefaultTrace.TraceWarning(logMessage); + DefaultTrace.TraceInformation($"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}"); } } }