diff --git a/sources/Cargo.lock b/sources/Cargo.lock index 576bc00c17e..5ea951758d2 100644 --- a/sources/Cargo.lock +++ b/sources/Cargo.lock @@ -1588,6 +1588,7 @@ dependencies = [ "serde_json", "snafu", "tokio", + "tokio-retry", "tokio-test", "url", ] @@ -3442,6 +3443,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project 1.0.8", + "rand", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.22.0" diff --git a/sources/api/early-boot-config/src/provider/aws.rs b/sources/api/early-boot-config/src/provider/aws.rs index 48d5e0a51dd..eb97bfe6af1 100644 --- a/sources/api/early-boot-config/src/provider/aws.rs +++ b/sources/api/early-boot-config/src/provider/aws.rs @@ -85,7 +85,7 @@ impl PlatformDataProvider for AwsDataProvider { ) -> std::result::Result, Box> { let mut output = Vec::new(); - let mut client = ImdsClient::new().await.context(error::ImdsClient)?; + let mut client = ImdsClient::new(); // Attempt to read from local file first on the `aws-dev` variant #[cfg(bottlerocket_platform = "aws-dev")] diff --git a/sources/api/pluto/src/main.rs b/sources/api/pluto/src/main.rs index fed04bc8d69..6500e2dac23 100644 --- a/sources/api/pluto/src/main.rs +++ b/sources/api/pluto/src/main.rs @@ -303,7 +303,7 @@ fn parse_args(mut args: env::Args) -> String { async fn run() -> Result<()> { let setting_name = parse_args(env::args()); - let mut client = ImdsClient::new().await.context(error::ImdsClient)?; + let mut client = ImdsClient::new(); let setting = match setting_name.as_ref() { "cluster-dns-ip" => get_cluster_dns_ip(&mut client).await, diff --git a/sources/api/shibaken/src/main.rs b/sources/api/shibaken/src/main.rs index f4f664de648..05d0abbee34 100644 --- a/sources/api/shibaken/src/main.rs +++ b/sources/api/shibaken/src/main.rs @@ -42,7 +42,7 @@ impl UserData { /// Returns a list of public keys. async fn fetch_public_keys_from_imds() -> Result> { info!("Connecting to IMDS"); - let mut client = ImdsClient::new().await.context(error::ImdsClient)?; + let mut client = ImdsClient::new(); let public_keys = client .fetch_public_ssh_keys() .await diff --git a/sources/imdsclient/Cargo.toml b/sources/imdsclient/Cargo.toml index 75dab30148f..085b0f7897b 100644 --- a/sources/imdsclient/Cargo.toml +++ b/sources/imdsclient/Cargo.toml @@ -16,6 +16,7 @@ reqwest = { version = "0.11.1", default-features = false } serde_json = "1" snafu = "0.6" tokio = { version = "~1.8", default-features = false, features = ["macros", "rt-multi-thread", "time"] } # LTS +tokio-retry = "0.3" url = "2.1.1" [build-dependencies] diff --git a/sources/imdsclient/README.md b/sources/imdsclient/README.md index edaa38e1a02..b48114953b0 100644 --- a/sources/imdsclient/README.md +++ b/sources/imdsclient/README.md @@ -6,6 +6,9 @@ Current version: 0.1.0 The library uses IMDSv2 (session-oriented) requests over a pinned schema to guarantee compatibility. Session tokens are fetched automatically and refreshed if the request receives a `401` response. +If an IMDS token fetch or query fails, the library will continue to retry with a fibonacci backoff +strategy until it is successful or times out. The default timeout is 300s to match the ifup timeout +set in wicked.service, but can configured using `.with_timeout` during client creation. Each public method is explicitly targeted and return either bytes or a `String`. diff --git a/sources/imdsclient/src/lib.rs b/sources/imdsclient/src/lib.rs index 9cdec9a7d2c..ed47df51fc8 100644 --- a/sources/imdsclient/src/lib.rs +++ b/sources/imdsclient/src/lib.rs @@ -3,6 +3,9 @@ The library uses IMDSv2 (session-oriented) requests over a pinned schema to guarantee compatibility. Session tokens are fetched automatically and refreshed if the request receives a `401` response. +If an IMDS token fetch or query fails, the library will continue to retry with a fibonacci backoff +strategy until it is successful or times out. The default timeout is 300s to match the ifup timeout +set in wicked.service, but can configured using `.with_timeout` during client creation. Each public method is explicitly targeted and return either bytes or a `String`. @@ -16,41 +19,63 @@ The result is returned as a `String` _(ex. m5.large)_. #![deny(rust_2018_idioms)] +use std::sync::RwLock; + use http::StatusCode; use log::{debug, info, trace, warn}; use reqwest::Client; use serde_json::Value; use snafu::{ensure, OptionExt, ResultExt}; -use std::time::Duration; -use tokio::time; +use tokio::time::{timeout, Duration}; +use tokio_retry::{strategy::FibonacciBackoff, Retry}; const BASE_URI: &str = "http://169.254.169.254"; const PINNED_SCHEMA: &str = "2021-01-03"; -// Currently only able to get fetch session tokens from `latest` +// Currently only able to get fetch session tokens from `latest`. const SESSION_TARGET: &str = "latest/api/token"; +// Retry timeout tied to wicked.service ifup timeout. +const RETRY_TIMEOUT_SECS: u64 = 300; + +fn retry_strategy() -> impl Iterator { + // Retry attempts at 0.25s, 0.5s, 1s, 1.75s, 3s, 5s, 8.25s, 13.5s, 22s and then every 10s after. + FibonacciBackoff::from_millis(250).max_delay(Duration::from_secs(10)) +} + /// A client for making IMDSv2 queries. -/// It obtains a session token when it is first instantiated and is reused between helper functions. pub struct ImdsClient { client: Client, imds_base_uri: String, - session_token: String, + retry_timeout: Duration, + // The token is reader-writer locked to prevent reads while it's being refreshed in retry logic. + session_token: RwLock>, +} + +impl Default for ImdsClient { + fn default() -> Self { + Self::new() + } } impl ImdsClient { - pub async fn new() -> Result { - Self::new_impl(BASE_URI.to_string()).await + pub fn new() -> Self { + Self::new_impl(BASE_URI.to_string()) } - async fn new_impl(imds_base_uri: String) -> Result { - let client = Client::new(); - let session_token = fetch_token(&client, &imds_base_uri).await?; - Ok(Self { - client, + fn new_impl(imds_base_uri: String) -> Self { + Self { + client: Client::new(), + retry_timeout: Duration::from_secs(RETRY_TIMEOUT_SECS), + session_token: RwLock::new(None), imds_base_uri, - session_token, - }) + } + } + + /// Overrides the default timeout when building your own ImdsClient. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.retry_timeout = timeout; + self } /// Gets `user-data` from IMDS. The user-data may be either a UTF-8 string or compressed bytes. @@ -105,7 +130,7 @@ impl ImdsClient { /// Gets the IPV6 address associated with the primary network interface from instance metadata. pub async fn fetch_primary_ipv6_address(&mut self) -> Result> { - // Get the mac address for the primary network interface + // Get the mac address for the primary network interface. let mac = self .fetch_mac_addresses() .await? @@ -114,7 +139,7 @@ impl ImdsClient { .context(error::MacAddresses)? .clone(); - // Get the IPv6 addresses associated with the primary network interface + // Get the IPv6 addresses associated with the primary network interface. let ipv6_address_target = format!("meta-data/network/interfaces/macs/{}/ipv6s", mac); let ipv6_address = self @@ -134,7 +159,7 @@ impl ImdsClient { /// Returns a list of public ssh keys skipping any keys that do not start with 'ssh'. pub async fn fetch_public_ssh_keys(&mut self) -> Result>> { info!("Fetching list of available public keys from IMDS"); - // Returns a list of available public keys as '0=my-public-key' + // Returns a list of available public keys as '0=my-public-key'. let public_key_list = match self.fetch_string("meta-data/public-keys").await? { Some(public_key_list) => { debug!("available public keys '{}'", &public_key_list); @@ -221,93 +246,120 @@ impl ImdsClient { target.as_ref() ); debug!("Requesting {}", &uri); - let mut attempt: u8 = 0; - let max_attempts: u8 = 3; - loop { - attempt += 1; - ensure!(attempt <= max_attempts, error::FailedFetchIMDS { attempt }); - if attempt > 1 { - time::sleep(Duration::from_secs(1)).await; - } - let response = self - .client - .get(&uri) - .header("X-aws-ec2-metadata-token", &self.session_token) - .send() - .await - .context(error::Request { - method: "GET", - uri: &uri, - })?; - trace!("IMDS response: {:?}", &response); - - match response.status() { - code @ StatusCode::OK => { - info!("Received {}", target.as_ref()); - let response_body = response - .bytes() - .await - .context(error::ResponseBody { - method: "GET", - uri: &uri, - code, - })? - .to_vec(); + timeout( + self.retry_timeout, + Retry::spawn(retry_strategy(), || async { + let session_token = match self.read_token().await? { + Some(session_token) => session_token, + None => self.write_token().await?, + }; + let response = self + .client + .get(&uri) + .header("X-aws-ec2-metadata-token", session_token) + .send() + .await + .context(error::Request { + method: "GET", + uri: &uri, + })?; + trace!("IMDS response: {:?}", &response); + + match response.status() { + code @ StatusCode::OK => { + info!("Received {}", target.as_ref()); + let response_body = response + .bytes() + .await + .context(error::ResponseBody { + method: "GET", + uri: &uri, + code, + })? + .to_vec(); + + let response_str = printable_string(&response_body); + trace!("Response: {:?}", response_str); + + Ok(Some(response_body)) + } - let response_str = printable_string(&response_body); - trace!("Response: {:?}", response_str); + // IMDS returns 404 if no user data is given, or if IMDS is disabled. + StatusCode::NOT_FOUND => Ok(None), - return Ok(Some(response_body)); - } + // IMDS returns 401 if the session token is expired or invalid. + StatusCode::UNAUTHORIZED => { + warn!("IMDS request unauthorized"); + self.clear_token()?; + error::TokenInvalid.fail() + } - // IMDS returns 404 if no user data is given, or if IMDS is disabled - StatusCode::NOT_FOUND => return Ok(None), + code => { + let response_body = response + .bytes() + .await + .context(error::ResponseBody { + method: "GET", + uri: &uri, + code, + })? + .to_vec(); - // IMDS returns 401 if the session token is expired or invalid - StatusCode::UNAUTHORIZED => { - info!("Session token is invalid or expired"); - self.refresh_token().await?; - info!("Refreshed session token"); - continue; - } + let response_str = printable_string(&response_body); - StatusCode::REQUEST_TIMEOUT => { - info!("Retrying request"); - continue; - } + trace!("Response: {:?}", response_str); - code => { - let response_body = response - .bytes() - .await - .context(error::ResponseBody { + error::Response { method: "GET", uri: &uri, code, - })? - .to_vec(); - - let response_str = printable_string(&response_body); - - trace!("Response: {:?}", response_str); - - return error::Response { - method: "GET", - uri: &uri, - code, - response_body: response_str, + response_body: response_str, + } + .fail() } - .fail(); } + }), + ) + .await + .context(error::TimeoutFetchIMDS)? + } + + /// Fetches a new session token and writes it to the current ImdsClient. + async fn write_token(&self) -> Result { + match fetch_token(&self.client, &self.imds_base_uri, &self.retry_timeout).await? { + Some(written_token) => { + *self + .session_token + .write() + .map_err(|_| error::Error::FailedWriteToken {})? = Some(written_token.clone()); + Ok(written_token) } + None => error::FailedWriteToken.fail(), } } - /// Fetches a new session token and adds it to the current ImdsClient. - async fn refresh_token(&mut self) -> Result<()> { - self.session_token = fetch_token(&self.client, &self.imds_base_uri).await?; + /// Clears the session token in the current ImdsClient. + fn clear_token(&self) -> Result<()> { + *self + .session_token + .write() + .map_err(|_| error::Error::FailedClearToken {})? = None; Ok(()) } + + /// Helper to read session token within the ImdsClient. + async fn read_token(&self) -> Result> { + match self + .session_token + .read() + .map_err(|_| error::Error::FailedReadToken {})? + // Cloned to release RwLock as soon as possible. + .clone() + { + Some(read_token) => Ok(Some(read_token)), + None => Ok(None), + } + } } /// Converts `bytes` to a `String` if it is a UTF-8 encoded string. @@ -349,45 +401,45 @@ fn build_public_key_targets(public_key_list: &str) -> Vec { } /// Helper to fetch an IMDSv2 session token that is valid for 60 seconds. -async fn fetch_token(client: &Client, imds_base_uri: &str) -> Result { +async fn fetch_token( + client: &Client, + imds_base_uri: &str, + retry_timeout: &Duration, +) -> Result> { let uri = format!("{}/{}", imds_base_uri, SESSION_TARGET); - let mut attempt: u8 = 0; - let max_attempts: u8 = 3; - loop { - attempt += 1; - ensure!(attempt <= max_attempts, error::FailedFetchToken { attempt }); - if attempt > 1 { - time::sleep(Duration::from_secs(5)).await; - } - let response = client - .put(&uri) - .header("X-aws-ec2-metadata-token-ttl-seconds", "60") - .send() - .await - .context(error::Request { - method: "PUT", - uri: &uri, - })?; + timeout( + *retry_timeout, + Retry::spawn(retry_strategy(), || async { + let response = client + .put(&uri) + .header("X-aws-ec2-metadata-token-ttl-seconds", "60") + .send() + .await + .context(error::Request { + method: "PUT", + uri: &uri, + })?; - let code = response.status(); - if code == StatusCode::OK { - return response.text().await.context(error::ResponseBody { + let code = response.status(); + ensure!(code == StatusCode::OK, error::FailedFetchToken); + + let response_body = response.text().await.context(error::ResponseBody { method: "PUT", uri: &uri, code, - }); - } else { - info!("Retrying token request"); - continue; - } - } + })?; + Ok(Some(response_body)) + }), + ) + .await + .context(error::TimeoutFetchToken)? } mod error { use http::StatusCode; use snafu::Snafu; - // Extracts the status code from a reqwest::Error and converts it to a string to be displayed + // Extracts the status code from a reqwest::Error and converts it to a string to be displayed. fn get_status_code(source: &reqwest::Error) -> String { source .status() @@ -400,19 +452,29 @@ mod error { #[derive(Debug, Snafu)] #[snafu(visibility = "pub(super)")] + // snafu doesn't yet support the lifetimes used by std::sync::PoisonError. pub enum Error { #[snafu(display("Response '{}' from '{}': {}", get_status_code(source), uri, source))] BadResponse { uri: String, source: reqwest::Error }, + #[snafu(display("Failed to clear token within ImdsClient"))] + FailedClearToken, + #[snafu(display("IMDS fetch failed after {} attempts", attempt))] FailedFetchIMDS { attempt: u8 }, - #[snafu(display("Failed to fetch IMDSv2 session token after {} attempts", attempt))] - FailedFetchToken { attempt: u8 }, + #[snafu(display("Failed to fetch IMDSv2 session token"))] + FailedFetchToken, + + #[snafu(display("Failed to read token within ImdsClient"))] + FailedReadToken, #[snafu(display("IMDS session failed: {}", source))] FailedSession { source: reqwest::Error }, + #[snafu(display("Failed to write token to ImdsClient"))] + FailedWriteToken, + #[snafu(display("Error retrieving key from {}", target))] KeyNotFound { target: String }, @@ -453,6 +515,15 @@ mod error { #[snafu(display("Deserialization error: {}", source))] Serde { source: serde_json::Error }, + + #[snafu(display("Timed out fetching data from IMDS: {}", source))] + TimeoutFetchIMDS { source: tokio::time::error::Elapsed }, + + #[snafu(display("Timed out fetching IMDSv2 session token: {}", source))] + TimeoutFetchToken { source: tokio::time::error::Elapsed }, + + #[snafu(display("IMDSv2 session token is invalid or expired."))] + TokenInvalid, } } @@ -464,24 +535,6 @@ mod test { use super::*; use httptest::{matchers::*, responders::*, Expectation, Server}; - #[tokio::test] - async fn new_imds_client() { - let server = Server::run(); - let base_uri = format!("http://{}", server.addr()); - let token = "some+token"; - server.expect( - Expectation::matching(request::method_path("PUT", "/latest/api/token")) - .times(1) - .respond_with( - status_code(200) - .append_header("X-aws-ec2-metadata-token-ttl-seconds", "60") - .body(token), - ), - ); - let imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); - assert_eq!(imds_client.session_token, token); - } - #[tokio::test] async fn fetch_imds() { let server = Server::run(); @@ -512,11 +565,13 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client .fetch_imds(schema_version, target) .await .unwrap(); + let imds_token = imds_client.read_token().await.unwrap().unwrap(); + assert_eq!(imds_token, token); assert_eq!(imds_data, Some(response_body.as_bytes().to_vec())); } @@ -547,7 +602,7 @@ mod test { status_code(response_code).append_header("X-aws-ec2-metadata-token", token), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client .fetch_imds(schema_version, target) .await @@ -563,9 +618,10 @@ mod test { let schema_version = "latest"; let target = "meta-data/instance-type"; let response_code = 401; + let retry_timeout = Duration::from_secs(2); server.expect( Expectation::matching(request::method_path("PUT", "/latest/api/token")) - .times(4) + .times(2..) .respond_with( status_code(200) .append_header("X-aws-ec2-metadata-token-ttl-seconds", "60") @@ -577,12 +633,12 @@ mod test { "GET", format!("/{}/{}", schema_version, target), )) - .times(3) + .times(2..) .respond_with( status_code(response_code).append_header("X-aws-ec2-metadata-token", token), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri).with_timeout(retry_timeout); assert!(imds_client .fetch_imds(schema_version, target) .await @@ -597,6 +653,7 @@ mod test { let schema_version = "latest"; let target = "meta-data/instance-type"; let response_code = 408; + let retry_timeout = Duration::from_secs(2); server.expect( Expectation::matching(request::method_path("PUT", "/latest/api/token")) .times(1) @@ -611,12 +668,12 @@ mod test { "GET", format!("/{}/{}", schema_version, target), )) - .times(3) + .times(2..) .respond_with( status_code(response_code).append_header("X-aws-ec2-metadata-token", token), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri).with_timeout(retry_timeout); assert!(imds_client .fetch_imds(schema_version, target) .await @@ -627,14 +684,17 @@ mod test { async fn fetch_token_timeout() { let server = Server::run(); let base_uri = format!("http://{}", server.addr()); + let retry_timeout = Duration::from_secs(2); let response_code = 408; server.expect( Expectation::matching(request::method_path("PUT", "/latest/api/token")) - .times(3) + .times(2..) .respond_with(status_code(response_code)), ); let client = Client::new(); - assert!(fetch_token(&client, &base_uri).await.is_err()); + assert!(fetch_token(&client, &base_uri, &retry_timeout) + .await + .is_err()); } #[tokio::test] @@ -666,7 +726,7 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client.fetch_string(end_target).await.unwrap(); assert_eq!(imds_data, Some(response_body.to_string())); } @@ -700,7 +760,7 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client.fetch_bytes(end_target).await.unwrap(); assert_eq!(imds_data, Some(response_body.as_bytes().to_vec())); } @@ -733,7 +793,7 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client.fetch_userdata().await.unwrap(); assert_eq!(imds_data, Some(response_body.as_bytes().to_vec())); }