Skip to content

Commit

Permalink
imdsclient: move token fetch out ImdsClient::new()
Browse files Browse the repository at this point in the history
This moves the token fetch to just before the IMDS data fetch so
that building a new ImdsClient no longer needs an await, nor does it
return a result.

`refresh_token` has been replaced by `clear_token` and `write_token`.
  • Loading branch information
jpculp committed Dec 14, 2021
1 parent 0493a81 commit 5e5200e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 71 deletions.
2 changes: 1 addition & 1 deletion sources/api/early-boot-config/src/provider/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl PlatformDataProvider for AwsDataProvider {
) -> std::result::Result<Vec<SettingsJson>, Box<dyn std::error::Error>> {
let mut output = Vec::new();

let mut client = ImdsClient::new().await.context(error::ImdsClient)?;
let mut client = ImdsClient::new();

// Attempt to read from local file first on the `aws-dev` variant
#[cfg(bottlerocket_platform = "aws-dev")]
Expand Down
2 changes: 1 addition & 1 deletion sources/api/pluto/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ fn parse_args(mut args: env::Args) -> String {

async fn run() -> Result<()> {
let setting_name = parse_args(env::args());
let mut client = ImdsClient::new().await.context(error::ImdsClient)?;
let mut client = ImdsClient::new();

let setting = match setting_name.as_ref() {
"cluster-dns-ip" => get_cluster_dns_ip(&mut client).await,
Expand Down
2 changes: 1 addition & 1 deletion sources/api/shibaken/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl UserData {
/// Returns a list of public keys.
async fn fetch_public_keys_from_imds() -> Result<Vec<String>> {
info!("Connecting to IMDS");
let mut client = ImdsClient::new().await.context(error::ImdsClient)?;
let mut client = ImdsClient::new();
let public_keys = client
.fetch_public_ssh_keys()
.await
Expand Down
138 changes: 70 additions & 68 deletions sources/imdsclient/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,32 @@ fn retry_strategy() -> impl Iterator<Item = Duration> {
}

/// A client for making IMDSv2 queries.
/// It obtains a session token when it is first instantiated and is reused between helper functions.
pub struct ImdsClient {
client: Client,
imds_base_uri: String,
retry_timeout: Duration,
// The token is reader-writer locked to prevent reads while it's being refreshed in retry logic.
session_token: RwLock<String>,
session_token: RwLock<Option<String>>,
}

impl Default for ImdsClient {
fn default() -> Self {
Self::new()
}
}

impl ImdsClient {
pub async fn new() -> Result<Self> {
Self::new_impl(BASE_URI.to_string()).await
pub fn new() -> Self {
Self::new_impl(BASE_URI.to_string())
}

async fn new_impl(imds_base_uri: String) -> Result<Self> {
let client = Client::new();
let retry_timeout = Duration::from_secs(RETRY_TIMEOUT_SECS);
let session_token =
RwLock::new(fetch_token(&client, &imds_base_uri, &retry_timeout).await?);
Ok(Self {
client,
fn new_impl(imds_base_uri: String) -> Self {
Self {
client: Client::new(),
retry_timeout: Duration::from_secs(RETRY_TIMEOUT_SECS),
session_token: RwLock::new(None),
imds_base_uri,
retry_timeout,
session_token,
})
}
}

/// Overrides the default timeout when building your own ImdsClient.
Expand Down Expand Up @@ -129,7 +130,7 @@ impl ImdsClient {

/// Gets the IPV6 address associated with the primary network interface from instance metadata.
pub async fn fetch_primary_ipv6_address(&mut self) -> Result<Option<String>> {
// Get the mac address for the primary network interface
// Get the mac address for the primary network interface.
let mac = self
.fetch_mac_addresses()
.await?
Expand Down Expand Up @@ -158,7 +159,7 @@ impl ImdsClient {
/// Returns a list of public ssh keys skipping any keys that do not start with 'ssh'.
pub async fn fetch_public_ssh_keys(&mut self) -> Result<Option<Vec<String>>> {
info!("Fetching list of available public keys from IMDS");
// Returns a list of available public keys as '0=my-public-key'
// Returns a list of available public keys as '0=my-public-key'.
let public_key_list = match self.fetch_string("meta-data/public-keys").await? {
Some(public_key_list) => {
debug!("available public keys '{}'", &public_key_list);
Expand Down Expand Up @@ -248,7 +249,10 @@ impl ImdsClient {
timeout(
self.retry_timeout,
Retry::spawn(retry_strategy(), || async {
let session_token = self.read_token().await?;
let session_token = match self.read_token().await? {
Some(session_token) => session_token,
None => self.write_token().await?,
};
let response = self
.client
.get(&uri)
Expand Down Expand Up @@ -285,9 +289,9 @@ impl ImdsClient {

// IMDS returns 401 if the session token is expired or invalid.
StatusCode::UNAUTHORIZED => {
warn!("Session token is invalid or expired");
self.refresh_token().await?;
error::TokenRefreshed.fail()
warn!("IMDS request unauthorized");
self.clear_token()?;
error::TokenInvalid.fail()
}

code => {
Expand Down Expand Up @@ -320,25 +324,41 @@ impl ImdsClient {
.context(error::TimeoutFetchIMDS)?
}

/// Fetches a new session token and adds it to the current ImdsClient.
async fn refresh_token(&self) -> Result<()> {
/// Fetches a new session token and writes it to the current ImdsClient.
async fn write_token(&self) -> Result<String> {
match fetch_token(&self.client, &self.imds_base_uri, &self.retry_timeout).await? {
Some(written_token) => {
*self
.session_token
.write()
.map_err(|_| error::Error::FailedWriteToken {})? = Some(written_token.clone());
Ok(written_token)
}
None => error::FailedWriteToken.fail(),
}
}

/// Clears the session token in the current ImdsClient.
fn clear_token(&self) -> Result<()> {
*self
.session_token
.write()
.map_err(|_| error::Error::FailedWriteToken {})? =
fetch_token(&self.client, &self.imds_base_uri, &self.retry_timeout).await?;
.map_err(|_| error::Error::FailedClearToken {})? = None;
Ok(())
}

/// Helper to read session token within the ImdsClient
async fn read_token(&self) -> Result<String> {
let session_token = self
/// Helper to read session token within the ImdsClient.
async fn read_token(&self) -> Result<Option<String>> {
match self
.session_token
.read()
.map_err(|_| error::Error::FailedReadToken {})?
// Cloned to release RwLock as soon as possible
.clone();
Ok(session_token)
// Cloned to release RwLock as soon as possible.
.clone()
{
Some(read_token) => Ok(Some(read_token)),
None => Ok(None),
}
}
}

Expand Down Expand Up @@ -385,7 +405,7 @@ async fn fetch_token(
client: &Client,
imds_base_uri: &str,
retry_timeout: &Duration,
) -> Result<String> {
) -> Result<Option<String>> {
let uri = format!("{}/{}", imds_base_uri, SESSION_TARGET);
timeout(
*retry_timeout,
Expand All @@ -402,11 +422,13 @@ async fn fetch_token(

let code = response.status();
ensure!(code == StatusCode::OK, error::FailedFetchToken);
return response.text().await.context(error::ResponseBody {

let response_body = response.text().await.context(error::ResponseBody {
method: "PUT",
uri: &uri,
code,
});
})?;
Ok(Some(response_body))
}),
)
.await
Expand All @@ -417,7 +439,7 @@ mod error {
use http::StatusCode;
use snafu::Snafu;

// Extracts the status code from a reqwest::Error and converts it to a string to be displayed
// Extracts the status code from a reqwest::Error and converts it to a string to be displayed.
fn get_status_code(source: &reqwest::Error) -> String {
source
.status()
Expand All @@ -435,6 +457,9 @@ mod error {
#[snafu(display("Response '{}' from '{}': {}", get_status_code(source), uri, source))]
BadResponse { uri: String, source: reqwest::Error },

#[snafu(display("Failed to clear token within ImdsClient"))]
FailedClearToken,

#[snafu(display("IMDS fetch failed after {} attempts", attempt))]
FailedFetchIMDS { attempt: u8 },

Expand Down Expand Up @@ -497,8 +522,8 @@ mod error {
#[snafu(display("Timed out fetching IMDSv2 session token: {}", source))]
TimeoutFetchToken { source: tokio::time::error::Elapsed },

#[snafu(display("IMDSv2 session token was refreshed."))]
TokenRefreshed,
#[snafu(display("IMDSv2 session token is invalid or expired."))]
TokenInvalid,
}
}

Expand All @@ -510,25 +535,6 @@ mod test {
use super::*;
use httptest::{matchers::*, responders::*, Expectation, Server};

#[tokio::test]
async fn new_imds_client() {
let server = Server::run();
let base_uri = format!("http://{}", server.addr());
let token = "some+token";
server.expect(
Expectation::matching(request::method_path("PUT", "/latest/api/token"))
.times(1)
.respond_with(
status_code(200)
.append_header("X-aws-ec2-metadata-token-ttl-seconds", "60")
.body(token),
),
);
let imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let imds_token = imds_client.read_token().await.unwrap();
assert_eq!(imds_token, token);
}

#[tokio::test]
async fn fetch_imds() {
let server = Server::run();
Expand Down Expand Up @@ -559,11 +565,13 @@ mod test {
.body(response_body),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let mut imds_client = ImdsClient::new_impl(base_uri);
let imds_data = imds_client
.fetch_imds(schema_version, target)
.await
.unwrap();
let imds_token = imds_client.read_token().await.unwrap().unwrap();
assert_eq!(imds_token, token);
assert_eq!(imds_data, Some(response_body.as_bytes().to_vec()));
}

Expand Down Expand Up @@ -594,7 +602,7 @@ mod test {
status_code(response_code).append_header("X-aws-ec2-metadata-token", token),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let mut imds_client = ImdsClient::new_impl(base_uri);
let imds_data = imds_client
.fetch_imds(schema_version, target)
.await
Expand Down Expand Up @@ -630,10 +638,7 @@ mod test {
status_code(response_code).append_header("X-aws-ec2-metadata-token", token),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri)
.await
.unwrap()
.with_timeout(retry_timeout);
let mut imds_client = ImdsClient::new_impl(base_uri).with_timeout(retry_timeout);
assert!(imds_client
.fetch_imds(schema_version, target)
.await
Expand Down Expand Up @@ -668,10 +673,7 @@ mod test {
status_code(response_code).append_header("X-aws-ec2-metadata-token", token),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri)
.await
.unwrap()
.with_timeout(retry_timeout);
let mut imds_client = ImdsClient::new_impl(base_uri).with_timeout(retry_timeout);
assert!(imds_client
.fetch_imds(schema_version, target)
.await
Expand Down Expand Up @@ -724,7 +726,7 @@ mod test {
.body(response_body),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let mut imds_client = ImdsClient::new_impl(base_uri);
let imds_data = imds_client.fetch_string(end_target).await.unwrap();
assert_eq!(imds_data, Some(response_body.to_string()));
}
Expand Down Expand Up @@ -758,7 +760,7 @@ mod test {
.body(response_body),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let mut imds_client = ImdsClient::new_impl(base_uri);
let imds_data = imds_client.fetch_bytes(end_target).await.unwrap();
assert_eq!(imds_data, Some(response_body.as_bytes().to_vec()));
}
Expand Down Expand Up @@ -791,7 +793,7 @@ mod test {
.body(response_body),
),
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let mut imds_client = ImdsClient::new_impl(base_uri);
let imds_data = imds_client.fetch_userdata().await.unwrap();
assert_eq!(imds_data, Some(response_body.as_bytes().to_vec()));
}
Expand Down

0 comments on commit 5e5200e

Please sign in to comment.