Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ jobs:
toolchain: 'stable'

- name: Check
run: cargo test --features http3
run: cargo test --features http3,stream
env:
RUSTFLAGS: --cfg reqwest_unstable
RUSTDOCFLAGS: --cfg reqwest_unstable
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ once_cell = "1.18"
log = "0.4.17"
mime = "0.3.16"
percent-encoding = "2.3"
tokio = { version = "1.0", default-features = false, features = ["net", "time"] }
tokio = { version = "1.0", default-features = false, features = ["net", "time", "macros"] }
tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] }
pin-project-lite = "0.2.11"
ipnet = "2.3"
Expand Down
70 changes: 55 additions & 15 deletions src/async_impl/h3_client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::mpsc::{Receiver, TryRecvError};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::watch;
use tokio::sync::{oneshot, watch};
use tokio::time::Instant;

use crate::async_impl::body::ResponseBody;
Expand All @@ -17,7 +17,7 @@ use h3::client::SendRequest;
use h3_quinn::{Connection, OpenStreams};
use http::uri::{Authority, Scheme};
use http::{Request, Response, Uri};
use log::trace;
use log::{error, trace};

pub(super) type Key = (Scheme, Authority);

Expand Down Expand Up @@ -209,7 +209,7 @@ impl PoolClient {
) -> Result<Response<ResponseBody>, BoxError> {
use hyper::body::Body as _;

let (head, req_body) = req.into_parts();
let (head, mut req_body) = req.into_parts();
let mut req = Request::from_parts(head, ());

if let Some(n) = req_body.size_hint().exact() {
Expand All @@ -219,22 +219,52 @@ impl PoolClient {
}
}

let mut stream = self.inner.send_request(req).await?;
let (mut send, mut recv) = self.inner.send_request(req).await?.split();

match req_body.as_bytes() {
Some(b) if !b.is_empty() => {
stream.send_data(Bytes::copy_from_slice(b)).await?;
let (tx, mut rx) = oneshot::channel::<Result<(), BoxError>>();
tokio::spawn(async move {
let mut req_body = Pin::new(&mut req_body);
loop {
match std::future::poll_fn(|cx| req_body.as_mut().poll_frame(cx)).await {
Some(Ok(frame)) => {
if let Ok(b) = frame.into_data() {
if let Err(e) = send.send_data(Bytes::copy_from_slice(&b)).await {
if let Err(e) = tx.send(Err(e.into())) {
error!("Failed to communicate send.send_data() error: {e:?}");
}
return;
}
}
}
Some(Err(e)) => {
if let Err(e) = tx.send(Err(e.into())) {
error!("Failed to communicate req_body read error: {e:?}");
}
return;
}

None => break,
}
}
_ => {}
}

stream.finish().await?;

let resp = stream.recv_response().await?;
if let Err(e) = send.finish().await {
if let Err(e) = tx.send(Err(e.into())) {
error!("Failed to communicate send.finish read error: {e:?}");
}
return;
}

let resp_body = crate::async_impl::body::boxed(Incoming::new(stream, resp.headers()));
let _ = tx.send(Ok(()));
});

Ok(resp.map(|_| resp_body))
tokio::select! {
Ok(Err(e)) = &mut rx => Err(e),
resp = recv.recv_response() => {
let resp = resp?;
let resp_body = crate::async_impl::body::boxed(Incoming::new(recv, resp.headers(), rx));
Ok(resp.map(|_| resp_body))
}
}
}
}

Expand Down Expand Up @@ -271,16 +301,22 @@ impl PoolConnection {
struct Incoming<S, B> {
inner: h3::client::RequestStream<S, B>,
content_length: Option<u64>,
send_rx: oneshot::Receiver<Result<(), BoxError>>,
}

impl<S, B> Incoming<S, B> {
fn new(stream: h3::client::RequestStream<S, B>, headers: &http::header::HeaderMap) -> Self {
fn new(
stream: h3::client::RequestStream<S, B>,
headers: &http::header::HeaderMap,
send_rx: oneshot::Receiver<Result<(), BoxError>>,
) -> Self {
Self {
inner: stream,
content_length: headers
.get(http::header::CONTENT_LENGTH)
.and_then(|h| h.to_str().ok())
.and_then(|v| v.parse().ok()),
send_rx,
}
}
}
Expand All @@ -296,6 +332,10 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
if let Ok(Err(e)) = self.send_rx.try_recv() {
return Poll::Ready(Some(Err(crate::error::body(e))));
}

match futures_core::ready!(self.inner.poll_recv_data(cx)) {
Ok(Some(mut b)) => Poll::Ready(Some(Ok(hyper::body::Frame::data(
b.copy_to_bytes(b.remaining()),
Expand Down
76 changes: 76 additions & 0 deletions tests/http3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,79 @@ async fn http3_test_reconnection() {
assert_eq!(res.status(), reqwest::StatusCode::OK);
drop(server);
}

#[cfg(all(feature = "http3", feature = "stream"))]
#[tokio::test]
async fn http3_request_stream() {
use http_body_util::BodyExt;

let server = server::Http3::new().build(move |req| async move {
let reqb = req.collect().await.unwrap().to_bytes();
assert_eq!(reqb, "hello world");
http::Response::default()
});

let url = format!("https://{}", server.addr());
let body = reqwest::Body::wrap_stream(futures_util::stream::iter(vec![
Ok::<_, std::convert::Infallible>("hello"),
Ok::<_, std::convert::Infallible>(" "),
Ok::<_, std::convert::Infallible>("world"),
]));

let res = reqwest::Client::builder()
.http3_prior_knowledge()
.danger_accept_invalid_certs(true)
.build()
.expect("client builder")
.post(url)
.version(http::Version::HTTP_3)
.body(body)
.send()
.await
.expect("request");

assert_eq!(res.version(), http::Version::HTTP_3);
assert_eq!(res.status(), reqwest::StatusCode::OK);
}

#[cfg(all(feature = "http3", feature = "stream"))]
#[tokio::test]
async fn http3_request_stream_error() {
use http_body_util::BodyExt;

let server = server::Http3::new().build(move |req| async move {
// HTTP/3 response can start and finish before the entire request body has been received.
// To avoid prematurely terminating the session, collect full request body before responding.
let _ = req.collect().await;

http::Response::default()
});

let url = format!("https://{}", server.addr());
let body = reqwest::Body::wrap_stream(futures_util::stream::iter(vec![
Ok::<_, std::io::Error>("first chunk"),
Err::<_, std::io::Error>(std::io::Error::other("oh no!")),
]));

let res = reqwest::Client::builder()
.http3_prior_knowledge()
.danger_accept_invalid_certs(true)
.build()
.expect("client builder")
.post(url)
.version(http::Version::HTTP_3)
.body(body)
.send()
.await;

let err = res.unwrap_err();
assert!(err.is_request());
let err = err
.source()
.unwrap()
.source()
.unwrap()
.downcast_ref::<reqwest::Error>()
.unwrap();
assert!(err.is_body());
}
Loading