diff --git a/src/body/body.rs b/src/body/body.rs index e50e9f123e..9c199fd2c8 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -5,8 +5,6 @@ use std::fmt; use bytes::Bytes; use futures_channel::mpsc; -#[cfg(any(feature = "http1", feature = "http2"))] -#[cfg(feature = "client")] use futures_channel::oneshot; use futures_core::Stream; // for mpsc::Receiver #[cfg(feature = "stream")] @@ -17,14 +15,16 @@ use http_body::{Body as HttpBody, SizeHint}; use super::DecodedLength; #[cfg(feature = "stream")] use crate::common::sync_wrapper::SyncWrapper; -use crate::common::{task, watch, Pin, Poll}; +use crate::common::Future; #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "client")] -use crate::common::{Future, Never}; +use crate::common::Never; +use crate::common::{task, watch, Pin, Poll}; #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] use crate::proto::h2::ping; type BodySender = mpsc::Sender>; +type TrailersSender = oneshot::Sender; /// A stream of `Bytes`, used when receiving bodies. /// @@ -43,7 +43,8 @@ enum Kind { Chan { content_length: DecodedLength, want_tx: watch::Sender, - rx: mpsc::Receiver>, + data_rx: mpsc::Receiver>, + trailers_rx: oneshot::Receiver, }, #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] H2 { @@ -106,7 +107,8 @@ enum DelayEof { #[must_use = "Sender does nothing unless sent on"] pub struct Sender { want_rx: watch::Receiver, - tx: BodySender, + data_tx: BodySender, + trailers_tx: Option, } const WANT_PENDING: usize = 1; @@ -137,7 +139,8 @@ impl Body { } pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) { - let (tx, rx) = mpsc::channel(0); + let (data_tx, data_rx) = mpsc::channel(0); + let (trailers_tx, trailers_rx) = oneshot::channel(); // If wanter is true, `Sender::poll_ready()` won't becoming ready // until the `Body` has been polled for data once. @@ -145,11 +148,16 @@ impl Body { let (want_tx, want_rx) = watch::channel(want); - let tx = Sender { want_rx, tx }; + let tx = Sender { + want_rx, + data_tx, + trailers_tx: Some(trailers_tx), + }; let rx = Body::new(Kind::Chan { content_length, want_tx, - rx, + data_rx, + trailers_rx, }); (tx, rx) @@ -282,12 +290,13 @@ impl Body { Kind::Once(ref mut val) => Poll::Ready(val.take().map(Ok)), Kind::Chan { content_length: ref mut len, - ref mut rx, + ref mut data_rx, ref mut want_tx, + .. } => { want_tx.send(WANT_READY); - match ready!(Pin::new(rx).poll_next(cx)?) { + match ready!(Pin::new(data_rx).poll_next(cx)?) { Some(chunk) => { len.sub_if(chunk.len() as u64); Poll::Ready(Some(Ok(chunk))) @@ -368,10 +377,15 @@ impl HttpBody for Body { } Err(e) => Poll::Ready(Err(crate::Error::new_h2(e))), }, - + Kind::Chan { + ref mut trailers_rx, + .. + } => match ready!(Pin::new(trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Ok(Some(t))), + Err(_) => Poll::Ready(Ok(None)), + }, #[cfg(feature = "ffi")] Kind::Ffi(ref mut body) => body.poll_trailers(cx), - _ => Poll::Ready(Ok(None)), } } @@ -527,7 +541,7 @@ impl Sender { pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { // Check if the receiver end has tried polling for the body yet ready!(self.poll_want(cx)?); - self.tx + self.data_tx .poll_ready(cx) .map_err(|_| crate::Error::new_closed()) } @@ -545,14 +559,23 @@ impl Sender { futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await } - /// Send data on this channel when it is ready. + /// Send data on data channel when it is ready. pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> { self.ready().await?; - self.tx + self.data_tx .try_send(Ok(chunk)) .map_err(|_| crate::Error::new_closed()) } + /// Send trailers on trailers channel. + pub async fn send_trailers(&mut self, trailers: HeaderMap) -> crate::Result<()> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(crate::Error::new_closed()), + }; + tx.send(trailers).map_err(|_| crate::Error::new_closed()) + } + /// Try to send data on this channel. /// /// # Errors @@ -566,7 +589,7 @@ impl Sender { /// that doesn't have an async context. If in an async context, prefer /// `send_data()` instead. pub fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { - self.tx + self.data_tx .try_send(Ok(chunk)) .map_err(|err| err.into_inner().expect("just sent Ok")) } @@ -574,7 +597,7 @@ impl Sender { /// Aborts the body in an abnormal fashion. pub fn abort(self) { let _ = self - .tx + .data_tx // clone so the send works even if buffer is full .clone() .try_send(Err(crate::Error::new_body_write_aborted())); @@ -582,7 +605,7 @@ impl Sender { #[cfg(feature = "http1")] pub(crate) fn send_error(&mut self, err: crate::Error) { - let _ = self.tx.try_send(Err(err)); + let _ = self.data_tx.try_send(Err(err)); } } @@ -628,7 +651,7 @@ mod tests { assert_eq!( mem::size_of::(), - mem::size_of::() * 4, + mem::size_of::() * 5, "Sender" );