Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions crates/uv-auth/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ enum TokenState {
Initialized(Option<AccessToken>),
}

#[derive(Clone)]
enum S3CredentialState {
/// The S3 credential state has not yet been initialized.
Uninitialized,
/// The S3 credential state has been initialized, with either a signer or `None` if
/// no S3 endpoint is configured.
Initialized(Option<Arc<Authentication>>),
}

/// A middleware that adds basic authentication to requests.
///
/// Uses a cache to propagate credentials from previously seen requests and
Expand All @@ -150,6 +159,8 @@ pub struct AuthMiddleware {
pyx_token_store: Option<PyxTokenStore>,
/// Tokens to use for persistent credentials.
pyx_token_state: Mutex<TokenState>,
/// Cached S3 credentials to avoid running the credential helper multiple times.
s3_credential_state: Mutex<S3CredentialState>,
preview: Preview,
}

Expand All @@ -172,6 +183,7 @@ impl AuthMiddleware {
base_client: None,
pyx_token_store: None,
pyx_token_state: Mutex::new(TokenState::Uninitialized),
s3_credential_state: Mutex::new(S3CredentialState::Uninitialized),
preview: Preview::default(),
}
}
Expand Down Expand Up @@ -678,13 +690,26 @@ impl AuthMiddleware {
return Some(credentials);
}

if let Some(credentials) = S3EndpointProvider::credentials_for(url, self.preview)
.map(Authentication::from)
.map(Arc::new)
{
debug!("Found S3 credentials for {url}");
self.cache().fetches.done(key, Some(credentials.clone()));
return Some(credentials);
if S3EndpointProvider::is_s3_endpoint(url, self.preview) {
let mut s3_state = self.s3_credential_state.lock().await;

// If the S3 credential state is uninitialized, initialize it.
let credentials = match &*s3_state {
S3CredentialState::Uninitialized => {
trace!("Initializing S3 credentials for {url}");
let signer = S3EndpointProvider::create_signer();
let credentials = Arc::new(Authentication::from(signer));
*s3_state = S3CredentialState::Initialized(Some(credentials.clone()));
Some(credentials)
}
S3CredentialState::Initialized(credentials) => credentials.clone(),
};

if let Some(credentials) = credentials {
debug!("Found S3 credentials for {url}");
self.cache().fetches.done(key, Some(credentials.clone()));
return Some(credentials);
}
}

// If this is a known URL, authenticate it via the token store.
Expand Down
35 changes: 21 additions & 14 deletions crates/uv-auth/src/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ static S3_ENDPOINT_REALM: LazyLock<Option<Realm>> = LazyLock::new(|| {
pub(crate) struct S3EndpointProvider;

impl S3EndpointProvider {
/// Returns the credentials for the S3 endpoint, if available.
pub(crate) fn credentials_for(url: &Url, preview: Preview) -> Option<DefaultSigner> {
/// Returns `true` if the URL matches the configured S3 endpoint.
pub(crate) fn is_s3_endpoint(url: &Url, preview: Preview) -> bool {
if let Some(s3_endpoint_realm) = S3_ENDPOINT_REALM.as_ref().map(RealmRef::from) {
if !preview.is_enabled(PreviewFeatures::S3_ENDPOINT) {
warn_user_once!(
Expand All @@ -79,19 +79,26 @@ impl S3EndpointProvider {
// Treat any URL on the same domain or subdomain as available for S3 signing.
let realm = RealmRef::from(url);
if realm == s3_endpoint_realm || realm.is_subdomain_of(s3_endpoint_realm) {
// TODO(charlie): Can `reqsign` infer the region for us? Profiles, for example,
// often have a region set already.
let region = std::env::var(EnvVars::AWS_REGION)
.map(Cow::Owned)
.unwrap_or_else(|_| {
std::env::var(EnvVars::AWS_DEFAULT_REGION)
.map(Cow::Owned)
.unwrap_or_else(|_| Cow::Borrowed("us-east-1"))
});
let signer = reqsign::aws::default_signer("s3", &region);
return Some(signer);
return true;
}
}
None
false
}

/// Creates a new S3 signer with the configured region.
///
/// This is potentially expensive as it may invoke credential helpers, so the result
/// should be cached.
pub(crate) fn create_signer() -> DefaultSigner {
// TODO(charlie): Can `reqsign` infer the region for us? Profiles, for example,
// often have a region set already.
let region = std::env::var(EnvVars::AWS_REGION)
.map(Cow::Owned)
.unwrap_or_else(|_| {
std::env::var(EnvVars::AWS_DEFAULT_REGION)
.map(Cow::Owned)
.unwrap_or_else(|_| Cow::Borrowed("us-east-1"))
});
reqsign::aws::default_signer("s3", &region)
}
}
Loading