diff --git a/Microsoft.Azure.Cosmos/FaultInjection/src/implementation/FaultInjectionServerErrorResultInternal.cs b/Microsoft.Azure.Cosmos/FaultInjection/src/implementation/FaultInjectionServerErrorResultInternal.cs
index c354cf5fde..8b235ecd11 100644
--- a/Microsoft.Azure.Cosmos/FaultInjection/src/implementation/FaultInjectionServerErrorResultInternal.cs
+++ b/Microsoft.Azure.Cosmos/FaultInjection/src/implementation/FaultInjectionServerErrorResultInternal.cs
@@ -283,7 +283,10 @@ public StoreResponse GetInjectedServerError(ChannelCallArguments args, string ru
case FaultInjectionServerErrorType.AadTokenRevoked:
INameValueCollection aadTokenRevokedHeaders = args.RequestHeaders;
aadTokenRevokedHeaders.Set(WFConstants.BackendHeaders.LocalLSN, lsn);
- aadTokenRevokedHeaders.Set(WFConstants.BackendHeaders.SubStatus, "5013");
+ aadTokenRevokedHeaders.Set(WFConstants.BackendHeaders.SubStatus, ((int)SubStatusCodes.AadTokenRevoked).ToString());
+ aadTokenRevokedHeaders.Set(
+ HttpConstants.HttpHeaders.WwwAuthenticate,
+ this.GenerateWwwAuthenticateForRevocation());
storeResponse = new StoreResponse()
{
Status = 401,
@@ -592,14 +595,17 @@ public HttpResponseMessage GetInjectedServerError(DocumentServiceRequest dsr, st
new MemoryStream(
isProxyCall
? FaultInjectionResponseEncoding.GetBytes(
- GetProxyResponseMessageString((int)StatusCodes.Unauthorized, 5013, "AadTokenRevoked", ruleId))
+ GetProxyResponseMessageString((int)StatusCodes.Unauthorized, (int)SubStatusCodes.AadTokenRevoked, "AadTokenRevoked", ruleId))
: FaultInjectionResponseEncoding.GetBytes($"Fault Injection Server Error: AadTokenRevoked, rule: {ruleId}"))),
};
this.SetHttpHeaders(httpResponse, headers, isProxyCall);
httpResponse.Headers.Add(
WFConstants.BackendHeaders.SubStatus,
- "5013");
+ ((int)SubStatusCodes.AadTokenRevoked).ToString());
httpResponse.Headers.Add(WFConstants.BackendHeaders.LocalLSN, lsn);
+ httpResponse.Headers.TryAddWithoutValidation(
+ HttpConstants.HttpHeaders.WwwAuthenticate,
+ this.GenerateWwwAuthenticateForRevocation());
return httpResponse;
default:
throw new ArgumentException($"Server error type {this.serverErrorType} is not supported");
@@ -641,6 +647,15 @@ private static string GetProxyResponseMessageString(
return $"{{\"code\": \"{statusCode}:{subStatusCode}\",\"message\":\"Fault Injection Server Error: {message}, rule: {faultInjectionRuleId}\"}}";
}
+ private string GenerateWwwAuthenticateForRevocation()
+ {
+ long currentTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
+ string claimsJson = "{\"access_token\":{\"nbf\":{\"essential\":false,\"value\":\"" + currentTimestamp.ToString() + "\"}}}";
+ string base64Claims = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(claimsJson));
+
+ return "Bearer realm=\"\", authorization_uri=\"\", error=\"insufficient_claims\", claims=\"" + base64Claims + "\"";
+ }
+
internal class FaultInjectionHttpContent : HttpContent
{
private readonly Stream content;
diff --git a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs
index cccacb240a..ae05438bca 100644
--- a/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs
+++ b/Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs
@@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos
{
using System;
using System.Globalization;
+ using System.Net;
using System.Threading.Tasks;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Core.Trace;
@@ -17,7 +18,7 @@ internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationT
{
internal readonly TokenCredentialCache tokenCredentialCache;
private bool isDisposed = false;
-
+
internal readonly TokenCredential tokenCredential;
public AuthorizationTokenProviderTokenCredential(
@@ -99,5 +100,110 @@ public override void Dispose()
this.tokenCredentialCache.Dispose();
}
}
+
+ ///
+ /// Attempts to handle AAD token revocation by checking for claims challenge.
+ /// Extracts claims from WWW-Authenticate header value and resets cache for retry with fresh token.
+ ///
+ /// HTTP status code from the response
+ /// The WWW-Authenticate response header value
+ /// True if claims challenge detected and request should be retried; false otherwise
+ internal bool TryHandleTokenRevocation(
+ HttpStatusCode statusCode,
+ string wwwAuthenticateHeaderValue)
+ {
+ if (statusCode != HttpStatusCode.Unauthorized)
+ {
+ return false;
+ }
+
+ if (string.IsNullOrEmpty(wwwAuthenticateHeaderValue))
+ {
+ return false;
+ }
+
+ // Check for claims challenge indicators
+ bool hasClaimsChallenge = wwwAuthenticateHeaderValue.IndexOf("insufficient_claims", StringComparison.OrdinalIgnoreCase) >= 0
+ || wwwAuthenticateHeaderValue.IndexOf("claims=", StringComparison.OrdinalIgnoreCase) >= 0;
+
+ if (!hasClaimsChallenge)
+ {
+ return false;
+ }
+
+ string claimsChallenge = AuthorizationTokenProviderTokenCredential.ExtractClaimsFromWwwAuthenticate(wwwAuthenticateHeaderValue);
+
+ // Reset cache with claims challenge for next token request
+ this.tokenCredentialCache.ResetCachedToken(claimsChallenge);
+
+ DefaultTrace.TraceInformation(
+ "AAD token revocation detected (claims challenge present). Token cache reset. " +
+ "Request will be retried with fresh token including claims. HasClaims={0}",
+ claimsChallenge != null);
+
+ return true;
+ }
+
+ ///
+ /// Extracts the claims challenge from the WWW-Authenticate header value.
+ ///
+ /// WWW-Authenticate header value
+ /// Base64-encoded claims string, or null if not present
+ private static string ExtractClaimsFromWwwAuthenticate(string wwwAuthenticateHeader)
+ {
+ if (string.IsNullOrEmpty(wwwAuthenticateHeader))
+ {
+ return null;
+ }
+
+ const string claimsPrefix = "claims=\"";
+ int claimsIndex = wwwAuthenticateHeader.IndexOf(claimsPrefix, StringComparison.OrdinalIgnoreCase);
+ if (claimsIndex < 0)
+ {
+ return null;
+ }
+
+ int startIndex = claimsIndex + claimsPrefix.Length;
+ int endIndex = wwwAuthenticateHeader.IndexOf("\"", startIndex, StringComparison.Ordinal);
+ if (endIndex < 0)
+ {
+ return null;
+ }
+
+ return wwwAuthenticateHeader.Substring(startIndex, endIndex - startIndex);
+ }
+ ///
+ /// Checks if a DocumentClientException is a 401/5013 token revocation that can be handled
+ /// by extracting claims from WWW-Authenticate and resetting the token cache.
+ /// Used by code paths outside the handler pipeline (GatewayAccountReader, GatewayAddressCache).
+ /// Returns true if the caller should retry the request.
+ ///
+ internal static bool TryHandleRevocationException(
+ AuthorizationTokenProvider authorizationTokenProvider,
+ DocumentClientException exception)
+ {
+ if (exception.StatusCode != HttpStatusCode.Unauthorized)
+ {
+ return false;
+ }
+
+ if (!(authorizationTokenProvider is AuthorizationTokenProviderTokenCredential tokenProvider))
+ {
+ return false;
+ }
+
+ string wwwAuthenticate = exception.Headers?.Get(HttpConstants.HttpHeaders.WwwAuthenticate);
+
+ // Proceed if either substatus is AadTokenRevoked (emergency) or WWW-Authenticate is present (CAE)
+ if (exception.GetSubStatus() != SubStatusCodes.AadTokenRevoked
+ && string.IsNullOrEmpty(wwwAuthenticate))
+ {
+ return false;
+ }
+
+ return tokenProvider.TryHandleTokenRevocation(
+ HttpStatusCode.Unauthorized,
+ wwwAuthenticate);
+ }
}
}
diff --git a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs
index 41a05e137f..516186bec2 100644
--- a/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs
+++ b/Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs
@@ -25,7 +25,7 @@ public CosmosScopeProvider(Uri accountEndpoint)
public TokenRequestContext GetTokenRequestContext()
{
- return new TokenRequestContext(new[] { this.currentScope });
+ return new TokenRequestContext(new[] { this.currentScope }, isCaeEnabled: true);
}
public bool TryFallback(Exception exception)
diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
index 25f60af3ec..fbba690b8c 100644
--- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
+++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
@@ -62,6 +62,7 @@ public AuthState(AccessToken token, string authorizationHeader)
private TimeSpan? systemBackgroundTokenCredentialRefreshInterval;
private Task? currentRefreshOperation = null;
private volatile AuthState? authState = null;
+ private volatile string? cachedClaimsChallenge;
private bool isBackgroundTaskRunning = false;
private bool isDisposed = false;
@@ -142,6 +143,93 @@ public void Dispose()
this.isDisposed = true;
}
+ internal void ResetCachedToken(string? claimsChallenge = null)
+ {
+ if (this.isDisposed)
+ {
+ return;
+ }
+
+ lock (this.backgroundRefreshLock)
+ {
+ this.authState = null;
+ this.currentRefreshOperation = null;
+ this.isBackgroundTaskRunning = false;
+ this.cachedClaimsChallenge = claimsChallenge;
+ }
+
+ DefaultTrace.TraceInformation(
+ $"TokenCredentialCache: Token cache reset due to AAD revocation signal. HasClaims={claimsChallenge != null}");
+ }
+
+ internal static string MergeClaimsWithClientCapabilities(string? claimsChallenge)
+ {
+ const string clientCapabilitiesJson = "{\"access_token\":{\"xms_cc\":{\"values\":[\"cp1\"]}}}";
+
+ if (string.IsNullOrEmpty(claimsChallenge))
+ {
+ return clientCapabilitiesJson;
+ }
+
+ try
+ {
+ byte[] claimsBytes = Convert.FromBase64String(claimsChallenge);
+ string claimsJson = System.Text.Encoding.UTF8.GetString(claimsBytes);
+
+ int accessTokenIndex = claimsJson.IndexOf("\"access_token\"", StringComparison.Ordinal);
+ if (accessTokenIndex < 0)
+ {
+ DefaultTrace.TraceWarning("TokenCredentialCache: CAE claims challenge missing 'access_token' key, using client capabilities only");
+ return clientCapabilitiesJson;
+ }
+
+ int openBraceIndex = claimsJson.IndexOf('{', accessTokenIndex);
+ if (openBraceIndex < 0)
+ {
+ DefaultTrace.TraceWarning("TokenCredentialCache: Malformed CAE claims challenge, using client capabilities only");
+ return clientCapabilitiesJson;
+ }
+
+ int braceCount = 1;
+ int currentIndex = openBraceIndex + 1;
+ int closeBraceIndex = -1;
+
+ while (currentIndex < claimsJson.Length && braceCount > 0)
+ {
+ if (claimsJson[currentIndex] == '{')
+ {
+ braceCount++;
+ }
+ else if (claimsJson[currentIndex] == '}')
+ {
+ braceCount--;
+ if (braceCount == 0)
+ {
+ closeBraceIndex = currentIndex;
+ break;
+ }
+ }
+
+ currentIndex++;
+ }
+
+ if (closeBraceIndex < 0)
+ {
+ DefaultTrace.TraceWarning("TokenCredentialCache: Unable to locate the end of the 'access_token' object in the CAE claims challenge. Using client capabilities only");
+ return clientCapabilitiesJson;
+ }
+
+ return claimsJson.Substring(0, closeBraceIndex) +
+ ",\"xms_cc\":{\"values\":[\"cp1\"]}" +
+ claimsJson.Substring(closeBraceIndex);
+ }
+ catch (Exception ex)
+ {
+ DefaultTrace.TraceWarning($"TokenCredentialCache: Failed to merge claims challenge: {ex.Message}. Using client capabilities only.");
+ return clientCapabilitiesJson;
+ }
+ }
+
private async Task GetNewTokenAsync(
ITrace trace)
{
@@ -206,6 +294,25 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(
{
tokenRequestContext = this.scopeProvider.GetTokenRequestContext();
+ string mergedClaims = TokenCredentialCache.MergeClaimsWithClientCapabilities(this.cachedClaimsChallenge);
+ if (string.IsNullOrEmpty(this.cachedClaimsChallenge))
+ {
+ DefaultTrace.TraceInformation(
+ $"Requesting AAD token with CAE client capabilities (cp1). Retry={retry}");
+ }
+ else
+ {
+ DefaultTrace.TraceInformation(
+ $"Requesting AAD token for revocation with claims challenge and client capabilities (cp1). Retry={retry}");
+ }
+
+ tokenRequestContext = new TokenRequestContext(
+ scopes: tokenRequestContext.Scopes,
+ parentRequestId: tokenRequestContext.ParentRequestId,
+ claims: mergedClaims,
+ tenantId: tokenRequestContext.TenantId,
+ isCaeEnabled: tokenRequestContext.IsCaeEnabled);
+
AccessToken accessToken = await this.tokenCredential.GetTokenAsync(
requestContext: tokenRequestContext,
cancellationToken: this.cancellationToken);
@@ -225,6 +332,8 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(
this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
}
+ this.cachedClaimsChallenge = null;
+
AuthState newState = new AuthState(
accessToken,
this.tokenToAuthorizationHeader(accessToken.Token));
@@ -257,7 +366,12 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);
- DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
+ DefaultTrace.TraceError(
+ $"TokenCredential.GetTokenAuthorizationHeaderAsync() failed. " +
+ $"scope = {string.Join(";", tokenRequestContext.Scopes)}, " +
+ $"hasClaimsChallenge = {this.cachedClaimsChallenge != null}, " +
+ $"retry = {retry}, " +
+ $"Exception = {lastException.Message}");
// Don't retry on auth failures
if (exception is RequestFailedException requestFailedException &&
@@ -265,6 +379,7 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(
requestFailedException.Status == (int)HttpStatusCode.Forbidden))
{
this.authState = null;
+ this.cachedClaimsChallenge = null;
throw;
}
bool didFallback = this.scopeProvider.TryFallback(exception);
@@ -282,6 +397,8 @@ private async ValueTask RefreshCachedTokenWithRetryHelperAsync(
throw new ArgumentException("Last exception is null.");
}
+ this.cachedClaimsChallenge = null;
+
// The retries have been exhausted. Throw the last exception.
throw lastException;
}
diff --git a/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs b/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs
index 713c114b7e..a86fe63307 100644
--- a/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs
+++ b/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs
@@ -13,7 +13,8 @@ namespace Microsoft.Azure.Cosmos
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Routing;
- using Microsoft.Azure.Documents;
+ using Microsoft.Azure.Documents;
+ using Microsoft.Azure.Documents.Collections;
///
/// Client policy is combination of endpoint change retry + throttling retry.
@@ -23,7 +24,8 @@ internal sealed class ClientRetryPolicy : IDocumentClientRetryPolicy
private const int RetryIntervalInMS = 1000; // Once we detect failover wait for 1 second before retrying request.
private const int MaxRetryCount = 120;
private const int MaxServiceUnavailableRetryCount = 1;
- private const int MaxSessionTokenRetryCount = 2;
+ private const int MaxSessionTokenRetryCount = 2;
+ private const int MaxCaeRevocationRetryCount = 1;
// ----- DTX (Distributed Transaction) inner-loop retry constants -----
// The outer loop (DistributedTransactionCommitter) handles body-bearing isRetriable failures.
@@ -39,10 +41,12 @@ internal sealed class ClientRetryPolicy : IDocumentClientRetryPolicy
private readonly GlobalPartitionEndpointManager partitionKeyRangeLocationCache;
private readonly bool enableEndpointDiscovery;
private readonly bool isThinClientEnabled;
+ private readonly AuthorizationTokenProvider authorizationTokenProvider;
private int failoverRetryCount;
private int sessionTokenRetryCount;
- private int serviceUnavailableRetryCount;
+ private int serviceUnavailableRetryCount;
+ private int caeRevocationRetryCount;
private int distributedTransactionRetryCount;
private int distributedTransactionInfraFailureRetryCount;
private bool isReadRequest;
@@ -62,7 +66,8 @@ public ClientRetryPolicy(
GlobalPartitionEndpointManager partitionKeyRangeLocationCache,
RetryOptions retryOptions,
bool enableEndpointDiscovery,
- bool isThinClientEnabled)
+ bool isThinClientEnabled,
+ AuthorizationTokenProvider authorizationTokenProvider = null)
{
this.throttlingRetry = new ResourceThrottleRetryPolicy(
retryOptions.MaxRetryAttemptsOnThrottledRequests,
@@ -73,10 +78,12 @@ public ClientRetryPolicy(
this.failoverRetryCount = 0;
this.enableEndpointDiscovery = enableEndpointDiscovery;
this.sessionTokenRetryCount = 0;
- this.serviceUnavailableRetryCount = 0;
+ this.serviceUnavailableRetryCount = 0;
+ this.caeRevocationRetryCount = 0;
this.canUseMultipleWriteLocations = false;
this.isMultiMasterWriteRequest = false;
- this.isThinClientEnabled = isThinClientEnabled;
+ this.isThinClientEnabled = isThinClientEnabled;
+ this.authorizationTokenProvider = authorizationTokenProvider;
}
///
@@ -131,7 +138,8 @@ public async Task ShouldRetryAsync(
ShouldRetryResult shouldRetryResult = await this.ShouldRetryInternalAsync(
clientException?.StatusCode,
- clientException?.GetSubStatus(),
+ clientException?.GetSubStatus(),
+ clientException?.Headers,
clientException?.RetryAfter);
if (shouldRetryResult != null)
{
@@ -146,7 +154,8 @@ public async Task ShouldRetryAsync(
{
ShouldRetryResult shouldRetryResult = await this.ShouldRetryInternalAsync(
cosmosException.StatusCode,
- cosmosException.Headers.SubStatusCode,
+ cosmosException.Headers.SubStatusCode,
+ cosmosException.Headers,
cosmosException.RetryAfter);
if (shouldRetryResult != null)
{
@@ -190,7 +199,8 @@ public async Task ShouldRetryAsync(
ShouldRetryResult shouldRetryResult = await this.ShouldRetryInternalAsync(
cosmosResponseMessage?.StatusCode,
- cosmosResponseMessage?.Headers.SubStatusCode,
+ cosmosResponseMessage?.Headers.SubStatusCode,
+ cosmosResponseMessage?.Headers,
cosmosResponseMessage?.Headers.RetryAfter,
hasResponseBody);
if (shouldRetryResult != null)
@@ -281,9 +291,40 @@ public void OnBeforeSendRequest(DocumentServiceRequest request)
private async Task ShouldRetryInternalAsync(
HttpStatusCode? statusCode,
- SubStatusCodes? subStatusCode,
- TimeSpan? retryAfter = null,
- bool hasResponseBody = false)
+ SubStatusCodes? subStatusCode,
+ INameValueCollection responseHeaders,
+ TimeSpan? retryAfter = null,
+ bool hasResponseBody = false)
+ {
+ return await this.ShouldRetryInternalAsync(
+ statusCode,
+ subStatusCode,
+ responseHeaders?.Get(HttpConstants.HttpHeaders.WwwAuthenticate),
+ retryAfter,
+ hasResponseBody);
+ }
+
+ private async Task ShouldRetryInternalAsync(
+ HttpStatusCode? statusCode,
+ SubStatusCodes? subStatusCode,
+ Headers responseHeaders,
+ TimeSpan? retryAfter = null,
+ bool hasResponseBody = false)
+ {
+ return await this.ShouldRetryInternalAsync(
+ statusCode,
+ subStatusCode,
+ responseHeaders?[HttpConstants.HttpHeaders.WwwAuthenticate],
+ retryAfter,
+ hasResponseBody);
+ }
+
+ private async Task ShouldRetryInternalAsync(
+ HttpStatusCode? statusCode,
+ SubStatusCodes? subStatusCode,
+ string wwwAuthenticateHeaderValue,
+ TimeSpan? retryAfter = null,
+ bool hasResponseBody = false)
{
if (!statusCode.HasValue
&& (!subStatusCode.HasValue
@@ -391,6 +432,13 @@ private async Task ShouldRetryInternalAsync(
{
return this.ShouldRetryOnUnavailableEndpointStatusCodes();
}
+
+ if (statusCode == HttpStatusCode.Unauthorized
+ && (subStatusCode == SubStatusCodes.AadTokenRevoked
+ || !string.IsNullOrEmpty(wwwAuthenticateHeaderValue)))
+ {
+ return this.HandleUnauthorizedResponse(wwwAuthenticateHeaderValue);
+ }
if (this.isDtxRequest)
{
@@ -398,7 +446,37 @@ private async Task ShouldRetryInternalAsync(
}
return null;
- }
+ }
+
+ private ShouldRetryResult HandleUnauthorizedResponse(string wwwAuthenticateHeaderValue)
+ {
+ if (!(this.authorizationTokenProvider is AuthorizationTokenProviderTokenCredential tokenProvider)
+ || this.documentServiceRequest == null)
+ {
+ return null;
+ }
+
+ if (this.caeRevocationRetryCount >= ClientRetryPolicy.MaxCaeRevocationRetryCount)
+ {
+ DefaultTrace.TraceWarning(
+ "ClientRetryPolicy: Token revocation max retry count ({0}) exceeded. Not retrying.",
+ ClientRetryPolicy.MaxCaeRevocationRetryCount);
+ return ShouldRetryResult.NoRetry();
+ }
+
+ if (!tokenProvider.TryHandleTokenRevocation(
+ HttpStatusCode.Unauthorized,
+ wwwAuthenticateHeaderValue))
+ {
+ return null;
+ }
+
+ this.caeRevocationRetryCount++;
+ DefaultTrace.TraceInformation(
+ "ClientRetryPolicy: AAD token revocation handled. Retrying with fresh token. RetryCount={0}",
+ this.caeRevocationRetryCount);
+ return ShouldRetryResult.RetryAfter(TimeSpan.Zero);
+ }
private async Task ShouldRetryOnEndpointFailureAsync(
bool isReadRequest,
diff --git a/Microsoft.Azure.Cosmos/src/DocumentClient.cs b/Microsoft.Azure.Cosmos/src/DocumentClient.cs
index c70b0b8668..b3ec8fbeac 100644
--- a/Microsoft.Azure.Cosmos/src/DocumentClient.cs
+++ b/Microsoft.Azure.Cosmos/src/DocumentClient.cs
@@ -1080,7 +1080,8 @@ private async Task GetInitializationTaskAsync(IStoreClientFactory storeCli
globalEndpointManager: this.GlobalEndpointManager,
connectionPolicy: this.ConnectionPolicy,
partitionKeyRangeLocationCache: this.PartitionKeyRangeLocation,
- isThinClientEnabled: this.isThinClientEnabled);
+ isThinClientEnabled: this.isThinClientEnabled,
+ this.cosmosAuthorization);
this.ResetSessionTokenRetryPolicy = this.retryPolicy;
@@ -6813,7 +6814,8 @@ private void InitializeDirectConnectivity(IStoreClientFactory storeClientFactory
this.ConnectionPolicy,
this.httpClient,
this.storeClientFactory.GetConnectionStateListener(),
- this.enableAsyncCacheExceptionNoSharing);
+ this.enableAsyncCacheExceptionNoSharing,
+ authorizationTokenProvider: this.cosmosAuthorization);
this.CreateStoreModel(subscribeRntbdStatus: true);
}
diff --git a/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs b/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs
index 459ac9c5f5..d1a77c7167 100644
--- a/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs
+++ b/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs
@@ -8,6 +8,7 @@ namespace Microsoft.Azure.Cosmos
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
+ using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
using Microsoft.Azure.Cosmos.Routing;
using Microsoft.Azure.Cosmos.Tracing;
@@ -55,31 +56,8 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
try
{
- using (DocumentServiceRequest request = DocumentServiceRequest.Create(
- operationType: OperationType.Read,
- resourceType: ResourceType.DatabaseAccount,
- authorizationTokenType: AuthorizationTokenType.PrimaryMasterKey))
- {
- if (this.isThinClientEnabled)
- {
- headers.Add(
- ThinClientConstants.EnableThinClientEndpointDiscoveryHeaderName,
- this.isThinClientEnabled.ToString());
- }
-
- using (HttpResponseMessage responseMessage = await this.httpClient.GetAsync(
- uri: serviceEndpoint,
- additionalHeaders: headers,
- resourceType: ResourceType.DatabaseAccount,
- timeoutPolicy: HttpTimeoutPolicyControlPlaneRead.Instance,
- clientSideRequestStatistics: stats,
- cancellationToken: default,
- documentServiceRequest: request))
- using (DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(responseMessage))
- {
- return CosmosResource.FromStream(documentServiceResponse);
- }
- }
+ return await this.ExecuteAccountReadWithRevocationRetryAsync(
+ serviceEndpoint, headers, stats);
}
catch (ObjectDisposedException) when (this.cancellationToken.IsCancellationRequested)
{
@@ -100,6 +78,66 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
}
}
+ private async Task ExecuteAccountReadWithRevocationRetryAsync(
+ Uri serviceEndpoint,
+ INameValueCollection headers,
+ IClientSideRequestStatistics stats)
+ {
+ try
+ {
+ return await this.ExecuteAccountReadAsync(serviceEndpoint, headers, stats);
+ }
+ catch (DocumentClientException dce)
+ when (AuthorizationTokenProviderTokenCredential.TryHandleRevocationException(
+ this.cosmosAuthorization, dce))
+ {
+ DefaultTrace.TraceInformation(
+ "GatewayAccountReader: AAD token revocation detected on account read. Retrying with fresh token.");
+
+ // Re-add authorization header with the fresh token (cache was reset by TryHandleRevocationException)
+ headers.Remove(HttpConstants.HttpHeaders.Authorization);
+ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
+ headersCollection: headers,
+ serviceEndpoint,
+ HttpConstants.HttpMethods.Get,
+ AuthorizationTokenType.PrimaryMasterKey);
+
+ return await this.ExecuteAccountReadAsync(serviceEndpoint, headers, stats);
+ }
+ }
+
+ private async Task ExecuteAccountReadAsync(
+ Uri serviceEndpoint,
+ INameValueCollection headers,
+ IClientSideRequestStatistics stats)
+ {
+ using (DocumentServiceRequest request = DocumentServiceRequest.Create(
+ operationType: OperationType.Read,
+ resourceType: ResourceType.DatabaseAccount,
+ authorizationTokenType: AuthorizationTokenType.PrimaryMasterKey))
+ {
+ if (this.isThinClientEnabled)
+ {
+ headers.Set(
+ ThinClientConstants.EnableThinClientEndpointDiscoveryHeaderName,
+ this.isThinClientEnabled.ToString());
+ }
+
+ using (HttpResponseMessage responseMessage = await this.httpClient.GetAsync(
+ uri: serviceEndpoint,
+ additionalHeaders: headers,
+ resourceType: ResourceType.DatabaseAccount,
+ timeoutPolicy: HttpTimeoutPolicyControlPlaneRead.Instance,
+ clientSideRequestStatistics: stats,
+ cancellationToken: default,
+ documentServiceRequest: request))
+ using (DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(responseMessage))
+ {
+ return CosmosResource.FromStream(documentServiceResponse);
+ }
+ }
+ }
+
public async Task InitializeReaderAsync()
{
AccountProperties databaseAccount = await GlobalEndpointManager.GetDatabaseAccountFromAnyLocationsAsync(
diff --git a/Microsoft.Azure.Cosmos/src/RetryPolicy.cs b/Microsoft.Azure.Cosmos/src/RetryPolicy.cs
index f66841dd17..e523e30d61 100644
--- a/Microsoft.Azure.Cosmos/src/RetryPolicy.cs
+++ b/Microsoft.Azure.Cosmos/src/RetryPolicy.cs
@@ -13,25 +13,28 @@ internal sealed class RetryPolicy : IRetryPolicyFactory
private readonly GlobalPartitionEndpointManager partitionKeyRangeLocationCache;
private readonly GlobalEndpointManager globalEndpointManager;
private readonly bool enableEndpointDiscovery;
- private readonly bool isPartitionLevelFailoverEnabled;
+ private readonly bool isPartitionLevelFailoverEnabled;
private readonly bool isThinClientEnabled;
private readonly RetryOptions retryOptions;
+ private readonly AuthorizationTokenProvider authorizationTokenProvider;
///
/// Initialize the instance of the RetryPolicy class
///
public RetryPolicy(
- GlobalEndpointManager globalEndpointManager,
+ GlobalEndpointManager globalEndpointManager,
ConnectionPolicy connectionPolicy,
- GlobalPartitionEndpointManager partitionKeyRangeLocationCache,
- bool isThinClientEnabled)
+ GlobalPartitionEndpointManager partitionKeyRangeLocationCache,
+ bool isThinClientEnabled,
+ AuthorizationTokenProvider authorizationTokenProvider = null)
{
this.enableEndpointDiscovery = connectionPolicy.EnableEndpointDiscovery;
- this.isPartitionLevelFailoverEnabled = connectionPolicy.EnablePartitionLevelFailover;
+ this.isPartitionLevelFailoverEnabled = connectionPolicy.EnablePartitionLevelFailover;
this.globalEndpointManager = globalEndpointManager;
this.retryOptions = connectionPolicy.RetryOptions;
- this.partitionKeyRangeLocationCache = partitionKeyRangeLocationCache;
+ this.partitionKeyRangeLocationCache = partitionKeyRangeLocationCache;
this.isThinClientEnabled = isThinClientEnabled;
+ this.authorizationTokenProvider = authorizationTokenProvider;
}
///
@@ -43,8 +46,9 @@ public IDocumentClientRetryPolicy GetRequestPolicy()
this.globalEndpointManager,
this.partitionKeyRangeLocationCache,
this.retryOptions,
- this.enableEndpointDiscovery,
- this.isThinClientEnabled);
+ this.enableEndpointDiscovery,
+ this.isThinClientEnabled,
+ this.authorizationTokenProvider);
return clientRetryPolicy;
}
diff --git a/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs b/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
index 1df09fa6b5..8db3c88717 100644
--- a/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
+++ b/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
@@ -48,6 +48,7 @@ internal class GatewayAddressCache : IAddressCache, IDisposable
private readonly Protocol protocol;
private readonly string protocolFilter;
private readonly ICosmosAuthorizationTokenProvider tokenProvider;
+ private readonly AuthorizationTokenProvider authorizationTokenProvider;
private readonly bool enableTcpConnectionEndpointRediscovery;
private readonly SemaphoreSlim semaphore;
@@ -72,11 +73,13 @@ public GatewayAddressCache(
long suboptimalPartitionForceRefreshIntervalInSeconds = 600,
bool enableTcpConnectionEndpointRediscovery = false,
bool replicaAddressValidationEnabled = false,
- bool enableAsyncCacheExceptionNoSharing = true)
+ bool enableAsyncCacheExceptionNoSharing = true,
+ AuthorizationTokenProvider authorizationTokenProvider = null)
{
this.addressEndpoint = new Uri(serviceEndpoint + "/" + Paths.AddressPathSegment);
this.protocol = protocol;
this.tokenProvider = tokenProvider;
+ this.authorizationTokenProvider = authorizationTokenProvider;
this.serviceEndpoint = serviceEndpoint;
this.serviceConfigReader = serviceConfigReader;
this.serverPartitionAddressCache = new AsyncCacheNonBlocking(enableAsyncCacheExceptionNoSharing);
@@ -780,17 +783,51 @@ private async Task GetMasterAddressesViaGatewayAsync(
}
}
- using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
- uri: targetEndpoint,
- additionalHeaders: headers,
- resourceType: resourceType,
- timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout,
- clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
- cancellationToken: default))
+ try
+ {
+ using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
+ uri: targetEndpoint,
+ additionalHeaders: headers,
+ resourceType: resourceType,
+ timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout,
+ clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
+ cancellationToken: default))
+ {
+ DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
+ GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
+ return documentServiceResponse;
+ }
+ }
+ catch (DocumentClientException dce)
+ when (AuthorizationTokenProviderTokenCredential.TryHandleRevocationException(
+ this.authorizationTokenProvider, dce))
{
- DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
- GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
- return documentServiceResponse;
+ DefaultTrace.TraceInformation(
+ "GatewayAddressCache: AAD token revocation detected on master address resolution. Retrying.");
+
+ headers.Set(HttpConstants.HttpHeaders.XDate, Rfc1123DateTimeCache.UtcNow());
+ string retryToken = await this.tokenProvider.GetUserAuthorizationTokenAsync(
+ resourceAddress,
+ resourceTypeToSign,
+ HttpConstants.HttpMethods.Get,
+ headers,
+ AuthorizationTokenType.PrimaryMasterKey,
+ trace);
+
+ headers.Set(HttpConstants.HttpHeaders.Authorization, retryToken);
+
+ using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
+ uri: targetEndpoint,
+ additionalHeaders: headers,
+ resourceType: resourceType,
+ timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout,
+ clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
+ cancellationToken: default))
+ {
+ DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
+ GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
+ return documentServiceResponse;
+ }
}
}
}
@@ -886,17 +923,51 @@ private async Task GetServerAddressesViaGatewayAsync(
}
}
- using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
- uri: targetEndpoint,
- additionalHeaders: headers,
- resourceType: ResourceType.Document,
- timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout,
- clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
- cancellationToken: default))
+ try
{
- DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
- GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
- return documentServiceResponse;
+ using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
+ uri: targetEndpoint,
+ additionalHeaders: headers,
+ resourceType: ResourceType.Document,
+ timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout,
+ clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
+ cancellationToken: default))
+ {
+ DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
+ GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
+ return documentServiceResponse;
+ }
+ }
+ catch (DocumentClientException dce)
+ when (AuthorizationTokenProviderTokenCredential.TryHandleRevocationException(
+ this.authorizationTokenProvider, dce))
+ {
+ DefaultTrace.TraceInformation(
+ "GatewayAddressCache: AAD token revocation detected on server address resolution. Retrying.");
+
+ headers.Set(HttpConstants.HttpHeaders.XDate, Rfc1123DateTimeCache.UtcNow());
+ string retryToken = await this.tokenProvider.GetUserAuthorizationTokenAsync(
+ collectionRid,
+ resourceTypeToSign,
+ HttpConstants.HttpMethods.Get,
+ headers,
+ AuthorizationTokenType.PrimaryMasterKey,
+ trace);
+
+ headers.Set(HttpConstants.HttpHeaders.Authorization, retryToken);
+
+ using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
+ uri: targetEndpoint,
+ additionalHeaders: headers,
+ resourceType: ResourceType.Document,
+ timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout,
+ clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
+ cancellationToken: default))
+ {
+ DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
+ GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
+ return documentServiceResponse;
+ }
}
}
}
diff --git a/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs b/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs
index 88e142e4d3..d3d62e6006 100644
--- a/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs
+++ b/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs
@@ -41,6 +41,7 @@ internal sealed class GlobalAddressResolver : IAddressResolverExtension, IDispos
private readonly bool isReplicaAddressValidationEnabled;
private readonly bool enableAsyncCacheExceptionNoSharing;
private readonly IConnectionStateListener connectionStateListener;
+ private readonly AuthorizationTokenProvider authorizationTokenProvider;
private IOpenConnectionsHandler openConnectionsHandler;
public GlobalAddressResolver(
@@ -54,7 +55,8 @@ public GlobalAddressResolver(
ConnectionPolicy connectionPolicy,
CosmosHttpClient httpClient,
IConnectionStateListener connectionStateListener,
- bool enableAsyncCacheExceptionNoSharing = true)
+ bool enableAsyncCacheExceptionNoSharing = true,
+ AuthorizationTokenProvider authorizationTokenProvider = null)
{
this.endpointManager = endpointManager;
this.partitionKeyRangeLocationCache = partitionKeyRangeLocationCache;
@@ -65,6 +67,7 @@ public GlobalAddressResolver(
this.serviceConfigReader = serviceConfigReader;
this.httpClient = httpClient;
this.connectionStateListener = connectionStateListener;
+ this.authorizationTokenProvider = authorizationTokenProvider;
int maxBackupReadEndpoints =
!connectionPolicy.EnableReadRequestsFallback.HasValue || connectionPolicy.EnableReadRequestsFallback.Value
@@ -349,7 +352,8 @@ private EndpointCache GetOrAddEndpoint(Uri endpoint)
this.connectionStateListener,
enableTcpConnectionEndpointRediscovery: this.enableTcpConnectionEndpointRediscovery,
replicaAddressValidationEnabled: this.isReplicaAddressValidationEnabled,
- enableAsyncCacheExceptionNoSharing: this.enableAsyncCacheExceptionNoSharing);
+ enableAsyncCacheExceptionNoSharing: this.enableAsyncCacheExceptionNoSharing,
+ authorizationTokenProvider: this.authorizationTokenProvider);
string location = this.endpointManager.GetLocation(endpoint);
AddressResolver addressResolver = new AddressResolver(null, new NullRequestSigner(), location);
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs
index b2c586b7d8..aa7595fce3 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs
@@ -5,8 +5,10 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
{
using System;
using System.Collections.Generic;
- using System.Globalization;
- using System.Net;
+ using System.Globalization;
+ using System.Linq;
+ using System.Net;
+ using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -14,8 +16,8 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using Documents.Client;
using global::Azure;
using global::Azure.Core;
- using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.IdentityModel.Tokens;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
using static Microsoft.Azure.Cosmos.SDK.EmulatorTests.TransportClientHelper;
[TestClass]
@@ -443,5 +445,198 @@ void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token)
Assert.AreEqual(1, accountScopeCount, "Account scope must be used exactly once.");
Assert.AreEqual(0, cosmosScopeCount, "Cosmos scope must not be used (no fallback).");
}
+
+ ///
+ /// Generates a WWW-Authenticate header value matching the server's AadTokenRevocationHelper format.
+ /// Format: Bearer realm="", authorization_uri="", error="insufficient_claims", claims=""
+ /// where claims is base64 of: {"access_token":{"nbf":{"essential":false,"value":""}}}
+ ///
+ private static string GenerateWwwAuthenticateHeaderValue()
+ {
+ long currentTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
+ string claimsChallengeJson = "{\"access_token\":{\"nbf\":{\"essential\":false,\"value\":\"" + currentTimestamp.ToString() + "\"}}}";
+ string base64Claims = Convert.ToBase64String(Encoding.UTF8.GetBytes(claimsChallengeJson));
+ return "Bearer " + string.Join(", ",
+ "realm=\"\"",
+ "authorization_uri=\"\"",
+ "error=\"insufficient_claims\"",
+ "claims=\"" + base64Claims + "\"");
+ }
+
+ [TestMethod]
+ public async Task AadTokenRevocation_WithMockedServerResponse_ShouldTriggerTokenRefresh()
+ {
+ string databaseId = Guid.NewGuid().ToString();
+ string containerId = Guid.NewGuid().ToString();
+
+ using CosmosClient setupClient = TestCommon.CreateCosmosClient();
+ Database database = null;
+
+ try
+ {
+ database = await setupClient.CreateDatabaseIfNotExistsAsync(databaseId);
+ await database.CreateContainerIfNotExistsAsync(containerId, "/id");
+
+ (string endpoint, string authKey) = TestCommon.GetAccountInfo();
+
+ List tokenRequests = new List();
+ bool hasReturnedUnauthorized = false;
+
+ void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token)
+ {
+ tokenRequests.Add(context);
+ }
+
+ LocalEmulatorTokenCredential tokenCredential = new LocalEmulatorTokenCredential(
+ expectedScope: "https://127.0.0.1/.default",
+ masterKey: authKey,
+ getTokenCallback: GetAadTokenCallBack);
+
+ HttpClientHandlerHelper httpHandler = new HttpClientHandlerHelper
+ {
+ RequestCallBack = (request, cancellationToken) =>
+ {
+ bool isDocumentCreate = request.Method == HttpMethod.Post
+ && request.RequestUri.PathAndQuery.Contains("/docs");
+
+ if (isDocumentCreate && !hasReturnedUnauthorized)
+ {
+ hasReturnedUnauthorized = true;
+
+ // Return fake 401/5013 with WWW-Authenticate WITHOUT forwarding to the server
+ HttpResponseMessage unauthorizedResponse = new HttpResponseMessage(HttpStatusCode.Unauthorized)
+ {
+ RequestMessage = request,
+ Content = new StringContent("{\"code\":\"Unauthorized\",\"message\":\"Provided AAD token has been revoked.\"}")
+ };
+ unauthorizedResponse.Headers.Add("x-ms-substatus", ((int)Documents.SubStatusCodes.AadTokenRevoked).ToString());
+ unauthorizedResponse.Headers.Add(
+ "WWW-Authenticate",
+ CosmosAadTests.GenerateWwwAuthenticateHeaderValue());
+
+ return Task.FromResult(unauthorizedResponse);
+ }
+
+ // All other requests pass through to the real server
+ return null;
+ }
+ };
+
+ CosmosClientOptions clientOptions = new CosmosClientOptions()
+ {
+ ConnectionMode = ConnectionMode.Gateway,
+ HttpClientFactory = () => new HttpClient(httpHandler),
+ };
+
+ using (CosmosClient aadClient = new CosmosClient(endpoint, tokenCredential, clientOptions))
+ {
+ Container aadContainer = aadClient.GetContainer(databaseId, containerId);
+ tokenRequests.Clear();
+
+ ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
+
+ // First attempt: SDK uses cached token → handler returns fake 401/5013 (request never reaches server)
+ // SDK detects revocation → extracts claims → resets cache → gets fresh token → retries
+ // Second attempt: real request reaches server → document created → 201
+ ItemResponse response = await aadContainer.CreateItemAsync(item, new PartitionKey(item.id));
+ Assert.AreEqual(HttpStatusCode.Created, response.StatusCode, "Retry with fresh token should succeed.");
+
+ // Validate that 401 was simulated
+ Assert.IsTrue(hasReturnedUnauthorized, "Test should have returned 401 Unauthorized");
+
+ // Validate that the SDK requested a fresh token with claims challenge
+ Assert.IsTrue(tokenRequests.Count >= 1, "SDK should have requested a fresh token after revocation.");
+ Assert.IsTrue(
+ tokenRequests.Any(r => !string.IsNullOrEmpty(r.Claims) && r.Claims.Contains("nbf")),
+ "Retry token request should contain nbf claims from the server's claims challenge.");
+ }
+ }
+ finally
+ {
+ if (database != null)
+ {
+ await database.DeleteStreamAsync();
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task AadTokenRevocation_ExceedsMaxRetry_ShouldFail()
+ {
+ string databaseId = Guid.NewGuid().ToString();
+ string containerId = Guid.NewGuid().ToString();
+
+ using CosmosClient setupClient = TestCommon.CreateCosmosClient();
+ Database database = await setupClient.CreateDatabaseAsync(databaseId);
+
+ try
+ {
+ await database.CreateContainerAsync(containerId, "/id");
+ (string endpoint, string authKey) = TestCommon.GetAccountInfo();
+
+ int caeResponseCount = 0;
+
+ LocalEmulatorTokenCredential tokenCredential = new LocalEmulatorTokenCredential(
+ expectedScope: "https://127.0.0.1/.default",
+ masterKey: authKey);
+
+ HttpClientHandlerHelper httpHandler = new HttpClientHandlerHelper
+ {
+ RequestCallBack = (request, cancellationToken) =>
+ {
+ bool isDocumentCreate = request.Method == HttpMethod.Post
+ && request.RequestUri.PathAndQuery.Contains("/docs");
+
+ if (isDocumentCreate)
+ {
+ caeResponseCount++;
+
+ // Always return 401/5013 with claims challenge — never pass through
+ HttpResponseMessage caeResponse = new HttpResponseMessage(HttpStatusCode.Unauthorized)
+ {
+ RequestMessage = request,
+ Content = new StringContent("{\"code\":\"Unauthorized\",\"message\":\"Provided AAD token has been revoked.\"}")
+ };
+ caeResponse.Headers.Add("x-ms-substatus", ((int)Documents.SubStatusCodes.AadTokenRevoked).ToString());
+ caeResponse.Headers.Add(
+ "WWW-Authenticate",
+ CosmosAadTests.GenerateWwwAuthenticateHeaderValue());
+
+ return Task.FromResult(caeResponse);
+ }
+
+ return null;
+ }
+ };
+
+ CosmosClientOptions clientOptions = new CosmosClientOptions()
+ {
+ ConnectionMode = ConnectionMode.Gateway,
+ HttpClientFactory = () => new HttpClient(httpHandler),
+ };
+
+ using CosmosClient aadClient = new CosmosClient(endpoint, tokenCredential, clientOptions);
+
+ Container aadContainer = aadClient.GetContainer(databaseId, containerId);
+
+ ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
+
+ try
+ {
+ await aadContainer.CreateItemAsync(item, new PartitionKey(item.id));
+ Assert.Fail("Expected CosmosException after max CAE retries exceeded");
+ }
+ catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.Unauthorized)
+ {
+ // Expected - should fail after max retry (1 retry = 2 total attempts)
+ Assert.IsTrue(caeResponseCount <= 2,
+ $"Should stop after max retry. CAE responses: {caeResponseCount}");
+ }
+ }
+ finally
+ {
+ await database?.DeleteStreamAsync();
+ }
+ }
}
}
\ No newline at end of file
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTokenRevocationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTokenRevocationTests.cs
new file mode 100644
index 0000000000..9bad61acad
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTokenRevocationTests.cs
@@ -0,0 +1,399 @@
+//------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Net;
+ using System.Net.Http;
+ using System.Text;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Documents.Client;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ ///
+ /// Gateway Mode:
+ /// Operation Revocation
+ /// ------------------------------------------
+ /// Data-plane (doc CRUD, queries) YES
+ /// Container read/create/delete YES
+ /// Database read/create/delete YES
+ /// Account read (client init) YES
+ /// Collection metadata cache YES
+ /// Partition key ranges YES
+ ///
+ /// Direct Mode:
+ /// Operation Revocation
+ /// ------------------------------------------
+ /// Data-plane (doc CRUD, queries) NO
+ /// Container read/create/delete YES
+ /// Database read/create/delete YES
+ /// Account read (client init) YES
+ /// Address resolution YES
+ /// Collection metadata cache YES
+ /// Partition key ranges YES
+ ///
+ /// ThinClient Mode:
+ /// Operation Revocation
+ /// -------------------------------- ----------
+ /// Data-plane (doc CRUD, queries) NO
+ /// Container read/create/delete YES
+ /// Database read/create/delete YES
+ /// Account read (client init) YES
+ /// Address resolution YES
+ /// Collection metadata cache YES
+ /// Partition key ranges YES
+ ///
+ [TestClass]
+ public class CosmosAadTokenRevocationTests
+ {
+ private CosmosClient cosmosClient;
+ private Cosmos.Database database;
+ private Container container;
+
+ private static readonly string DatabaseId = string.Concat("RevocationTestDb_", Guid.NewGuid().ToString("N").AsSpan(0, 8));
+ private static readonly string ContainerId = "RevocationTestContainer";
+
+ [TestInitialize]
+ public async Task TestInitialize()
+ {
+ this.cosmosClient = TestCommon.CreateCosmosClient();
+ this.database = await this.cosmosClient.CreateDatabaseIfNotExistsAsync(DatabaseId);
+ this.container = await this.database.CreateContainerIfNotExistsAsync(ContainerId, "/id");
+ }
+
+ [TestCleanup]
+ public async Task TestCleanup()
+ {
+ if (this.database != null)
+ {
+ await this.database.DeleteStreamAsync();
+ }
+
+ this.cosmosClient?.Dispose();
+ }
+
+ ///
+ /// Gateway data-plane: CreateItemAsync goes through GatewayStoreModel → GatewayStoreClient (HTTP).
+ /// Handler fakes 401 on first POST /docs → SDK extracts claims → retries with fresh token → 201.
+ ///
+ [TestMethod]
+ public async Task Revocation_Gateway_DataPlane_ShouldRetryWithFreshToken()
+ {
+ await this.RunRevocationRetryTest(
+ connectionMode: ConnectionMode.Gateway,
+ targetPathContains: "/docs",
+ targetMethod: HttpMethod.Post,
+ executeOperation: async (c) =>
+ {
+ ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
+ ItemResponse r = await c.CreateItemAsync(item, new Cosmos.PartitionKey(item.id));
+ return r.StatusCode;
+ },
+ expectedStatusCode: HttpStatusCode.Created);
+ }
+
+ ///
+ /// Gateway mode: ReadContainerAsync (GET /colls/) is a container metadata read.
+ /// Always routes through GatewayStoreModel, even in direct mode.
+ /// Handler fakes 401 on first GET /colls/ → SDK extracts claims → retries → 200.
+ ///
+ [TestMethod]
+ public async Task Revocation_Gateway_ContainerRead_ShouldRetryWithFreshToken()
+ {
+ await this.RunRevocationRetryTest(
+ connectionMode: ConnectionMode.Gateway,
+ targetPathContains: "/colls/",
+ targetMethod: HttpMethod.Get,
+ executeOperation: async (c) =>
+ {
+ ContainerResponse r = await c.ReadContainerAsync();
+ return r.StatusCode;
+ },
+ expectedStatusCode: HttpStatusCode.OK);
+ }
+
+ ///
+ /// Account read (GET /) during client init — outside the handler pipeline.
+ /// GatewayAccountReader has its own catch-when revocation retry.
+ /// If init succeeds, revocation retry worked.
+ ///
+ [TestMethod]
+ public async Task Revocation_AccountRead_ShouldRetryWithFreshToken()
+ {
+ (string endpoint, string authKey) = TestCommon.GetAccountInfo();
+ int tokenCallCount = 0;
+ List claimsList = new List();
+
+ LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential(
+ expectedScope: "https://127.0.0.1/.default",
+ masterKey: authKey,
+ getTokenCallback: (ctx, ct) => { tokenCallCount++; claimsList.Add(ctx.Claims); });
+
+ RevocationSimulatingHandler handler = new RevocationSimulatingHandler(
+ new HttpClientHandler(),
+ targetPathEquals: "/",
+ targetMethod: HttpMethod.Get);
+
+ using CosmosClient aadClient = new CosmosClient(endpoint, credential,
+ new CosmosClientOptions
+ {
+ ConnectionMode = ConnectionMode.Gateway,
+ HttpClientFactory = () => new HttpClient(handler),
+ });
+
+ // Init succeeded — account read revocation retry worked
+ Container c = aadClient.GetContainer(DatabaseId, ContainerId);
+ ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
+ ItemResponse response = await c.CreateItemAsync(item, new Cosmos.PartitionKey(item.id));
+
+ Assert.AreEqual(HttpStatusCode.Created, response.StatusCode);
+ Assert.IsTrue(handler.SimulatedRevocationCount >= 1, "Should have simulated 401 on GET /.");
+ Assert.IsTrue(claimsList.Any(c2 => !string.IsNullOrEmpty(c2) && c2.Contains("nbf")),
+ "Retry token must include nbf claims.");
+ }
+
+ ///
+ /// Direct mode: container metadata read (GET /colls/) goes through gateway even in direct mode.
+ /// Handler fakes 401 → SDK retries with fresh token → 200.
+ ///
+ [TestMethod]
+ public async Task Revocation_Direct_ContainerRead_ShouldRetryWithFreshToken()
+ {
+ await this.RunRevocationRetryTest(
+ connectionMode: ConnectionMode.Direct,
+ targetPathContains: "/colls/",
+ targetMethod: HttpMethod.Get,
+ executeOperation: async (c) =>
+ {
+ ContainerResponse r = await c.ReadContainerAsync();
+ return r.StatusCode;
+ },
+ expectedStatusCode: HttpStatusCode.OK);
+ }
+
+ ///
+ /// Direct mode data-plane: CreateItemAsync goes via RNTBD directly to replicas.
+ /// The HTTP handler never sees the request, so no 401 is simulated.
+ /// This test confirms that direct data-plane is NOT subject to gateway revocation.
+ ///
+ [TestMethod]
+ public async Task Revocation_Direct_DataPlane_NotSubjectToGatewayRevocation()
+ {
+ (string endpoint, string authKey) = TestCommon.GetAccountInfo();
+
+ LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential(
+ expectedScope: "https://127.0.0.1/.default",
+ masterKey: authKey);
+
+ RevocationSimulatingHandler handler = new RevocationSimulatingHandler(
+ new HttpClientHandler(),
+ targetPathContains: "/docs",
+ targetMethod: HttpMethod.Post);
+
+ using CosmosClient aadClient = new CosmosClient(endpoint, credential,
+ new CosmosClientOptions
+ {
+ ConnectionMode = ConnectionMode.Direct,
+ ConnectionProtocol = Protocol.Tcp,
+ HttpClientFactory = () => new HttpClient(handler),
+ });
+
+ Container c = aadClient.GetContainer(DatabaseId, ContainerId);
+ ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
+ ItemResponse response = await c.CreateItemAsync(item, new Cosmos.PartitionKey(item.id));
+
+ // Request succeeds because RNTBD bypasses the HTTP handler entirely
+ Assert.AreEqual(HttpStatusCode.Created, response.StatusCode);
+ Assert.AreEqual(0, handler.SimulatedRevocationCount,
+ "Direct data-plane should NOT hit the HTTP handler — RNTBD bypasses the gateway.");
+ }
+
+ ///
+ /// Retry exhaustion: handler always returns 401 on every doc POST.
+ /// SDK retries once with fresh token (max retry = 1), then gives up.
+ /// Verifies no infinite retry loops.
+ ///
+ [TestMethod]
+ [DataRow(ConnectionMode.Gateway)]
+ public async Task Revocation_RetryExhausted_ShouldFailAfterOneRetry(ConnectionMode connectionMode)
+ {
+ (string endpoint, string authKey) = TestCommon.GetAccountInfo();
+ int tokenCallCount = 0;
+ List claimsList = new List();
+
+ LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential(
+ expectedScope: "https://127.0.0.1/.default",
+ masterKey: authKey,
+ getTokenCallback: (ctx, ct) => { tokenCallCount++; claimsList.Add(ctx.Claims); });
+
+ AlwaysRevokingHandler handler = new AlwaysRevokingHandler(new HttpClientHandler());
+
+ using CosmosClient aadClient = new CosmosClient(endpoint, credential,
+ new CosmosClientOptions
+ {
+ ConnectionMode = connectionMode,
+ ConnectionProtocol = connectionMode == ConnectionMode.Direct ? Protocol.Tcp : Protocol.Https,
+ HttpClientFactory = () => new HttpClient(handler),
+ });
+
+ Container c = aadClient.GetContainer(DatabaseId, ContainerId);
+ int tokenCallsAfterInit = tokenCallCount;
+
+ try
+ {
+ ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
+ await c.CreateItemAsync(item, new Cosmos.PartitionKey(item.id));
+ Assert.Fail("Should have thrown after retry exhaustion.");
+ }
+ catch (CosmosException ex)
+ {
+ Assert.AreEqual(HttpStatusCode.Unauthorized, ex.StatusCode);
+ Assert.AreEqual(2, handler.SimulatedRevocationCount,
+ "Exactly 2 simulated 401s: original + one retry, no more.");
+ Assert.IsTrue(claimsList.Skip(tokenCallsAfterInit)
+ .Any(c2 => !string.IsNullOrEmpty(c2) && c2.Contains("nbf")),
+ "Should have requested fresh token with claims before giving up.");
+ }
+ }
+
+ private async Task RunRevocationRetryTest(
+ ConnectionMode connectionMode,
+ string targetPathContains,
+ HttpMethod targetMethod,
+ Func> executeOperation,
+ HttpStatusCode expectedStatusCode)
+ {
+ (string endpoint, string authKey) = TestCommon.GetAccountInfo();
+ int tokenCallCount = 0;
+ List claimsList = new List();
+
+ LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential(
+ expectedScope: "https://127.0.0.1/.default",
+ masterKey: authKey,
+ getTokenCallback: (ctx, ct) => { tokenCallCount++; claimsList.Add(ctx.Claims); });
+
+ RevocationSimulatingHandler handler = new RevocationSimulatingHandler(
+ new HttpClientHandler(),
+ targetPathContains: targetPathContains,
+ targetMethod: targetMethod);
+
+ using CosmosClient aadClient = new CosmosClient(endpoint, credential,
+ new CosmosClientOptions
+ {
+ ConnectionMode = connectionMode,
+ ConnectionProtocol = connectionMode == ConnectionMode.Direct ? Protocol.Tcp : Protocol.Https,
+ HttpClientFactory = () => new HttpClient(handler),
+ });
+
+ Container c = aadClient.GetContainer(DatabaseId, ContainerId);
+ int tokenCallsAfterInit = tokenCallCount;
+
+ HttpStatusCode actualStatusCode = await executeOperation(c);
+ int retryTokenCalls = tokenCallCount - tokenCallsAfterInit;
+
+ Console.WriteLine($" StatusCode: {actualStatusCode}");
+ Console.WriteLine($" Simulated 401s: {handler.SimulatedRevocationCount}");
+ Console.WriteLine($" Token calls during test: {retryTokenCalls}");
+ for (int i = 0; i < handler.RequestLog.Count; i++)
+ {
+ (string method, string path, bool simulated) = handler.RequestLog[i];
+ Console.WriteLine($" Request[{i}]: {method} {path} {(simulated ? "→ SIMULATED 401" : "→ passthrough")}");
+ }
+
+ Assert.AreEqual(expectedStatusCode, actualStatusCode, "Request should succeed after revocation retry.");
+ Assert.IsTrue(handler.SimulatedRevocationCount >= 1,
+ "Handler should have intercepted and returned a fake 401/5013 with WWW-Authenticate.");
+ Assert.IsTrue(retryTokenCalls >= 1,
+ "SDK should have reset token cache and called credential again for a fresh token.");
+ Assert.IsTrue(claimsList.Skip(tokenCallsAfterInit)
+ .Any(c2 => !string.IsNullOrEmpty(c2) && c2.Contains("nbf")),
+ "Fresh token request should contain merged claims: nbf (from server challenge) + xms_cc (SDK cp1).");
+ }
+
+ private class RevocationSimulatingHandler : DelegatingHandler
+ {
+ private readonly string targetPathContains;
+ private readonly string targetPathEquals;
+ private readonly HttpMethod targetMethod;
+ private bool hasSimulated401;
+
+ public int SimulatedRevocationCount { get; private set; }
+ public List<(string method, string path, bool simulated)> RequestLog { get; }
+ = new List<(string, string, bool)>();
+
+ public RevocationSimulatingHandler(
+ HttpMessageHandler innerHandler,
+ HttpMethod targetMethod,
+ string targetPathContains = null,
+ string targetPathEquals = null)
+ : base(innerHandler)
+ {
+ this.targetMethod = targetMethod;
+ this.targetPathContains = targetPathContains;
+ this.targetPathEquals = targetPathEquals;
+ }
+
+ protected override async Task SendAsync(
+ HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ string path = request.RequestUri?.AbsolutePath ?? "";
+ bool match = request.Method == this.targetMethod
+ && (this.targetPathEquals != null
+ ? (path == this.targetPathEquals || path == this.targetPathEquals + "/")
+ : this.targetPathContains != null && path.Contains(this.targetPathContains));
+
+ if (match && !this.hasSimulated401)
+ {
+ this.hasSimulated401 = true;
+ this.SimulatedRevocationCount++;
+ this.RequestLog.Add((request.Method.ToString(), path, true));
+ return CreateFake401Response();
+ }
+
+ this.RequestLog.Add((request.Method.ToString(), path, false));
+ return await base.SendAsync(request, cancellationToken);
+ }
+ }
+
+ private class AlwaysRevokingHandler : DelegatingHandler
+ {
+ public int SimulatedRevocationCount { get; private set; }
+
+ public AlwaysRevokingHandler(HttpMessageHandler innerHandler)
+ : base(innerHandler) { }
+
+ protected override async Task SendAsync(
+ HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ string path = request.RequestUri?.AbsolutePath ?? "";
+ if (request.Method == HttpMethod.Post && path.Contains("/docs"))
+ {
+ this.SimulatedRevocationCount++;
+ return CreateFake401Response();
+ }
+
+ return await base.SendAsync(request, cancellationToken);
+ }
+ }
+
+ private static HttpResponseMessage CreateFake401Response()
+ {
+ long ts = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
+ string claimsJson = "{\"access_token\":{\"nbf\":{\"essential\":false,\"value\":\"" + ts + "\"}}}";
+ string base64Claims = Convert.ToBase64String(Encoding.UTF8.GetBytes(claimsJson));
+ string wwwAuth = "Bearer realm=\"\", authorization_uri=\"\", error=\"insufficient_claims\", claims=\"" + base64Claims + "\"";
+
+ HttpResponseMessage response = new HttpResponseMessage(HttpStatusCode.Unauthorized);
+ response.Headers.TryAddWithoutValidation("x-ms-substatus", ((int)Documents.SubStatusCodes.AadTokenRevoked).ToString());
+ response.Headers.TryAddWithoutValidation("x-ms-activity-id", Guid.NewGuid().ToString());
+ response.Content = new StringContent("{\"code\":\"Unauthorized\",\"message\":\"Provided AAD token has been revoked.\"}");
+ response.Headers.TryAddWithoutValidation("WWW-Authenticate", wwwAuth);
+ return response;
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs
index 5c2b155120..4b96facb21 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs
@@ -14,13 +14,16 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
+ using AccessToken = global::Azure.Core.AccessToken;
+ using TokenCredential = global::Azure.Core.TokenCredential;
+ using TokenRequestContext = global::Azure.Core.TokenRequestContext;
using Microsoft.Azure.Documents.Collections;
using Microsoft.Azure.Documents.Client;
using Microsoft.Azure.Cosmos.Common;
using System.Net.Http;
using System.Reflection;
- using System.Collections.Concurrent;
-
+ using System.Collections.Concurrent;
+
///
/// Tests for
///
@@ -29,7 +32,7 @@ public sealed class ClientRetryPolicyTests
{
private static Uri Location1Endpoint = new Uri("https://location1.documents.azure.com");
private static Uri Location2Endpoint = new Uri("https://location2.documents.azure.com");
-
+
private const string HubRegionHeader = "x-ms-cosmos-hub-region-processing-only";
private ReadOnlyCollection preferredLocations;
private AccountProperties databaseAccount;
@@ -89,37 +92,37 @@ public void MultimasterMetadataWriteRetryTest()
retryPolicy.OnBeforeSendRequest(request);
Assert.AreEqual(request.RequestContext.LocationEndpointToRoute, ClientRetryPolicyTests.Location1Endpoint);
}
-
+
///
- /// Test to validate that when 429.3092 is thrown from the service, write requests on
+ /// Test to validate that when 429.3092 is thrown from the service, write requests on
/// a multi master account should be converted to 503 and retried to the next region.
///
- [TestMethod]
- [DataRow(true, DisplayName = "Validate retry policy with multi master write account.")]
+ [TestMethod]
+ [DataRow(true, DisplayName = "Validate retry policy with multi master write account.")]
[DataRow(false, DisplayName = "Validate retry policy with single master write account.")]
- public async Task ShouldRetryAsync_WhenRequestThrottledWithResourceNotAvailable_ShouldThrow503OnMultiMasterWriteAndRetryOnNextRegion(
+ public async Task ShouldRetryAsync_WhenRequestThrottledWithResourceNotAvailable_ShouldThrow503OnMultiMasterWriteAndRetryOnNextRegion(
bool isMultiMasterAccount)
- {
- // Arrange.
+ {
+ // Arrange.
const bool enableEndpointDiscovery = true;
using GlobalEndpointManager endpointManager = this.Initialize(
useMultipleWriteLocations: isMultiMasterAccount,
enableEndpointDiscovery: enableEndpointDiscovery,
isPreferredLocationsListEmpty: false,
- multimasterMetadataWriteRetryTest: true);
-
+ multimasterMetadataWriteRetryTest: true);
+
await endpointManager.RefreshLocationAsync();
- ClientRetryPolicy retryPolicy = new (
- endpointManager,
- this.partitionKeyRangeLocationCache,
- new RetryOptions(),
- enableEndpointDiscovery,
+ ClientRetryPolicy retryPolicy = new (
+ endpointManager,
+ this.partitionKeyRangeLocationCache,
+ new RetryOptions(),
+ enableEndpointDiscovery,
false);
// Creates a sample write request.
- DocumentServiceRequest request = this.CreateRequest(
- isReadRequest: false,
+ DocumentServiceRequest request = this.CreateRequest(
+ isReadRequest: false,
isMasterResourceType: false);
// On first attempt should get (default/non hub) location.
@@ -128,7 +131,7 @@ public async Task ShouldRetryAsync_WhenRequestThrottledWithResourceNotAvailable_
// Creation of 429.3092 Error.
HttpStatusCode throttleException = HttpStatusCode.TooManyRequests;
- SubStatusCodes resourceNotAvailable = SubStatusCodes.SystemResourceUnavailable;
+ SubStatusCodes resourceNotAvailable = SubStatusCodes.SystemResourceUnavailable;
Exception innerException = new ();
Mock nameValueCollection = new ();
@@ -141,39 +144,39 @@ public async Task ShouldRetryAsync_WhenRequestThrottledWithResourceNotAvailable_
responseHeaders: nameValueCollection.Object);
// Act.
- Task shouldRetry = retryPolicy.ShouldRetryAsync(
- documentClientException,
- new CancellationToken());
-
- // Assert.
- Assert.IsTrue(shouldRetry.Result.ShouldRetry);
- retryPolicy.OnBeforeSendRequest(request);
-
- if (isMultiMasterAccount)
- {
- Assert.AreEqual(
- expected: ClientRetryPolicyTests.Location2Endpoint,
- actual: request.RequestContext.LocationEndpointToRoute,
- message: "The request should be routed to the next region, since the accound is a multi master write account and the request" +
- "failed with 429.309 which got converted into 503 internally. This should trigger another retry attempt to the next region.");
+ Task shouldRetry = retryPolicy.ShouldRetryAsync(
+ documentClientException,
+ new CancellationToken());
+
+ // Assert.
+ Assert.IsTrue(shouldRetry.Result.ShouldRetry);
+ retryPolicy.OnBeforeSendRequest(request);
+
+ if (isMultiMasterAccount)
+ {
+ Assert.AreEqual(
+ expected: ClientRetryPolicyTests.Location2Endpoint,
+ actual: request.RequestContext.LocationEndpointToRoute,
+ message: "The request should be routed to the next region, since the accound is a multi master write account and the request" +
+ "failed with 429.309 which got converted into 503 internally. This should trigger another retry attempt to the next region.");
}
- else
- {
- Assert.AreEqual(
- expected: ClientRetryPolicyTests.Location1Endpoint,
- actual: request.RequestContext.LocationEndpointToRoute,
- message: "Since this is asingle master account, the write request should not be retried on the next region.");
+ else
+ {
+ Assert.AreEqual(
+ expected: ClientRetryPolicyTests.Location1Endpoint,
+ actual: request.RequestContext.LocationEndpointToRoute,
+ message: "Since this is asingle master account, the write request should not be retried on the next region.");
}
- }
+ }
///
/// Tests to see if different 503 substatus and other similar status codes are handeled correctly
///
/// The substatus code being Tested.
[DataRow((int)StatusCodes.ServiceUnavailable, (int)SubStatusCodes.Unknown, "ServiceUnavailable")]
- [DataRow((int)StatusCodes.ServiceUnavailable, (int)SubStatusCodes.TransportGenerated503, "ServiceUnavailable")]
- [DataRow((int)StatusCodes.InternalServerError, (int)SubStatusCodes.Unknown, "InternalServerError")]
- [DataRow((int)StatusCodes.Gone, (int)SubStatusCodes.LeaseNotFound, "LeaseNotFound")]
+ [DataRow((int)StatusCodes.ServiceUnavailable, (int)SubStatusCodes.TransportGenerated503, "ServiceUnavailable")]
+ [DataRow((int)StatusCodes.InternalServerError, (int)SubStatusCodes.Unknown, "InternalServerError")]
+ [DataRow((int)StatusCodes.Gone, (int)SubStatusCodes.LeaseNotFound, "LeaseNotFound")]
[DataRow((int)StatusCodes.Forbidden, (int)SubStatusCodes.DatabaseAccountNotFound, "DatabaseAccountNotFound")]
[DataTestMethod]
public void Http503LikeSubStatusHandelingTests(int statusCode, int SubStatusCode, string message)
@@ -187,8 +190,8 @@ public void Http503LikeSubStatusHandelingTests(int statusCode, int SubStatusCode
isPreferredLocationsListEmpty: true);
//Create Retry Policy
- ClientRetryPolicy retryPolicy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false);
-
+ ClientRetryPolicy retryPolicy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false);
+
CancellationToken cancellationToken = new CancellationToken();
Exception serviceUnavailableException = new Exception();
Mock nameValueCollection = new Mock();
@@ -207,153 +210,153 @@ public void Http503LikeSubStatusHandelingTests(int statusCode, int SubStatusCode
Task retryStatus = retryPolicy.ShouldRetryAsync(documentClientException, cancellationToken);
Assert.IsFalse(retryStatus.Result.ShouldRetry);
- }
-
+ }
+
///
- /// Tests to validate that when HttpRequestException is thrown while connecting to a gateway endpoint for a single master write account with PPAF enabled,
+ /// Tests to validate that when HttpRequestException is thrown while connecting to a gateway endpoint for a single master write account with PPAF enabled,
/// a partition level failover is added and the request is retried to the next region.
///
- [TestMethod]
+ [TestMethod]
[DataRow(true, DisplayName = "Case when partition level failover is enabled.")]
[DataRow(false, DisplayName = "Case when partition level failover is disabled.")]
- public void HttpRequestExceptionHandelingTests(
+ public void HttpRequestExceptionHandelingTests(
bool enablePartitionLevelFailover)
- {
- const bool enableEndpointDiscovery = true;
- const string suffix = "-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF";
-
- //Creates a sample write request
- DocumentServiceRequest request = this.CreateRequest(false, false);
+ {
+ const bool enableEndpointDiscovery = true;
+ const string suffix = "-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF";
+
+ //Creates a sample write request
+ DocumentServiceRequest request = this.CreateRequest(false, false);
request.RequestContext.ResolvedPartitionKeyRange = new PartitionKeyRange() { Id = "0" , MinInclusive = "3F" + suffix, MaxExclusive = "5F" + suffix };
//Create GlobalEndpointManager
using GlobalEndpointManager endpointManager = this.Initialize(
useMultipleWriteLocations: false,
enableEndpointDiscovery: enableEndpointDiscovery,
- isPreferredLocationsListEmpty: false,
+ isPreferredLocationsListEmpty: false,
enablePartitionLevelFailover: enablePartitionLevelFailover);
-
- // Capture the read locations.
- ReadOnlyCollection readLocations = endpointManager.ReadEndpoints;
+
+ // Capture the read locations.
+ ReadOnlyCollection readLocations = endpointManager.ReadEndpoints;
//Create Retry Policy
- ClientRetryPolicy retryPolicy = new (
- globalEndpointManager: endpointManager,
- partitionKeyRangeLocationCache: this.partitionKeyRangeLocationCache,
- retryOptions: new RetryOptions(),
- enableEndpointDiscovery: enableEndpointDiscovery,
- isThinClientEnabled: false);
-
+ ClientRetryPolicy retryPolicy = new (
+ globalEndpointManager: endpointManager,
+ partitionKeyRangeLocationCache: this.partitionKeyRangeLocationCache,
+ retryOptions: new RetryOptions(),
+ enableEndpointDiscovery: enableEndpointDiscovery,
+ isThinClientEnabled: false);
+
CancellationToken cancellationToken = new ();
- HttpRequestException httpRequestException = new (message: "Connecting to endpoint has failed.");
-
- GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
- this.partitionKeyRangeLocationCache,
- request.RequestContext.ResolvedPartitionKeyRange,
- isReadOnlyOrMultiMasterWriteRequest: false);
-
- // Validate that the partition key range failover info is not present before the http request exception was captured in the retry policy.
- Assert.IsNull(partitionKeyRangeFailoverInfo);
-
+ HttpRequestException httpRequestException = new (message: "Connecting to endpoint has failed.");
+
+ GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
+ this.partitionKeyRangeLocationCache,
+ request.RequestContext.ResolvedPartitionKeyRange,
+ isReadOnlyOrMultiMasterWriteRequest: false);
+
+ // Validate that the partition key range failover info is not present before the http request exception was captured in the retry policy.
+ Assert.IsNull(partitionKeyRangeFailoverInfo);
+
retryPolicy.OnBeforeSendRequest(request);
- Task retryStatus = retryPolicy.ShouldRetryAsync(httpRequestException, cancellationToken);
-
- Assert.IsTrue(retryStatus.Result.ShouldRetry);
-
- partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
- this.partitionKeyRangeLocationCache,
- request.RequestContext.ResolvedPartitionKeyRange,
- isReadOnlyOrMultiMasterWriteRequest: false);
-
- if (enablePartitionLevelFailover)
- {
- // Validate that the partition key range failover info to the next account region is present after the http request exception was captured in the retry policy.
- Assert.AreEqual(partitionKeyRangeFailoverInfo.Current, readLocations[1]);
- }
- else
- {
- Assert.IsNull(partitionKeyRangeFailoverInfo);
- }
- }
-
+ Task retryStatus = retryPolicy.ShouldRetryAsync(httpRequestException, cancellationToken);
+
+ Assert.IsTrue(retryStatus.Result.ShouldRetry);
+
+ partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
+ this.partitionKeyRangeLocationCache,
+ request.RequestContext.ResolvedPartitionKeyRange,
+ isReadOnlyOrMultiMasterWriteRequest: false);
+
+ if (enablePartitionLevelFailover)
+ {
+ // Validate that the partition key range failover info to the next account region is present after the http request exception was captured in the retry policy.
+ Assert.AreEqual(partitionKeyRangeFailoverInfo.Current, readLocations[1]);
+ }
+ else
+ {
+ Assert.IsNull(partitionKeyRangeFailoverInfo);
+ }
+ }
+
///
- /// Test to validate that when an OperationCanceledException is thrown during the retry attempt, for a single master write account with PPAF enabled,
+ /// Test to validate that when an OperationCanceledException is thrown during the retry attempt, for a single master write account with PPAF enabled,
/// a partition level failover is applied and the subsequent requests will be retried on the next region for the faulty partition.
///
- [TestMethod]
+ [TestMethod]
[DataRow(true, true, DisplayName = "Read Request - Case when partition level failover is enabled.")]
[DataRow(false, true, DisplayName = "Write Request - Case when partition level failover is enabled.")]
[DataRow(true, false, DisplayName = "Read Request - Case when partition level failover is disabled.")]
[DataRow(false, false, DisplayName = "Write Request - Case when partition level failover is disabled.")]
- public void CosmosOperationCancelledExceptionHandelingTests(
- bool isReadOnlyRequest,
+ public void CosmosOperationCancelledExceptionHandelingTests(
+ bool isReadOnlyRequest,
bool enablePartitionLevelFailover)
- {
- int requestThreshold = isReadOnlyRequest ? 10 : 5;
- const bool enableEndpointDiscovery = true;
- const string suffix = "-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF";
-
- //Creates a sample write request
- DocumentServiceRequest request = this.CreateRequest(isReadOnlyRequest, false);
+ {
+ int requestThreshold = isReadOnlyRequest ? 10 : 5;
+ const bool enableEndpointDiscovery = true;
+ const string suffix = "-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF-FF";
+
+ //Creates a sample write request
+ DocumentServiceRequest request = this.CreateRequest(isReadOnlyRequest, false);
request.RequestContext.ResolvedPartitionKeyRange = new PartitionKeyRange() { Id = "0", MinInclusive = "3F" + suffix, MaxExclusive = "5F" + suffix };
//Create GlobalEndpointManager
using GlobalEndpointManager endpointManager = this.Initialize(
useMultipleWriteLocations: false,
enableEndpointDiscovery: enableEndpointDiscovery,
- isPreferredLocationsListEmpty: false,
+ isPreferredLocationsListEmpty: false,
enablePartitionLevelFailover: enablePartitionLevelFailover);
-
- // Capture the read locations.
- ReadOnlyCollection readLocations = endpointManager.ReadEndpoints;
+
+ // Capture the read locations.
+ ReadOnlyCollection readLocations = endpointManager.ReadEndpoints;
//Create Retry Policy
- ClientRetryPolicy retryPolicy = new(
- globalEndpointManager: endpointManager,
- partitionKeyRangeLocationCache: this.partitionKeyRangeLocationCache,
- retryOptions: new RetryOptions(),
- enableEndpointDiscovery: enableEndpointDiscovery,
- isThinClientEnabled: false);
-
+ ClientRetryPolicy retryPolicy = new(
+ globalEndpointManager: endpointManager,
+ partitionKeyRangeLocationCache: this.partitionKeyRangeLocationCache,
+ retryOptions: new RetryOptions(),
+ enableEndpointDiscovery: enableEndpointDiscovery,
+ isThinClientEnabled: false);
+
CancellationToken cancellationToken = new();
- OperationCanceledException operationCancelledException = new(message: "Operation was cancelled due to cancellation token expiry.");
-
- GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
- this.partitionKeyRangeLocationCache,
- request.RequestContext.ResolvedPartitionKeyRange,
- isReadOnlyOrMultiMasterWriteRequest: isReadOnlyRequest);
-
- // Validate that the partition key range failover info is not present before the http request exception was captured in the retry policy.
- Assert.IsNull(partitionKeyRangeFailoverInfo);
-
- Task retryStatus;
-
- // With cancellation token expiry, the retry policy should not failover the offending partition
- // until the write threshold is met.
- for (int i=0; i< requestThreshold; i++)
- {
- retryPolicy.OnBeforeSendRequest(request);
- retryStatus = retryPolicy.ShouldRetryAsync(operationCancelledException, cancellationToken);
- }
-
- retryStatus = retryPolicy.ShouldRetryAsync(operationCancelledException, cancellationToken);
- Assert.IsFalse(retryStatus.Result.ShouldRetry);
-
- partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
- this.partitionKeyRangeLocationCache,
- request.RequestContext.ResolvedPartitionKeyRange,
- isReadOnlyOrMultiMasterWriteRequest: isReadOnlyRequest);
-
- if (enablePartitionLevelFailover)
- {
- // Validate that the partition key range failover info to the next account region is present after the http request exception was captured in the retry policy.
- Assert.IsNotNull(partitionKeyRangeFailoverInfo);
- Assert.AreEqual(partitionKeyRangeFailoverInfo.Current, readLocations[1]);
- }
- else
- {
- Assert.IsNull(partitionKeyRangeFailoverInfo);
- }
+ OperationCanceledException operationCancelledException = new(message: "Operation was cancelled due to cancellation token expiry.");
+
+ GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
+ this.partitionKeyRangeLocationCache,
+ request.RequestContext.ResolvedPartitionKeyRange,
+ isReadOnlyOrMultiMasterWriteRequest: isReadOnlyRequest);
+
+ // Validate that the partition key range failover info is not present before the http request exception was captured in the retry policy.
+ Assert.IsNull(partitionKeyRangeFailoverInfo);
+
+ Task retryStatus;
+
+ // With cancellation token expiry, the retry policy should not failover the offending partition
+ // until the write threshold is met.
+ for (int i=0; i< requestThreshold; i++)
+ {
+ retryPolicy.OnBeforeSendRequest(request);
+ retryStatus = retryPolicy.ShouldRetryAsync(operationCancelledException, cancellationToken);
+ }
+
+ retryStatus = retryPolicy.ShouldRetryAsync(operationCancelledException, cancellationToken);
+ Assert.IsFalse(retryStatus.Result.ShouldRetry);
+
+ partitionKeyRangeFailoverInfo = ClientRetryPolicyTests.GetPartitionKeyRangeFailoverInfoUsingReflection(
+ this.partitionKeyRangeLocationCache,
+ request.RequestContext.ResolvedPartitionKeyRange,
+ isReadOnlyOrMultiMasterWriteRequest: isReadOnlyRequest);
+
+ if (enablePartitionLevelFailover)
+ {
+ // Validate that the partition key range failover info to the next account region is present after the http request exception was captured in the retry policy.
+ Assert.IsNotNull(partitionKeyRangeFailoverInfo);
+ Assert.AreEqual(partitionKeyRangeFailoverInfo.Current, readLocations[1]);
+ }
+ else
+ {
+ Assert.IsNull(partitionKeyRangeFailoverInfo);
+ }
}
[TestMethod]
@@ -402,8 +405,8 @@ public async Task ClientRetryPolicy_NoRetry_MultiMaster_Read_NoPreferredLocation
public async Task ClientRetryPolicy_NoRetry_MultiMaster_Write_NoPreferredLocationsAsync()
{
await this.ValidateConnectTimeoutTriggersClientRetryPolicyAsync(isReadRequest: false, useMultipleWriteLocations: true, usesPreferredLocations: false, true);
- }
-
+ }
+
///
/// Test to validate that hub region header is added on 404/1002 for single master accounts only,
/// starting from the second retry (after first retry also fails). For multi-master accounts,
@@ -851,6 +854,111 @@ public async Task ClientRetryPolicy_NullSharedContext_LocalFlagStillWorks()
Assert.AreEqual(bool.TrueString, headerValues[0]);
}
+ [TestMethod]
+ public async Task ClientRetryPolicy_TokenRevocationWithClaims_ShouldRetryOnceWithTokenCredential()
+ {
+ const bool enableEndpointDiscovery = true;
+ using GlobalEndpointManager endpointManager = this.Initialize(
+ useMultipleWriteLocations: false,
+ enableEndpointDiscovery: enableEndpointDiscovery,
+ isPreferredLocationsListEmpty: false);
+
+ Mock mockTokenCredential = new Mock();
+ mockTokenCredential
+ .Setup(x => x.GetTokenAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AccessToken("test-token", DateTimeOffset.UtcNow.AddHours(1)));
+
+ using AuthorizationTokenProviderTokenCredential tokenProvider = new AuthorizationTokenProviderTokenCredential(
+ mockTokenCredential.Object,
+ new Uri("https://test-account.documents.azure.com"),
+ backgroundTokenCredentialRefreshInterval: TimeSpan.FromMinutes(5),
+ tokenToAuthorizationHeader: AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature);
+
+ ClientRetryPolicy retryPolicy = new ClientRetryPolicy(
+ endpointManager,
+ this.partitionKeyRangeLocationCache,
+ new RetryOptions(),
+ enableEndpointDiscovery,
+ isThinClientEnabled: false,
+ authorizationTokenProvider: tokenProvider);
+
+ DocumentServiceRequest request = this.CreateRequest(isReadRequest: false, isMasterResourceType: false);
+ retryPolicy.OnBeforeSendRequest(request);
+
+ StoreResponseNameValueCollection responseHeaders = new StoreResponseNameValueCollection();
+ responseHeaders.Set(
+ HttpConstants.HttpHeaders.WwwAuthenticate,
+ "Bearer error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnt9fQ==\"");
+
+ DocumentClientException revocationException = new DocumentClientException(
+ message: "AAD token revocation",
+ innerException: null,
+ statusCode: HttpStatusCode.Unauthorized,
+ substatusCode: SubStatusCodes.AadTokenRevoked,
+ requestUri: request.RequestContext.LocationEndpointToRoute,
+ responseHeaders: responseHeaders);
+
+ ShouldRetryResult firstResult = await retryPolicy.ShouldRetryAsync(revocationException, CancellationToken.None);
+ Assert.IsTrue(firstResult.ShouldRetry, "Token revocation with claims should retry on first attempt.");
+ Assert.AreEqual(TimeSpan.Zero, firstResult.BackoffTime, "Retry should be immediate for token revocation.");
+
+ ShouldRetryResult secondResult = await retryPolicy.ShouldRetryAsync(revocationException, CancellationToken.None);
+ Assert.IsFalse(secondResult.ShouldRetry, "Token revocation should not retry after the revocation retry budget is exhausted.");
+ }
+
+ [DataTestMethod]
+ [DataRow(null, DisplayName = "No WWW-Authenticate header")]
+ [DataRow("Bearer realm=\"test\"", DisplayName = "WWW-Authenticate without claims")]
+ public async Task ClientRetryPolicy_401WithoutCaeIndicators_DoesNotRetry(string wwwAuthenticateValue)
+ {
+ const bool enableEndpointDiscovery = true;
+ using GlobalEndpointManager endpointManager = this.Initialize(
+ useMultipleWriteLocations: false,
+ enableEndpointDiscovery: enableEndpointDiscovery,
+ isPreferredLocationsListEmpty: false);
+
+ Mock mockTokenCredential = new Mock();
+ mockTokenCredential
+ .Setup(x => x.GetTokenAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AccessToken("test-token", DateTimeOffset.UtcNow.AddHours(1)));
+
+ using AuthorizationTokenProviderTokenCredential tokenProvider = new AuthorizationTokenProviderTokenCredential(
+ mockTokenCredential.Object,
+ new Uri("https://test-account.documents.azure.com"),
+ backgroundTokenCredentialRefreshInterval: TimeSpan.FromMinutes(5),
+ tokenToAuthorizationHeader: AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature);
+
+ ClientRetryPolicy retryPolicy = new ClientRetryPolicy(
+ endpointManager,
+ this.partitionKeyRangeLocationCache,
+ new RetryOptions(),
+ enableEndpointDiscovery,
+ isThinClientEnabled: false,
+ authorizationTokenProvider: tokenProvider);
+
+ DocumentServiceRequest request = this.CreateRequest(isReadRequest: false, isMasterResourceType: false);
+ retryPolicy.OnBeforeSendRequest(request);
+
+ StoreResponseNameValueCollection responseHeaders = new StoreResponseNameValueCollection();
+ if (wwwAuthenticateValue != null)
+ {
+ responseHeaders.Set(HttpConstants.HttpHeaders.WwwAuthenticate, wwwAuthenticateValue);
+ }
+
+ DocumentClientException unauthorizedException = new DocumentClientException(
+ message: "Unauthorized",
+ innerException: null,
+ statusCode: HttpStatusCode.Unauthorized,
+ substatusCode: SubStatusCodes.Unknown,
+ requestUri: request.RequestContext.LocationEndpointToRoute,
+ responseHeaders: responseHeaders);
+
+ ShouldRetryResult result = await retryPolicy.ShouldRetryAsync(unauthorizedException, CancellationToken.None);
+
+ Assert.IsNotNull(result, "Should get a result from the retry pipeline.");
+ Assert.IsFalse(result.ShouldRetry, "401 without CAE indicators should not trigger a retry.");
+ }
+
private async Task ValidateConnectTimeoutTriggersClientRetryPolicyAsync(
bool isReadRequest,
bool useMultipleWriteLocations,
@@ -895,7 +1003,7 @@ private async Task ValidateConnectTimeoutTriggersClientRetryPolicyAsync(
replicatedResourceClient.GoneAndRetryWithRetryTimeoutInSecondsOverride = 1;
this.partitionKeyRangeLocationCache = GlobalPartitionEndpointManagerNoOp.Instance;
-
+
ClientRetryPolicy retryPolicy = new ClientRetryPolicy(mockDocumentClientContext.GlobalEndpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery: true, false);
INameValueCollection headers = new DictionaryNameValueCollection();
@@ -1194,28 +1302,28 @@ private static DocumentServiceRequest CreateDtxRequest()
AuthorizationTokenType.PrimaryMasterKey);
}
- private static GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo GetPartitionKeyRangeFailoverInfoUsingReflection(
- GlobalPartitionEndpointManager globalPartitionEndpointManager,
- PartitionKeyRange pkRange,
- bool isReadOnlyOrMultiMasterWriteRequest)
- {
- string fieldName = isReadOnlyOrMultiMasterWriteRequest ? "PartitionKeyRangeToLocationForReadAndWrite" : "PartitionKeyRangeToLocationForWrite";
- FieldInfo fieldInfo = globalPartitionEndpointManager
- .GetType()
- .GetField(
- name: fieldName,
- bindingAttr: BindingFlags.Instance | BindingFlags.NonPublic);
-
- if (fieldInfo != null)
- {
- Lazy> partitionKeyRangeToLocation = (Lazy>)fieldInfo.GetValue(globalPartitionEndpointManager);
- partitionKeyRangeToLocation.Value.TryGetValue(pkRange, out GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo partitionKeyRangeFailoverInfo);
-
- return partitionKeyRangeFailoverInfo;
- }
-
- return null;
- }
+ private static GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo GetPartitionKeyRangeFailoverInfoUsingReflection(
+ GlobalPartitionEndpointManager globalPartitionEndpointManager,
+ PartitionKeyRange pkRange,
+ bool isReadOnlyOrMultiMasterWriteRequest)
+ {
+ string fieldName = isReadOnlyOrMultiMasterWriteRequest ? "PartitionKeyRangeToLocationForReadAndWrite" : "PartitionKeyRangeToLocationForWrite";
+ FieldInfo fieldInfo = globalPartitionEndpointManager
+ .GetType()
+ .GetField(
+ name: fieldName,
+ bindingAttr: BindingFlags.Instance | BindingFlags.NonPublic);
+
+ if (fieldInfo != null)
+ {
+ Lazy> partitionKeyRangeToLocation = (Lazy>)fieldInfo.GetValue(globalPartitionEndpointManager);
+ partitionKeyRangeToLocation.Value.TryGetValue(pkRange, out GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo partitionKeyRangeFailoverInfo);
+
+ return partitionKeyRangeFailoverInfo;
+ }
+
+ return null;
+ }
private static AccountProperties CreateDatabaseAccount(
bool useMultipleWriteLocations,
@@ -1309,9 +1417,9 @@ private GlobalEndpointManager Initialize(
if (enablePartitionLevelFailover)
{
- this.partitionKeyRangeLocationCache = new GlobalPartitionEndpointManagerCore(
- globalEndpointManager: endpointManager,
- isPartitionLevelFailoverEnabled: enablePartitionLevelFailover,
+ this.partitionKeyRangeLocationCache = new GlobalPartitionEndpointManagerCore(
+ globalEndpointManager: endpointManager,
+ isPartitionLevelFailoverEnabled: enablePartitionLevelFailover,
isPartitionLevelCircuitBreakerEnabled: enablePartitionLevelFailover || enablePartitionLevelCircuitBreaker);
}
else
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs
index 06b84086bf..36ef902fec 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs
@@ -600,6 +600,145 @@ public async Task TestTokenCredentialMultiThreadAsync()
}
}
+ [TestMethod]
+ [DataRow("Bearer error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnt9fQ==\"", true, DisplayName = "With insufficient_claims")]
+ [DataRow("Bearer claims=\"eyJhY2Nlc3NfdG9rZW4iOnt9fQ==\"", true, DisplayName = "With claims only")]
+ [DataRow("Bearer realm=\"test\"", false, DisplayName = "Without claims challenge")]
+ [DataRow("", false, DisplayName = "Empty header")]
+ public void TryHandleTokenRevocation_VariousHeaders(string wwwAuthenticateValue, bool expectedResult)
+ {
+ // Arrange
+ Mock mockTokenCredential = new Mock();
+ mockTokenCredential
+ .Setup(x => x.GetTokenAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AccessToken("test-token", DateTimeOffset.MaxValue));
+
+ using AuthorizationTokenProviderTokenCredential tokenProvider = new AuthorizationTokenProviderTokenCredential(
+ mockTokenCredential.Object,
+ CosmosAuthorizationTests.AccountEndpoint,
+ backgroundTokenCredentialRefreshInterval: TimeSpan.FromMinutes(5),
+ tokenToAuthorizationHeader: AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature);
+
+ // Act
+ bool result = tokenProvider.TryHandleTokenRevocation(HttpStatusCode.Unauthorized, wwwAuthenticateValue);
+
+ // Assert
+ Assert.AreEqual(expectedResult, result);
+ }
+
+ [TestMethod]
+ [DataRow(HttpStatusCode.Forbidden)]
+ [DataRow(HttpStatusCode.BadRequest)]
+ [DataRow(HttpStatusCode.NotFound)]
+ public void TryHandleTokenRevocation_NonUnauthorizedStatus_ReturnsFalse(HttpStatusCode statusCode)
+ {
+ // Arrange
+ Mock mockTokenCredential = new Mock();
+ mockTokenCredential
+ .Setup(x => x.GetTokenAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AccessToken("test-token", DateTimeOffset.MaxValue));
+
+ using AuthorizationTokenProviderTokenCredential tokenProvider = new AuthorizationTokenProviderTokenCredential(
+ mockTokenCredential.Object,
+ CosmosAuthorizationTests.AccountEndpoint,
+ backgroundTokenCredentialRefreshInterval: TimeSpan.FromMinutes(5),
+ tokenToAuthorizationHeader: AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature);
+
+ string wwwAuthenticateValue = "Bearer error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnt9fQ==\"";
+
+ // Act
+ bool result = tokenProvider.TryHandleTokenRevocation(statusCode, wwwAuthenticateValue);
+ // Assert
+ Assert.IsFalse(result);
+ }
+
+ [TestMethod]
+ [DataRow(null, "{\"access_token\":{\"xms_cc\":{\"values\":[\"cp1\"]}}}", DisplayName = "Null claims")]
+ [DataRow("", "{\"access_token\":{\"xms_cc\":{\"values\":[\"cp1\"]}}}", DisplayName = "Empty claims")]
+ [DataRow("not-valid-base64!!!", "{\"access_token\":{\"xms_cc\":{\"values\":[\"cp1\"]}}}", DisplayName = "Invalid base64")]
+ public void MergeClaimsWithClientCapabilities_InvalidInput_ReturnsOnlyCp1(string claimsChallenge, string expected)
+ {
+ // Act
+ string result = TokenCredentialCache.MergeClaimsWithClientCapabilities(claimsChallenge);
+
+ // Assert
+ Assert.AreEqual(expected, result);
+ }
+
+ [TestMethod]
+ public void MergeClaimsWithClientCapabilities_ValidClaims_MergesWithCp1()
+ {
+ // Arrange - Base64 encoded: {"access_token":{"acrs":{"essential":true,"value":"c1"}}}
+ string claimsChallenge = "eyJhY2Nlc3NfdG9rZW4iOnsiYWNycyI6eyJlc3NlbnRpYWwiOnRydWUsInZhbHVlIjoiYzEifX19";
+
+ // Act
+ string result = TokenCredentialCache.MergeClaimsWithClientCapabilities(claimsChallenge);
+
+ // Assert
+ Assert.IsTrue(result.Contains("\"xms_cc\":{\"values\":[\"cp1\"]}"), "Should contain cp1");
+ Assert.IsTrue(result.Contains("\"acrs\""), "Should contain original claims");
+ }
+
+ [TestMethod]
+ public async Task TokenCredentialCache_ResetWithClaims_RefreshesTokenWithClaims()
+ {
+ // Arrange
+ int callCount = 0;
+ List claimsReceived = new List();
+
+ TestTokenCredential testTokenCredential = new TestTokenCredential(() =>
+ {
+ callCount++;
+ return new ValueTask(new AccessToken($"Token{callCount}", DateTimeOffset.MaxValue));
+ });
+
+ using TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential);
+
+ // Get initial token
+ string t1 = await tokenCredentialCache.GetTokenAuthorizationHeaderAsync(NoOpTrace.Singleton);
+ Assert.AreEqual(AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature("Token1"), t1);
+ Assert.AreEqual(1, callCount);
+
+ // Simulate CAE revocation with claims
+ string claimsChallenge = Convert.ToBase64String(
+ System.Text.Encoding.UTF8.GetBytes("{\"access_token\":{\"acrs\":{\"essential\":true,\"value\":\"c1\"}}}"));
+ tokenCredentialCache.ResetCachedToken(claimsChallenge);
+
+ // Get token again
+ string t2 = await tokenCredentialCache.GetTokenAuthorizationHeaderAsync(NoOpTrace.Singleton);
+
+ // Assert
+ Assert.AreEqual(AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature("Token2"), t2);
+ Assert.AreEqual(2, callCount);
+ }
+
+ [TestMethod]
+ public async Task TokenCredentialCache_ResetWithNullClaims_RefreshesToken()
+ {
+ // Arrange
+ int callCount = 0;
+ TestTokenCredential testTokenCredential = new TestTokenCredential(() =>
+ {
+ callCount++;
+ return new ValueTask(new AccessToken($"Token{callCount}", DateTimeOffset.MaxValue));
+ });
+
+ using TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential);
+
+ // Get initial token
+ await tokenCredentialCache.GetTokenAuthorizationHeaderAsync(NoOpTrace.Singleton);
+ Assert.AreEqual(1, callCount);
+
+ // Reset with null claims
+ tokenCredentialCache.ResetCachedToken(claimsChallenge: null);
+
+ // Get token again
+ await tokenCredentialCache.GetTokenAuthorizationHeaderAsync(NoOpTrace.Singleton);
+
+ // Assert
+ Assert.AreEqual(2, callCount);
+ }
+
private TokenCredentialCache CreateTokenCredentialCache(
TokenCredential tokenCredential)
{
diff --git a/changelog.md b/changelog.md
index 584d48ec5d..879b25d95c 100644
--- a/changelog.md
+++ b/changelog.md
@@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [5838](https://github.com/Azure/azure-cosmos-dotnet-v3/pull/5838) EmbeddingGenerator: Adds ICosmosEmbeddingGenerator client-wide configuration (preview)
+- [#5549](https://github.com/Azure/azure-cosmos-dotnet-v3/pull/5549) Adds AAD token revocation (CAE / Emergency) transparent retry handling
+
#### Breaking Changes
#### Bugs Fixed