Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

### Other Changes

- [[#4756]] (https://github.com/Azure/azure-sdk-for-cpp/issues/4756) `BearerTokenAuthenticationPolicy` now uses shared mutex lock for read operations.

## 1.11.0-beta.2 (2023-11-02)

### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <memory>
#include <mutex>
#include <set>
#include <shared_mutex>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -549,7 +550,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
Credentials::TokenRequestContext m_tokenRequestContext;

mutable Credentials::AccessToken m_accessToken;
mutable std::mutex m_accessTokenMutex;
mutable std::shared_timed_mutex m_accessTokenMutex;
mutable Credentials::TokenRequestContext m_accessTokenContext;

public:
Expand Down Expand Up @@ -581,6 +582,9 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
BearerTokenAuthenticationPolicy(BearerTokenAuthenticationPolicy const& other)
: BearerTokenAuthenticationPolicy(other.m_credential, other.m_tokenRequestContext)
{
std::shared_lock<std::shared_timed_mutex> readLock(other.m_accessTokenMutex);
m_accessToken = other.m_accessToken;
m_accessTokenContext = other.m_accessTokenContext;
}

void operator=(BearerTokenAuthenticationPolicy const&) = delete;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,49 @@ bool BearerTokenAuthenticationPolicy::AuthorizeRequestOnChallenge(
return false;
}

namespace {
bool TokenNeedsRefresh(
Azure::Core::Credentials::AccessToken const& cachedToken,
Azure::Core::Credentials::TokenRequestContext const& cachedTokenRequestContext,
Azure::DateTime const& currentTime,
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext)
{
return newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
|| newTokenRequestContext.Scopes != cachedTokenRequestContext.Scopes
|| currentTime > (cachedToken.ExpiresOn - newTokenRequestContext.MinimumExpiration);
}
Comment on lines +65 to +74
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic could benefit from a comment. For example, I understand that we need a token refresh if the tenant id and scope don't match, but why are we using current token's and cached token's expirations together like this: cachedToken.ExpiresOn - newTokenRequestContext.MinimumExpiration?


void ApplyBearerToken(
Azure::Core::Http::Request& request,
Azure::Core::Credentials::AccessToken const& token)
{
request.SetHeader("authorization", "Bearer " + token.Token);
}
} // namespace

void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(
Request& request,
Credentials::TokenRequestContext const& tokenRequestContext,
Context const& context) const
{
std::lock_guard<std::mutex> lock(m_accessTokenMutex);
DateTime const currentTime = std::chrono::system_clock::now();

{
std::shared_lock<std::shared_timed_mutex> readLock(m_accessTokenMutex);
if (!TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
{
ApplyBearerToken(request, m_accessToken);
return;
}
}

if (tokenRequestContext.TenantId != m_accessTokenContext.TenantId
|| tokenRequestContext.Scopes != m_accessTokenContext.Scopes
|| std::chrono::system_clock::now()
> (m_accessToken.ExpiresOn - tokenRequestContext.MinimumExpiration))
std::unique_lock<std::shared_timed_mutex> writeLock(m_accessTokenMutex);
// Check if token needs refresh for the second time in case another thread has just updated it.
if (TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
{
m_accessToken = m_credential->GetToken(tokenRequestContext, context);
m_accessTokenContext = tokenRequestContext;
}

request.SetHeader("authorization", "Bearer " + m_accessToken.Token);
ApplyBearerToken(request, m_accessToken);
}