From 44c34ce9adc888916bd67656cc54c35f7908f536 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 23 Jan 2018 13:02:44 -0800 Subject: [PATCH] fix(server): error if Response code is 1xx Returning a Response from a Service with a 1xx StatusCode is not currently supported in hyper. It has always resulted in broken semantics. This patch simply errors better. - A Response with 1xx status is converted into a 500 response with no body. - An error is returned from the `server::Connection` to alert about the bad response. --- src/proto/conn.rs | 28 +++++++++++++++++++++++----- src/proto/dispatch.rs | 1 + src/proto/h1/parse.rs | 26 ++++++++++++++++++++------ src/proto/mod.rs | 2 +- tests/server.rs | 36 +++++++++++++++++++++++++++++++++++- 5 files changed, 80 insertions(+), 13 deletions(-) diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 5f572db327..c3f8bff0bf 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -40,6 +40,7 @@ where I: AsyncRead + AsyncWrite, Conn { io: Buffered::new(io), state: State { + error: None, keep_alive: keep_alive, method: None, read_task: None, @@ -437,11 +438,18 @@ where I: AsyncRead + AsyncWrite, buf.extend_from_slice(pending.buf()); } } - let encoder = T::encode(head, body, &mut self.state.method, buf); - self.state.writing = if !encoder.is_eof() { - Writing::Body(encoder, None) - } else { - Writing::KeepAlive + self.state.writing = match T::encode(head, body, &mut self.state.method, buf) { + Ok(encoder) => { + if !encoder.is_eof() { + Writing::Body(encoder, None) + } else { + Writing::KeepAlive + } + }, + Err(err) => { + self.state.error = Some(err); + Writing::Closed + } }; } @@ -626,6 +634,14 @@ where I: AsyncRead + AsyncWrite, self.state.disable_keep_alive(); } } + + pub fn take_error(&mut self) -> ::Result<()> { + if let Some(err) = self.state.error.take() { + Err(err) + } else { + Ok(()) + } + } } // ==== tokio_proto impl ==== @@ -736,6 +752,7 @@ impl, T, K: KeepAlive> fmt::Debug for Conn { } struct State { + error: Option<::Error>, keep_alive: K, method: Option, read_task: Option, @@ -767,6 +784,7 @@ impl, K: KeepAlive> fmt::Debug for State { .field("reading", &self.reading) .field("writing", &self.writing) .field("keep_alive", &self.keep_alive.status()) + .field("error", &self.error) //.field("method", &self.method) .field("read_task", &self.read_task) .finish() diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs index 91c95c176d..e0f5475906 100644 --- a/src/proto/dispatch.rs +++ b/src/proto/dispatch.rs @@ -73,6 +73,7 @@ where if self.is_done() { try_ready!(self.conn.shutdown()); + self.conn.take_error()?; trace!("Dispatch::poll done"); Ok(Async::Ready(())) } else { diff --git a/src/proto/h1/parse.rs b/src/proto/h1/parse.rs index 3c6363b623..0ca640cc68 100644 --- a/src/proto/h1/parse.rs +++ b/src/proto/h1/parse.rs @@ -111,10 +111,23 @@ impl Http1Transaction for ServerTransaction { } - fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> Encoder { + fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result { trace!("ServerTransaction::encode has_body={}, method={:?}", has_body, method); - let body = ServerTransaction::set_length(&mut head, has_body, method.as_ref()); + // hyper currently doesn't support returning 1xx status codes as a Response + // This is because Service only allows returning a single Response, and + // so if you try to reply with a e.g. 100 Continue, you have no way of + // replying with the latter status code response. + let ret = if head.subject.is_informational() { + error!("response with 1xx status code not supported"); + head = MessageHead::default(); + head.subject = ::StatusCode::InternalServerError; + head.headers.set(ContentLength(0)); + Err(::Error::Status) + } else { + Ok(ServerTransaction::set_length(&mut head, has_body, method.as_ref())) + }; + let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; dst.reserve(init_cap); @@ -133,7 +146,8 @@ impl Http1Transaction for ServerTransaction { extend(dst, b"\r\n"); } extend(dst, b"\r\n"); - body + + ret } fn should_error_on_parse_eof() -> bool { @@ -289,7 +303,7 @@ impl Http1Transaction for ClientTransaction { } } - fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> Encoder { + fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result { trace!("ClientTransaction::encode has_body={}, method={:?}", has_body, method); *method = Some(head.subject.0.clone()); @@ -300,7 +314,7 @@ impl Http1Transaction for ClientTransaction { dst.reserve(init_cap); let _ = write!(FastWrite(dst), "{} {}\r\n{}\r\n", head.subject, head.version, head.headers); - body + Ok(body) } fn should_error_on_parse_eof() -> bool { @@ -645,7 +659,7 @@ mod tests { b.iter(|| { let mut vec = Vec::new(); - ServerTransaction::encode(head.clone(), true, &mut None, &mut vec); + ServerTransaction::encode(head.clone(), true, &mut None, &mut vec).unwrap(); assert_eq!(vec.len(), len); ::test::black_box(vec); }) diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 5bd8ead5b3..a562b3804c 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -148,7 +148,7 @@ pub trait Http1Transaction { type Outgoing: Default; fn parse(bytes: &mut BytesMut) -> ParseResult; fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result>; - fn encode(head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> h1::Encoder; + fn encode(head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result; fn should_error_on_parse_eof() -> bool; fn should_read_first() -> bool; diff --git a/tests/server.rs b/tests/server.rs index ec9fd76d3e..159271802d 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -22,7 +22,8 @@ use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; -use hyper::server::{Http, Request, Response, Service, NewService}; +use hyper::StatusCode; +use hyper::server::{Http, Request, Response, Service, NewService, service_fn}; #[test] @@ -867,6 +868,38 @@ fn nonempty_parse_eof_returns_error() { core.run(fut).unwrap_err(); } +#[test] +fn returning_1xx_response_is_error() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"GET / HTTP/1.1\r\n\r\n").unwrap(); + let mut buf = [0; 256]; + tcp.read(&mut buf).unwrap(); + + let expected = "HTTP/1.1 500 "; + assert_eq!(s(&buf[..expected.len()]), expected); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::::new() + .serve_connection(socket, service_fn(|_| { + Ok(Response::::new() + .with_status(StatusCode::Continue)) + })) + .map(|_| ()) + }); + + core.run(fut).unwrap_err(); +} + #[test] fn remote_addr() { let server = serve(); @@ -1191,3 +1224,4 @@ impl Drop for Dropped { self.0.store(true, Ordering::SeqCst); } } +