Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ jobs:
toolchain: 'stable'

- name: Check
run: cargo test --features http3
run: cargo test --features http3,stream
env:
RUSTFLAGS: --cfg reqwest_unstable
RUSTDOCFLAGS: --cfg reqwest_unstable
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ socks = ["dep:tokio-socks"]
macos-system-configuration = ["dep:system-configuration"]

# Experimental HTTP/3 client.
http3 = ["rustls-tls-manual-roots", "dep:h3", "dep:h3-quinn", "dep:quinn", "dep:slab", "dep:futures-channel"]
http3 = ["rustls-tls-manual-roots", "dep:h3", "dep:h3-quinn", "dep:quinn", "dep:slab", "dep:futures-channel", "tokio/macros"]


# Internal (PRIVATE!) features used to aid testing.
Expand Down
70 changes: 55 additions & 15 deletions src/async_impl/h3_client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::mpsc::{Receiver, TryRecvError};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::watch;
use tokio::sync::{oneshot, watch};
use tokio::time::Instant;

use crate::async_impl::body::ResponseBody;
Expand All @@ -17,7 +17,7 @@ use h3::client::SendRequest;
use h3_quinn::{Connection, OpenStreams};
use http::uri::{Authority, Scheme};
use http::{Request, Response, Uri};
use log::trace;
use log::{error, trace};

pub(super) type Key = (Scheme, Authority);

Expand Down Expand Up @@ -209,7 +209,7 @@ impl PoolClient {
) -> Result<Response<ResponseBody>, BoxError> {
use hyper::body::Body as _;

let (head, req_body) = req.into_parts();
let (head, mut req_body) = req.into_parts();
let mut req = Request::from_parts(head, ());

if let Some(n) = req_body.size_hint().exact() {
Expand All @@ -219,22 +219,52 @@ impl PoolClient {
}
}

let mut stream = self.inner.send_request(req).await?;
let (mut send, mut recv) = self.inner.send_request(req).await?.split();

match req_body.as_bytes() {
Some(b) if !b.is_empty() => {
stream.send_data(Bytes::copy_from_slice(b)).await?;
let (tx, mut rx) = oneshot::channel::<Result<(), BoxError>>();
tokio::spawn(async move {
let mut req_body = Pin::new(&mut req_body);
loop {
match std::future::poll_fn(|cx| req_body.as_mut().poll_frame(cx)).await {
Some(Ok(frame)) => {
if let Ok(b) = frame.into_data() {
if let Err(e) = send.send_data(Bytes::copy_from_slice(&b)).await {
if let Err(e) = tx.send(Err(e.into())) {
error!("Failed to communicate send.send_data() error: {e:?}");
}
return;
}
}
}
Some(Err(e)) => {
if let Err(e) = tx.send(Err(e.into())) {
error!("Failed to communicate req_body read error: {e:?}");
}
return;
}

None => break,
}
}
_ => {}
}

stream.finish().await?;

let resp = stream.recv_response().await?;
if let Err(e) = send.finish().await {
if let Err(e) = tx.send(Err(e.into())) {
error!("Failed to communicate send.finish read error: {e:?}");
}
return;
}

let resp_body = crate::async_impl::body::boxed(Incoming::new(stream, resp.headers()));
let _ = tx.send(Ok(()));
});

Ok(resp.map(|_| resp_body))
tokio::select! {
Ok(Err(e)) = &mut rx => Err(e),
resp = recv.recv_response() => {
let resp = resp?;
let resp_body = crate::async_impl::body::boxed(Incoming::new(recv, resp.headers(), rx));
Ok(resp.map(|_| resp_body))
}
}
}
}

Expand Down Expand Up @@ -271,16 +301,22 @@ impl PoolConnection {
struct Incoming<S, B> {
inner: h3::client::RequestStream<S, B>,
content_length: Option<u64>,
send_rx: oneshot::Receiver<Result<(), BoxError>>,
}

impl<S, B> Incoming<S, B> {
fn new(stream: h3::client::RequestStream<S, B>, headers: &http::header::HeaderMap) -> Self {
fn new(
stream: h3::client::RequestStream<S, B>,
headers: &http::header::HeaderMap,
send_rx: oneshot::Receiver<Result<(), BoxError>>,
) -> Self {
Self {
inner: stream,
content_length: headers
.get(http::header::CONTENT_LENGTH)
.and_then(|h| h.to_str().ok())
.and_then(|v| v.parse().ok()),
send_rx,
}
}
}
Expand All @@ -296,6 +332,10 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
if let Ok(Err(e)) = self.send_rx.try_recv() {
return Poll::Ready(Some(Err(crate::error::body(e))));
}

match futures_core::ready!(self.inner.poll_recv_data(cx)) {
Ok(Some(mut b)) => Poll::Ready(Some(Ok(hyper::body::Frame::data(
b.copy_to_bytes(b.remaining()),
Expand Down
76 changes: 76 additions & 0 deletions tests/http3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,79 @@ async fn http3_test_reconnection() {
assert_eq!(res.status(), reqwest::StatusCode::OK);
drop(server);
}

#[cfg(all(feature = "http3", feature = "stream"))]
#[tokio::test]
async fn http3_request_stream() {
use http_body_util::BodyExt;

let server = server::Http3::new().build(move |req| async move {
let reqb = req.collect().await.unwrap().to_bytes();
assert_eq!(reqb, "hello world");
http::Response::default()
});

let url = format!("https://{}", server.addr());
let body = reqwest::Body::wrap_stream(futures_util::stream::iter(vec![
Ok::<_, std::convert::Infallible>("hello"),
Ok::<_, std::convert::Infallible>(" "),
Ok::<_, std::convert::Infallible>("world"),
]));

let res = reqwest::Client::builder()
.http3_prior_knowledge()
.danger_accept_invalid_certs(true)
.build()
.expect("client builder")
.post(url)
.version(http::Version::HTTP_3)
.body(body)
.send()
.await
.expect("request");

assert_eq!(res.version(), http::Version::HTTP_3);
assert_eq!(res.status(), reqwest::StatusCode::OK);
}

#[cfg(all(feature = "http3", feature = "stream"))]
#[tokio::test]
async fn http3_request_stream_error() {
use http_body_util::BodyExt;

let server = server::Http3::new().build(move |req| async move {
// HTTP/3 response can start and finish before the entire request body has been received.
// To avoid prematurely terminating the session, collect full request body before responding.
let _ = req.collect().await;

http::Response::default()
});

let url = format!("https://{}", server.addr());
let body = reqwest::Body::wrap_stream(futures_util::stream::iter(vec![
Ok::<_, std::io::Error>("first chunk"),
Err::<_, std::io::Error>(std::io::Error::other("oh no!")),
]));

let res = reqwest::Client::builder()
.http3_prior_knowledge()
.danger_accept_invalid_certs(true)
.build()
.expect("client builder")
.post(url)
.version(http::Version::HTTP_3)
.body(body)
.send()
.await;

let err = res.unwrap_err();
assert!(err.is_request());
let err = err
.source()
.unwrap()
.source()
.unwrap()
.downcast_ref::<reqwest::Error>()
.unwrap();
assert!(err.is_body());
}
Loading