diff --git a/src/body/body.rs b/src/body/body.rs index 71c180f07a..7be740bc9f 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -35,6 +35,7 @@ pub struct Body { enum Kind { Once(Option), Chan { + content_length: Option, abort_rx: oneshot::Receiver<()>, rx: mpsc::Receiver>, }, @@ -85,6 +86,11 @@ impl Body { /// Useful when wanting to stream chunks from another thread. #[inline] pub fn channel() -> (Sender, Body) { + Self::new_channel(None) + } + + #[inline] + pub(crate) fn new_channel(content_length: Option) -> (Sender, Body) { let (tx, rx) = mpsc::channel(0); let (abort_tx, abort_rx) = oneshot::channel(); @@ -93,8 +99,9 @@ impl Body { tx: tx, }; let rx = Body::new(Kind::Chan { - abort_rx: abort_rx, - rx: rx, + content_length, + abort_rx, + rx, }); (tx, rx) @@ -188,13 +195,19 @@ impl Body { fn poll_inner(&mut self) -> Poll, ::Error> { match self.kind { Kind::Once(ref mut val) => Ok(Async::Ready(val.take())), - Kind::Chan { ref mut rx, ref mut abort_rx } => { + Kind::Chan { content_length: ref mut len, ref mut rx, ref mut abort_rx } => { if let Ok(Async::Ready(())) = abort_rx.poll() { return Err(::Error::new_body_write("body write aborted")); } match rx.poll().expect("mpsc cannot error") { - Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))), + Async::Ready(Some(Ok(chunk))) => { + if let Some(ref mut len) = *len { + debug_assert!(*len >= chunk.len() as u64); + *len = *len - chunk.len() as u64; + } + Ok(Async::Ready(Some(chunk))) + } Async::Ready(Some(Err(err))) => Err(err), Async::Ready(None) => Ok(Async::Ready(None)), Async::NotReady => Ok(Async::NotReady), @@ -243,7 +256,7 @@ impl Payload for Body { fn is_end_stream(&self) -> bool { match self.kind { Kind::Once(ref val) => val.is_none(), - Kind::Chan { .. } => false, + Kind::Chan { content_length: len, .. } => len == Some(0), Kind::H2(ref h2) => h2.is_end_stream(), Kind::Wrapped(..) => false, } @@ -253,7 +266,7 @@ impl Payload for Body { match self.kind { Kind::Once(Some(ref val)) => Some(val.len() as u64), Kind::Once(None) => Some(0), - Kind::Chan { .. } => None, + Kind::Chan { content_length: len, .. } => len, Kind::H2(..) => None, Kind::Wrapped(..) => None, } diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 45c4050210..4c57e6e87e 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -114,7 +114,7 @@ where I: AsyncRead + AsyncWrite, read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE } - pub fn read_head(&mut self) -> Poll, bool)>, ::Error> { + pub fn read_head(&mut self) -> Poll, Option)>, ::Error> { debug_assert!(self.can_read_head()); trace!("Conn::read_head"); @@ -162,7 +162,6 @@ where I: AsyncRead + AsyncWrite, continue; } }; - debug!("incoming body is {}", decoder); self.state.busy(); @@ -172,20 +171,23 @@ where I: AsyncRead + AsyncWrite, } let wants_keep_alive = msg.keep_alive; self.state.keep_alive &= wants_keep_alive; - let (body, reading) = if decoder.is_eof() { - (false, Reading::KeepAlive) - } else { - (true, Reading::Body(decoder)) - }; + + let content_length = decoder.content_length(); + if let Reading::Closed = self.state.reading { // actually want an `if not let ...` } else { - self.state.reading = reading; + self.state.reading = if content_length.is_none() { + Reading::KeepAlive + } else { + Reading::Body(decoder) + }; } - if !body { + if content_length.is_none() { self.try_keep_alive(); } - return Ok(Async::Ready(Some((head, body)))); + + return Ok(Async::Ready(Some((head, content_length)))); } } diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 4521547f77..03296878af 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -7,6 +7,7 @@ use futures::{Async, Poll}; use bytes::Bytes; use super::io::MemRead; +use super::BodyLength; use self::Kind::{Length, Chunked, Eof}; @@ -84,6 +85,16 @@ impl Decoder { } } + pub fn content_length(&self) -> Option { + match self.kind { + Length(0) | + Chunked(ChunkedState::End, _) | + Eof(true) => None, + Length(len) => Some(BodyLength::Known(len)), + _ => Some(BodyLength::Unknown), + } + } + pub fn decode(&mut self, body: &mut R) -> Poll { trace!("decode; state={:?}", self.kind); match self.kind { diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 5b34c7e4c8..ff411e81ee 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -190,9 +190,14 @@ where } // dispatch is ready for a message, try to read one match self.conn.read_head() { - Ok(Async::Ready(Some((head, has_body)))) => { - let body = if has_body { - let (mut tx, rx) = Body::channel(); + Ok(Async::Ready(Some((head, body_len)))) => { + let body = if let Some(body_len) = body_len { + let (mut tx, rx) = + Body::new_channel(if let BodyLength::Known(len) = body_len { + Some(len) + } else { + None + }); let _ = tx.poll_ready(); // register this task if rx is dropped self.body_tx = Some(tx); rx @@ -201,7 +206,7 @@ where }; self.dispatch.recv_msg(Ok((head, body)))?; Ok(Async::Ready(())) - }, + } Ok(Async::Ready(None)) => { // read eof, conn will start to shutdown automatically Ok(Async::Ready(())) diff --git a/tests/client.rs b/tests/client.rs index bd5958a7f8..5ea5f5a0a7 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1424,6 +1424,63 @@ mod conn { res.join(rx).map(|r| r.0).wait().unwrap(); } + #[test] + fn incoming_content_length() { + use hyper::body::Payload; + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut runtime = Runtime::new().unwrap(); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + let n = sock.read(&mut buf).expect("read 1"); + + let expected = "GET / HTTP/1.1\r\n\r\n"; + assert_eq!(s(&buf[..n]), expected); + + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello").unwrap(); + let _ = tx1.send(()); + }); + + let tcp = tcp_connect(&addr).wait().unwrap(); + + let (mut client, conn) = conn::handshake(tcp).wait().unwrap(); + + runtime.spawn(conn.map(|_| ()).map_err(|e| panic!("conn error: {}", e))); + + let req = Request::builder() + .uri("/") + .body(Default::default()) + .unwrap(); + let res = client.send_request(req).and_then(move |mut res| { + assert_eq!(res.status(), hyper::StatusCode::OK); + assert_eq!(res.body().content_length(), Some(5)); + assert!(!res.body().is_end_stream()); + loop { + let chunk = res.body_mut().poll_data().unwrap(); + match chunk { + Async::Ready(Some(chunk)) => { + assert_eq!(chunk.len(), 5); + break; + } + _ => continue + } + } + res.into_body().concat2() + }); + let rx = rx1.expect("thread panicked"); + + let timeout = Delay::new(Duration::from_millis(200)); + let rx = rx.and_then(move |_| timeout.expect("timeout")); + res.join(rx).map(|r| r.0).wait().unwrap(); + } + #[test] fn aborted_body_isnt_completed() { let _ = ::pretty_env_logger::try_init();