diff --git a/sdk/core/Azure.Core/CHANGELOG.md b/sdk/core/Azure.Core/CHANGELOG.md index 271a6966de58..61f9e6dd31ce 100644 --- a/sdk/core/Azure.Core/CHANGELOG.md +++ b/sdk/core/Azure.Core/CHANGELOG.md @@ -7,6 +7,7 @@ ### Breaking Changes ### Bugs Fixed +- Fixed an issue that could result in `BearerTokenAuthenticationPolicy` fails to refresh a token, resulting in a `OperationCanceledException`. ### Other Changes diff --git a/sdk/core/Azure.Core/src/Pipeline/BearerTokenAuthenticationPolicy.cs b/sdk/core/Azure.Core/src/Pipeline/BearerTokenAuthenticationPolicy.cs index ea49deb64b9e..0b8c7de0d5a1 100644 --- a/sdk/core/Azure.Core/src/Pipeline/BearerTokenAuthenticationPolicy.cs +++ b/sdk/core/Azure.Core/src/Pipeline/BearerTokenAuthenticationPolicy.cs @@ -173,7 +173,7 @@ protected void AuthenticateAndAuthorizeRequest(HttpMessage message, TokenRequest message.Request.Headers.SetValue(HttpHeader.Names.Authorization, headerValue); } - private class AccessTokenCache + internal class AccessTokenCache { private readonly object _syncObj = new object(); private readonly TokenCredential _credential; @@ -181,7 +181,7 @@ private class AccessTokenCache private readonly TimeSpan _tokenRefreshRetryDelay; // must be updated under lock (_syncObj) - private TokenRequestState? _state; + internal TokenRequestState? _state; public AccessTokenCache(TokenCredential credential, TimeSpan tokenRefreshOffset, TimeSpan tokenRefreshRetryDelay) { @@ -204,7 +204,7 @@ public async ValueTask GetAuthHeaderValueAsync(HttpMessage message, Toke { if (localState.BackgroundTokenUpdateTcs != null) { - headerValueInfo = await localState.GetCurrentHeaderValue(async).ConfigureAwait(false); + headerValueInfo = await localState.GetCurrentHeaderValue(async, false, message.CancellationToken).ConfigureAwait(false); _ = Task.Run(() => GetHeaderValueFromCredentialInBackgroundAsync(localState.BackgroundTokenUpdateTcs, headerValueInfo, context, async)); return headerValueInfo.HeaderValue; } @@ -355,7 +355,7 @@ private async ValueTask SetResultOnTcsFromCredentialAsync(TokenRequestContext co targetTcs.SetResult(new AuthHeaderValueInfo("Bearer " + token.Token, token.ExpiresOn, token.RefreshOn.HasValue ? token.RefreshOn.Value : token.ExpiresOn - _tokenRefreshOffset)); } - private readonly struct AuthHeaderValueInfo + internal readonly struct AuthHeaderValueInfo { public string HeaderValue { get; } public DateTimeOffset ExpiresOn { get; } @@ -369,7 +369,7 @@ public AuthHeaderValueInfo(string headerValue, DateTimeOffset expiresOn, DateTim } } - private class TokenRequestState + internal class TokenRequestState { public TokenRequestContext CurrentContext { get; } public TaskCompletionSource CurrentTokenTcs { get; } @@ -409,7 +409,7 @@ public TokenRequestState WithBackgroundUpdateTcsAsCurrent() => new TokenRequestState(CurrentContext, BackgroundTokenUpdateTcs!, default); public TokenRequestState WithNewCurrentTokenTcs() => - new TokenRequestState(CurrentContext, new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously), BackgroundTokenUpdateTcs); + new TokenRequestState(CurrentContext, new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously), default); public TokenRequestState WithNewBackroundUpdateTokenTcs() => new TokenRequestState(CurrentContext, CurrentTokenTcs, new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously)); diff --git a/sdk/core/Azure.Core/tests/BearerTokenAuthenticationPolicyTests.cs b/sdk/core/Azure.Core/tests/BearerTokenAuthenticationPolicyTests.cs index 92a1b2d907f8..9d079b6c8723 100644 --- a/sdk/core/Azure.Core/tests/BearerTokenAuthenticationPolicyTests.cs +++ b/sdk/core/Azure.Core/tests/BearerTokenAuthenticationPolicyTests.cs @@ -888,6 +888,58 @@ public async Task BearerTokenAuthenticationPolicy_SwitchedTenants() Assert.AreEqual(3, callCount); } + [Test] + public async Task TokenCacheCurrentTcsTOkenIsExpiredAndBackgroundTcsInitialized() + { + var currentTcs = new TaskCompletionSource(); + var backgroundTcs = new TaskCompletionSource(); + + currentTcs.SetResult(new BearerTokenAuthenticationPolicy.AccessTokenCache.AuthHeaderValueInfo("token", DateTimeOffset.UtcNow.AddMinutes(-5), DateTimeOffset.UtcNow.AddMinutes(-5))); + + TokenRequestContext ctx = new TokenRequestContext(new[] { "scope" }); + var cache = new BearerTokenAuthenticationPolicy.AccessTokenCache( + new TokenCredentialStub((r, c) => new AccessToken(string.Empty, DateTimeOffset.MaxValue), IsAsync), + TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30)) + { + _state = new BearerTokenAuthenticationPolicy.AccessTokenCache.TokenRequestState( + ctx, + currentTcs, + backgroundTcs + ) + }; + var msg = new HttpMessage(new MockRequest(), ResponseClassifier.Shared); + var cts = new CancellationTokenSource(); + cts.CancelAfter(5000); + msg.CancellationToken = cts.Token; + await cache.GetAuthHeaderValueAsync(msg, ctx, IsAsync); + } + + [Test] + public async Task TokenCacheCurrentTcsIsCancelledAndBackgroundTcsInitialized() + { + var currentTcs = new TaskCompletionSource(); + var backgroundTcs = new TaskCompletionSource(); + + currentTcs.SetCanceled(); + + TokenRequestContext ctx = new TokenRequestContext(new[] { "scope" }); + var cache = new BearerTokenAuthenticationPolicy.AccessTokenCache( + new TokenCredentialStub((r, c) => new AccessToken(string.Empty, DateTimeOffset.MaxValue), IsAsync), + TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30)) + { + _state = new BearerTokenAuthenticationPolicy.AccessTokenCache.TokenRequestState( + ctx, + currentTcs, + backgroundTcs + ) + }; + var msg = new HttpMessage(new MockRequest(), ResponseClassifier.Shared); + var cts = new CancellationTokenSource(); + cts.CancelAfter(5000); + msg.CancellationToken = cts.Token; + await cache.GetAuthHeaderValueAsync(msg, ctx, IsAsync); + } + private class ChallengeBasedAuthenticationTestPolicy : BearerTokenAuthenticationPolicy { public string TenantId { get; private set; }