Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Authorization
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using global::Azure.Core;

internal sealed class CosmosScopeProvider : IScopeProvider
{
private const string AadInvalidScopeErrorMessage = "AADSTS500011";
private const string AadDefaultScope = "https://cosmos.azure.com/.default";
private const string ScopeFormat = "https://{0}/.default";

private readonly string accountScope;
private readonly string overrideScope;
private string currentScope;

public CosmosScopeProvider(Uri accountEndpoint)
{
this.overrideScope = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);
this.accountScope = string.Format(ScopeFormat, accountEndpoint.Host);
this.currentScope = this.overrideScope ?? this.accountScope;
}

public TokenRequestContext GetTokenRequestContext()
{
return new TokenRequestContext(new[] { this.currentScope });
}

public bool TryFallback(Exception exception)
{
// If override scope is set, never fallback
if (!string.IsNullOrEmpty(this.overrideScope))
{
return false;
}

// If already using fallback scope, do not fallback again
if (this.currentScope == CosmosScopeProvider.AadDefaultScope)
{
return false;
}

#pragma warning disable CDX1003 // DontUseExceptionToString
if (exception.ToString().Contains(CosmosScopeProvider.AadInvalidScopeErrorMessage) == true)
{
this.currentScope = CosmosScopeProvider.AadDefaultScope;
return true;
}
#pragma warning restore CDX1003 // DontUseExceptionToString

return false;
}

public void Dispose()
{
}
}
}
16 changes: 16 additions & 0 deletions Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Authorization
{
using System;
using System.Collections.Generic;
using System.Text;
using global::Azure.Core;

internal interface IScopeProvider : IDisposable
{
TokenRequestContext GetTokenRequestContext();
bool TryFallback(Exception ex);
}
}
60 changes: 34 additions & 26 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace Microsoft.Azure.Cosmos
using System.Threading;
using System.Threading.Tasks;
using global::Azure;
using global::Azure.Core;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Authorization;
using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
using Microsoft.Azure.Cosmos.Tracing;
Expand All @@ -36,9 +37,7 @@ internal sealed class TokenCredentialCache : IDisposable
// If the background refresh fails with less than a minute then just allow the request to hit the exception.
public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1);

private const string ScopeFormat = "https://{0}/.default";

private readonly TokenRequestContext tokenRequestContext;
private readonly IScopeProvider scopeProvider;
private readonly TokenCredential tokenCredential;
private readonly CancellationTokenSource cancellationTokenSource;
private readonly CancellationToken cancellationToken;
Expand All @@ -51,7 +50,7 @@ internal sealed class TokenCredentialCache : IDisposable
private Task<AccessToken>? currentRefreshOperation = null;
private AccessToken? cachedAccessToken = null;
private bool isBackgroundTaskRunning = false;
private bool isDisposed = false;
private bool isDisposed = false;

internal TokenCredentialCache(
TokenCredential tokenCredential,
Expand All @@ -65,14 +64,7 @@ internal TokenCredentialCache(
throw new ArgumentNullException(nameof(accountEndpoint));
}

string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);

this.tokenRequestContext = new TokenRequestContext(new string[]
{
!string.IsNullOrEmpty(scopeOverride)
? scopeOverride
: string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host)
});
this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint);

if (backgroundTokenCredentialRefreshInterval.HasValue)
{
Expand Down Expand Up @@ -171,11 +163,13 @@ private async Task<AccessToken> GetNewTokenAsync(

private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
ITrace trace)
{
{
try
{
Exception? lastException = null;
const int totalRetryCount = 2;
const int totalRetryCount = 2;
TokenRequestContext tokenRequestContext = default;

for (int retry = 0; retry < totalRetryCount; retry++)
{
if (this.cancellationToken.IsCancellationRequested)
Expand All @@ -190,11 +184,13 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
{
try
{
{
tokenRequestContext = this.scopeProvider.GetTokenRequestContext();

this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
requestContext: tokenRequestContext,
cancellationToken: this.cancellationToken);

if (!this.cachedAccessToken.HasValue)
Expand All @@ -220,31 +216,37 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
return this.cachedAccessToken.Value;
}
catch (RequestFailedException requestFailedException)
{
{
lastException = requestFailedException;
getTokenTrace.AddDatum(
$"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedException.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty<string>())}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}

// Fallback logic
if (this.scopeProvider.TryFallback(requestFailedException))
{
continue;
}
}
catch (OperationCanceledException operationCancelled)
{
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
operationCancelled.Message);
DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty<string>())}, retry = {retry}, Exception = {lastException.Message}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
Expand All @@ -262,8 +264,14 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes ?? Array.Empty<string>())}, retry = {retry}, Exception = {lastException.Message}");

// Fallback logic
if (this.scopeProvider.TryFallback(exception))
{
continue;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System.Web;
using Documents.Client;
using global::Azure;
using global::Azure.Core;
using global::Azure.Core;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.IdentityModel.Tokens;
using static Microsoft.Azure.Cosmos.SDK.EmulatorTests.TransportClientHelper;
Expand Down Expand Up @@ -263,6 +263,136 @@ void GetAadTokenCallBack(
Assert.IsTrue(ce.ToString().Contains(errorMessage));
}
}
}
}

[TestMethod]
public async Task Aad_OverrideScope_NoFallback_OnFailure_E2E()
{
// Arrange
(string endpoint, string authKey) = TestCommon.GetAccountInfo();
string databaseId = "db-" + Guid.NewGuid();
using (CosmosClient setupClient = TestCommon.CreateCosmosClient())
{
await setupClient.CreateDatabaseAsync(databaseId);
}

string overrideScope = "https://override/.default";
string accountScope = $"https://{new Uri(endpoint).Host}/.default";
int overrideScopeCount = 0;
int accountScopeCount = 0;

string previous = Environment.GetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE");
Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", overrideScope);

void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token)
{
string scope = context.Scopes[0];
if (scope == overrideScope)
{
overrideScopeCount++;
throw new RequestFailedException(408, "Simulated override scope failure");
}
if (scope == accountScope)
{
accountScopeCount++;
}
}

LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential(
expectedScopes: new[] { overrideScope, accountScope },
masterKey: authKey,
getTokenCallback: GetAadTokenCallBack);

CosmosClientOptions clientOptions = new CosmosClientOptions
{
ConnectionMode = ConnectionMode.Gateway,
TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60)
};

try
{
using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions);

try
{
// Act
ResponseMessage r = await aadClient.GetDatabase(databaseId).ReadStreamAsync();
Assert.Fail("Expected failure when override scope token acquisition fails.");
}
catch (RequestFailedException ex) when (ex.Status == (int)HttpStatusCode.RequestTimeout || ex.Status == 408)
{
// Assert
Assert.IsTrue(overrideScopeCount > 0, "Override scope should have been attempted.");
Assert.AreEqual(0, accountScopeCount, "No fallback to account scope must occur when override is configured.");
}
}
finally
{
Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous);
using CosmosClient cleanup = TestCommon.CreateCosmosClient();
await cleanup.GetDatabase(databaseId).DeleteAsync();
}
}

[TestMethod]
public async Task Aad_AccountScope_Fallbacks_ToCosmosScope()
{
(string endpoint, string authKey) = TestCommon.GetAccountInfo();

string previous = Environment.GetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE");
Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", null);

string accountScope = $"https://{new Uri(endpoint).Host}/.default";
string aadScope = "https://cosmos.azure.com/.default";

int accountScopeCount = 0;
int cosmosScopeCount = 0;

void GetAadTokenCallBack(TokenRequestContext context, CancellationToken token)
{
string scope = context.Scopes[0];

if (string.Equals(scope, accountScope, StringComparison.OrdinalIgnoreCase))
{
accountScopeCount++;
throw new Exception(
message: "AADSTS500011",
innerException: new Exception("AADSTS500011"));
}

if (string.Equals(scope, aadScope, StringComparison.OrdinalIgnoreCase))
{
cosmosScopeCount++;
}
}

LocalEmulatorTokenCredential credential = new LocalEmulatorTokenCredential(
expectedScopes: new[] { accountScope, aadScope },
masterKey: authKey,
getTokenCallback: GetAadTokenCallBack);

CosmosClientOptions clientOptions = new CosmosClientOptions
{
ConnectionMode = ConnectionMode.Gateway,
TokenCredentialBackgroundRefreshInterval = TimeSpan.FromSeconds(60)
};

try
{
using CosmosClient aadClient = new CosmosClient(endpoint, credential, clientOptions);
TokenCredentialCache tokenCredentialCache =
((AuthorizationTokenProviderTokenCredential)aadClient.AuthorizationTokenProvider).tokenCredentialCache;

string token = await tokenCredentialCache.GetTokenAsync(Tracing.Trace.GetRootTrace("account-fallback-to-cosmos-test"));
Assert.IsFalse(string.IsNullOrEmpty(token), "Fallback should succeed and produce a token.");

Assert.IsTrue(accountScopeCount >= 1, "Account scope must be attempted first.");
Assert.IsTrue(cosmosScopeCount >= 1, "The client must fall back to cosmos.azure.com scope.");
}
finally
{
Environment.SetEnvironmentVariable("AZURE_COSMOS_AAD_SCOPE_OVERRIDE", previous);
}
}
}
}
Loading
Loading