Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace credentials cache with identity cache #3077

Merged
merged 12 commits into from
Oct 21, 2023
18 changes: 18 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,21 @@ message = """
references = ["smithy-rs#3076"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "**This change has [detailed upgrade guidance](https://github.com/awslabs/aws-sdk-rust/discussions/923).** <br><br>The AWS credentials cache has been replaced with a more generic identity cache."
references = ["smithy-rs#3077"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "**Behavior Break!** Identities for auth are now cached by default. See the `Config` builder's `identity_cache()` method docs for an example of how to disable this caching."
references = ["smithy-rs#3077"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

[[smithy-rs]]
message = "Clients now have a default async sleep implementation so that one does not need to be specified if you're using Tokio."
references = ["smithy-rs#3071"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"
5 changes: 4 additions & 1 deletion aws/rust-runtime/aws-config/external-types.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ allowed_external_types = [
"aws_smithy_async::rt::sleep::SharedAsyncSleep",
"aws_smithy_async::time::SharedTimeSource",
"aws_smithy_async::time::TimeSource",
"aws_smithy_types::body::SdkBody",
"aws_smithy_http::endpoint",
"aws_smithy_http::endpoint::error::InvalidEndpointError",
"aws_smithy_http::result::SdkError",
"aws_smithy_runtime::client::identity::cache::IdentityCache",
"aws_smithy_runtime::client::identity::cache::lazy::LazyCacheBuilder",
"aws_smithy_runtime_api::client::dns::ResolveDns",
"aws_smithy_runtime_api::client::dns::SharedDnsResolver",
"aws_smithy_runtime_api::client::http::HttpClient",
"aws_smithy_runtime_api::client::http::SharedHttpClient",
"aws_smithy_runtime_api::client::identity::ResolveCachedIdentity",
"aws_smithy_runtime_api::client::identity::ResolveIdentity",
"aws_smithy_types::body::SdkBody",
"aws_smithy_types::retry",
"aws_smithy_types::retry::*",
"aws_smithy_types::timeout",
Expand Down
5 changes: 4 additions & 1 deletion aws/rust-runtime/aws-config/src/imds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ impl Builder {
.runtime_plugin(common_plugin.clone())
.runtime_plugin(TokenRuntimePlugin::new(
common_plugin,
config.time_source(),
self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
))
.with_connection_poisoning()
Expand Down Expand Up @@ -748,6 +747,7 @@ pub(crate) mod test {
/// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
#[tokio::test]
async fn token_refresh_buffer() {
let _logs = capture_test_logs();
let (_, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://[fd00:ec2::254]", 600),
Expand Down Expand Up @@ -785,11 +785,14 @@ pub(crate) mod test {
.token_ttl(Duration::from_secs(600))
.build();

tracing::info!("resp1 -----------------------------------------------------------");
let resp1 = client.get("/latest/metadata").await.expect("success");
// now the cached credential has expired
time_source.advance(Duration::from_secs(400));
tracing::info!("resp2 -----------------------------------------------------------");
let resp2 = client.get("/latest/metadata").await.expect("success");
time_source.advance(Duration::from_secs(150));
tracing::info!("resp3 -----------------------------------------------------------");
let resp3 = client.get("/latest/metadata").await.expect("success");
http_client.assert_requests_match(&[]);
assert_eq!("test-imds-output1", resp1.as_ref());
Expand Down
84 changes: 46 additions & 38 deletions aws/rust-runtime/aws-config/src/imds/client/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
//! - Retry token loading when it fails
//! - Attach the token to the request in the `x-aws-ec2-metadata-token` header

use crate::identity::IdentityCache;
use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind};
use aws_credential_types::cache::ExpiringCache;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::expiring_cache::ExpiringCache;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
use aws_smithy_runtime_api::client::auth::{
Expand Down Expand Up @@ -50,6 +51,12 @@ const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl
const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token";
const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN);

#[derive(Debug)]
struct TtlToken {
value: HeaderValue,
ttl: Duration,
}

/// IMDS Token
#[derive(Clone)]
struct Token {
Expand All @@ -76,20 +83,18 @@ pub(super) struct TokenRuntimePlugin {
}

impl TokenRuntimePlugin {
pub(super) fn new(
common_plugin: SharedRuntimePlugin,
time_source: SharedTimeSource,
token_ttl: Duration,
) -> Self {
pub(super) fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
Self {
components: RuntimeComponentsBuilder::new("TokenRuntimePlugin")
.with_auth_scheme(TokenAuthScheme::new())
.with_auth_scheme_option_resolver(Some(StaticAuthSchemeOptionResolver::new(vec![
IMDS_TOKEN_AUTH_SCHEME,
])))
// The TokenResolver has a cache of its own, so don't use identity caching
.with_identity_cache(Some(IdentityCache::no_cache()))
.with_identity_resolver(
IMDS_TOKEN_AUTH_SCHEME,
TokenResolver::new(common_plugin, time_source, token_ttl),
TokenResolver::new(common_plugin, token_ttl),
),
}
}
Expand All @@ -107,8 +112,7 @@ impl RuntimePlugin for TokenRuntimePlugin {
#[derive(Debug)]
struct TokenResolverInner {
cache: ExpiringCache<Token, ImdsError>,
refresh: Operation<(), Token, TokenError>,
time_source: SharedTimeSource,
refresh: Operation<(), TtlToken, TokenError>,
}

#[derive(Clone, Debug)]
Expand All @@ -117,11 +121,7 @@ struct TokenResolver {
}

impl TokenResolver {
fn new(
common_plugin: SharedRuntimePlugin,
time_source: SharedTimeSource,
token_ttl: Duration,
) -> Self {
fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
Self {
inner: Arc::new(TokenResolverInner {
cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER),
Expand All @@ -141,34 +141,34 @@ impl TokenResolver {
.try_into()
.unwrap())
})
.deserializer({
let time_source = time_source.clone();
move |response| {
let now = time_source.now();
parse_token_response(response, now)
.map_err(OrchestratorError::operation)
}
.deserializer(move |response| {
parse_token_response(response).map_err(OrchestratorError::operation)
})
.build(),
time_source,
}),
}
}

async fn get_token(&self) -> Result<(Token, SystemTime), ImdsError> {
self.inner
.refresh
.invoke(())
.await
async fn get_token(
&self,
time_source: SharedTimeSource,
) -> Result<(Token, SystemTime), ImdsError> {
let result = self.inner.refresh.invoke(()).await;
let now = time_source.now();
result
.map(|token| {
let token = Token {
value: token.value,
expiry: now + token.ttl,
};
let expiry = token.expiry;
(token, expiry)
})
.map_err(ImdsError::failed_to_load_token)
}
}

fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Token, TokenError> {
fn parse_token_response(response: &HttpResponse) -> Result<TtlToken, TokenError> {
match response.status().as_u16() {
400 => return Err(TokenErrorKind::InvalidParameters.into()),
403 => return Err(TokenErrorKind::Forbidden.into()),
Expand All @@ -187,30 +187,38 @@ fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Toke
.map_err(|_| TokenErrorKind::InvalidTtl)?
.parse()
.map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;
Ok(Token {
Ok(TtlToken {
value,
expiry: now + Duration::from_secs(ttl),
ttl: Duration::from_secs(ttl),
})
}

impl ResolveIdentity for TokenResolver {
fn resolve_identity<'a>(
&'a self,
_components: &'a RuntimeComponents,
components: &'a RuntimeComponents,
_config_bag: &'a ConfigBag,
) -> IdentityFuture<'a> {
let time_source = components
.time_source()
.expect("time source required for IMDS token caching");
IdentityFuture::new(async {
let preloaded_token = self
.inner
.cache
.yield_or_clear_if_expired(self.inner.time_source.now())
.await;
let now = time_source.now();
let preloaded_token = self.inner.cache.yield_or_clear_if_expired(now).await;
let token = match preloaded_token {
Some(token) => Ok(token),
Some(token) => {
tracing::trace!(
buffer_time=?TOKEN_REFRESH_BUFFER,
expiration=?token.expiry,
now=?now,
"loaded IMDS token from cache");
Ok(token)
}
None => {
tracing::debug!("IMDS token cache miss");
self.inner
.cache
.get_or_load(|| async { self.get_token().await })
.get_or_load(|| async { self.get_token(time_source).await })
.await
}
}?;
Expand Down
Loading