Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/core/Azure.Core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Breaking Changes

### Bugs Fixed
- Fixed an issue that could result in `BearerTokenAuthenticationPolicy` fails to refresh a token, resulting in a `OperationCanceledException`.

### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ protected void AuthenticateAndAuthorizeRequest(HttpMessage message, TokenRequest
message.Request.Headers.SetValue(HttpHeader.Names.Authorization, headerValue);
}

private class AccessTokenCache
internal class AccessTokenCache
{
private readonly object _syncObj = new object();
private readonly TokenCredential _credential;
private readonly TimeSpan _tokenRefreshOffset;
private readonly TimeSpan _tokenRefreshRetryDelay;

// must be updated under lock (_syncObj)
private TokenRequestState? _state;
internal TokenRequestState? _state;

public AccessTokenCache(TokenCredential credential, TimeSpan tokenRefreshOffset, TimeSpan tokenRefreshRetryDelay)
{
Expand All @@ -204,7 +204,7 @@ public async ValueTask<string> GetAuthHeaderValueAsync(HttpMessage message, Toke
{
if (localState.BackgroundTokenUpdateTcs != null)
{
headerValueInfo = await localState.GetCurrentHeaderValue(async).ConfigureAwait(false);
headerValueInfo = await localState.GetCurrentHeaderValue(async, false, message.CancellationToken).ConfigureAwait(false);
_ = Task.Run(() => GetHeaderValueFromCredentialInBackgroundAsync(localState.BackgroundTokenUpdateTcs, headerValueInfo, context, async));
return headerValueInfo.HeaderValue;
}
Expand Down Expand Up @@ -355,7 +355,7 @@ private async ValueTask SetResultOnTcsFromCredentialAsync(TokenRequestContext co
targetTcs.SetResult(new AuthHeaderValueInfo("Bearer " + token.Token, token.ExpiresOn, token.RefreshOn.HasValue ? token.RefreshOn.Value : token.ExpiresOn - _tokenRefreshOffset));
}

private readonly struct AuthHeaderValueInfo
internal readonly struct AuthHeaderValueInfo
{
public string HeaderValue { get; }
public DateTimeOffset ExpiresOn { get; }
Expand All @@ -369,7 +369,7 @@ public AuthHeaderValueInfo(string headerValue, DateTimeOffset expiresOn, DateTim
}
}

private class TokenRequestState
internal class TokenRequestState
{
public TokenRequestContext CurrentContext { get; }
public TaskCompletionSource<AuthHeaderValueInfo> CurrentTokenTcs { get; }
Expand Down Expand Up @@ -409,7 +409,7 @@ public TokenRequestState WithBackgroundUpdateTcsAsCurrent() =>
new TokenRequestState(CurrentContext, BackgroundTokenUpdateTcs!, default);

public TokenRequestState WithNewCurrentTokenTcs() =>
new TokenRequestState(CurrentContext, new TaskCompletionSource<AuthHeaderValueInfo>(TaskCreationOptions.RunContinuationsAsynchronously), BackgroundTokenUpdateTcs);
new TokenRequestState(CurrentContext, new TaskCompletionSource<AuthHeaderValueInfo>(TaskCreationOptions.RunContinuationsAsynchronously), default);

public TokenRequestState WithNewBackroundUpdateTokenTcs() =>
new TokenRequestState(CurrentContext, CurrentTokenTcs, new TaskCompletionSource<AuthHeaderValueInfo>(TaskCreationOptions.RunContinuationsAsynchronously));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,58 @@ public async Task BearerTokenAuthenticationPolicy_SwitchedTenants()
Assert.AreEqual(3, callCount);
}

[Test]
public async Task TokenCacheCurrentTcsTOkenIsExpiredAndBackgroundTcsInitialized()
{
var currentTcs = new TaskCompletionSource<BearerTokenAuthenticationPolicy.AccessTokenCache.AuthHeaderValueInfo>();
var backgroundTcs = new TaskCompletionSource<BearerTokenAuthenticationPolicy.AccessTokenCache.AuthHeaderValueInfo>();

currentTcs.SetResult(new BearerTokenAuthenticationPolicy.AccessTokenCache.AuthHeaderValueInfo("token", DateTimeOffset.UtcNow.AddMinutes(-5), DateTimeOffset.UtcNow.AddMinutes(-5)));

TokenRequestContext ctx = new TokenRequestContext(new[] { "scope" });
var cache = new BearerTokenAuthenticationPolicy.AccessTokenCache(
new TokenCredentialStub((r, c) => new AccessToken(string.Empty, DateTimeOffset.MaxValue), IsAsync),
TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30))
{
_state = new BearerTokenAuthenticationPolicy.AccessTokenCache.TokenRequestState(
ctx,
currentTcs,
backgroundTcs
)
};
var msg = new HttpMessage(new MockRequest(), ResponseClassifier.Shared);
var cts = new CancellationTokenSource();
cts.CancelAfter(5000);
msg.CancellationToken = cts.Token;
await cache.GetAuthHeaderValueAsync(msg, ctx, IsAsync);
}

[Test]
public async Task TokenCacheCurrentTcsIsCancelledAndBackgroundTcsInitialized()
{
var currentTcs = new TaskCompletionSource<BearerTokenAuthenticationPolicy.AccessTokenCache.AuthHeaderValueInfo>();
var backgroundTcs = new TaskCompletionSource<BearerTokenAuthenticationPolicy.AccessTokenCache.AuthHeaderValueInfo>();

currentTcs.SetCanceled();

TokenRequestContext ctx = new TokenRequestContext(new[] { "scope" });
var cache = new BearerTokenAuthenticationPolicy.AccessTokenCache(
new TokenCredentialStub((r, c) => new AccessToken(string.Empty, DateTimeOffset.MaxValue), IsAsync),
TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30))
{
_state = new BearerTokenAuthenticationPolicy.AccessTokenCache.TokenRequestState(
ctx,
currentTcs,
backgroundTcs
)
};
var msg = new HttpMessage(new MockRequest(), ResponseClassifier.Shared);
var cts = new CancellationTokenSource();
cts.CancelAfter(5000);
msg.CancellationToken = cts.Token;
await cache.GetAuthHeaderValueAsync(msg, ctx, IsAsync);
}

private class ChallengeBasedAuthenticationTestPolicy : BearerTokenAuthenticationPolicy
{
public string TenantId { get; private set; }
Expand Down