Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- [[#4474]](https://github.com/Azure/azure-sdk-for-cpp/issues/4474) Enable proactive renewal of Managed Identity tokens.

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/src/azure_cli_credential.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ AccessToken AzureCliCredential::GetToken(
"accessToken",
"expiresIn",
std::vector<std::string>{"expires_on", "expiresOn"},
"",
false,
GetLocalTimeToUtcDiffSeconds());
}
catch (json::exception const&)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ AccessToken ClientCertificateCredential::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
return m_tokenCredentialImpl->GetToken(context, false, [&]() {
auto body = m_requestBody;
if (!scopesStr.empty())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ AccessToken ClientSecretCredential::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
return m_tokenCredentialImpl->GetToken(context, false, [&]() {
auto body = m_requestBody;

if (!scopesStr.empty())
Expand Down
7 changes: 4 additions & 3 deletions sdk/identity/azure-identity/src/managed_identity_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);

if (!scopesStr.empty())
Expand Down Expand Up @@ -219,7 +219,7 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
using Azure::Core::Url;
using Azure::Core::Http::HttpMethod;

Expand Down Expand Up @@ -320,6 +320,7 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken(
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(
context,
true,
createRequest,
[&](auto const statusCode, auto const& response) -> std::unique_ptr<TokenRequest> {
using Core::Credentials::AuthenticationException;
Expand Down Expand Up @@ -418,7 +419,7 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);

if (!scopesStr.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ namespace Azure { namespace Identity { namespace _detail {
* @param expiresOnPropertyNames Names of properties in the JSON object that represent token
* expiration as absolute date-time stamp. Can be empty, in which case no attempt to parse the
* corresponding property will be made. Empty string elements will be ignored.
* @param refreshInPropertyName Name of a property in the JSON object that represents when to
* refresh the token in number of seconds from now.
* @param proactiveRenewal A value to indicate whether to refresh tokens, proactively, with half
* lifetime or not.
* @param utcDiffSeconds Optional. If not 0, it represents the difference between the UTC and a
* desired time zone, in seconds. Then, should an RFC3339 timestamp come without a time zone
* information, a corresponding time zone offset will be applied to such timestamp.
Expand All @@ -88,6 +92,8 @@ namespace Azure { namespace Identity { namespace _detail {
std::string const& accessTokenPropertyName,
std::string const& expiresInPropertyName,
std::vector<std::string> const& expiresOnPropertyNames,
std::string const& refreshInPropertyName = "",
bool proactiveRenewal = false,
int utcDiffSeconds = 0);

/**
Expand All @@ -101,6 +107,10 @@ namespace Azure { namespace Identity { namespace _detail {
* @param expiresOnPropertyName Name of a property in the JSON object that represents token
* expiration as absolute date-time stamp. Can be empty, in which case no attempt to parse it is
* made.
* @param refreshInPropertyName Name of a property in the JSON object that represents
* when to refresh the token in number of seconds from now.
* @param proactiveRenewal A value to indicate whether to refresh tokens, proactively, with half
* lifetime or not.
*
* @return A successfully parsed access token.
*
Expand All @@ -110,13 +120,17 @@ namespace Azure { namespace Identity { namespace _detail {
std::string const& jsonString,
std::string const& accessTokenPropertyName,
std::string const& expiresInPropertyName,
std::string const& expiresOnPropertyName)
std::string const& expiresOnPropertyName,
std::string const& refreshInPropertyName = "",
bool proactiveRenewal = false)
{
return ParseToken(
jsonString,
accessTokenPropertyName,
expiresInPropertyName,
std::vector<std::string>{expiresOnPropertyName});
std::vector<std::string>{expiresOnPropertyName},
refreshInPropertyName,
proactiveRenewal);
}

/**
Expand Down Expand Up @@ -169,6 +183,8 @@ namespace Azure { namespace Identity { namespace _detail {
* @brief Gets an authentication token.
*
* @param context A context to control the request lifetime.
* @param proactiveRenewal A value to indicate whether to refresh tokens, proactively, with half
* lifetime or not.
* @param createRequest A function to create a token request.
* @param shouldRetry A function to determine whether a response should be retried with
* another request.
Expand All @@ -177,6 +193,7 @@ namespace Azure { namespace Identity { namespace _detail {
*/
Core::Credentials::AccessToken GetToken(
Core::Context const& context,
bool proactiveRenewal,
std::function<std::unique_ptr<TokenRequest>()> const& createRequest,
std::function<std::unique_ptr<TokenRequest>(
Core::Http::HttpStatusCode statusCode,
Expand Down
93 changes: 83 additions & 10 deletions sdk/identity/azure-identity/src/token_credential_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using Azure::Core::Http::HttpStatusCode;
using Azure::Core::Http::RawResponse;
using Azure::Core::Json::_internal::json;

using namespace std::chrono_literals;

TokenCredentialImpl::TokenCredentialImpl(TokenCredentialOptions const& options)
: m_httpPipeline(options, "identity", PackageVersion::ToString(), {}, {})
{
Expand Down Expand Up @@ -91,6 +93,7 @@ std::string TokenCredentialImpl::FormatScopes(

AccessToken TokenCredentialImpl::GetToken(
Context const& context,
bool proactiveRenewal,
std::function<std::unique_ptr<TokenCredentialImpl::TokenRequest>()> const& createRequest,
std::function<std::unique_ptr<TokenCredentialImpl::TokenRequest>(
HttpStatusCode statusCode,
Expand Down Expand Up @@ -140,7 +143,9 @@ AccessToken TokenCredentialImpl::GetToken(
std::string(responseBodyVector.begin(), responseBodyVector.end()),
"access_token",
"expires_in",
"expires_on");
"expires_on",
"refresh_in",
proactiveRenewal);
}
catch (AuthenticationException const&)
{
Expand Down Expand Up @@ -224,13 +229,39 @@ std::string TimeZoneOffsetAsString(int utcDiffSeconds)
return os.str();
}

// Proactive renewal by cutting the refresh time in half if the token expires in more than
// 2 hours.
std::chrono::seconds GetProactiveRenewalSeconds(std::chrono::seconds seconds)
{
if (seconds >= std::chrono::seconds(2h))
{
return seconds / 2;
}
else
{
return seconds;
}
}

DateTime GetProactiveRenewalDateTime(std::int64_t posixTimestamp)
{
const DateTime now = DateTime::clock::now();

const auto renewInSeconds = std::chrono::duration_cast<std::chrono::seconds>(
PosixTimeConverter::PosixTimeToDateTime(posixTimestamp) - now);

return DateTime(now + GetProactiveRenewalSeconds(renewInSeconds));
}

} // namespace

AccessToken TokenCredentialImpl::ParseToken(
std::string const& jsonString,
std::string const& accessTokenPropertyName,
std::string const& expiresInPropertyName,
std::vector<std::string> const& expiresOnPropertyNames,
std::string const& refreshInPropertyName,
bool proactiveRenewal,
int utcDiffSeconds)
{
json parsedJson;
Expand Down Expand Up @@ -262,6 +293,35 @@ AccessToken TokenCredentialImpl::ParseToken(
accessToken.Token = parsedJson[accessTokenPropertyName].get<std::string>();
accessToken.ExpiresOn = std::chrono::system_clock::now();

// expiresIn = number of seconds until refresh.
// expiresOn = timestamp of refresh expressed as seconds since epoch.

if (!refreshInPropertyName.empty() && parsedJson.contains(refreshInPropertyName))
{
auto const& refreshIn = parsedJson[refreshInPropertyName];
if (refreshIn.is_number_unsigned())
{
try
{
// 'refresh_in' as number (seconds until refresh)
auto const value = refreshIn.get<std::int64_t>();
if (value <= MaxExpirationInSeconds)
{
static_assert(
MaxExpirationInSeconds <= std::numeric_limits<std::int32_t>::max(),
"Can safely cast to int32");

accessToken.ExpiresOn += std::chrono::seconds(static_cast<std::int32_t>(value));
return accessToken;
}
}
catch (std::exception const&)
{
// refreshIn.get<std::int64_t>() has thrown, we may throw later.
}
}
}

if (parsedJson.contains(expiresInPropertyName))
{
auto const& expiresIn = parsedJson[expiresInPropertyName];
Expand All @@ -278,7 +338,9 @@ AccessToken TokenCredentialImpl::ParseToken(
MaxExpirationInSeconds <= std::numeric_limits<std::int32_t>::max(),
"Can safely cast to int32");

accessToken.ExpiresOn += std::chrono::seconds(static_cast<std::int32_t>(value));
auto expiresInSeconds = std::chrono::seconds(static_cast<std::int32_t>(value));
accessToken.ExpiresOn
+= proactiveRenewal ? GetProactiveRenewalSeconds(expiresInSeconds) : expiresInSeconds;
return accessToken;
}
}
Expand All @@ -297,8 +359,10 @@ AccessToken TokenCredentialImpl::ParseToken(
MaxExpirationInSeconds <= std::numeric_limits<std::int32_t>::max(),
"Can safely cast to int32");

accessToken.ExpiresOn += std::chrono::seconds(static_cast<std::int32_t>(
auto expiresInSeconds = std::chrono::seconds(static_cast<std::int32_t>(
ParseNumericExpiration(expiresIn.get<std::string>(), MaxExpirationInSeconds)));
accessToken.ExpiresOn
+= proactiveRenewal ? GetProactiveRenewalSeconds(expiresInSeconds) : expiresInSeconds;

return accessToken;
}
Expand Down Expand Up @@ -342,7 +406,9 @@ AccessToken TokenCredentialImpl::ParseToken(
auto const value = expiresOn.get<std::int64_t>();
if (value <= MaxPosixTimestamp)
{
accessToken.ExpiresOn = PosixTimeConverter::PosixTimeToDateTime(value);
accessToken.ExpiresOn = proactiveRenewal
? GetProactiveRenewalDateTime(value)
: PosixTimeConverter::PosixTimeToDateTime(value);
return accessToken;
}
}
Expand All @@ -359,16 +425,23 @@ AccessToken TokenCredentialImpl::ParseToken(
for (auto const& parse : {
std::function<DateTime(std::string const&)>([&](auto const& s) {
// 'expires_on' as RFC3339 date string (absolute timestamp)
return DateTime::Parse(s + tzOffsetStr, DateTime::DateFormat::Rfc3339);
auto dateTime = DateTime::Parse(s + tzOffsetStr, DateTime::DateFormat::Rfc3339);
return proactiveRenewal ? GetProactiveRenewalDateTime(
PosixTimeConverter::DateTimeToPosixTime(dateTime))
: dateTime;
}),
std::function<DateTime(std::string const&)>([](auto const& s) {
std::function<DateTime(std::string const&)>([&](auto const& s) {
// 'expires_on' as numeric string (posix time representing an absolute timestamp)
return PosixTimeConverter::PosixTimeToDateTime(
ParseNumericExpiration(s, MaxPosixTimestamp));
auto value = ParseNumericExpiration(s, MaxPosixTimestamp);
return proactiveRenewal ? GetProactiveRenewalDateTime(value)
: PosixTimeConverter::PosixTimeToDateTime(value);
}),
std::function<DateTime(std::string const&)>([](auto const& s) {
std::function<DateTime(std::string const&)>([&](auto const& s) {
// 'expires_on' as RFC1123 date string (absolute timestamp)
return DateTime::Parse(s, DateTime::DateFormat::Rfc1123);
auto dateTime = DateTime::Parse(s, DateTime::DateFormat::Rfc1123);
return proactiveRenewal ? GetProactiveRenewalDateTime(
PosixTimeConverter::DateTimeToPosixTime(dateTime))
: dateTime;
}),
})
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ AccessToken WorkloadIdentityCredential::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
return m_tokenCredentialImpl->GetToken(context, false, [&]() {
auto body = m_requestBody;
if (!scopesStr.empty())
{
Expand Down
Loading