Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ h2 = { version = "0.4", optional = true }
log = "0.4.17"
percent-encoding = "2.3"
tokio = { version = "1.0", default-features = false, features = ["net", "time"] }
tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] }
tower = { version = "0.5.2", default-features = false, features = ["retry", "timeout", "util"] }
tower-http = { version = "0.6.5", default-features = false, features = ["follow-redirect"] }
pin-project-lite = "0.2.11"

Expand Down
9 changes: 0 additions & 9 deletions src/async_impl/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,6 @@ impl Body {
}
}

pub(crate) fn try_reuse(self) -> (Option<Bytes>, Self) {
let reuse = match self.inner {
Inner::Reusable(ref chunk) => Some(chunk.clone()),
Inner::Streaming { .. } => None,
};

(reuse, self)
}

pub(crate) fn try_clone(&self) -> Option<Body> {
match self.inner {
Inner::Reusable(ref chunk) => Some(Body::reusable(chunk.clone())),
Expand Down
245 changes: 69 additions & 176 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#[cfg(any(feature = "native-tls", feature = "__rustls",))]
use std::any::Any;
#[cfg(feature = "http2")]
use std::error::Error;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
Expand Down Expand Up @@ -45,7 +43,6 @@ use crate::Certificate;
use crate::Identity;
use crate::{IntoUrl, Method, Proxy, Url};

use bytes::Bytes;
use http::header::{
Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, PROXY_AUTHORIZATION, RANGE, USER_AGENT,
};
Expand Down Expand Up @@ -176,6 +173,7 @@ struct Config {
proxies: Vec<ProxyMatcher>,
auto_sys_proxy: bool,
redirect_policy: redirect::Policy,
retry_policy: crate::retry::Builder,
referer: bool,
read_timeout: Option<Duration>,
timeout: Option<Duration>,
Expand Down Expand Up @@ -300,6 +298,7 @@ impl ClientBuilder {
proxies: Vec::new(),
auto_sys_proxy: true,
redirect_policy: redirect::Policy::default(),
retry_policy: crate::retry::Builder::default(),
referer: true,
read_timeout: None,
timeout: None,
Expand Down Expand Up @@ -999,14 +998,18 @@ impl ClientBuilder {
hyper: hyper_client,
};

let policy = {
let redirect_policy = {
let mut p = TowerRedirectPolicy::new(config.redirect_policy);
p.with_referer(config.referer)
.with_https_only(config.https_only);
p
};

let hyper = FollowRedirect::with_policy(hyper_service, policy.clone());
let retry_policy = config.retry_policy.into_policy();

let retries = tower::retry::Retry::new(retry_policy.clone(), hyper_service);

let hyper = FollowRedirect::with_policy(retries, redirect_policy.clone());

Ok(Client {
inner: Arc::new(ClientRef {
Expand All @@ -1026,7 +1029,8 @@ impl ClientBuilder {
config.pool_idle_timeout,
config.cookie_store,
);
Some(FollowRedirect::with_policy(h3_service, policy))
let retries = tower::retry::Retry::new(retry_policy, h3_service);
Some(FollowRedirect::with_policy(retries, redirect_policy))
}
None => None,
},
Expand Down Expand Up @@ -1337,6 +1341,17 @@ impl ClientBuilder {
self
}

// Retry options

/// Set a request retry policy.
///
/// Default behavior is to retry protocol NACKs.
// XXX: accept an `impl retry::IntoPolicy` instead?
pub fn retry(mut self, policy: crate::retry::Builder) -> ClientBuilder {
self.config.retry_policy = policy;
self
}

// Proxy options

/// Add a `Proxy` to the list of proxies the `Client` will use.
Expand Down Expand Up @@ -2505,13 +2520,7 @@ impl Client {
_ => return Pending::new_err(error::url_invalid_uri(url)),
};

let (reusable, body) = match body {
Some(body) => {
let (reusable, body) = body.try_reuse();
(Some(reusable), body)
}
None => (None, Body::empty()),
};
let body = body.unwrap_or_else(Body::empty);

self.proxy_auth(&uri, &mut headers);
self.proxy_custom_headers(&uri, &mut headers);
Expand Down Expand Up @@ -2556,9 +2565,6 @@ impl Client {
method,
url,
headers,
body: reusable,

retry_count: 0,

client: self.inner.clone(),

Expand Down Expand Up @@ -2792,14 +2798,18 @@ impl Config {
}
}

type LayeredService<T> =
FollowRedirect<tower::retry::Retry<crate::retry::Policy, T>, TowerRedirectPolicy>;
type LayeredFuture<T> = <LayeredService<T> as Service<http::Request<Body>>>::Future;

struct ClientRef {
accepts: Accepts,
#[cfg(feature = "cookies")]
cookie_store: Option<Arc<dyn cookie::CookieStore>>,
headers: HeaderMap,
hyper: FollowRedirect<HyperService, TowerRedirectPolicy>,
hyper: LayeredService<HyperService>,
#[cfg(feature = "http3")]
h3_client: Option<FollowRedirect<H3Client, TowerRedirectPolicy>>,
h3_client: Option<LayeredService<H3Client>>,
referer: bool,
request_timeout: RequestConfig<RequestTimeout>,
read_timeout: Option<Duration>,
Expand Down Expand Up @@ -2863,9 +2873,6 @@ pin_project! {
method: Method,
url: Url,
headers: HeaderMap,
body: Option<Option<Bytes>>,

retry_count: usize,

client: Arc<ClientRef>,

Expand All @@ -2880,9 +2887,9 @@ pin_project! {
}

enum ResponseFuture {
Default(tower_http::follow_redirect::ResponseFuture<HyperService, Body, TowerRedirectPolicy>),
Default(LayeredFuture<HyperService>),
#[cfg(feature = "http3")]
H3(tower_http::follow_redirect::ResponseFuture<H3Client, Body, TowerRedirectPolicy>),
H3(LayeredFuture<H3Client>),
}

impl PendingRequest {
Expand All @@ -2897,103 +2904,6 @@ impl PendingRequest {
fn read_timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Pin<Box<Sleep>>>> {
self.project().read_timeout_fut
}

#[cfg(any(feature = "http2", feature = "http3"))]
fn retry_error(mut self: Pin<&mut Self>, err: &(dyn std::error::Error + 'static)) -> bool {
use log::trace;

if !is_retryable_error(err) {
return false;
}

trace!("can retry {err:?}");

let body = match self.body {
Some(Some(ref body)) => Body::reusable(body.clone()),
Some(None) => {
log::debug!("error was retryable, but body not reusable");
return false;
}
None => Body::empty(),
};

if self.retry_count >= 2 {
trace!("retry count too high");
return false;
}
self.retry_count += 1;

// If it parsed once, it should parse again
let uri = try_uri(&self.url).expect("URL was already validated as URI");

*self.as_mut().in_flight().get_mut() = match *self.as_mut().in_flight().as_ref() {
#[cfg(feature = "http3")]
ResponseFuture::H3(_) => {
let mut req = hyper::Request::builder()
.method(self.method.clone())
.uri(uri)
.body(body)
.expect("valid request parts");
*req.headers_mut() = self.headers.clone();
let mut h3 = self
.client
.h3_client
.as_ref()
.expect("H3 client must exists, otherwise we can't have a h3 request here")
.clone();
ResponseFuture::H3(h3.call(req))
}
_ => {
let mut req = hyper::Request::builder()
.method(self.method.clone())
.uri(uri)
.body(body)
.expect("valid request parts");
*req.headers_mut() = self.headers.clone();
let mut hyper = self.client.hyper.clone();
ResponseFuture::Default(hyper.call(req))
}
};

true
}
}

#[cfg(any(feature = "http2", feature = "http3"))]
fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool {
// pop the legacy::Error
let err = if let Some(err) = err.source() {
err
} else {
return false;
};

#[cfg(feature = "http3")]
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<h3::error::ConnectionError>() {
log::debug!("determining if HTTP/3 error {err} can be retried");
// TODO: Does h3 provide an API for checking the error?
return err.to_string().as_str() == "timeout";
}
}

#[cfg(feature = "http2")]
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<h2::Error>() {
// They sent us a graceful shutdown, try with a new connection!
if err.is_go_away() && err.is_remote() && err.reason() == Some(h2::Reason::NO_ERROR) {
return true;
}

// REFUSED_STREAM was sent from the server, which is safe to retry.
// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.7-3.2
if err.is_reset() && err.is_remote() && err.reason() == Some(h2::Reason::REFUSED_STREAM)
{
return true;
}
}
}
false
}

impl Pending {
Expand Down Expand Up @@ -3042,66 +2952,49 @@ impl Future for PendingRequest {
}
}

loop {
let res = match self.as_mut().in_flight().get_mut() {
ResponseFuture::Default(r) => match ready!(Pin::new(r).poll(cx)) {
Err(e) => {
#[cfg(feature = "http2")]
if e.is_request() {
if let Some(e) = e.source() {
if self.as_mut().retry_error(e) {
continue;
}
}
}

return Poll::Ready(Err(e.if_no_url(|| self.url.clone())));
}
Ok(res) => res.map(super::body::boxed),
},
#[cfg(feature = "http3")]
ResponseFuture::H3(r) => match ready!(Pin::new(r).poll(cx)) {
Err(e) => {
if self.as_mut().retry_error(&e) {
continue;
}
return Poll::Ready(Err(
crate::error::request(e).with_url(self.url.clone())
));
}
Ok(res) => res,
},
};

#[cfg(feature = "cookies")]
{
if let Some(ref cookie_store) = self.client.cookie_store {
let mut cookies =
cookie::extract_response_cookie_headers(res.headers()).peekable();
if cookies.peek().is_some() {
cookie_store.set_cookies(&mut cookies, &self.url);
}
let res = match self.as_mut().in_flight().get_mut() {
ResponseFuture::Default(r) => match ready!(Pin::new(r).poll(cx)) {
Err(e) => {
return Poll::Ready(Err(e.if_no_url(|| self.url.clone())));
}
}
if let Some(url) = &res
.extensions()
.get::<tower_http::follow_redirect::RequestUri>()
{
self.url = match Url::parse(&url.0.to_string()) {
Ok(url) => url,
Err(e) => return Poll::Ready(Err(crate::error::decode(e))),
Ok(res) => res.map(super::body::boxed),
},
#[cfg(feature = "http3")]
ResponseFuture::H3(r) => match ready!(Pin::new(r).poll(cx)) {
Err(e) => {
return Poll::Ready(Err(crate::error::request(e).with_url(self.url.clone())));
}
};
Ok(res) => res,
},
};

let res = Response::new(
res,
self.url.clone(),
self.client.accepts,
self.total_timeout.take(),
self.read_timeout,
);
return Poll::Ready(Ok(res));
#[cfg(feature = "cookies")]
{
if let Some(ref cookie_store) = self.client.cookie_store {
let mut cookies = cookie::extract_response_cookie_headers(res.headers()).peekable();
if cookies.peek().is_some() {
cookie_store.set_cookies(&mut cookies, &self.url);
}
}
}
if let Some(url) = &res
.extensions()
.get::<tower_http::follow_redirect::RequestUri>()
{
self.url = match Url::parse(&url.0.to_string()) {
Ok(url) => url,
Err(e) => return Poll::Ready(Err(crate::error::decode(e))),
}
};

let res = Response::new(
res,
self.url.clone(),
self.client.accepts,
self.total_timeout.take(),
self.read_timeout,
);
Poll::Ready(Ok(res))
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/blocking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,13 @@ impl ClientBuilder {
self.with_inner(move |inner| inner.redirect(policy))
}

/// Set a request retry policy.
///
/// Default behavior is to retry protocol NACKs.
pub fn retry(self, policy: crate::retry::Builder) -> ClientBuilder {
self.with_inner(move |inner| inner.retry(policy))
}

/// Enable or disable automatic setting of the `Referer` header.
///
/// Default is `true`.
Expand Down
Loading
Loading