From 8ba9a8d2c4bab0f44b3f94a326b3b91c82d7877e Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 5 Dec 2019 17:51:37 -0800 Subject: [PATCH] feat(body): add `body::aggregate` and `body::to_bytes` functions Adds utility functions to `hyper::body` to help asynchronously collecting all the buffers of some `HttpBody` into one. - `aggregate` will collect all into an `impl Buf` without copying the contents. This is ideal if you don't need a contiguous buffer. - `to_bytes` will copy all the data into a single contiguous `Bytes` buffer. --- Cargo.toml | 7 +++- benches/body.rs | 89 +++++++++++++++++++++++++++++++++++++++++ examples/client.rs | 2 + examples/client_json.rs | 15 ++++--- examples/echo.rs | 11 ++--- examples/params.rs | 8 +--- examples/web_api.rs | 26 ++++++------ src/body/aggregate.rs | 25 ++++++++++++ src/body/mod.rs | 5 +++ src/body/to_bytes.rs | 36 +++++++++++++++++ src/common/buf.rs | 75 ++++++++++++++++++++++++++++++++++ src/common/mod.rs | 1 + src/proto/h1/io.rs | 88 +++++++--------------------------------- tests/client.rs | 10 +---- tests/support/mod.rs | 12 +----- 15 files changed, 282 insertions(+), 128 deletions(-) create mode 100644 benches/body.rs create mode 100644 src/body/aggregate.rs create mode 100644 src/body/to_bytes.rs create mode 100644 src/common/buf.rs diff --git a/Cargo.toml b/Cargo.toml index 5732c5a1ce..2b314931d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,7 +98,7 @@ required-features = ["runtime"] [[example]] name = "client_json" path = "examples/client_json.rs" -required-features = ["runtime", "stream"] +required-features = ["runtime"] [[example]] name = "echo" @@ -162,6 +162,11 @@ path = "examples/web_api.rs" required-features = ["runtime", "stream"] +[[bench]] +name = "body" +path = "benches/body.rs" +required-features = ["runtime", "stream"] + [[bench]] name = "connect" path = "benches/connect.rs" diff --git a/benches/body.rs b/benches/body.rs new file mode 100644 index 0000000000..6c25dfbe2c --- /dev/null +++ b/benches/body.rs @@ -0,0 +1,89 @@ +#![feature(test)] +#![deny(warnings)] + +extern crate test; + +use bytes::Buf; +use futures_util::stream; +use futures_util::StreamExt; +use hyper::body::Body; + +macro_rules! bench_stream { + ($bencher:ident, bytes: $bytes:expr, count: $count:expr, $total_ident:ident, $body_pat:pat, $block:expr) => {{ + let mut rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .expect("rt build"); + + let $total_ident: usize = $bytes * $count; + $bencher.bytes = $total_ident as u64; + let __s: &'static [&'static [u8]] = &[&[b'x'; $bytes] as &[u8]; $count] as _; + + $bencher.iter(|| { + rt.block_on(async { + let $body_pat = Body::wrap_stream( + stream::iter(__s.iter()).map(|&s| Ok::<_, std::convert::Infallible>(s)), + ); + $block; + }); + }); + }}; +} + +macro_rules! 
benches { + ($($name:ident, $bytes:expr, $count:expr;)+) => ( + mod aggregate { + use super::*; + + $( + #[bench] + fn $name(b: &mut test::Bencher) { + bench_stream!(b, bytes: $bytes, count: $count, total, body, { + let buf = hyper::body::aggregate(body).await.unwrap(); + assert_eq!(buf.remaining(), total); + }); + } + )+ + } + + mod manual_into_vec { + use super::*; + + $( + #[bench] + fn $name(b: &mut test::Bencher) { + bench_stream!(b, bytes: $bytes, count: $count, total, mut body, { + let mut vec = Vec::new(); + while let Some(chunk) = body.next().await { + vec.extend_from_slice(&chunk.unwrap()); + } + assert_eq!(vec.len(), total); + }); + } + )+ + } + + mod to_bytes { + use super::*; + + $( + #[bench] + fn $name(b: &mut test::Bencher) { + bench_stream!(b, bytes: $bytes, count: $count, total, body, { + let bytes = hyper::body::to_bytes(body).await.unwrap(); + assert_eq!(bytes.len(), total); + }); + } + )+ + } + ) +} + +// ===== Actual Benchmarks ===== + +benches! { + bytes_1_000_count_2, 1_000, 2; + bytes_1_000_count_10, 1_000, 10; + bytes_10_000_count_1, 10_000, 1; + bytes_10_000_count_10, 10_000, 10; +} diff --git a/examples/client.rs b/examples/client.rs index de4e4ec3be..de52d8a706 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -40,6 +40,8 @@ async fn fetch_url(url: hyper::Uri) -> Result<()> { println!("Response: {}", res.status()); println!("Headers: {:#?}\n", res.headers()); + // Stream the body, writing each chunk to stdout as we get it + // (instead of buffering and printing at the end). while let Some(next) = res.body_mut().data().await { let chunk = next?; io::stdout().write_all(&chunk).await?; diff --git a/examples/client_json.rs b/examples/client_json.rs index 98a79682d8..9027f05e31 100644 --- a/examples/client_json.rs +++ b/examples/client_json.rs @@ -4,7 +4,7 @@ #[macro_use] extern crate serde_derive; -use futures_util::StreamExt; +use bytes::buf::BufExt as _; use hyper::Client; // A simple type alias so as to DRY. @@ -27,14 +27,13 @@ async fn fetch_json(url: hyper::Uri) -> Result> { let client = Client::new(); // Fetch the url... - let mut res = client.get(url).await?; - // asynchronously concatenate chunks of the body - let mut body = Vec::new(); - while let Some(chunk) = res.body_mut().next().await { - body.extend_from_slice(&chunk?); - } + let res = client.get(url).await?; + + // asynchronously aggregate the chunks of the body + let body = hyper::body::aggregate(res.into_body()).await?; + // try to parse as json with serde_json - let users = serde_json::from_slice(&body)?; + let users = serde_json::from_reader(body.reader())?; Ok(users) } diff --git a/examples/echo.rs b/examples/echo.rs index 3aab063abb..ff7573049e 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -1,12 +1,12 @@ -//#![deny(warnings)] +#![deny(warnings)] -use futures_util::{StreamExt, TryStreamExt}; +use futures_util::TryStreamExt; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; /// This is our service handler. It receives a Request, routes on its /// path, and returns a Future of a Response. 
-async fn echo(mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
+async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
     match (req.method(), req.uri().path()) {
         // Serve some instructions at /
         (&Method::GET, "/") => Ok(Response::new(Body::from(
@@ -34,10 +34,7 @@ async fn echo(mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
         // So here we do `.await` on the future, waiting on concatenating the full body,
         // then afterwards the content can be reversed. Only then can we return a `Response`.
         (&Method::POST, "/echo/reversed") => {
-            let mut whole_body = Vec::new();
-            while let Some(chunk) = req.body_mut().next().await {
-                whole_body.extend_from_slice(&chunk?);
-            }
+            let whole_body = hyper::body::to_bytes(req.into_body()).await?;
 
             let reversed_body = whole_body.iter().rev().cloned().collect::<Vec<u8>>();
             Ok(Response::new(Body::from(reversed_body)))
diff --git a/examples/params.rs b/examples/params.rs
index c2e08e1fd6..d3f2966e48 100644
--- a/examples/params.rs
+++ b/examples/params.rs
@@ -4,7 +4,6 @@
 use hyper::service::{make_service_fn, service_fn};
 use hyper::{Body, Method, Request, Response, Server, StatusCode};
-use futures_util::StreamExt;
 use std::collections::HashMap;
 use url::form_urlencoded;
 
@@ -13,15 +12,12 @@ static MISSING: &[u8] = b"Missing field";
 static NOTNUMERIC: &[u8] = b"Number field is not numeric";
 
 // Using service_fn, we can turn this function into a `Service`.
-async fn param_example(mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
+async fn param_example(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
     match (req.method(), req.uri().path()) {
         (&Method::GET, "/") | (&Method::GET, "/post") => Ok(Response::new(INDEX.into())),
         (&Method::POST, "/post") => {
             // Concatenate the body...
-            let mut b = Vec::new();
-            while let Some(chunk) = req.body_mut().next().await {
-                b.extend_from_slice(&chunk?);
-            }
+            let b = hyper::body::to_bytes(req.into_body()).await?;
             // Parse the request body. form_urlencoded::parse
             // always succeeds, but in general parsing may
             // fail (for example, an invalid post of json), so
diff --git a/examples/web_api.rs b/examples/web_api.rs
index aad71af175..7f30c65dfe 100644
--- a/examples/web_api.rs
+++ b/examples/web_api.rs
@@ -1,6 +1,7 @@
 #![deny(warnings)]
-use futures_util::{StreamExt, TryStreamExt};
+use bytes::buf::BufExt;
+use futures_util::{stream, StreamExt};
 use hyper::client::HttpConnector;
 use hyper::service::{make_service_fn, service_fn};
 use hyper::{header, Body, Client, Method, Request, Response, Server, StatusCode};
@@ -24,25 +25,24 @@ async fn client_request_response(client: &Client<HttpConnector>) -> Result<Response<Body>> {
-            "POST request body: {}
Response: {}", + let before = stream::once(async { + Ok(format!( + "POST request body: {}
Response: ",
             POST_DATA,
-            std::str::from_utf8(&b).unwrap()
         )
-    }));
+        .into())
+    });
+    let after = web_res.into_body();
+    let body = Body::wrap_stream(before.chain(after));
 
     Ok(Response::new(body))
 }
 
-async fn api_post_response(mut req: Request<Body>) -> Result<Response<Body>> {
-    // Concatenate the body...
-    let mut whole_body = Vec::new();
-    while let Some(chunk) = req.body_mut().next().await {
-        whole_body.extend_from_slice(&chunk?);
-    }
+async fn api_post_response(req: Request<Body>) -> Result<Response<Body>> {
+    // Aggregate the body...
+    let whole_body = hyper::body::aggregate(req.into_body()).await?;
     // Decode as JSON...
-    let mut data: serde_json::Value = serde_json::from_slice(&whole_body)?;
+    let mut data: serde_json::Value = serde_json::from_reader(whole_body.reader())?;
     // Change the JSON...
     data["test"] = serde_json::Value::from("test_value");
     // And respond with the new JSON.
diff --git a/src/body/aggregate.rs b/src/body/aggregate.rs
new file mode 100644
index 0000000000..97b6c2d91f
--- /dev/null
+++ b/src/body/aggregate.rs
@@ -0,0 +1,25 @@
+use bytes::Buf;
+
+use super::HttpBody;
+use crate::common::buf::BufList;
+
+/// Aggregate the data buffers from a body asynchronously.
+///
+/// The returned `impl Buf` groups the `Buf`s from the `HttpBody` without
+/// copying them. This is ideal if you don't require a contiguous buffer.
+pub async fn aggregate<T>(body: T) -> Result<impl Buf, T::Error>
+where
+    T: HttpBody,
+{
+    let mut bufs = BufList::new();
+
+    futures_util::pin_mut!(body);
+    while let Some(buf) = body.data().await {
+        let buf = buf?;
+        if buf.has_remaining() {
+            bufs.push(buf);
+        }
+    }
+
+    Ok(bufs)
+}
diff --git a/src/body/mod.rs b/src/body/mod.rs
index 1a28093a39..0d8d358617 100644
--- a/src/body/mod.rs
+++ b/src/body/mod.rs
@@ -18,11 +18,16 @@
 pub use bytes::{Buf, Bytes};
 pub use http_body::Body as HttpBody;
 
+pub use self::aggregate::aggregate;
 pub use self::body::{Body, Sender};
+pub use self::to_bytes::to_bytes;
+
 pub(crate) use self::payload::Payload;
 
+mod aggregate;
 mod body;
 mod payload;
+mod to_bytes;
 
 /// An optimization to try to take a full body if immediately available.
 ///
diff --git a/src/body/to_bytes.rs b/src/body/to_bytes.rs
new file mode 100644
index 0000000000..e631580fd2
--- /dev/null
+++ b/src/body/to_bytes.rs
@@ -0,0 +1,36 @@
+use bytes::{Buf, BufMut, Bytes};
+
+use super::HttpBody;
+
+/// Concatenate all the data buffers from a body into a single contiguous `Bytes` asynchronously.
+pub async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
+where
+    T: HttpBody,
+{
+    futures_util::pin_mut!(body);
+
+    // If there's only 1 chunk, we can just return Buf::to_bytes()
+    let mut first = if let Some(buf) = body.data().await {
+        buf?
+    } else {
+        return Ok(Bytes::new());
+    };
+
+    let second = if let Some(buf) = body.data().await {
+        buf?
+    } else {
+        return Ok(first.to_bytes());
+    };
+
+    // With more than 1 buf, we gotta flatten into a Vec first.
+    let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
+    let mut vec = Vec::with_capacity(cap);
+    vec.put(first);
+    vec.put(second);
+
+    while let Some(buf) = body.data().await {
+        vec.put(buf?);
+    }
+
+    Ok(vec.into())
+}
diff --git a/src/common/buf.rs b/src/common/buf.rs
new file mode 100644
index 0000000000..4de4d947dd
--- /dev/null
+++ b/src/common/buf.rs
@@ -0,0 +1,75 @@
+use std::collections::VecDeque;
+use std::io::IoSlice;
+
+use bytes::Buf;
+
+pub(crate) struct BufList<T> {
+    bufs: VecDeque<T>,
+}
+
+impl<T: Buf> BufList<T> {
+    pub(crate) fn new() -> BufList<T> {
+        BufList {
+            bufs: VecDeque::new(),
+        }
+    }
+
+    #[inline]
+    pub(crate) fn push(&mut self, buf: T) {
+        debug_assert!(buf.has_remaining());
+        self.bufs.push_back(buf);
+    }
+
+    #[inline]
+    pub(crate) fn bufs_cnt(&self) -> usize {
+        self.bufs.len()
+    }
+}
+
+impl<T: Buf> Buf for BufList<T> {
+    #[inline]
+    fn remaining(&self) -> usize {
+        self.bufs.iter().map(|buf| buf.remaining()).sum()
+    }
+
+    #[inline]
+    fn bytes(&self) -> &[u8] {
+        for buf in &self.bufs {
+            return buf.bytes();
+        }
+        &[]
+    }
+
+    #[inline]
+    fn advance(&mut self, mut cnt: usize) {
+        while cnt > 0 {
+            {
+                let front = &mut self.bufs[0];
+                let rem = front.remaining();
+                if rem > cnt {
+                    front.advance(cnt);
+                    return;
+                } else {
+                    front.advance(rem);
+                    cnt -= rem;
+                }
+            }
+            self.bufs.pop_front();
+        }
+    }
+
+    #[inline]
+    fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
+        if dst.is_empty() {
+            return 0;
+        }
+        let mut vecs = 0;
+        for buf in &self.bufs {
+            vecs += buf.bytes_vectored(&mut dst[vecs..]);
+            if vecs == dst.len() {
+                break;
+            }
+        }
+        vecs
+    }
+}
diff --git a/src/common/mod.rs b/src/common/mod.rs
index 28169a2f5c..394e549895 100644
--- a/src/common/mod.rs
+++ b/src/common/mod.rs
@@ -7,6 +7,7 @@ macro_rules! ready {
     };
 }
 
+pub(crate) mod buf;
 pub(crate) mod drain;
 pub(crate) mod exec;
 pub(crate) mod io;
diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs
index 75d3f355c3..7cfaaaa01b 100644
--- a/src/proto/h1/io.rs
+++ b/src/proto/h1/io.rs
@@ -1,6 +1,5 @@
 use std::cell::Cell;
 use std::cmp;
-use std::collections::VecDeque;
 use std::fmt;
 use std::io::{self, IoSlice};
 
@@ -8,6 +7,7 @@
 use bytes::{Buf, BufMut, Bytes, BytesMut};
 use tokio::io::{AsyncRead, AsyncWrite};
 
 use super::{Http1Transaction, ParseContext, ParsedMessage};
+use crate::common::buf::BufList;
 use crate::common::{task, Pin, Poll, Unpin};
 
 /// The initial buffer size allocated before trying to read from IO.
@@ -90,7 +90,7 @@ where pub fn set_write_strategy_flatten(&mut self) { // this should always be called only at construction time, // so this assert is here to catch myself - debug_assert!(self.write_buf.queue.bufs.is_empty()); + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); self.write_buf.set_strategy(WriteStrategy::Flatten); } @@ -431,16 +431,16 @@ pub(super) struct WriteBuf { headers: Cursor>, max_buf_size: usize, /// Deque of user buffers if strategy is Queue - queue: BufDeque, + queue: BufList, strategy: WriteStrategy, } -impl WriteBuf { +impl WriteBuf { fn new() -> WriteBuf { WriteBuf { headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)), max_buf_size: DEFAULT_MAX_BUFFER_SIZE, - queue: BufDeque::new(), + queue: BufList::new(), strategy: WriteStrategy::Auto, } } @@ -479,7 +479,7 @@ where } } WriteStrategy::Auto | WriteStrategy::Queue => { - self.queue.bufs.push_back(buf.into()); + self.queue.push(buf.into()); } } } @@ -488,7 +488,7 @@ where match self.strategy { WriteStrategy::Flatten => self.remaining() < self.max_buf_size, WriteStrategy::Auto | WriteStrategy::Queue => { - self.queue.bufs.len() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size + self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size } } } @@ -608,66 +608,6 @@ enum WriteStrategy { Queue, } -struct BufDeque { - bufs: VecDeque, -} - -impl BufDeque { - fn new() -> BufDeque { - BufDeque { - bufs: VecDeque::new(), - } - } -} - -impl Buf for BufDeque { - #[inline] - fn remaining(&self) -> usize { - self.bufs.iter().map(|buf| buf.remaining()).sum() - } - - #[inline] - fn bytes(&self) -> &[u8] { - for buf in &self.bufs { - return buf.bytes(); - } - &[] - } - - #[inline] - fn advance(&mut self, mut cnt: usize) { - while cnt > 0 { - { - let front = &mut self.bufs[0]; - let rem = front.remaining(); - if rem > cnt { - front.advance(cnt); - return; - } else { - front.advance(rem); - cnt -= rem; - } - } - self.bufs.pop_front(); - } - } - - #[inline] - fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { - if dst.is_empty() { - return 0; - } - let mut vecs = 0; - for buf in &self.bufs { - vecs += buf.bytes_vectored(&mut dst[vecs..]); - if vecs == dst.len() { - break; - } - } - vecs - } -} - #[cfg(test)] mod tests { use super::*; @@ -871,12 +811,12 @@ mod tests { buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); - assert_eq!(buffered.write_buf.queue.bufs.len(), 3); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); buffered.flush().unwrap(); assert_eq!(buffered.io, b"hello world, it's hyper!"); assert_eq!(buffered.io.num_writes(), 1); - assert_eq!(buffered.write_buf.queue.bufs.len(), 0); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); } */ @@ -896,7 +836,7 @@ mod tests { buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); - assert_eq!(buffered.write_buf.queue.bufs.len(), 0); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); buffered.flush().await.expect("flush"); } @@ -921,11 +861,11 @@ mod tests { buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); - assert_eq!(buffered.write_buf.queue.bufs.len(), 3); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); buffered.flush().await.expect("flush"); - assert_eq!(buffered.write_buf.queue.bufs.len(), 0); + 
assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); } #[tokio::test] @@ -949,11 +889,11 @@ mod tests { buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); - assert_eq!(buffered.write_buf.queue.bufs.len(), 3); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); buffered.flush().await.expect("flush"); - assert_eq!(buffered.write_buf.queue.bufs.len(), 0); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); } #[cfg(feature = "nightly")] diff --git a/tests/client.rs b/tests/client.rs index 68076365cd..03c3913beb 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -11,12 +11,12 @@ use std::task::{Context, Poll}; use std::thread; use std::time::Duration; +use hyper::body::to_bytes as concat; use hyper::{Body, Client, Method, Request, StatusCode}; use futures_channel::oneshot; use futures_core::{Future, Stream, TryFuture}; use futures_util::future::{self, FutureExt, TryFutureExt}; -use futures_util::StreamExt; use tokio::net::TcpStream; use tokio::runtime::Runtime; @@ -28,14 +28,6 @@ fn tcp_connect(addr: &SocketAddr) -> impl Future Result { - let mut vec = Vec::new(); - while let Some(chunk) = body.next().await { - vec.extend_from_slice(&chunk?); - } - Ok(vec.into()) -} - macro_rules! test { ( name: $name:ident, diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 48863410c2..17095392f3 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -355,7 +355,7 @@ async fn async_test(cfg: __TestConfig) { func(&req.headers()); } let sbody = sreq.body; - concat(req.into_body()).map_ok(move |body| { + hyper::body::to_bytes(req.into_body()).map_ok(move |body| { assert_eq!(body.as_ref(), sbody.as_slice(), "client body"); let mut res = Response::builder() @@ -410,7 +410,7 @@ async fn async_test(cfg: __TestConfig) { for func in &cheaders { func(&res.headers()); } - concat(res.into_body()) + hyper::body::to_bytes(res.into_body()) }) .map_ok(move |body| { assert_eq!(body.as_ref(), cbody.as_slice(), "server body"); @@ -473,11 +473,3 @@ fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) { let proxy_addr = srv.local_addr(); (proxy_addr, srv.map(|res| res.expect("proxy error"))) } - -async fn concat(mut body: Body) -> Result { - let mut vec = Vec::new(); - while let Some(chunk) = body.next().await { - vec.extend_from_slice(&chunk?); - } - Ok(vec.into()) -}
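
Usage sketch (not part of the diff): a minimal, self-contained example of how the two new helpers are intended to be called, mirroring the `Body::wrap_stream` setup used in `benches/body.rs`. This assumes hyper built with the `stream` feature (for `wrap_stream`) and a Tokio runtime with the `macros` feature; the chunk values are made up for illustration.

    use bytes::Buf;
    use futures_util::stream;
    use hyper::Body;

    #[tokio::main]
    async fn main() -> Result<(), hyper::Error> {
        // A body that arrives as three separate chunks.
        let chunks: Vec<Result<&'static str, std::io::Error>> =
            vec![Ok("hello"), Ok(" "), Ok("world")];
        let body = Body::wrap_stream(stream::iter(chunks));

        // `aggregate` keeps the chunks where they are and exposes them behind
        // one `impl Buf`; nothing is copied into a new contiguous allocation.
        let buf = hyper::body::aggregate(body).await?;
        assert_eq!(buf.remaining(), "hello world".len());

        // `to_bytes` copies everything into a single contiguous `Bytes`,
        // which is what you want when a parser needs one flat slice.
        let chunks: Vec<Result<&'static str, std::io::Error>> =
            vec![Ok("hello"), Ok(" "), Ok("world")];
        let body = Body::wrap_stream(stream::iter(chunks));
        let bytes = hyper::body::to_bytes(body).await?;
        assert_eq!(&bytes[..], &b"hello world"[..]);

        Ok(())
    }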