From 75c71170206db3119d9b298ea5cf3ee860803124 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 1 Sep 2015 16:44:58 -0700 Subject: [PATCH] fix(client): be resilient to invalid response bodies When an Http11Message knows that the previous response should not have included a body per RFC7230, and fails to parse the following response, the bytes are shuffled along, checking for the start of the next response. Closes #640 --- src/client/mod.rs | 29 +++++++++++++++ src/client/pool.rs | 58 ++++++++++++++++++++--------- src/client/response.rs | 7 ++-- src/http/h1.rs | 84 +++++++++++++++++++++++++++++------------- src/mock.rs | 51 ++++++++++++++++++++----- src/net.rs | 11 ++++++ 6 files changed, 183 insertions(+), 57 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 955320c370..ad9bb8596f 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -446,8 +446,10 @@ fn get_host_and_port(url: &Url) -> ::Result<(String, u16)> { #[cfg(test)] mod tests { + use std::io::Read; use header::Server; use super::{Client, RedirectPolicy}; + use super::pool::Pool; use url::Url; mock_connector!(MockRedirectPolicy { @@ -494,4 +496,31 @@ mod tests { let res = client.get("http://127.0.0.1").send().unwrap(); assert_eq!(res.headers.get(), Some(&Server("mock2".to_owned()))); } + + mock_connector!(Issue640Connector { + b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n", + b"GET", + b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\n", + b"HEAD", + b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\n", + b"POST" + }); + + // see issue #640 + #[test] + fn test_head_response_body_keep_alive() { + let client = Client::with_connector(Pool::with_connector(Default::default(), Issue640Connector)); + + let mut s = String::new(); + client.get("http://127.0.0.1").send().unwrap().read_to_string(&mut s).unwrap(); + assert_eq!(s, "GET"); + + let mut s = String::new(); + client.head("http://127.0.0.1").send().unwrap().read_to_string(&mut s).unwrap(); + assert_eq!(s, ""); + + let mut s = String::new(); + client.post("http://127.0.0.1").send().unwrap().read_to_string(&mut s).unwrap(); + assert_eq!(s, "POST"); + } } diff --git a/src/client/pool.rs b/src/client/pool.rs index d2fe6fa592..1012f23d05 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -34,7 +34,7 @@ impl Default for Config { #[derive(Debug)] struct PoolImpl { - conns: HashMap>, + conns: HashMap>>, config: Config, } @@ -90,7 +90,7 @@ impl Pool { } impl PoolImpl { - fn reuse(&mut self, key: Key, conn: S) { + fn reuse(&mut self, key: Key, conn: PooledStreamInner) { trace!("reuse {:?}", key); let conns = self.conns.entry(key).or_insert(vec![]); if conns.len() < self.config.max_idle { @@ -105,73 +105,97 @@ impl, S: NetworkStream + Send> NetworkConnector fo let key = key(host, port, scheme); let mut locked = self.inner.lock().unwrap(); let mut should_remove = false; - let conn = match locked.conns.get_mut(&key) { + let inner = match locked.conns.get_mut(&key) { Some(ref mut vec) => { trace!("Pool had connection, using"); should_remove = vec.len() == 1; vec.pop().unwrap() } - _ => try!(self.connector.connect(host, port, scheme)) + _ => PooledStreamInner { + key: key.clone(), + stream: try!(self.connector.connect(host, port, scheme)), + previous_response_expected_no_content: false, + } }; if should_remove { locked.conns.remove(&key); } Ok(PooledStream { - inner: Some((key, conn)), + inner: Some(inner), is_closed: false, - pool: self.inner.clone() + pool: self.inner.clone(), }) } } /// A Stream that will try to be returned to the Pool when dropped. pub struct PooledStream { - inner: Option<(Key, S)>, + inner: Option>, is_closed: bool, - pool: Arc>> + pool: Arc>>, +} + +#[derive(Debug)] +struct PooledStreamInner { + key: Key, + stream: S, + previous_response_expected_no_content: bool, } impl Read for PooledStream { #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.as_mut().unwrap().1.read(buf) + self.inner.as_mut().unwrap().stream.read(buf) } } impl Write for PooledStream { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.as_mut().unwrap().1.write(buf) + self.inner.as_mut().unwrap().stream.write(buf) } #[inline] fn flush(&mut self) -> io::Result<()> { - self.inner.as_mut().unwrap().1.flush() + self.inner.as_mut().unwrap().stream.flush() } } impl NetworkStream for PooledStream { #[inline] fn peer_addr(&mut self) -> io::Result { - self.inner.as_mut().unwrap().1.peer_addr() + self.inner.as_mut().unwrap().stream.peer_addr() } #[cfg(feature = "timeouts")] #[inline] fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.inner.as_ref().unwrap().1.set_read_timeout(dur) + self.inner.as_ref().unwrap().stream.set_read_timeout(dur) } #[cfg(feature = "timeouts")] #[inline] fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.inner.as_ref().unwrap().1.set_write_timeout(dur) + self.inner.as_ref().unwrap().stream.set_write_timeout(dur) } #[inline] fn close(&mut self, how: Shutdown) -> io::Result<()> { self.is_closed = true; - self.inner.as_mut().unwrap().1.close(how) + self.inner.as_mut().unwrap().stream.close(how) + } + + #[inline] + fn set_previous_response_expected_no_content(&mut self, expected: bool) { + trace!("set_previous_response_expected_no_content {}", expected); + self.inner.as_mut().unwrap().previous_response_expected_no_content = expected; + } + + #[inline] + fn previous_response_expected_no_content(&self) -> bool { + let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content; + trace!("previous_response_expected_no_content {}", answer); + answer } } @@ -179,9 +203,9 @@ impl Drop for PooledStream { fn drop(&mut self) { trace!("PooledStream.drop, is_closed={}", self.is_closed); if !self.is_closed { - self.inner.take().map(|(key, conn)| { + self.inner.take().map(|inner| { if let Ok(mut pool) = self.pool.lock() { - pool.reuse(key, conn); + pool.reuse(inner.key.clone(), inner); } // else poisoned, give up }); diff --git a/src/client/response.rs b/src/client/response.rs index 8b14ef9263..89376e4af0 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -64,7 +64,6 @@ impl Response { pub fn status_raw(&self) -> &RawStatus { &self.status_raw } - } impl Read for Response { @@ -91,11 +90,11 @@ impl Drop for Response { // // otherwise, the response has been drained. we should check that the // server has agreed to keep the connection open - trace!("Response.is_drained = {:?}", self.is_drained); + trace!("Response.drop is_drained={}", self.is_drained); if !(self.is_drained && http::should_keep_alive(self.version, &self.headers)) { - trace!("closing connection"); + trace!("Response.drop closing connection"); if let Err(e) = self.message.close_connection() { - error!("error closing connection: {}", e); + error!("Response.drop error closing connection: {}", e); } } } diff --git a/src/http/h1.rs b/src/http/h1.rs index 1208cbeb7d..921037486f 100644 --- a/src/http/h1.rs +++ b/src/http/h1.rs @@ -33,6 +33,8 @@ use http::{ use header; use version; +const MAX_INVALID_RESPONSE_BYTES: usize = 1024 * 128; + /// An implementation of the `HttpMessage` trait for HTTP/1.1. #[derive(Debug)] pub struct Http11Message { @@ -169,19 +171,38 @@ impl HttpMessage for Http11Message { } }; + let expected_no_content = stream.previous_response_expected_no_content(); + trace!("previous_response_expected_no_content = {}", expected_no_content); + let mut stream = BufReader::new(stream); - let head = match parse_response(&mut stream) { - Ok(head) => head, - Err(e) => { - self.stream = Some(stream.into_inner()); - return Err(e); - } - }; + let mut invalid_bytes_read = 0; + let head; + loop { + head = match parse_response(&mut stream) { + Ok(head) => head, + Err(::Error::Version) + if expected_no_content && invalid_bytes_read < MAX_INVALID_RESPONSE_BYTES => { + trace!("expected_no_content, found content"); + invalid_bytes_read += 1; + stream.consume(1); + continue; + } + Err(e) => { + self.stream = Some(stream.into_inner()); + return Err(e); + } + }; + break; + } + let raw_status = head.subject; let headers = head.headers; let method = self.method.take().unwrap_or(Method::Get); + + let is_empty = !should_have_response_body(&method, raw_status.0); + stream.get_mut().set_previous_response_expected_no_content(is_empty); // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. HEAD reponses, and Status 1xx, 204, and 304 cannot have a body. // 2. Status 2xx to a CONNECT cannot have a body. @@ -190,27 +211,24 @@ impl HttpMessage for Http11Message { // 5. Content-Length header has a sized body. // 6. Not Client. // 7. Read till EOF. - self.reader = Some(match (method, raw_status.0) { - (Method::Head, _) => EmptyReader(stream), - (_, 100...199) | (_, 204) | (_, 304) => EmptyReader(stream), - (Method::Connect, 200...299) => EmptyReader(stream), - _ => { - if let Some(&TransferEncoding(ref codings)) = headers.get() { - if codings.last() == Some(&Chunked) { - ChunkedReader(stream, None) - } else { - trace!("not chuncked. read till eof"); - EofReader(stream) - } - } else if let Some(&ContentLength(len)) = headers.get() { - SizedReader(stream, len) - } else if headers.has::() { - trace!("illegal Content-Length: {:?}", headers.get_raw("Content-Length")); - return Err(Error::Header); + self.reader = Some(if is_empty { + EmptyReader(stream) + } else { + if let Some(&TransferEncoding(ref codings)) = headers.get() { + if codings.last() == Some(&Chunked) { + ChunkedReader(stream, None) } else { - trace!("neither Transfer-Encoding nor Content-Length"); + trace!("not chuncked. read till eof"); EofReader(stream) } + } else if let Some(&ContentLength(len)) = headers.get() { + SizedReader(stream, len) + } else if headers.has::() { + trace!("illegal Content-Length: {:?}", headers.get_raw("Content-Length")); + return Err(Error::Header); + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + EofReader(stream) } }); @@ -226,7 +244,9 @@ impl HttpMessage for Http11Message { fn has_body(&self) -> bool { match self.reader { - Some(EmptyReader(..)) => false, + Some(EmptyReader(..)) | + Some(SizedReader(_, 0)) | + Some(ChunkedReader(_, Some(0))) => false, _ => true } } @@ -597,6 +617,18 @@ fn read_chunk_size(rdr: &mut R) -> io::Result { Ok(size) } +fn should_have_response_body(method: &Method, status: u16) -> bool { + trace!("should_have_response_body({:?}, {})", method, status); + match (method, status) { + (&Method::Head, _) | + (_, 100...199) | + (_, 204) | + (_, 304) | + (&Method::Connect, 200...299) => false, + _ => true + } +} + /// Writers to handle different Transfer-Encodings. pub enum HttpWriter { /// A no-op Writer, used initially before Transfer-Encoding is determined. diff --git a/src/mock.rs b/src/mock.rs index 5111b13edf..9f4ec88448 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -19,6 +19,7 @@ use net::{NetworkStream, NetworkConnector}; #[derive(Clone, Debug)] pub struct MockStream { pub read: Cursor>, + next_reads: Vec>, pub write: Vec, pub is_closed: bool, pub error_on_write: bool, @@ -40,27 +41,33 @@ impl MockStream { MockStream::with_input(b"") } - #[cfg(not(feature = "timeouts"))] pub fn with_input(input: &[u8]) -> MockStream { + MockStream::with_responses(vec![input]) + } + + #[cfg(feature = "timeouts")] + pub fn with_responses(mut responses: Vec<&[u8]>) -> MockStream { MockStream { - read: Cursor::new(input.to_vec()), + read: Cursor::new(responses.remove(0).to_vec()), + next_reads: responses.into_iter().map(|arr| arr.to_vec()).collect(), write: vec![], is_closed: false, error_on_write: false, error_on_read: false, + read_timeout: Cell::new(None), + write_timeout: Cell::new(None), } } - #[cfg(feature = "timeouts")] - pub fn with_input(input: &[u8]) -> MockStream { + #[cfg(not(feature = "timeouts"))] + pub fn with_responses(mut responses: Vec<&[u8]>) -> MockStream { MockStream { - read: Cursor::new(input.to_vec()), + read: Cursor::new(responses.remove(0).to_vec()), + next_reads: responses.into_iter().map(|arr| arr.to_vec()).collect(), write: vec![], is_closed: false, error_on_write: false, error_on_read: false, - read_timeout: Cell::new(None), - write_timeout: Cell::new(None), } } } @@ -70,7 +77,17 @@ impl Read for MockStream { if self.error_on_read { Err(io::Error::new(io::ErrorKind::Other, "mock error")) } else { - self.read.read(buf) + match self.read.read(buf) { + Ok(n) => { + if self.read.position() as usize == self.read.get_ref().len() { + if self.next_reads.len() > 0 { + self.read = Cursor::new(self.next_reads.remove(0)); + } + } + Ok(n) + }, + r => r + } } } } @@ -191,7 +208,7 @@ macro_rules! mock_connector ( struct $name; - impl ::net::NetworkConnector for $name { + impl $crate::net::NetworkConnector for $name { type Stream = ::mock::MockStream; fn connect(&self, host: &str, port: u16, scheme: &str) -> $crate::Result<::mock::MockStream> { @@ -210,7 +227,21 @@ macro_rules! mock_connector ( } } - ) + ); + + ($name:ident { $($response:expr),+ }) => ( + struct $name; + + impl $crate::net::NetworkConnector for $name { + type Stream = $crate::mock::MockStream; + fn connect(&self, _: &str, _: u16, _: &str) + -> $crate::Result<$crate::mock::MockStream> { + Ok($crate::mock::MockStream::with_responses(vec![ + $($response),+ + ])) + } + } + ); ); impl TransportStream for MockStream { diff --git a/src/net.rs b/src/net.rs index 602bfee946..98b6b41e34 100644 --- a/src/net.rs +++ b/src/net.rs @@ -62,6 +62,17 @@ pub trait NetworkStream: Read + Write + Any + Send + Typeable { fn close(&mut self, _how: Shutdown) -> io::Result<()> { Ok(()) } + + // Unsure about name and implementation... + + #[doc(hidden)] + fn set_previous_response_expected_no_content(&mut self, _expected: bool) { + + } + #[doc(hidden)] + fn previous_response_expected_no_content(&self) -> bool { + false + } } /// A connector creates a NetworkStream.