diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a35f8df6..cf56103b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -215,7 +215,7 @@ jobs: toolchain: 'stable' - name: Check - run: cargo test --features http3 + run: cargo test --features http3,stream env: RUSTFLAGS: --cfg reqwest_unstable RUSTDOCFLAGS: --cfg reqwest_unstable diff --git a/Cargo.toml b/Cargo.toml index 3fde9e882..91df31686 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ socks = ["dep:tokio-socks"] macos-system-configuration = ["dep:system-configuration"] # Experimental HTTP/3 client. -http3 = ["rustls-tls-manual-roots", "dep:h3", "dep:h3-quinn", "dep:quinn", "dep:slab", "dep:futures-channel"] +http3 = ["rustls-tls-manual-roots", "dep:h3", "dep:h3-quinn", "dep:quinn", "dep:slab", "dep:futures-channel", "tokio/macros"] # Internal (PRIVATE!) features used to aid testing. diff --git a/src/async_impl/h3_client/pool.rs b/src/async_impl/h3_client/pool.rs index 0926fb3cf..8d730f39b 100644 --- a/src/async_impl/h3_client/pool.rs +++ b/src/async_impl/h3_client/pool.rs @@ -6,7 +6,7 @@ use std::sync::mpsc::{Receiver, TryRecvError}; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::sync::watch; +use tokio::sync::{oneshot, watch}; use tokio::time::Instant; use crate::async_impl::body::ResponseBody; @@ -17,7 +17,7 @@ use h3::client::SendRequest; use h3_quinn::{Connection, OpenStreams}; use http::uri::{Authority, Scheme}; use http::{Request, Response, Uri}; -use log::trace; +use log::{error, trace}; pub(super) type Key = (Scheme, Authority); @@ -209,7 +209,7 @@ impl PoolClient { ) -> Result, BoxError> { use hyper::body::Body as _; - let (head, req_body) = req.into_parts(); + let (head, mut req_body) = req.into_parts(); let mut req = Request::from_parts(head, ()); if let Some(n) = req_body.size_hint().exact() { @@ -219,22 +219,52 @@ impl PoolClient { } } - let mut stream = self.inner.send_request(req).await?; + let (mut send, mut recv) = self.inner.send_request(req).await?.split(); - match req_body.as_bytes() { - Some(b) if !b.is_empty() => { - stream.send_data(Bytes::copy_from_slice(b)).await?; + let (tx, mut rx) = oneshot::channel::>(); + tokio::spawn(async move { + let mut req_body = Pin::new(&mut req_body); + loop { + match std::future::poll_fn(|cx| req_body.as_mut().poll_frame(cx)).await { + Some(Ok(frame)) => { + if let Ok(b) = frame.into_data() { + if let Err(e) = send.send_data(Bytes::copy_from_slice(&b)).await { + if let Err(e) = tx.send(Err(e.into())) { + error!("Failed to communicate send.send_data() error: {e:?}"); + } + return; + } + } + } + Some(Err(e)) => { + if let Err(e) = tx.send(Err(e.into())) { + error!("Failed to communicate req_body read error: {e:?}"); + } + return; + } + + None => break, + } } - _ => {} - } - stream.finish().await?; - - let resp = stream.recv_response().await?; + if let Err(e) = send.finish().await { + if let Err(e) = tx.send(Err(e.into())) { + error!("Failed to communicate send.finish read error: {e:?}"); + } + return; + } - let resp_body = crate::async_impl::body::boxed(Incoming::new(stream, resp.headers())); + let _ = tx.send(Ok(())); + }); - Ok(resp.map(|_| resp_body)) + tokio::select! { + Ok(Err(e)) = &mut rx => Err(e), + resp = recv.recv_response() => { + let resp = resp?; + let resp_body = crate::async_impl::body::boxed(Incoming::new(recv, resp.headers(), rx)); + Ok(resp.map(|_| resp_body)) + } + } } } @@ -271,16 +301,22 @@ impl PoolConnection { struct Incoming { inner: h3::client::RequestStream, content_length: Option, + send_rx: oneshot::Receiver>, } impl Incoming { - fn new(stream: h3::client::RequestStream, headers: &http::header::HeaderMap) -> Self { + fn new( + stream: h3::client::RequestStream, + headers: &http::header::HeaderMap, + send_rx: oneshot::Receiver>, + ) -> Self { Self { inner: stream, content_length: headers .get(http::header::CONTENT_LENGTH) .and_then(|h| h.to_str().ok()) .and_then(|v| v.parse().ok()), + send_rx, } } } @@ -296,6 +332,10 @@ where mut self: Pin<&mut Self>, cx: &mut Context, ) -> Poll, Self::Error>>> { + if let Ok(Err(e)) = self.send_rx.try_recv() { + return Poll::Ready(Some(Err(crate::error::body(e)))); + } + match futures_core::ready!(self.inner.poll_recv_data(cx)) { Ok(Some(mut b)) => Poll::Ready(Some(Ok(hyper::body::Frame::data( b.copy_to_bytes(b.remaining()), diff --git a/tests/http3.rs b/tests/http3.rs index 57a5331fd..9f097f6c3 100644 --- a/tests/http3.rs +++ b/tests/http3.rs @@ -213,3 +213,79 @@ async fn http3_test_reconnection() { assert_eq!(res.status(), reqwest::StatusCode::OK); drop(server); } + +#[cfg(all(feature = "http3", feature = "stream"))] +#[tokio::test] +async fn http3_request_stream() { + use http_body_util::BodyExt; + + let server = server::Http3::new().build(move |req| async move { + let reqb = req.collect().await.unwrap().to_bytes(); + assert_eq!(reqb, "hello world"); + http::Response::default() + }); + + let url = format!("https://{}", server.addr()); + let body = reqwest::Body::wrap_stream(futures_util::stream::iter(vec![ + Ok::<_, std::convert::Infallible>("hello"), + Ok::<_, std::convert::Infallible>(" "), + Ok::<_, std::convert::Infallible>("world"), + ])); + + let res = reqwest::Client::builder() + .http3_prior_knowledge() + .danger_accept_invalid_certs(true) + .build() + .expect("client builder") + .post(url) + .version(http::Version::HTTP_3) + .body(body) + .send() + .await + .expect("request"); + + assert_eq!(res.version(), http::Version::HTTP_3); + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + +#[cfg(all(feature = "http3", feature = "stream"))] +#[tokio::test] +async fn http3_request_stream_error() { + use http_body_util::BodyExt; + + let server = server::Http3::new().build(move |req| async move { + // HTTP/3 response can start and finish before the entire request body has been received. + // To avoid prematurely terminating the session, collect full request body before responding. + let _ = req.collect().await; + + http::Response::default() + }); + + let url = format!("https://{}", server.addr()); + let body = reqwest::Body::wrap_stream(futures_util::stream::iter(vec![ + Ok::<_, std::io::Error>("first chunk"), + Err::<_, std::io::Error>(std::io::Error::other("oh no!")), + ])); + + let res = reqwest::Client::builder() + .http3_prior_knowledge() + .danger_accept_invalid_certs(true) + .build() + .expect("client builder") + .post(url) + .version(http::Version::HTTP_3) + .body(body) + .send() + .await; + + let err = res.unwrap_err(); + assert!(err.is_request()); + let err = err + .source() + .unwrap() + .source() + .unwrap() + .downcast_ref::() + .unwrap(); + assert!(err.is_body()); +}