Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 138 additions & 43 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@ internal sealed class TokenCredentialCache : IDisposable
// If the background refresh fails with less than a minute then just allow the request to hit the exception.
public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1);

private const string ScopeFormat = "https://{0}/.default";
private const string ScopeFormat = "https://{0}/.default";
private const string AadInvalidScopeErrorMessage = "AADSTS500011";
private const string AadDefaultScope = "https://cosmos.azure.com/.default";

private readonly TokenRequestContext tokenRequestContext;
private readonly TokenCredential tokenCredential;
private readonly CancellationTokenSource cancellationTokenSource;
private readonly CancellationToken cancellationToken;
private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval;
private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval;
private readonly string accountScope;
private readonly bool isOverrideScopeProvided;

private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1);
private readonly object backgroundRefreshLock = new object();
Expand All @@ -67,11 +71,12 @@ internal TokenCredentialCache(

string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);

this.accountScope = string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host);
this.isOverrideScopeProvided = !string.IsNullOrEmpty(scopeOverride);

this.tokenRequestContext = new TokenRequestContext(new string[]
{
!string.IsNullOrEmpty(scopeOverride)
? scopeOverride
: string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host)
this.isOverrideScopeProvided ? scopeOverride! : this.accountScope
});

if (backgroundTokenCredentialRefreshInterval.HasValue)
Expand Down Expand Up @@ -167,6 +172,22 @@ private async Task<AccessToken> GetNewTokenAsync(
}

return await currentTask;
}

private void ApplyTokenAndSetRefreshInterval(AccessToken token)
{
this.cachedAccessToken = token;

if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
{
double refreshIntervalInSeconds =
(token.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;

refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds);
refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds);

this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
}
}

private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
Expand All @@ -190,34 +211,27 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
try
{
this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
cancellationToken: this.cancellationToken);

if (!this.cachedAccessToken.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
}

if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}");
}

if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
{
double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;

// Ensure the background refresh interval is a valid range.
refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds);
refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds);
this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
}

return this.cachedAccessToken.Value;
{
bool shouldAttemptAadFallback = false;

try
{
AccessToken? tokenNullable = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
cancellationToken: this.cancellationToken);

if (!tokenNullable.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
}

if (tokenNullable.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{tokenNullable.Value.ExpiresOn:O}");
}

this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value);
return tokenNullable.Value;
}
catch (RequestFailedException requestFailedException)
{
Expand All @@ -228,13 +242,17 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
if (this.isOverrideScopeProvided)
{
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}

continue;
}
}
catch (OperationCanceledException operationCancelled)
{
Expand Down Expand Up @@ -263,9 +281,86 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
exception.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
}
}
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

if (this.isOverrideScopeProvided)
{
continue;
}

if (exception.InnerException?.Message.Contains(AadInvalidScopeErrorMessage) == true)
{
shouldAttemptAadFallback = true;
}
}

if (!this.isOverrideScopeProvided && shouldAttemptAadFallback)
{
TokenRequestContext fallbackContext = new TokenRequestContext(new[] { AadDefaultScope });

try
{
AccessToken? tokenNullable = await this.tokenCredential.GetTokenAsync(
requestContext: fallbackContext,
cancellationToken: this.cancellationToken);

if (!tokenNullable.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
}

if (tokenNullable.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{tokenNullable.Value.ExpiresOn:O}");
}

this.ApplyTokenAndSetRefreshInterval(tokenNullable.Value);
return tokenNullable.Value;
}
catch (RequestFailedException requestFailedExceptionFallback)
{
lastException = requestFailedExceptionFallback;
getTokenTrace.AddDatum(
$"RequestFailedException (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedExceptionFallback.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

if (requestFailedExceptionFallback.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedExceptionFallback.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
}
catch (OperationCanceledException operationCancelledFallback)
{
lastException = operationCancelledFallback;
getTokenTrace.AddDatum(
$"OperationCanceledException (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelledFallback.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
headers: new Headers() { SubStatusCode = SubStatusCodes.FailedToGetAadToken, },
innerException: lastException,
trace: getTokenTrace);
}
catch (Exception exceptionFallback)
{
lastException = exceptionFallback;
getTokenTrace.AddDatum(
$"Exception (fallback) at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exceptionFallback.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed on fallback. scope = {string.Join(";", fallbackContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
}
}
}
}

if (lastException == null)
Expand Down
Loading
Loading