Skip to content

Commit

Permalink
Merge pull request #687 from phu-cinemo/async
Browse files Browse the repository at this point in the history
tough: migrate to async
  • Loading branch information
webern authored Nov 3, 2023
2 parents cfaa71f + c8361e8 commit f5fbea3
Show file tree
Hide file tree
Showing 55 changed files with 1,944 additions and 1,261 deletions.
85 changes: 79 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions tough-kms/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,21 @@
// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT OR Apache-2.0

use crate::error::{self, Result};
use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_config::default_provider::region::DefaultRegionChain;
use aws_sdk_kms::Client as KmsClient;
use snafu::ResultExt;
use std::thread;

/// Builds a KMS client for a given profile name.
pub(crate) fn build_client_kms(profile: Option<&str>) -> Result<KmsClient> {
// We are cloning this so that we can send it across a thread boundary
let profile = profile.map(std::borrow::ToOwned::to_owned);
// We need to spin up a new thread to deal with the async nature of the
// AWS SDK Rust
let client: Result<KmsClient> = thread::spawn(move || {
let runtime = tokio::runtime::Runtime::new().context(error::RuntimeCreationSnafu)?;
Ok(runtime.block_on(async_build_client_kms(profile)))
})
.join()
.map_err(|_| error::Error::ThreadJoin {})?;
client
}

async fn async_build_client_kms(profile: Option<String>) -> KmsClient {
pub(crate) async fn build_client_kms(profile: Option<&str>) -> KmsClient {
let config = aws_config::from_env();
let client_config = if let Some(profile) = profile {
let region = DefaultRegionChain::builder()
.profile_name(&profile)
.profile_name(profile)
.build()
.region()
.await;
let creds = DefaultCredentialsChain::builder()
.profile_name(&profile)
.profile_name(profile)
.region(region.clone())
.build()
.await;
Expand Down
10 changes: 0 additions & 10 deletions tough-kms/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ pub type Result<T> = std::result::Result<T, Error>;
#[non_exhaustive]
#[allow(missing_docs)]
pub enum Error {
/// The library failed to instantiate 'tokio Runtime'.
#[snafu(display("Unable to create tokio runtime: {}", source))]
RuntimeCreation {
source: std::io::Error,
backtrace: Backtrace,
},
/// The library failed to join 'tokio Runtime'.
#[snafu(display("Unable to join tokio thread used to offload async workloads"))]
ThreadJoin,

/// The library failed to get public key from AWS KMS
#[snafu(display(
"Failed to get public key for aws-kms://{}/{} : {}",
Expand Down
32 changes: 15 additions & 17 deletions tough-kms/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use ring::rand::SecureRandom;
use snafu::{ensure, OptionExt, ResultExt};
use std::collections::HashMap;
use std::fmt;
use tough::async_trait;
use tough::key_source::KeySource;
use tough::schema::decoded::{Decoded, RsaPem};
use tough::schema::key::{Key, RsaKey, RsaScheme};
Expand Down Expand Up @@ -76,23 +77,22 @@ impl fmt::Debug for KmsKeySource {
}

/// Implement the `KeySource` trait.
#[async_trait]
impl KeySource for KmsKeySource {
fn as_sign(
async fn as_sign(
&self,
) -> std::result::Result<Box<dyn Sign>, Box<dyn std::error::Error + Send + Sync + 'static>>
{
let kms_client = match self.client.clone() {
Some(value) => value,
None => client::build_client_kms(self.profile.as_deref())?,
None => client::build_client_kms(self.profile.as_deref()).await,
};
// Get the public key from AWS KMS
let fut = kms_client
let response = kms_client
.get_public_key()
.key_id(self.key_id.clone())
.send();
let response = tokio::runtime::Runtime::new()
.context(error::RuntimeCreationSnafu)?
.block_on(fut)
.send()
.await
.context(error::KmsGetPublicKeySnafu {
profile: self.profile.clone(),
key_id: self.key_id.clone(),
Expand Down Expand Up @@ -131,7 +131,7 @@ impl KeySource for KmsKeySource {
}))
}

fn write(
async fn write(
&self,
_value: &str,
_key_id_hex: &str,
Expand Down Expand Up @@ -166,6 +166,7 @@ impl fmt::Debug for KmsRsaKey {
}
}

#[async_trait]
impl Sign for KmsRsaKey {
fn tuf_key(&self) -> Key {
// Create a Key struct for the public key
Expand All @@ -179,27 +180,24 @@ impl Sign for KmsRsaKey {
}
}

fn sign(
async fn sign(
&self,
msg: &[u8],
_rng: &dyn SecureRandom,
_rng: &(dyn SecureRandom + Sync),
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync + 'static>> {
let kms_client = match self.client.clone() {
Some(value) => value,
None => client::build_client_kms(self.profile.as_deref())?,
None => client::build_client_kms(self.profile.as_deref()).await,
};
let blob = Blob::new(digest(&SHA256, msg).as_ref().to_vec());
let sign_fut = kms_client
let response = kms_client
.sign()
.key_id(self.key_id.clone())
.message(blob)
.message_type(aws_sdk_kms::types::MessageType::Digest)
.signing_algorithm(self.signing_algorithm.value())
.send();

let response = tokio::runtime::Runtime::new()
.context(error::RuntimeCreationSnafu)?
.block_on(sign_fut)
.send()
.await
.context(error::KmsSignMessageSnafu {
profile: self.profile.clone(),
key_id: self.key_id.clone(),
Expand Down
Loading

0 comments on commit f5fbea3

Please sign in to comment.