Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Enable proactive renewal of Managed Identity tokens.

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/src/azure_cli_credential.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ AccessToken AzureCliCredential::GetToken(
"accessToken",
"expiresIn",
std::vector<std::string>{"expires_on", "expiresOn"},
"",
false,
GetLocalTimeToUtcDiffSeconds());
}
catch (json::exception const&)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ AccessToken ClientCertificateCredential::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
return m_tokenCredentialImpl->GetToken(context, false, [&]() {
auto body = m_requestBody;
if (!scopesStr.empty())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ AccessToken ClientSecretCredential::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
return m_tokenCredentialImpl->GetToken(context, false, [&]() {
auto body = m_requestBody;

if (!scopesStr.empty())
Expand Down
7 changes: 4 additions & 3 deletions sdk/identity/azure-identity/src/managed_identity_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);

if (!scopesStr.empty())
Expand Down Expand Up @@ -219,7 +219,7 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
using Azure::Core::Url;
using Azure::Core::Http::HttpMethod;

Expand Down Expand Up @@ -320,6 +320,7 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken(
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(
context,
true,
createRequest,
[&](auto const statusCode, auto const& response) -> std::unique_ptr<TokenRequest> {
using Core::Credentials::AuthenticationException;
Expand Down Expand Up @@ -418,7 +419,7 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);

if (!scopesStr.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ namespace Azure { namespace Identity { namespace _detail {
* @param expiresOnPropertyNames Names of properties in the JSON object that represent token
* expiration as absolute date-time stamp. Can be empty, in which case no attempt to parse the
* corresponding property will be made. Empty string elements will be ignored.
* @param refreshInPropertyName Name of a property in the JSON object that represents when to
* refresh the token in number of seconds from now.
* @param proactiveRenewal A value to indicate whether to refresh tokens, proactively, with half
* lifetime or not.
* @param utcDiffSeconds Optional. If not 0, it represents the difference between the UTC and a
* desired time zone, in seconds. Then, should an RFC3339 timestamp come without a time zone
* information, a corresponding time zone offset will be applied to such timestamp.
Expand All @@ -88,6 +92,8 @@ namespace Azure { namespace Identity { namespace _detail {
std::string const& accessTokenPropertyName,
std::string const& expiresInPropertyName,
std::vector<std::string> const& expiresOnPropertyNames,
std::string const& refreshInPropertyName = "",
bool proactiveRenewal = false,
int utcDiffSeconds = 0);

/**
Expand All @@ -101,6 +107,10 @@ namespace Azure { namespace Identity { namespace _detail {
* @param expiresOnPropertyName Name of a property in the JSON object that represents token
* expiration as absolute date-time stamp. Can be empty, in which case no attempt to parse it is
* made.
* @param refreshInPropertyName Name of a property in the JSON object that represents
* when to refresh the token in number of seconds from now.
* @param proactiveRenewal A value to indicate whether to refresh tokens, proactively, with half
* lifetime or not.
*
* @return A successfully parsed access token.
*
Expand All @@ -110,13 +120,17 @@ namespace Azure { namespace Identity { namespace _detail {
std::string const& jsonString,
std::string const& accessTokenPropertyName,
std::string const& expiresInPropertyName,
std::string const& expiresOnPropertyName)
std::string const& expiresOnPropertyName,
std::string const& refreshInPropertyName = "",
bool proactiveRenewal = false)
{
return ParseToken(
jsonString,
accessTokenPropertyName,
expiresInPropertyName,
std::vector<std::string>{expiresOnPropertyName});
std::vector<std::string>{expiresOnPropertyName},
refreshInPropertyName,
proactiveRenewal);
}

/**
Expand Down Expand Up @@ -169,6 +183,8 @@ namespace Azure { namespace Identity { namespace _detail {
* @brief Gets an authentication token.
*
* @param context A context to control the request lifetime.
* @param proactiveRenewal A value to indicate whether to refresh tokens, proactively, with half
* lifetime or not.
* @param createRequest A function to create a token request.
* @param shouldRetry A function to determine whether a response should be retried with
* another request.
Expand All @@ -177,6 +193,7 @@ namespace Azure { namespace Identity { namespace _detail {
*/
Core::Credentials::AccessToken GetToken(
Core::Context const& context,
bool proactiveRenewal,
std::function<std::unique_ptr<TokenRequest>()> const& createRequest,
std::function<std::unique_ptr<TokenRequest>(
Core::Http::HttpStatusCode statusCode,
Expand Down
94 changes: 83 additions & 11 deletions sdk/identity/azure-identity/src/token_credential_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using Azure::Core::Http::HttpStatusCode;
using Azure::Core::Http::RawResponse;
using Azure::Core::Json::_internal::json;

using namespace std::chrono_literals;

TokenCredentialImpl::TokenCredentialImpl(TokenCredentialOptions const& options)
: m_httpPipeline(options, "identity", PackageVersion::ToString(), {}, {})
{
Expand Down Expand Up @@ -91,6 +93,7 @@ std::string TokenCredentialImpl::FormatScopes(

AccessToken TokenCredentialImpl::GetToken(
Context const& context,
bool proactiveRenewal,
std::function<std::unique_ptr<TokenCredentialImpl::TokenRequest>()> const& createRequest,
std::function<std::unique_ptr<TokenCredentialImpl::TokenRequest>(
HttpStatusCode statusCode,
Expand Down Expand Up @@ -140,7 +143,9 @@ AccessToken TokenCredentialImpl::GetToken(
std::string(responseBodyVector.begin(), responseBodyVector.end()),
"access_token",
"expires_in",
"expires_on");
"expires_on",
"refresh_in",
proactiveRenewal);
}
catch (AuthenticationException const&)
{
Expand Down Expand Up @@ -226,11 +231,36 @@ std::string TimeZoneOffsetAsString(int utcDiffSeconds)

} // namespace

// Proactive renewal by cutting the refresh time in half if the token expires in more than
// 2 hours.
static int64_t ProactiveRenewal(int64_t value)
{
if (value >= std::chrono::seconds(2h).count())
{
return value / 2;
}
else
{
return value;
}
}

static DateTime ProactiveRenewalDateTime(int64_t value)
{
std::int64_t curentSecondsSinceEpoch = std::chrono::seconds(std::time(NULL)).count();
std::int64_t duration = value - curentSecondsSinceEpoch;

return PosixTimeConverter::PosixTimeToDateTime(
ProactiveRenewal(duration) + curentSecondsSinceEpoch);
}

AccessToken TokenCredentialImpl::ParseToken(
std::string const& jsonString,
std::string const& accessTokenPropertyName,
std::string const& expiresInPropertyName,
std::vector<std::string> const& expiresOnPropertyNames,
std::string const& refreshInPropertyName,
bool proactiveRenewal,
int utcDiffSeconds)
{
json parsedJson;
Expand Down Expand Up @@ -262,6 +292,35 @@ AccessToken TokenCredentialImpl::ParseToken(
accessToken.Token = parsedJson[accessTokenPropertyName].get<std::string>();
accessToken.ExpiresOn = std::chrono::system_clock::now();

// expiresIn = number of seconds until refresh.
// expiresOn = timestamp of refresh expressed as seconds since epoch.

if (!refreshInPropertyName.empty() && parsedJson.contains(refreshInPropertyName))
{
auto const& refreshIn = parsedJson[refreshInPropertyName];
if (refreshIn.is_number_unsigned())
{
try
{
// 'refresh_in' as number (seconds until refresh)
auto const value = refreshIn.get<std::int64_t>();
if (value <= MaxExpirationInSeconds)
{
static_assert(
MaxExpirationInSeconds <= std::numeric_limits<std::int32_t>::max(),
"Can safely cast to int32");

accessToken.ExpiresOn += std::chrono::seconds(static_cast<std::int32_t>(value));
return accessToken;
}
}
catch (std::exception const&)
{
// refreshIn.get<std::int64_t>() has thrown, we may throw later.
}
}
}

if (parsedJson.contains(expiresInPropertyName))
{
auto const& expiresIn = parsedJson[expiresInPropertyName];
Expand All @@ -278,7 +337,9 @@ AccessToken TokenCredentialImpl::ParseToken(
MaxExpirationInSeconds <= std::numeric_limits<std::int32_t>::max(),
"Can safely cast to int32");

accessToken.ExpiresOn += std::chrono::seconds(static_cast<std::int32_t>(value));
accessToken.ExpiresOn += std::chrono::seconds(
proactiveRenewal ? ProactiveRenewal(static_cast<std::int32_t>(value))
: static_cast<std::int32_t>(value));
return accessToken;
}
}
Expand All @@ -297,8 +358,10 @@ AccessToken TokenCredentialImpl::ParseToken(
MaxExpirationInSeconds <= std::numeric_limits<std::int32_t>::max(),
"Can safely cast to int32");

accessToken.ExpiresOn += std::chrono::seconds(static_cast<std::int32_t>(
ParseNumericExpiration(expiresIn.get<std::string>(), MaxExpirationInSeconds)));
auto value = static_cast<std::int32_t>(
ParseNumericExpiration(expiresIn.get<std::string>(), MaxExpirationInSeconds));
accessToken.ExpiresOn
+= std::chrono::seconds(proactiveRenewal ? ProactiveRenewal(value) : value);

return accessToken;
}
Expand Down Expand Up @@ -342,7 +405,9 @@ AccessToken TokenCredentialImpl::ParseToken(
auto const value = expiresOn.get<std::int64_t>();
if (value <= MaxPosixTimestamp)
{
accessToken.ExpiresOn = PosixTimeConverter::PosixTimeToDateTime(value);
accessToken.ExpiresOn = proactiveRenewal
? ProactiveRenewalDateTime(value)
: PosixTimeConverter::PosixTimeToDateTime(value);
return accessToken;
}
}
Expand All @@ -359,16 +424,23 @@ AccessToken TokenCredentialImpl::ParseToken(
for (auto const& parse : {
std::function<DateTime(std::string const&)>([&](auto const& s) {
// 'expires_on' as RFC3339 date string (absolute timestamp)
return DateTime::Parse(s + tzOffsetStr, DateTime::DateFormat::Rfc3339);
auto dateTime = DateTime::Parse(s + tzOffsetStr, DateTime::DateFormat::Rfc3339);
return proactiveRenewal
? ProactiveRenewalDateTime(PosixTimeConverter::DateTimeToPosixTime(dateTime))
: dateTime;
}),
std::function<DateTime(std::string const&)>([](auto const& s) {
std::function<DateTime(std::string const&)>([&](auto const& s) {
// 'expires_on' as numeric string (posix time representing an absolute timestamp)
return PosixTimeConverter::PosixTimeToDateTime(
ParseNumericExpiration(s, MaxPosixTimestamp));
auto value = ParseNumericExpiration(s, MaxPosixTimestamp);
return proactiveRenewal ? ProactiveRenewalDateTime(value)
: PosixTimeConverter::PosixTimeToDateTime(value);
}),
std::function<DateTime(std::string const&)>([](auto const& s) {
std::function<DateTime(std::string const&)>([&](auto const& s) {
// 'expires_on' as RFC1123 date string (absolute timestamp)
return DateTime::Parse(s, DateTime::DateFormat::Rfc1123);
auto dateTime = DateTime::Parse(s, DateTime::DateFormat::Rfc1123);
return proactiveRenewal
? ProactiveRenewalDateTime(PosixTimeConverter::DateTimeToPosixTime(dateTime))
: dateTime;
}),
})
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ AccessToken WorkloadIdentityCredential::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
return m_tokenCredentialImpl->GetToken(context, false, [&]() {
auto body = m_requestBody;
if (!scopesStr.empty())
{
Expand Down
Loading