Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/uv-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,3 @@ hyper = { version = "1.4.1", features = ["server", "http1"] }
hyper-util = { version = "0.1.8", features = ["tokio"] }
insta = { version = "1.40.0", features = ["filters", "json", "redactions"] }
tokio = { workspace = true }
wiremock = { workspace = true }
232 changes: 9 additions & 223 deletions crates/uv-client/src/base_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ use std::sync::Arc;
use std::time::Duration;
use std::{env, iter};

use anyhow::anyhow;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use itertools::Itertools;
use reqwest::{multipart, Client, ClientBuilder, IntoUrl, Proxy, Request, Response};
use reqwest::{Client, ClientBuilder, Proxy, Response};
use reqwest_middleware::{ClientWithMiddleware, Middleware};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{
Expand Down Expand Up @@ -62,24 +60,6 @@ pub struct BaseClientBuilder<'a> {
default_timeout: Duration,
extra_middleware: Option<ExtraMiddleware>,
proxies: Vec<Proxy>,
redirect_policy: RedirectPolicy,
}

/// The policy for handling redirects.
#[derive(Debug, Default, Clone, Copy)]
pub enum RedirectPolicy {
#[default]
BypassMiddleware,
RetriggerMiddleware,
}

impl RedirectPolicy {
pub fn reqwest_policy(self) -> reqwest::redirect::Policy {
match self {
RedirectPolicy::BypassMiddleware => reqwest::redirect::Policy::default(),
RedirectPolicy::RetriggerMiddleware => reqwest::redirect::Policy::none(),
}
}
}

/// A list of user-defined middlewares to be applied to the client.
Expand Down Expand Up @@ -115,7 +95,6 @@ impl BaseClientBuilder<'_> {
default_timeout: Duration::from_secs(30),
extra_middleware: None,
proxies: vec![],
redirect_policy: RedirectPolicy::default(),
}
}
}
Expand Down Expand Up @@ -193,12 +172,6 @@ impl<'a> BaseClientBuilder<'a> {
self
}

#[must_use]
pub fn redirect(mut self, policy: RedirectPolicy) -> Self {
self.redirect_policy = policy;
self
}

pub fn is_offline(&self) -> bool {
matches!(self.connectivity, Connectivity::Offline)
}
Expand Down Expand Up @@ -255,7 +228,6 @@ impl<'a> BaseClientBuilder<'a> {
timeout,
ssl_cert_file_exists,
Security::Secure,
self.redirect_policy,
);

// Create an insecure client that accepts invalid certificates.
Expand All @@ -264,18 +236,11 @@ impl<'a> BaseClientBuilder<'a> {
timeout,
ssl_cert_file_exists,
Security::Insecure,
self.redirect_policy,
);

// Wrap in any relevant middleware and handle connectivity.
let client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_client.clone()),
redirect_policy: self.redirect_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
};
let client = self.apply_middleware(raw_client.clone());
let dangerous_client = self.apply_middleware(raw_dangerous_client.clone());

BaseClient {
connectivity: self.connectivity,
Expand All @@ -292,14 +257,8 @@ impl<'a> BaseClientBuilder<'a> {
/// Share the underlying client between two different middleware configurations.
pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient {
// Wrap in any relevant middleware and handle connectivity.
let client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_client.clone()),
redirect_policy: self.redirect_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
};
let client = self.apply_middleware(existing.raw_client.clone());
let dangerous_client = self.apply_middleware(existing.raw_dangerous_client.clone());

BaseClient {
connectivity: self.connectivity,
Expand All @@ -319,16 +278,14 @@ impl<'a> BaseClientBuilder<'a> {
timeout: Duration,
ssl_cert_file_exists: bool,
security: Security,
redirect_policy: RedirectPolicy,
) -> Client {
// Configure the builder.
let client_builder = ClientBuilder::new()
.http1_title_case_headers()
.user_agent(user_agent)
.pool_max_idle_per_host(20)
.read_timeout(timeout)
.tls_built_in_root_certs(false)
.redirect(redirect_policy.reqwest_policy());
.tls_built_in_root_certs(false);

// If necessary, accept invalid certificates.
let client_builder = match security {
Expand Down Expand Up @@ -425,9 +382,9 @@ impl<'a> BaseClientBuilder<'a> {
#[derive(Debug, Clone)]
pub struct BaseClient {
/// The underlying HTTP client that enforces valid certificates.
client: RedirectClientWithMiddleware,
client: ClientWithMiddleware,
/// The underlying HTTP client that accepts invalid certificates.
dangerous_client: RedirectClientWithMiddleware,
dangerous_client: ClientWithMiddleware,
/// The HTTP client without middleware.
raw_client: Client,
/// The HTTP client that accepts invalid certificates without middleware.
Expand All @@ -452,20 +409,14 @@ enum Security {

impl BaseClient {
/// Selects the appropriate client based on the host's trustworthiness.
pub fn for_host(&self, url: &Url) -> &RedirectClientWithMiddleware {
pub fn for_host(&self, url: &Url) -> &ClientWithMiddleware {
if self.disable_ssl(url) {
&self.dangerous_client
} else {
&self.client
}
}

/// Executes a request, applying redirect policy.
pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
let client = self.for_host(req.url());
client.execute(req).await
}

/// Returns `true` if the host is trusted to use the insecure client.
pub fn disable_ssl(&self, url: &Url) -> bool {
self.allow_insecure_host
Expand All @@ -489,171 +440,6 @@ impl BaseClient {
}
}

/// Wrapper around [`ClientWithMiddleware`] that manages redirects.
#[derive(Debug, Clone)]
pub struct RedirectClientWithMiddleware {
client: ClientWithMiddleware,
redirect_policy: RedirectPolicy,
}

impl RedirectClientWithMiddleware {
/// Convenience method to make a `GET` request to a URL.
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.get(url), self)
}

/// Convenience method to make a `POST` request to a URL.
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.post(url), self)
}

/// Convenience method to make a `HEAD` request to a URL.
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.head(url), self)
}

/// Executes a request, applying the redirect policy.
pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
match self.redirect_policy {
RedirectPolicy::BypassMiddleware => self.client.execute(req).await,
RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await,
}
}

/// Executes a request. If the response is a 302 redirect, executes the
/// request again with the redirect location URL (up to a maximum number
/// of redirects).
///
/// Unlike the built-in reqwest redirect policies, this sends the
/// redirect request through the entire middleware pipeline again.
async fn execute_with_redirect_handling(
&self,
req: Request,
) -> reqwest_middleware::Result<Response> {
let mut request = req;
let mut redirects = 0;
// This is the default used by reqwest.
let max_redirects = 10;

loop {
let result = self
.client
.execute(request.try_clone().expect("HTTP request must be cloneable"))
.await;
if redirects == max_redirects {
return result;
}
let Ok(response) = result else {
return result;
};

// Handle redirect if we receive a 301, 302, 307, or 308.
if matches!(
response.status(),
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
) {
let location_str = response
.headers()
.get("location")
.ok_or(reqwest_middleware::Error::Middleware(anyhow!(
"Missing 302 location header"
)))?
.to_str()
.map_err(|_| {
reqwest_middleware::Error::Middleware(anyhow!(
"Invalid 302 location header"
))
})?;
let redirect_url = Url::parse(location_str).map_err(|_| {
reqwest_middleware::Error::Middleware(anyhow!("Invalid 302 location URL"))
})?;
debug!("Received 302 redirect to {redirect_url}");
*request.url_mut() = redirect_url;
redirects += 1;
continue;
}

return Ok(response);
}
}

pub fn raw_client(&self) -> &ClientWithMiddleware {
&self.client
}
}

impl From<RedirectClientWithMiddleware> for ClientWithMiddleware {
fn from(item: RedirectClientWithMiddleware) -> ClientWithMiddleware {
item.client
}
}

/// A builder to construct the properties of a `Request`.
///
/// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`]
/// redirect policy is respected if `send()` is called.
#[derive(Debug)]
#[must_use]
pub struct RequestBuilder<'a> {
builder: reqwest_middleware::RequestBuilder,
client: &'a RedirectClientWithMiddleware,
}

impl<'a> RequestBuilder<'a> {
pub fn new(
builder: reqwest_middleware::RequestBuilder,
client: &'a RedirectClientWithMiddleware,
) -> Self {
Self { builder, client }
}

/// Add a `Header` to this Request.
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}

/// Add a set of Headers to the existing ones on this Request.
///
/// The headers will be merged in to any already set.
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.builder = self.builder.headers(headers);
self
}

#[cfg(not(target_arch = "wasm32"))]
pub fn version(mut self, version: reqwest::Version) -> Self {
self.builder = self.builder.version(version);
self
}

#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub fn multipart(mut self, multipart: multipart::Form) -> Self {
self.builder = self.builder.multipart(multipart);
self
}

/// Build a `Request`.
pub fn build(self) -> reqwest::Result<Request> {
self.builder.build()
}

/// Constructs the Request and sends it to the target URL, returning a
/// future Response.
pub async fn send(self) -> reqwest_middleware::Result<Response> {
self.client.execute(self.build()?).await
}
}

/// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases.
pub struct UvRetryableStrategy;

Expand Down
2 changes: 2 additions & 0 deletions crates/uv-client/src/cached_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ impl CachedClient {
debug!("Sending revalidation request for: {url}");
let response = self
.0
.for_host(req.url())
.execute(req)
.instrument(info_span!("revalidation_request", url = url.as_str()))
.await
Expand Down Expand Up @@ -550,6 +551,7 @@ impl CachedClient {
let cache_policy_builder = CachePolicyBuilder::new(&req);
let response = self
.0
.for_host(&url)
.execute(req)
.await
.map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?
Expand Down
2 changes: 1 addition & 1 deletion crates/uv-client/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub use base_client::{
is_extended_transient_error, AuthIntegration, BaseClient, BaseClientBuilder, ExtraMiddleware,
RedirectClientWithMiddleware, RequestBuilder, UvRetryableStrategy, DEFAULT_RETRIES,
UvRetryableStrategy, DEFAULT_RETRIES,
};
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
pub use error::{Error, ErrorKind, WrappedReqwestError};
Expand Down
Loading
Loading