diff --git a/sources/api/early-boot-config/src/provider/aws.rs b/sources/api/early-boot-config/src/provider/aws.rs index 48d5e0a51dd..eb97bfe6af1 100644 --- a/sources/api/early-boot-config/src/provider/aws.rs +++ b/sources/api/early-boot-config/src/provider/aws.rs @@ -85,7 +85,7 @@ impl PlatformDataProvider for AwsDataProvider { ) -> std::result::Result, Box> { let mut output = Vec::new(); - let mut client = ImdsClient::new().await.context(error::ImdsClient)?; + let mut client = ImdsClient::new(); // Attempt to read from local file first on the `aws-dev` variant #[cfg(bottlerocket_platform = "aws-dev")] diff --git a/sources/api/pluto/src/main.rs b/sources/api/pluto/src/main.rs index fed04bc8d69..6500e2dac23 100644 --- a/sources/api/pluto/src/main.rs +++ b/sources/api/pluto/src/main.rs @@ -303,7 +303,7 @@ fn parse_args(mut args: env::Args) -> String { async fn run() -> Result<()> { let setting_name = parse_args(env::args()); - let mut client = ImdsClient::new().await.context(error::ImdsClient)?; + let mut client = ImdsClient::new(); let setting = match setting_name.as_ref() { "cluster-dns-ip" => get_cluster_dns_ip(&mut client).await, diff --git a/sources/api/shibaken/src/main.rs b/sources/api/shibaken/src/main.rs index f4f664de648..05d0abbee34 100644 --- a/sources/api/shibaken/src/main.rs +++ b/sources/api/shibaken/src/main.rs @@ -42,7 +42,7 @@ impl UserData { /// Returns a list of public keys. async fn fetch_public_keys_from_imds() -> Result> { info!("Connecting to IMDS"); - let mut client = ImdsClient::new().await.context(error::ImdsClient)?; + let mut client = ImdsClient::new(); let public_keys = client .fetch_public_ssh_keys() .await diff --git a/sources/imdsclient/src/lib.rs b/sources/imdsclient/src/lib.rs index a7eb242bf1e..ed47df51fc8 100644 --- a/sources/imdsclient/src/lib.rs +++ b/sources/imdsclient/src/lib.rs @@ -44,31 +44,32 @@ fn retry_strategy() -> impl Iterator { } /// A client for making IMDSv2 queries. -/// It obtains a session token when it is first instantiated and is reused between helper functions. pub struct ImdsClient { client: Client, imds_base_uri: String, retry_timeout: Duration, // The token is reader-writer locked to prevent reads while it's being refreshed in retry logic. - session_token: RwLock, + session_token: RwLock>, +} + +impl Default for ImdsClient { + fn default() -> Self { + Self::new() + } } impl ImdsClient { - pub async fn new() -> Result { - Self::new_impl(BASE_URI.to_string()).await + pub fn new() -> Self { + Self::new_impl(BASE_URI.to_string()) } - async fn new_impl(imds_base_uri: String) -> Result { - let client = Client::new(); - let retry_timeout = Duration::from_secs(RETRY_TIMEOUT_SECS); - let session_token = - RwLock::new(fetch_token(&client, &imds_base_uri, &retry_timeout).await?); - Ok(Self { - client, + fn new_impl(imds_base_uri: String) -> Self { + Self { + client: Client::new(), + retry_timeout: Duration::from_secs(RETRY_TIMEOUT_SECS), + session_token: RwLock::new(None), imds_base_uri, - retry_timeout, - session_token, - }) + } } /// Overrides the default timeout when building your own ImdsClient. @@ -129,7 +130,7 @@ impl ImdsClient { /// Gets the IPV6 address associated with the primary network interface from instance metadata. pub async fn fetch_primary_ipv6_address(&mut self) -> Result> { - // Get the mac address for the primary network interface + // Get the mac address for the primary network interface. let mac = self .fetch_mac_addresses() .await? @@ -158,7 +159,7 @@ impl ImdsClient { /// Returns a list of public ssh keys skipping any keys that do not start with 'ssh'. pub async fn fetch_public_ssh_keys(&mut self) -> Result>> { info!("Fetching list of available public keys from IMDS"); - // Returns a list of available public keys as '0=my-public-key' + // Returns a list of available public keys as '0=my-public-key'. let public_key_list = match self.fetch_string("meta-data/public-keys").await? { Some(public_key_list) => { debug!("available public keys '{}'", &public_key_list); @@ -248,7 +249,10 @@ impl ImdsClient { timeout( self.retry_timeout, Retry::spawn(retry_strategy(), || async { - let session_token = self.read_token().await?; + let session_token = match self.read_token().await? { + Some(session_token) => session_token, + None => self.write_token().await?, + }; let response = self .client .get(&uri) @@ -285,9 +289,9 @@ impl ImdsClient { // IMDS returns 401 if the session token is expired or invalid. StatusCode::UNAUTHORIZED => { - warn!("Session token is invalid or expired"); - self.refresh_token().await?; - error::TokenRefreshed.fail() + warn!("IMDS request unauthorized"); + self.clear_token()?; + error::TokenInvalid.fail() } code => { @@ -320,25 +324,41 @@ impl ImdsClient { .context(error::TimeoutFetchIMDS)? } - /// Fetches a new session token and adds it to the current ImdsClient. - async fn refresh_token(&self) -> Result<()> { + /// Fetches a new session token and writes it to the current ImdsClient. + async fn write_token(&self) -> Result { + match fetch_token(&self.client, &self.imds_base_uri, &self.retry_timeout).await? { + Some(written_token) => { + *self + .session_token + .write() + .map_err(|_| error::Error::FailedWriteToken {})? = Some(written_token.clone()); + Ok(written_token) + } + None => error::FailedWriteToken.fail(), + } + } + + /// Clears the session token in the current ImdsClient. + fn clear_token(&self) -> Result<()> { *self .session_token .write() - .map_err(|_| error::Error::FailedWriteToken {})? = - fetch_token(&self.client, &self.imds_base_uri, &self.retry_timeout).await?; + .map_err(|_| error::Error::FailedClearToken {})? = None; Ok(()) } - /// Helper to read session token within the ImdsClient - async fn read_token(&self) -> Result { - let session_token = self + /// Helper to read session token within the ImdsClient. + async fn read_token(&self) -> Result> { + match self .session_token .read() .map_err(|_| error::Error::FailedReadToken {})? - // Cloned to release RwLock as soon as possible - .clone(); - Ok(session_token) + // Cloned to release RwLock as soon as possible. + .clone() + { + Some(read_token) => Ok(Some(read_token)), + None => Ok(None), + } } } @@ -385,7 +405,7 @@ async fn fetch_token( client: &Client, imds_base_uri: &str, retry_timeout: &Duration, -) -> Result { +) -> Result> { let uri = format!("{}/{}", imds_base_uri, SESSION_TARGET); timeout( *retry_timeout, @@ -402,11 +422,13 @@ async fn fetch_token( let code = response.status(); ensure!(code == StatusCode::OK, error::FailedFetchToken); - return response.text().await.context(error::ResponseBody { + + let response_body = response.text().await.context(error::ResponseBody { method: "PUT", uri: &uri, code, - }); + })?; + Ok(Some(response_body)) }), ) .await @@ -417,7 +439,7 @@ mod error { use http::StatusCode; use snafu::Snafu; - // Extracts the status code from a reqwest::Error and converts it to a string to be displayed + // Extracts the status code from a reqwest::Error and converts it to a string to be displayed. fn get_status_code(source: &reqwest::Error) -> String { source .status() @@ -435,6 +457,9 @@ mod error { #[snafu(display("Response '{}' from '{}': {}", get_status_code(source), uri, source))] BadResponse { uri: String, source: reqwest::Error }, + #[snafu(display("Failed to clear token within ImdsClient"))] + FailedClearToken, + #[snafu(display("IMDS fetch failed after {} attempts", attempt))] FailedFetchIMDS { attempt: u8 }, @@ -497,8 +522,8 @@ mod error { #[snafu(display("Timed out fetching IMDSv2 session token: {}", source))] TimeoutFetchToken { source: tokio::time::error::Elapsed }, - #[snafu(display("IMDSv2 session token was refreshed."))] - TokenRefreshed, + #[snafu(display("IMDSv2 session token is invalid or expired."))] + TokenInvalid, } } @@ -510,25 +535,6 @@ mod test { use super::*; use httptest::{matchers::*, responders::*, Expectation, Server}; - #[tokio::test] - async fn new_imds_client() { - let server = Server::run(); - let base_uri = format!("http://{}", server.addr()); - let token = "some+token"; - server.expect( - Expectation::matching(request::method_path("PUT", "/latest/api/token")) - .times(1) - .respond_with( - status_code(200) - .append_header("X-aws-ec2-metadata-token-ttl-seconds", "60") - .body(token), - ), - ); - let imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); - let imds_token = imds_client.read_token().await.unwrap(); - assert_eq!(imds_token, token); - } - #[tokio::test] async fn fetch_imds() { let server = Server::run(); @@ -559,11 +565,13 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client .fetch_imds(schema_version, target) .await .unwrap(); + let imds_token = imds_client.read_token().await.unwrap().unwrap(); + assert_eq!(imds_token, token); assert_eq!(imds_data, Some(response_body.as_bytes().to_vec())); } @@ -594,7 +602,7 @@ mod test { status_code(response_code).append_header("X-aws-ec2-metadata-token", token), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client .fetch_imds(schema_version, target) .await @@ -630,10 +638,7 @@ mod test { status_code(response_code).append_header("X-aws-ec2-metadata-token", token), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri) - .await - .unwrap() - .with_timeout(retry_timeout); + let mut imds_client = ImdsClient::new_impl(base_uri).with_timeout(retry_timeout); assert!(imds_client .fetch_imds(schema_version, target) .await @@ -668,10 +673,7 @@ mod test { status_code(response_code).append_header("X-aws-ec2-metadata-token", token), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri) - .await - .unwrap() - .with_timeout(retry_timeout); + let mut imds_client = ImdsClient::new_impl(base_uri).with_timeout(retry_timeout); assert!(imds_client .fetch_imds(schema_version, target) .await @@ -724,7 +726,7 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client.fetch_string(end_target).await.unwrap(); assert_eq!(imds_data, Some(response_body.to_string())); } @@ -758,7 +760,7 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client.fetch_bytes(end_target).await.unwrap(); assert_eq!(imds_data, Some(response_body.as_bytes().to_vec())); } @@ -791,7 +793,7 @@ mod test { .body(response_body), ), ); - let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap(); + let mut imds_client = ImdsClient::new_impl(base_uri); let imds_data = imds_client.fetch_userdata().await.unwrap(); assert_eq!(imds_data, Some(response_body.as_bytes().to_vec())); }