Skip to content

Commit

Permalink
Trusted host
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Aug 24, 2024
1 parent 31019ff commit b2db1d7
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 94 deletions.
191 changes: 113 additions & 78 deletions crates/uv-client/src/base_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use reqwest_retry::{
DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy,
};
use tracing::debug;

use url::Url;
use pep508_rs::MarkerEnvironment;
use platform_tags::Platform;
use uv_auth::AuthMiddleware;
Expand All @@ -24,7 +24,7 @@ use uv_warnings::warn_user_once;
use crate::linehaul::LineHaul;
use crate::middleware::OfflineMiddleware;
use crate::tls::read_identity;
use crate::Connectivity;
use crate::{CachedClient, Connectivity};

/// A builder for an [`BaseClient`].
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -109,6 +109,7 @@ impl<'a> BaseClientBuilder<'a> {
// Create user agent.
let mut user_agent_string = format!("uv/{}", version());


// Add linehaul metadata.
if let Some(markers) = self.markers {
let linehaul = LineHaul::new(markers, self.platform);
Expand All @@ -117,6 +118,18 @@ impl<'a> BaseClientBuilder<'a> {
}
}

// Check for the presence of an `SSL_CERT_FILE`.
let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| {
let path_exists = Path::new(&path).exists();
if !path_exists {
warn_user_once!(
"Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.",
path.simplified_display().cyan()
);
}
path_exists
});

// Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout
// `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6
let default_timeout = 30;
Expand All @@ -134,100 +147,132 @@ impl<'a> BaseClientBuilder<'a> {
.unwrap_or(default_timeout);
debug!("Using request timeout of {timeout}s");

// Initialize the base client.
let client = self.client.clone().unwrap_or_else(|| {
// Check for the presence of an `SSL_CERT_FILE`.
let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| {
let path_exists = Path::new(&path).exists();
if !path_exists {
warn_user_once!(
"Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.",
path.simplified_display().cyan()
);
}
path_exists
});

// Configure the builder.
let client_core = ClientBuilder::new()
.user_agent(user_agent_string)
.pool_max_idle_per_host(20)
.read_timeout(std::time::Duration::from_secs(timeout))
.tls_built_in_root_certs(false);

// Configure TLS.
let client_core = if self.native_tls || ssl_cert_file_exists {
client_core.tls_built_in_native_certs(true)
} else {
client_core.tls_built_in_webpki_certs(true)
};

// Configure mTLS.
let client_core = if let Some(ssl_client_cert) = env::var_os("SSL_CLIENT_CERT") {
match read_identity(&ssl_client_cert) {
Ok(identity) => client_core.identity(identity),
Err(err) => {
warn_user_once!("Ignoring invalid `SSL_CLIENT_CERT`: {err}");
client_core
}
}
} else {
client_core
};
// Create a secure client that validates certificates.
let client =
self.create_client(&user_agent_string, timeout, false, ssl_cert_file_exists);

client_core.build().expect("Failed to build HTTP client")
});
// Create an insecure client that accepts invalid certificates.
let dangerous_client =
self.create_client(&user_agent_string, timeout, true, ssl_cert_file_exists);

// Wrap in any relevant middleware.

// Wrap in any relevant middleware and handle connectivity.
let client = match self.connectivity {
Connectivity::Online => {
let client = reqwest_middleware::ClientBuilder::new(client.clone());

// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
retry_policy,
UvRetryableStrategy,
);
let client = client.with(retry_strategy);

// Initialize the authentication middleware to set headers.
let client =
client.with(AuthMiddleware::new().with_keyring(self.keyring.to_provider()));

client.build()
}
Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone())
.with(OfflineMiddleware)
.build(),
Connectivity::Online => self.apply_middleware(client),
Connectivity::Offline => self.apply_offline_middleware(client),
};
let dangerous_client = match self.connectivity {
Connectivity::Online => self.apply_middleware(dangerous_client),
Connectivity::Offline => self.apply_offline_middleware(dangerous_client),
};

BaseClient {
connectivity: self.connectivity,
client,
dangerous_client,
timeout,
trusted_host: vec![]
}
}

fn create_client(
&self,
user_agent: &str,
timeout: u64,
accept_invalid_certs: bool,
ssl_cert_file_exists: bool,
) -> Client {
// Configure the builder.
let client_builder = ClientBuilder::new()
.user_agent(user_agent)
.pool_max_idle_per_host(20)
.read_timeout(std::time::Duration::from_secs(timeout))
.tls_built_in_root_certs(false)
.danger_accept_invalid_certs(accept_invalid_certs);

let client_builder = if self.native_tls || ssl_cert_file_exists {
client_builder.tls_built_in_native_certs(true)
} else {
client_builder.tls_built_in_webpki_certs(true)
};

// Configure mTLS.
let client_builder = if let Some(ssl_client_cert) = env::var_os("SSL_CLIENT_CERT") {
match read_identity(&ssl_client_cert) {
Ok(identity) => client_builder.identity(identity),
Err(err) => {
warn_user_once!("Ignoring invalid `SSL_CLIENT_CERT`: {err}");
client_builder
}
}
} else {
client_builder
};

client_builder
.build()
.expect("Failed to build HTTP client.")
}

fn apply_middleware(&self, client: Client) -> ClientWithMiddleware {
let client = reqwest_middleware::ClientBuilder::new(client.clone());

// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
retry_policy,
UvRetryableStrategy,
);
let client = client.with(retry_strategy);

// Initialize the authentication middleware to set headers.
let client =
client.with(AuthMiddleware::new().with_keyring(self.keyring.to_provider()));


client.build()
}

fn apply_offline_middleware(&self, client: Client) -> ClientWithMiddleware {
reqwest_middleware::ClientBuilder::new(client)
.with(OfflineMiddleware)
.build()
}
}

/// A base client for HTTP requests
#[derive(Debug, Clone)]
pub struct BaseClient {
/// The underlying HTTP client.
/// The underlying HTTP client that enforces valid certificates.
client: ClientWithMiddleware,
/// The underlying HTTP client that accepts invalid certificates.
dangerous_client: ClientWithMiddleware,
/// The connectivity mode to use.
connectivity: Connectivity,
/// Configured client timeout, in seconds.
timeout: u64,
/// The host that is trusted to use the insecure client.
trusted_host: Vec<Url>,
}

impl BaseClient {
/// The underlying [`ClientWithMiddleware`].
/// The underlying [`ClientWithMiddleware`] for secure requests.
pub fn client(&self) -> ClientWithMiddleware {
self.client.clone()
}

/// Selects the appropriate client based on the host's trustworthiness.
pub fn for_host(&self, url: &Url) -> &ClientWithMiddleware {
if self
.trusted_host.iter().any(|trusted| url.host() == trusted.host())
{
&self.dangerous_client
} else {
&self.client
}
}

/// The configured client timeout, in seconds.
pub fn timeout(&self) -> u64 {
self.timeout
Expand All @@ -239,16 +284,6 @@ impl BaseClient {
}
}

// To avoid excessively verbose call chains, as the [`BaseClient`] is often nested within other client types.
impl Deref for BaseClient {
type Target = ClientWithMiddleware;

/// Deference to the underlying [`ClientWithMiddleware`].
fn deref(&self) -> &Self::Target {
&self.client
}
}

/// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases.
struct UvRetryableStrategy;

Expand Down
8 changes: 5 additions & 3 deletions crates/uv-client/src/cached_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ impl CachedClient {
Self(client)
}

/// The base client
pub fn uncached(&self) -> BaseClient {
self.0.clone()
/// The underlying [`BaseClient`] without caching.
pub fn uncached(&self) -> &BaseClient {
&self.0
}

/// Make a cached request with a custom response transformation
Expand Down Expand Up @@ -460,6 +460,7 @@ impl CachedClient {
debug!("Sending revalidation request for: {url}");
let response = self
.0
.for_host(req.url())
.execute(req)
.instrument(info_span!("revalidation_request", url = url.as_str()))
.await
Expand Down Expand Up @@ -499,6 +500,7 @@ impl CachedClient {
let cache_policy_builder = CachePolicyBuilder::new(&req);
let response = self
.0
.for_host(req.url())
.execute(req)
.await
.map_err(ErrorKind::from)?
Expand Down
2 changes: 1 addition & 1 deletion crates/uv-client/src/flat_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl<'a> FlatIndexClient<'a> {

let flat_index_request = self
.client
.uncached_client()
.uncached_client(url)
.get(url.clone())
.header("Accept-Encoding", "gzip")
.header("Accept", "text/html")
Expand Down
15 changes: 8 additions & 7 deletions crates/uv-client/src/registry_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use async_http_range_reader::AsyncHttpRangeReader;
use futures::{FutureExt, TryStreamExt};
use http::HeaderMap;
use reqwest::{Client, Response, StatusCode};
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
Expand Down Expand Up @@ -171,8 +172,8 @@ impl RegistryClient {
}

/// Return the [`BaseClient`] used by this client.
pub fn uncached_client(&self) -> BaseClient {
self.client.uncached()
pub fn uncached_client(&self, url: &Url) -> &ClientWithMiddleware {
self.client.uncached().for_host(url)
}

/// Return the [`Connectivity`] mode used by this client.
Expand Down Expand Up @@ -298,7 +299,7 @@ impl RegistryClient {
cache_control: CacheControl,
) -> Result<OwnedArchive<SimpleMetadata>, Error> {
let simple_request = self
.uncached_client()
.uncached_client(&url)
.get(url.clone())
.header("Accept-Encoding", "gzip")
.header("Accept", MediaType::accepts())
Expand Down Expand Up @@ -512,7 +513,7 @@ impl RegistryClient {
})
};
let req = self
.uncached_client()
.uncached_client(&url)
.get(url.clone())
.build()
.map_err(ErrorKind::from)?;
Expand Down Expand Up @@ -551,7 +552,7 @@ impl RegistryClient {
};

let req = self
.uncached_client()
.uncached_client(&url)
.head(url.clone())
.header(
"accept-encoding",
Expand All @@ -571,7 +572,7 @@ impl RegistryClient {
let read_metadata_range_request = |response: Response| {
async {
let mut reader = AsyncHttpRangeReader::from_head_response(
self.uncached_client().client(),
self.uncached_client(&url).clone(),
response,
url.clone(),
headers,
Expand Down Expand Up @@ -619,7 +620,7 @@ impl RegistryClient {

// Create a request to stream the file.
let req = self
.uncached_client()
.uncached_client(&url)
.get(url.clone())
.header(
// `reqwest` defaults to accepting compressed responses.
Expand Down
2 changes: 2 additions & 0 deletions crates/uv-client/tests/user_agent_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async fn test_user_agent_has_version() -> Result<()> {
let res = client
.cached_client()
.uncached()
.client()
.get(format!("http://{addr}"))
.send()
.await?;
Expand Down Expand Up @@ -151,6 +152,7 @@ async fn test_user_agent_has_linehaul() -> Result<()> {
let res = client
.cached_client()
.uncached()
.client()
.get(format!("http://{addr}"))
.send()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion crates/uv-distribution/src/distribution_database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
fn request(&self, url: Url) -> Result<reqwest::Request, reqwest::Error> {
self.client
.unmanaged
.uncached_client()
.uncached_client(&url)
.get(url)
.header(
// `reqwest` defaults to accepting compressed responses.
Expand Down
Loading

0 comments on commit b2db1d7

Please sign in to comment.