Skip to content

Commit

Permalink
Replace credentials cache with identity cache (#3077)
Browse files Browse the repository at this point in the history
This PR replaces the credentials cache with the new identity cache, and
adds config validation via the `SharedConfigValidator` runtime component
and `ValidateConfig` trait.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
jdisanti authored Oct 20, 2023
1 parent 66a3acf commit 40f4662
Show file tree
Hide file tree
Showing 52 changed files with 1,157 additions and 1,962 deletions.
18 changes: 18 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,21 @@ message = """
references = ["smithy-rs#3076"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "**This change has [detailed upgrade guidance](https://github.com/awslabs/aws-sdk-rust/discussions/923).** <br><br>The AWS credentials cache has been replaced with a more generic identity cache."
references = ["smithy-rs#3077"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "**Behavior Break!** Identities for auth are now cached by default. See the `Config` builder's `identity_cache()` method docs for an example of how to disable this caching."
references = ["smithy-rs#3077"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

[[smithy-rs]]
message = "Clients now have a default async sleep implementation so that one does not need to be specified if you're using Tokio."
references = ["smithy-rs#3071"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"
5 changes: 4 additions & 1 deletion aws/rust-runtime/aws-config/external-types.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ allowed_external_types = [
"aws_smithy_async::rt::sleep::SharedAsyncSleep",
"aws_smithy_async::time::SharedTimeSource",
"aws_smithy_async::time::TimeSource",
"aws_smithy_types::body::SdkBody",
"aws_smithy_http::endpoint",
"aws_smithy_http::endpoint::error::InvalidEndpointError",
"aws_smithy_http::result::SdkError",
"aws_smithy_runtime::client::identity::cache::IdentityCache",
"aws_smithy_runtime::client::identity::cache::lazy::LazyCacheBuilder",
"aws_smithy_runtime_api::client::dns::ResolveDns",
"aws_smithy_runtime_api::client::dns::SharedDnsResolver",
"aws_smithy_runtime_api::client::http::HttpClient",
"aws_smithy_runtime_api::client::http::SharedHttpClient",
"aws_smithy_runtime_api::client::identity::ResolveCachedIdentity",
"aws_smithy_runtime_api::client::identity::ResolveIdentity",
"aws_smithy_types::body::SdkBody",
"aws_smithy_types::retry",
"aws_smithy_types::retry::*",
"aws_smithy_types::timeout",
Expand Down
5 changes: 4 additions & 1 deletion aws/rust-runtime/aws-config/src/imds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ impl Builder {
.runtime_plugin(common_plugin.clone())
.runtime_plugin(TokenRuntimePlugin::new(
common_plugin,
config.time_source(),
self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
))
.with_connection_poisoning()
Expand Down Expand Up @@ -748,6 +747,7 @@ pub(crate) mod test {
/// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
#[tokio::test]
async fn token_refresh_buffer() {
let _logs = capture_test_logs();
let (_, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://[fd00:ec2::254]", 600),
Expand Down Expand Up @@ -785,11 +785,14 @@ pub(crate) mod test {
.token_ttl(Duration::from_secs(600))
.build();

tracing::info!("resp1 -----------------------------------------------------------");
let resp1 = client.get("/latest/metadata").await.expect("success");
// now the cached credential has expired
time_source.advance(Duration::from_secs(400));
tracing::info!("resp2 -----------------------------------------------------------");
let resp2 = client.get("/latest/metadata").await.expect("success");
time_source.advance(Duration::from_secs(150));
tracing::info!("resp3 -----------------------------------------------------------");
let resp3 = client.get("/latest/metadata").await.expect("success");
http_client.assert_requests_match(&[]);
assert_eq!("test-imds-output1", resp1.as_ref());
Expand Down
84 changes: 46 additions & 38 deletions aws/rust-runtime/aws-config/src/imds/client/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
//! - Retry token loading when it fails
//! - Attach the token to the request in the `x-aws-ec2-metadata-token` header

use crate::identity::IdentityCache;
use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind};
use aws_credential_types::cache::ExpiringCache;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::expiring_cache::ExpiringCache;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
use aws_smithy_runtime_api::client::auth::{
Expand Down Expand Up @@ -50,6 +51,12 @@ const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl
const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token";
const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN);

#[derive(Debug)]
struct TtlToken {
value: HeaderValue,
ttl: Duration,
}

/// IMDS Token
#[derive(Clone)]
struct Token {
Expand All @@ -76,20 +83,18 @@ pub(super) struct TokenRuntimePlugin {
}

impl TokenRuntimePlugin {
pub(super) fn new(
common_plugin: SharedRuntimePlugin,
time_source: SharedTimeSource,
token_ttl: Duration,
) -> Self {
pub(super) fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
Self {
components: RuntimeComponentsBuilder::new("TokenRuntimePlugin")
.with_auth_scheme(TokenAuthScheme::new())
.with_auth_scheme_option_resolver(Some(StaticAuthSchemeOptionResolver::new(vec![
IMDS_TOKEN_AUTH_SCHEME,
])))
// The TokenResolver has a cache of its own, so don't use identity caching
.with_identity_cache(Some(IdentityCache::no_cache()))
.with_identity_resolver(
IMDS_TOKEN_AUTH_SCHEME,
TokenResolver::new(common_plugin, time_source, token_ttl),
TokenResolver::new(common_plugin, token_ttl),
),
}
}
Expand All @@ -107,8 +112,7 @@ impl RuntimePlugin for TokenRuntimePlugin {
#[derive(Debug)]
struct TokenResolverInner {
cache: ExpiringCache<Token, ImdsError>,
refresh: Operation<(), Token, TokenError>,
time_source: SharedTimeSource,
refresh: Operation<(), TtlToken, TokenError>,
}

#[derive(Clone, Debug)]
Expand All @@ -117,11 +121,7 @@ struct TokenResolver {
}

impl TokenResolver {
fn new(
common_plugin: SharedRuntimePlugin,
time_source: SharedTimeSource,
token_ttl: Duration,
) -> Self {
fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
Self {
inner: Arc::new(TokenResolverInner {
cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER),
Expand All @@ -141,34 +141,34 @@ impl TokenResolver {
.try_into()
.unwrap())
})
.deserializer({
let time_source = time_source.clone();
move |response| {
let now = time_source.now();
parse_token_response(response, now)
.map_err(OrchestratorError::operation)
}
.deserializer(move |response| {
parse_token_response(response).map_err(OrchestratorError::operation)
})
.build(),
time_source,
}),
}
}

async fn get_token(&self) -> Result<(Token, SystemTime), ImdsError> {
self.inner
.refresh
.invoke(())
.await
async fn get_token(
&self,
time_source: SharedTimeSource,
) -> Result<(Token, SystemTime), ImdsError> {
let result = self.inner.refresh.invoke(()).await;
let now = time_source.now();
result
.map(|token| {
let token = Token {
value: token.value,
expiry: now + token.ttl,
};
let expiry = token.expiry;
(token, expiry)
})
.map_err(ImdsError::failed_to_load_token)
}
}

fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Token, TokenError> {
fn parse_token_response(response: &HttpResponse) -> Result<TtlToken, TokenError> {
match response.status().as_u16() {
400 => return Err(TokenErrorKind::InvalidParameters.into()),
403 => return Err(TokenErrorKind::Forbidden.into()),
Expand All @@ -187,30 +187,38 @@ fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Toke
.map_err(|_| TokenErrorKind::InvalidTtl)?
.parse()
.map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;
Ok(Token {
Ok(TtlToken {
value,
expiry: now + Duration::from_secs(ttl),
ttl: Duration::from_secs(ttl),
})
}

impl ResolveIdentity for TokenResolver {
fn resolve_identity<'a>(
&'a self,
_components: &'a RuntimeComponents,
components: &'a RuntimeComponents,
_config_bag: &'a ConfigBag,
) -> IdentityFuture<'a> {
let time_source = components
.time_source()
.expect("time source required for IMDS token caching");
IdentityFuture::new(async {
let preloaded_token = self
.inner
.cache
.yield_or_clear_if_expired(self.inner.time_source.now())
.await;
let now = time_source.now();
let preloaded_token = self.inner.cache.yield_or_clear_if_expired(now).await;
let token = match preloaded_token {
Some(token) => Ok(token),
Some(token) => {
tracing::trace!(
buffer_time=?TOKEN_REFRESH_BUFFER,
expiration=?token.expiry,
now=?now,
"loaded IMDS token from cache");
Ok(token)
}
None => {
tracing::debug!("IMDS token cache miss");
self.inner
.cache
.get_or_load(|| async { self.get_token().await })
.get_or_load(|| async { self.get_token(time_source).await })
.await
}
}?;
Expand Down
Loading

0 comments on commit 40f4662

Please sign in to comment.