Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Env vars and profile to configure IMDS retry and timeouts (#625) #626

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 71 additions & 24 deletions sdk/aws-config/src/imds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//!
//! Client for direct access to IMDSv2.

use std::borrow::Cow;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt::{Display, Formatter};
Expand Down Expand Up @@ -48,8 +47,8 @@ mod token;
// 6 hours
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(21_600);
const DEFAULT_ATTEMPTS: u32 = 4;
const DEFAULT_CONNECT_TIMEOUT: Option<Duration> = Some(Duration::from_secs(1));
const DEFAULT_READ_TIMEOUT: Option<Duration> = Some(Duration::from_secs(1));
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);

fn user_agent() -> AwsUserAgent {
AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
Expand Down Expand Up @@ -555,18 +554,28 @@ impl Builder {
/// Build an IMDSv2 Client
pub async fn build(self) -> Result<Client, BuildError> {
let config = self.config.unwrap_or_default();
let env = Environment::new(config.env(), config.fs()).await?;
let http_timeout_env = env.parse::<f64>(env::TIMEOUT, profile_keys::TIMEOUT)
.transpose()
.ok()
.flatten()
.map(Duration::from_secs_f64);
let num_attempts_env = env.parse::<u32>(env::NUM_ATTEMPTS, profile_keys::NUM_ATTEMPTS)
.transpose()
.ok()
.flatten();
let http_timeout_config = timeout::Http::new()
.with_connect_timeout(self.connect_timeout.or(DEFAULT_CONNECT_TIMEOUT).into())
.with_read_timeout(self.read_timeout.or(DEFAULT_READ_TIMEOUT).into());
.with_connect_timeout(self.connect_timeout.or(http_timeout_env.or(Some(DEFAULT_CONNECT_TIMEOUT))).into())
.with_read_timeout(self.read_timeout.or(http_timeout_env.or(Some(DEFAULT_READ_TIMEOUT))).into());
let http_settings = HttpSettings::default().with_http_timeout_config(http_timeout_config);
let connector = expect_connector(config.connector(&http_settings));
let endpoint_source = self
.endpoint
.unwrap_or_else(|| EndpointSource::Env(config.env(), config.fs()));
.unwrap_or_else(|| EndpointSource::Env(env));
let endpoint = endpoint_source.endpoint(self.mode_override).await?;
let endpoint = Endpoint::immutable(endpoint);
let retry_config = retry::Config::default()
.with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
.with_max_attempts(self.max_attempts.unwrap_or(num_attempts_env.unwrap_or(DEFAULT_ATTEMPTS)));
let timeout_config = timeout::Config::default();
let token_loader = token::TokenMiddleware::new(
connector.clone(),
Expand Down Expand Up @@ -599,18 +608,64 @@ impl Builder {
mod env {
pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
pub(super) const NUM_ATTEMPTS: &str = "AWS_METADATA_SERVICE_NUM_ATTEMPTS";
pub(super) const TIMEOUT: &str = "AWS_METADATA_SERVICE_TIMEOUT";
}

mod profile_keys {
pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
pub(super) const NUM_ATTEMPTS: &str = "metadata_service_num_attempts";
pub(super) const TIMEOUT: &str = "metadata_service_timeout";
}

/// Profile and Environment Variable Abstraction
#[derive(Debug, Clone)]
struct Environment {
env: Env,
profile: profile::ProfileSet,
}

impl Environment {
async fn new(env: Env, fs: Fs) -> Result<Self, BuildError> {
let profile = profile::load(&fs, &env)
.await
.map_err(BuildError::InvalidProfile)?;
Ok(Self{env, profile})
}

fn parse<T: FromStr>(&self, env_key: &str, profile_key: &str) -> Option<Result<T, T::Err>> {
if let Ok(value) = self.env.get(env_key) {
let parsed = value.parse::<T>();
if let Err(_) = parsed {
tracing::warn!(
key = env_key,
value = value,
"Failed to parse environment variable into type `{}`", stringify!(T)
);
};
Some(parsed)
} else if let Some(value) = self.profile.get(profile_key) {
let parsed = value.parse::<T>();
if let Err(_) = parsed {
tracing::warn!(
key = profile_key,
value = value,
"Failed to parse profile value into type `{}`", stringify!(T)
);
}
Some(value.parse::<T>())
} else {
None
}
}
}

/// Endpoint Configuration Abstraction
#[derive(Debug, Clone)]
enum EndpointSource {
Explicit(Uri),
Env(Env, Fs),
Env(Environment),
}

impl EndpointSource {
Expand All @@ -624,29 +679,21 @@ impl EndpointSource {
}
Ok(uri.clone())
}
EndpointSource::Env(env, fs) => {
EndpointSource::Env(env) => {
// load an endpoint override from the environment
let profile = profile::load(fs, env)
.await
.map_err(BuildError::InvalidProfile)?;
let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
Some(Cow::Owned(uri))
} else {
profile.get(profile_keys::ENDPOINT).map(Cow::Borrowed)
};
let uri_override = env
.parse::<String>(env::ENDPOINT, profile_keys::ENDPOINT)
.transpose()
.unwrap(); // safe because parsing `str` to `String` will always succeed
if let Some(uri) = uri_override {
return Uri::try_from(uri.as_ref()).map_err(BuildError::InvalidEndpointUri);
return Uri::try_from(&uri).map_err(BuildError::InvalidEndpointUri);
}

// if not, load a endpoint mode from the environment
let mode = if let Some(mode) = mode_override {
mode
} else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
mode.parse::<EndpointMode>()
.map_err(BuildError::InvalidEndpointMode)?
} else if let Some(mode) = profile.get(profile_keys::ENDPOINT_MODE) {
mode.parse::<EndpointMode>()
.map_err(BuildError::InvalidEndpointMode)?
} else if let Some(mode) = env.parse::<EndpointMode>(env::ENDPOINT_MODE, profile_keys::ENDPOINT_MODE) {
mode.map_err(BuildError::InvalidEndpointMode)?
} else {
EndpointMode::IpV4
};
Expand Down