From f7532b71d141ebe41172dbb863d58d519e387a4e Mon Sep 17 00:00:00 2001
From: Sean McArthur
Date: Mon, 2 Oct 2017 15:05:40 -0700
Subject: [PATCH] feat(lib): add support to disable tokio-proto internals

For now, this adds `client::Config::no_proto`, `server::Http::no_proto`,
and `server::Server::no_proto` to skip tokio-proto implementations, and
use an internal dispatch system instead.

`Http::no_proto` is similar to `Http::bind_connection`, but returns a
`Connection` that is a `Future` to drive HTTP with the provided service.
Any errors prior to parsing a request, or after delivering a response
(but before flushing the response body), will be returned from this
future.

See #1342 for more.
---
 .travis.yml           |   2 +
 examples/client.rs    |   6 +-
 examples/hello.rs     |   3 +-
 examples/server.rs    |   3 +-
 src/client/mod.rs     | 179 +++++++++++++++-------
 src/proto/body.rs     |   1 +
 src/proto/conn.rs     | 124 +++++++++++-----
 src/proto/dispatch.rs | 325 ++++++++++++++++++++++++++++++++++++++++
 src/proto/h1/parse.rs |   8 +
 src/proto/io.rs       |   6 +-
 src/proto/mod.rs      |   3 +
 src/server/mod.rs     | 108 +++++++++++++-
 tests/client.rs       | 338 ++++++++++++++++++++++++++++++++++++------
 tests/server.rs       |  93 +++++++++++-
 14 files changed, 1042 insertions(+), 157 deletions(-)
 create mode 100644 src/proto/dispatch.rs

diff --git a/.travis.yml b/.travis.yml
index df3a678a8d..9834e7e2dc 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,6 +8,8 @@ matrix:
       env: FEATURES="--features nightly"
     - rust: beta
     - rust: stable
+    - rust: stable
+      env: HYPER_NO_PROTO=1
     - rust: stable
       env: FEATURES="--features compat"
     - rust: 1.17.0
diff --git a/examples/client.rs b/examples/client.rs
index ccaaa2e989..a5014e5a84 100644
--- a/examples/client.rs
+++ b/examples/client.rs
@@ -1,4 +1,4 @@
-#![deny(warnings)]
+//#![deny(warnings)]
 extern crate futures;
 extern crate hyper;
 extern crate tokio_core;
@@ -32,7 +32,9 @@ fn main() {
     let mut core = tokio_core::reactor::Core::new().unwrap();
     let handle = core.handle();
 
-    let client = Client::new(&handle);
+    let client = Client::configure()
+        .no_proto()
+        .build(&handle);
 
     let work = client.get(url).and_then(|res| {
         println!("Response: {}", res.status());
diff --git a/examples/hello.rs b/examples/hello.rs
index f6924402f0..d468a680ad 100644
--- a/examples/hello.rs
+++ b/examples/hello.rs
@@ -31,7 +31,8 @@ impl Service for Hello {
 fn main() {
     pretty_env_logger::init().unwrap();
     let addr = "127.0.0.1:3000".parse().unwrap();
-    let server = Http::new().bind(&addr, || Ok(Hello)).unwrap();
+    let mut server = Http::new().bind(&addr, || Ok(Hello)).unwrap();
+    server.no_proto();
     println!("Listening on http://{} with 1 thread.", server.local_addr().unwrap());
     server.run().unwrap();
 }
diff --git a/examples/server.rs b/examples/server.rs
index 4e164bf906..4c790a7775 100644
--- a/examples/server.rs
+++ b/examples/server.rs
@@ -47,7 +47,8 @@ fn main() {
     pretty_env_logger::init().unwrap();
     let addr = "127.0.0.1:1337".parse().unwrap();
 
-    let server = Http::new().bind(&addr, || Ok(Echo)).unwrap();
+    let mut server = Http::new().bind(&addr, || Ok(Echo)).unwrap();
+    server.no_proto();
     println!("Listening on http://{} with 1 thread.", server.local_addr().unwrap());
     server.run().unwrap();
 }
diff --git a/src/client/mod.rs b/src/client/mod.rs
index ba29f628f4..bd65729ebc 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -20,7 +20,7 @@ use tokio_proto::util::client_proxy::ClientProxy;
 pub use tokio_service::Service;
 
 use header::{Headers, Host};
-use proto::{self, TokioBody};
+use proto::{self, RequestHead, TokioBody};
 use proto::response;
 use
proto::request; use method::Method; @@ -45,7 +45,7 @@ pub mod compat; pub struct Client { connector: C, handle: Handle, - pool: Pool>, + pool: Dispatch, } impl Client { @@ -93,7 +93,11 @@ impl Client { Client { connector: config.connector, handle: handle.clone(), - pool: Pool::new(config.keep_alive, config.keep_alive_timeout), + pool: if config.no_proto { + Dispatch::Hyper(Pool::new(config.keep_alive, config.keep_alive_timeout)) + } else { + Dispatch::Proto(Pool::new(config.keep_alive, config.keep_alive_timeout)) + } } } } @@ -187,48 +191,100 @@ where C: Connect, headers.extend(head.headers.iter()); head.headers = headers; - let checkout = self.pool.checkout(domain.as_ref()); - let connect = { - let handle = self.handle.clone(); - let pool = self.pool.clone(); - let pool_key = Rc::new(domain.to_string()); - self.connector.connect(url) - .map(move |io| { - let (tx, rx) = oneshot::channel(); - let client = HttpClient { - client_rx: RefCell::new(Some(rx)), - }.bind_client(&handle, io); - let pooled = pool.pooled(pool_key, client); - drop(tx.send(pooled.clone())); - pooled - }) - }; + match self.pool { + Dispatch::Proto(ref pool) => { + trace!("proto_dispatch"); + let checkout = pool.checkout(domain.as_ref()); + let connect = { + let handle = self.handle.clone(); + let pool = pool.clone(); + let pool_key = Rc::new(domain.to_string()); + self.connector.connect(url) + .map(move |io| { + let (tx, rx) = oneshot::channel(); + let client = HttpClient { + client_rx: RefCell::new(Some(rx)), + }.bind_client(&handle, io); + let pooled = pool.pooled(pool_key, client); + drop(tx.send(pooled.clone())); + pooled + }) + }; + + let race = checkout.select(connect) + .map(|(client, _work)| client) + .map_err(|(e, _work)| { + // the Pool Checkout cannot error, so the only error + // is from the Connector + // XXX: should wait on the Checkout? Problem is + // that if the connector is failing, it may be that we + // never had a pooled stream at all + e.into() + }); + let resp = race.and_then(move |client| { + let msg = match body { + Some(body) => { + Message::WithBody(head, body.into()) + }, + None => Message::WithoutBody(head), + }; + client.call(msg) + }); + FutureResponse(Box::new(resp.map(|msg| { + match msg { + Message::WithoutBody(head) => response::from_wire(head, None), + Message::WithBody(head, body) => response::from_wire(head, Some(body.into())), + } + }))) + }, + Dispatch::Hyper(ref pool) => { + trace!("no_proto dispatch"); + use futures::Sink; + use futures::sync::{mpsc, oneshot}; + + let checkout = pool.checkout(domain.as_ref()); + let connect = { + let handle = self.handle.clone(); + let pool = pool.clone(); + let pool_key = Rc::new(domain.to_string()); + self.connector.connect(url) + .map(move |io| { + let (tx, rx) = mpsc::channel(1); + let pooled = pool.pooled(pool_key, RefCell::new(tx)); + let conn = proto::Conn::<_, _, proto::ClientTransaction, _>::new(io, pooled.clone()); + let dispatch = proto::dispatch::Dispatcher::new(proto::dispatch::Client::new(rx), conn); + handle.spawn(dispatch.map_err(|err| error!("no_proto error: {}", err))); + pooled + }) + }; + + let race = checkout.select(connect) + .map(|(client, _work)| client) + .map_err(|(e, _work)| { + // the Pool Checkout cannot error, so the only error + // is from the Connector + // XXX: should wait on the Checkout? 
Problem is + // that if the connector is failing, it may be that we + // never had a pooled stream at all + e.into() + }); + + let resp = race.and_then(move |client| { + let (callback, rx) = oneshot::channel(); + client.borrow_mut().start_send((head, body, callback)).unwrap(); + rx.then(|res| { + match res { + Ok(Ok(res)) => Ok(res), + Ok(Err(err)) => Err(err), + Err(_) => panic!("dispatch dropped without returning error"), + } + }) + }); + + FutureResponse(Box::new(resp)) - let race = checkout.select(connect) - .map(|(client, _work)| client) - .map_err(|(e, _work)| { - // the Pool Checkout cannot error, so the only error - // is from the Connector - // XXX: should wait on the Checkout? Problem is - // that if the connector is failing, it may be that we - // never had a pooled stream at all - e.into() - }); - let resp = race.and_then(move |client| { - let msg = match body { - Some(body) => { - Message::WithBody(head, body.into()) - }, - None => Message::WithoutBody(head), - }; - client.call(msg) - }); - FutureResponse(Box::new(resp.map(|msg| { - match msg { - Message::WithoutBody(head) => response::from_wire(head, None), - Message::WithBody(head, body) => response::from_wire(head, Some(body.into())), } - }))) + } } } @@ -238,7 +294,10 @@ impl Clone for Client { Client { connector: self.connector.clone(), handle: self.handle.clone(), - pool: self.pool.clone(), + pool: match self.pool { + Dispatch::Proto(ref pool) => Dispatch::Proto(pool.clone()), + Dispatch::Hyper(ref pool) => Dispatch::Hyper(pool.clone()), + } } } } @@ -249,10 +308,16 @@ impl fmt::Debug for Client { } } -type TokioClient = ClientProxy, Message, ::Error>; +type ProtoClient = ClientProxy, Message, ::Error>; +type HyperClient = RefCell<::futures::sync::mpsc::Sender<(RequestHead, Option, ::futures::sync::oneshot::Sender<::Result<::Response>>)>>; + +enum Dispatch { + Proto(Pool>), + Hyper(Pool>), +} struct HttpClient { - client_rx: RefCell>>>>, + client_rx: RefCell>>>>, } impl ClientProto for HttpClient @@ -265,7 +330,7 @@ where T: AsyncRead + AsyncWrite + 'static, type Response = proto::ResponseHead; type ResponseBody = proto::Chunk; type Error = ::Error; - type Transport = proto::Conn>>; + type Transport = proto::Conn>>; type BindTransport = BindingClient; fn bind_transport(&self, io: T) -> Self::BindTransport { @@ -277,7 +342,7 @@ where T: AsyncRead + AsyncWrite + 'static, } struct BindingClient { - rx: oneshot::Receiver>>, + rx: oneshot::Receiver>>, io: Option, } @@ -286,7 +351,7 @@ where T: AsyncRead + AsyncWrite + 'static, B: Stream, B::Item: AsRef<[u8]>, { - type Item = proto::Conn>>; + type Item = proto::Conn>>; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -309,6 +374,7 @@ pub struct Config { keep_alive_timeout: Option, //TODO: make use of max_idle config max_idle: usize, + no_proto: bool, } /// Phantom type used to signal that `Config` should create a `HttpConnector`. @@ -324,6 +390,7 @@ impl Default for Config { keep_alive: true, keep_alive_timeout: Some(Duration::from_secs(90)), max_idle: 5, + no_proto: false, } } } @@ -347,6 +414,7 @@ impl Config { keep_alive: self.keep_alive, keep_alive_timeout: self.keep_alive_timeout, max_idle: self.max_idle, + no_proto: self.no_proto, } } @@ -360,6 +428,7 @@ impl Config { keep_alive: self.keep_alive, keep_alive_timeout: self.keep_alive_timeout, max_idle: self.max_idle, + no_proto: self.no_proto, } } @@ -393,6 +462,13 @@ impl Config { self } */ + + /// Disable tokio-proto internal usage. 
+ #[inline] + pub fn no_proto(mut self) -> Config { + self.no_proto = true; + self + } } impl Config @@ -431,11 +507,8 @@ impl fmt::Debug for Config { impl Clone for Config { fn clone(&self) -> Config { Config { - _body_type: PhantomData::, connector: self.connector.clone(), - keep_alive: self.keep_alive, - keep_alive_timeout: self.keep_alive_timeout, - max_idle: self.max_idle, + .. *self } } } diff --git a/src/proto/body.rs b/src/proto/body.rs index 27fb34ec78..cd7935f5ab 100644 --- a/src/proto/body.rs +++ b/src/proto/body.rs @@ -7,6 +7,7 @@ use std::borrow::Cow; use super::Chunk; pub type TokioBody = tokio_proto::streaming::Body; +pub type BodySender = mpsc::Sender>; /// A `Stream` for `Chunk`s used in requests and responses. #[must_use = "streams do nothing unless polled"] diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 63b527042b..823ae8b4bf 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -7,7 +7,7 @@ use futures::task::Task; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_proto::streaming::pipeline::{Frame, Transport}; -use proto::{Http1Transaction}; +use proto::Http1Transaction; use super::io::{Cursor, Buffered}; use super::h1::{Encoder, Decoder}; use method::Method; @@ -51,15 +51,28 @@ where I: AsyncRead + AsyncWrite, self.io.set_flush_pipeline(enabled); } - fn poll2(&mut self) -> Poll, super::Chunk, ::Error>>, io::Error> { - trace!("Conn::poll()"); + fn poll_incoming(&mut self) -> Poll, super::Chunk, ::Error>>, io::Error> { + trace!("Conn::poll_incoming()"); loop { if self.is_read_closed() { trace!("Conn::poll when closed"); return Ok(Async::Ready(None)); } else if self.can_read_head() { - return self.read_head(); + return match self.read_head() { + Ok(Async::Ready(Some((head, body)))) => { + Ok(Async::Ready(Some(Frame::Message { + message: head, + body: body, + }))) + }, + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(::Error::Io(err)) => Err(err), + Err(err) => Ok(Async::Ready(Some(Frame::Error { + error: err, + }))), + }; } else if self.can_write_continue() { try_nb!(self.flush()); } else if self.can_read_body() { @@ -79,16 +92,15 @@ where I: AsyncRead + AsyncWrite, } } - fn is_read_closed(&self) -> bool { + pub fn is_read_closed(&self) -> bool { self.state.is_read_closed() } - #[allow(unused)] - fn is_write_closed(&self) -> bool { + pub fn is_write_closed(&self) -> bool { self.state.is_write_closed() } - fn can_read_head(&self) -> bool { + pub fn can_read_head(&self) -> bool { match self.state.reading { Reading::Init => true, _ => false, @@ -102,14 +114,14 @@ where I: AsyncRead + AsyncWrite, } } - fn can_read_body(&self) -> bool { + pub fn can_read_body(&self) -> bool { match self.state.reading { Reading::Body(..) => true, _ => false, } } - fn read_head(&mut self) -> Poll, super::Chunk, ::Error>>, io::Error> { + pub fn read_head(&mut self) -> Poll, bool)>, ::Error> { debug_assert!(self.can_read_head()); trace!("Conn::read_head"); @@ -117,13 +129,16 @@ where I: AsyncRead + AsyncWrite, Ok(Async::Ready(head)) => (head.version, head), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(e) => { - let must_respond_with_error = !self.state.is_idle(); + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. 
+ let must_error = !self.state.is_idle() && T::should_error_on_parse_eof(); self.state.close_read(); self.io.consume_leading_lines(); let was_mid_parse = !self.io.read_buf().is_empty(); - return if was_mid_parse || must_respond_with_error { + return if was_mid_parse || must_error { debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); - Ok(Async::Ready(Some(Frame::Error { error: e }))) + Err(e) } else { debug!("read eof"); Ok(Async::Ready(None)) @@ -138,7 +153,7 @@ where I: AsyncRead + AsyncWrite, Err(e) => { debug!("decoder error = {:?}", e); self.state.close_read(); - return Ok(Async::Ready(Some(Frame::Error { error: e }))); + return Err(e); } }; self.state.busy(); @@ -154,17 +169,17 @@ where I: AsyncRead + AsyncWrite, (true, Reading::Body(decoder)) }; self.state.reading = reading; - Ok(Async::Ready(Some(Frame::Message { message: head, body: body }))) + Ok(Async::Ready(Some((head, body)))) }, _ => { error!("unimplemented HTTP Version = {:?}", version); self.state.close_read(); - Ok(Async::Ready(Some(Frame::Error { error: ::Error::Version }))) + Err(::Error::Version) } } } - fn read_body(&mut self) -> Poll, io::Error> { + pub fn read_body(&mut self) -> Poll, io::Error> { debug_assert!(self.can_read_body()); trace!("Conn::read_body"); @@ -187,7 +202,7 @@ where I: AsyncRead + AsyncWrite, ret } - fn maybe_park_read(&mut self) { + pub fn maybe_park_read(&mut self) { if !self.io.is_read_blocked() { // the Io object is ready to read, which means it will never alert // us that it is ready until we drain it. However, we're currently @@ -236,13 +251,16 @@ where I: AsyncRead + AsyncWrite, return }, Err(e) => { - trace!("maybe_notify read_from_io error: {}", e); + trace!("maybe_notify; read_from_io error: {}", e); self.state.close(); } } } if let Some(ref task) = self.state.read_task { + trace!("maybe_notify; notifying task"); task.notify(); + } else { + trace!("maybe_notify; no task to notify"); } } } @@ -252,14 +270,14 @@ where I: AsyncRead + AsyncWrite, self.maybe_notify(); } - fn can_write_head(&self) -> bool { + pub fn can_write_head(&self) -> bool { match self.state.writing { Writing::Continue(..) | Writing::Init => true, _ => false } } - fn can_write_body(&self) -> bool { + pub fn can_write_body(&self) -> bool { match self.state.writing { Writing::Body(..) => true, Writing::Continue(..) 
| @@ -277,7 +295,7 @@ where I: AsyncRead + AsyncWrite, } } - fn write_head(&mut self, head: super::MessageHead, body: bool) { + pub fn write_head(&mut self, head: super::MessageHead, body: bool) { debug_assert!(self.can_write_head()); let wants_keep_alive = head.should_keep_alive(); @@ -298,7 +316,7 @@ where I: AsyncRead + AsyncWrite, }; } - fn write_body(&mut self, chunk: Option) -> StartSend, io::Error> { + pub fn write_body(&mut self, chunk: Option) -> StartSend, io::Error> { debug_assert!(self.can_write_body()); if self.has_queued_body() { @@ -397,7 +415,7 @@ where I: AsyncRead + AsyncWrite, Ok(Async::Ready(())) } - fn flush(&mut self) -> Poll<(), io::Error> { + pub fn flush(&mut self) -> Poll<(), io::Error> { loop { let queue_finished = try!(self.write_queued()).is_ready(); try_nb!(self.io.flush()); @@ -410,8 +428,18 @@ where I: AsyncRead + AsyncWrite, Ok(Async::Ready(())) } + + pub fn close_read(&mut self) { + self.state.close_read(); + } + + pub fn close_write(&mut self) { + self.state.close_write(); + } } +// ==== tokio_proto impl ==== + impl Stream for Conn where I: AsyncRead + AsyncWrite, B: AsRef<[u8]>, @@ -423,7 +451,7 @@ where I: AsyncRead + AsyncWrite, #[inline] fn poll(&mut self) -> Poll, Self::Error> { - self.poll2().map_err(|err| { + self.poll_incoming().map_err(|err| { debug!("poll error: {}", err); err }) @@ -635,6 +663,12 @@ impl State { self.keep_alive.disable(); } + fn close_write(&mut self) { + trace!("State::close_write()"); + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + fn try_keep_alive(&mut self) { match (&self.reading, &self.writing) { (&Reading::KeepAlive, &Writing::KeepAlive) => { @@ -652,14 +686,6 @@ impl State { } } - fn is_idle(&self) -> bool { - if let KA::Idle = self.keep_alive.status() { - true - } else { - false - } - } - fn busy(&mut self) { if let KA::Disabled = self.keep_alive.status() { return; @@ -674,6 +700,14 @@ impl State { self.keep_alive.idle(); } + fn is_idle(&self) -> bool { + if let KA::Idle = self.keep_alive.status() { + true + } else { + false + } + } + fn is_read_closed(&self) -> bool { match self.reading { Reading::Closed => true, @@ -681,7 +715,6 @@ impl State { } } - #[allow(unused)] fn is_write_closed(&self) -> bool { match self.writing { Writing::Closed => true, @@ -727,7 +760,7 @@ mod tests { use futures::future; use tokio_proto::streaming::pipeline::Frame; - use proto::{self, MessageHead, ServerTransaction}; + use proto::{self, ClientTransaction, MessageHead, ServerTransaction}; use super::super::h1::Encoder; use mock::AsyncIo; @@ -799,21 +832,32 @@ mod tests { let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); conn.state.idle(); - match conn.poll().unwrap() { - Async::Ready(Some(Frame::Error { .. })) => {}, - other => panic!("frame is not Error: {:?}", other) + match conn.poll() { + Err(ref err) if err.kind() == ::std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) } } #[test] fn test_conn_init_read_eof_busy() { + // server ignores let io = AsyncIo::new_buf(vec![], 1); let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); conn.state.busy(); match conn.poll().unwrap() { - Async::Ready(Some(Frame::Error { .. 
})) => {}, - other => panic!("frame is not Error: {:?}", other) + Async::Ready(None) => {}, + other => panic!("unexpected frame: {:?}", other) + } + + // client, when busy, returns the error + let io = AsyncIo::new_buf(vec![], 1); + let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); + conn.state.busy(); + + match conn.poll() { + Err(ref err) if err.kind() == ::std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) } } diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs new file mode 100644 index 0000000000..e24ae79b8d --- /dev/null +++ b/src/proto/dispatch.rs @@ -0,0 +1,325 @@ +use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; +use futures::sync::{mpsc, oneshot}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_service::Service; + +use super::{Body, Conn, KeepAlive, Http1Transaction, MessageHead, RequestHead, ResponseHead}; +use ::StatusCode; + +pub struct Dispatcher { + conn: Conn, + dispatch: D, + body_tx: Option, + body_rx: Option, +} + +pub trait Dispatch { + type PollItem; + type PollBody; + type RecvItem; + fn poll_msg(&mut self) -> Poll)>, ::Error>; + fn recv_msg(&mut self, msg: ::Result<(Self::RecvItem, Body)>) -> ::Result<()>; + fn should_poll(&self) -> bool; +} + +pub struct Server { + in_flight: Option, + service: S, +} + +pub struct Client { + callback: Option>>, + rx: ClientRx, +} + +type ClientRx = mpsc::Receiver<(RequestHead, Option, oneshot::Sender<::Result<::Response>>)>; + +impl Dispatcher +where + D: Dispatch, PollBody=Bs, RecvItem=MessageHead>, + I: AsyncRead + AsyncWrite, + B: AsRef<[u8]>, + T: Http1Transaction, + K: KeepAlive, + Bs: Stream, +{ + pub fn new(dispatch: D, conn: Conn) -> Self { + Dispatcher { + conn: conn, + dispatch: dispatch, + body_tx: None, + body_rx: None, + } + } + + fn poll_read(&mut self) -> Poll<(), ::Error> { + loop { + if self.conn.can_read_head() { + match self.conn.read_head() { + Ok(Async::Ready(Some((head, has_body)))) => { + let body = if has_body { + let (tx, rx) = super::Body::pair(); + self.body_tx = Some(tx); + rx + } else { + Body::empty() + }; + self.dispatch.recv_msg(Ok((head, body))).expect("recv_msg with Ok shouldn't error"); + }, + Ok(Async::Ready(None)) => { + // read eof, conn will start to shutdown automatically + return Ok(Async::Ready(())); + } + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => { + debug!("read_head error: {}", err); + self.dispatch.recv_msg(Err(err))?; + // if here, the dispatcher gave the user the error + // somewhere else. we still need to shutdown, but + // not as a second error. 
+ return Ok(Async::Ready(())); + } + } + } else if let Some(mut body) = self.body_tx.take() { + let can_read_body = self.conn.can_read_body(); + match body.poll_ready() { + Ok(Async::Ready(())) => (), + Ok(Async::NotReady) => { + self.body_tx = Some(body); + return Ok(Async::NotReady); + }, + Err(_canceled) => { + // user doesn't care about the body + // so we should stop reading + if can_read_body { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + return Ok(Async::Ready(())); + } + } + } + if can_read_body { + match self.conn.read_body() { + Ok(Async::Ready(Some(chunk))) => { + match body.start_send(Ok(chunk)) { + Ok(AsyncSink::Ready) => { + self.body_tx = Some(body); + }, + Ok(AsyncSink::NotReady(_chunk)) => { + unreachable!("mpsc poll_ready was ready, start_send was not"); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } + + } + }, + Ok(Async::Ready(None)) => { + let _ = body.close(); + }, + Ok(Async::NotReady) => { + self.body_tx = Some(body); + return Ok(Async::NotReady); + } + Err(e) => { + let _ = body.start_send(Err(::Error::Io(e))); + } + } + } else { + let _ = body.close(); + } + } else { + self.conn.maybe_park_read(); + return Ok(Async::Ready(())); + } + } + } + + fn poll_write(&mut self) -> Poll<(), ::Error> { + loop { + if self.body_rx.is_none() && self.dispatch.should_poll() { + if let Some((head, body)) = try_ready!(self.dispatch.poll_msg()) { + self.conn.write_head(head, body.is_some()); + self.body_rx = body; + } else { + self.conn.close_write(); + return Ok(Async::Ready(())); + } + } else if let Some(mut body) = self.body_rx.take() { + let chunk = match body.poll()? { + Async::Ready(Some(chunk)) => { + self.body_rx = Some(body); + chunk + }, + Async::Ready(None) => { + if self.conn.can_write_body() { + self.conn.write_body(None)?; + } + continue; + }, + Async::NotReady => { + self.body_rx = Some(body); + return Ok(Async::NotReady); + } + }; + self.conn.write_body(Some(chunk))?; + } else { + return Ok(Async::NotReady); + } + } + } + + fn poll_flush(&mut self) -> Poll<(), ::Error> { + self.conn.flush().map_err(|err| { + debug!("error writing: {}", err); + err.into() + }) + } + + fn is_done(&self) -> bool { + let read_done = self.conn.is_read_closed(); + let write_done = self.conn.is_write_closed() || + (!self.dispatch.should_poll() && self.body_rx.is_none()); + + read_done && write_done + } +} + + +impl Future for Dispatcher +where + D: Dispatch, PollBody=Bs, RecvItem=MessageHead>, + I: AsyncRead + AsyncWrite, + B: AsRef<[u8]>, + T: Http1Transaction, + K: KeepAlive, + Bs: Stream, +{ + type Item = (); + type Error = ::Error; + + #[inline] + fn poll(&mut self) -> Poll { + self.poll_read()?; + self.poll_write()?; + self.poll_flush()?; + + if self.is_done() { + trace!("Dispatch::poll done"); + Ok(Async::Ready(())) + } else { + Ok(Async::NotReady) + } + } +} + +// ===== impl Server ===== + +impl Server where S: Service { + pub fn new(service: S) -> Server { + Server { + in_flight: None, + service: service, + } + } +} + +impl Dispatch for Server +where + S: Service, Error=::Error>, + Bs: Stream, + Bs::Item: AsRef<[u8]>, +{ + type PollItem = MessageHead; + type PollBody = Bs; + type RecvItem = RequestHead; + + fn poll_msg(&mut self) -> Poll)>, ::Error> { + if let Some(mut fut) = self.in_flight.take() { + let resp = match fut.poll()? 
{ + Async::Ready(res) => res, + Async::NotReady => { + self.in_flight = Some(fut); + return Ok(Async::NotReady); + } + }; + let (head, body) = super::response::split(resp); + Ok(Async::Ready(Some((head.into(), body)))) + } else { + unreachable!("poll_msg shouldn't be called if no inflight"); + } + } + + fn recv_msg(&mut self, msg: ::Result<(Self::RecvItem, Body)>) -> ::Result<()> { + let (msg, body) = msg?; + let req = super::request::from_wire(None, msg, body); + self.in_flight = Some(self.service.call(req)); + Ok(()) + } + + fn should_poll(&self) -> bool { + self.in_flight.is_some() + } +} + +// ===== impl Client ===== + +impl Client { + pub fn new(rx: ClientRx) -> Client { + Client { + callback: None, + rx: rx, + } + } +} + +impl Dispatch for Client +where + B: Stream, + B::Item: AsRef<[u8]>, +{ + type PollItem = RequestHead; + type PollBody = B; + type RecvItem = ResponseHead; + + fn poll_msg(&mut self) -> Poll)>, ::Error> { + match self.rx.poll() { + Ok(Async::Ready(Some((head, body, cb)))) => { + self.callback = Some(cb); + Ok(Async::Ready(Some((head, body)))) + }, + Ok(Async::Ready(None)) => { + // user has dropped sender handle + Ok(Async::Ready(None)) + }, + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(()) => unreachable!("mpsc receiver cannot error"), + } + } + + fn recv_msg(&mut self, msg: ::Result<(Self::RecvItem, Body)>) -> ::Result<()> { + match msg { + Ok((msg, body)) => { + let res = super::response::from_wire(msg, Some(body)); + let cb = self.callback.take().expect("recv_msg without callback"); + let _ = cb.send(Ok(res)); + Ok(()) + }, + Err(err) => { + if let Some(cb) = self.callback.take() { + let _ = cb.send(Err(err)); + Ok(()) + } else { + Err(err) + } + } + } + } + + fn should_poll(&self) -> bool { + self.callback.is_none() + } +} diff --git a/src/proto/h1/parse.rs b/src/proto/h1/parse.rs index 8fff641a2b..93a96a6e7c 100644 --- a/src/proto/h1/parse.rs +++ b/src/proto/h1/parse.rs @@ -132,6 +132,10 @@ impl Http1Transaction for ServerTransaction { extend(dst, b"\r\n"); body } + + fn should_error_on_parse_eof() -> bool { + false + } } impl ServerTransaction { @@ -281,6 +285,10 @@ impl Http1Transaction for ClientTransaction { body } + + fn should_error_on_parse_eof() -> bool { + true + } } impl ClientTransaction { diff --git a/src/proto/io.rs b/src/proto/io.rs index 0117758aa6..1cc0a1e489 100644 --- a/src/proto/io.rs +++ b/src/proto/io.rs @@ -84,7 +84,7 @@ impl Buffered { match try_ready!(self.read_from_io()) { 0 => { trace!("parse eof"); - //TODO: With Rust 1.14, this can be Error::from(ErrorKind) + //TODO: utilize Error::Incomplete when Error type is redesigned return Err(io::Error::new(io::ErrorKind::UnexpectedEof, ParseEof).into()); } _ => {}, @@ -335,13 +335,13 @@ struct ParseEof; impl fmt::Display for ParseEof { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("parse eof") + f.write_str(::std::error::Error::description(self)) } } impl ::std::error::Error for ParseEof { fn description(&self) -> &str { - "parse eof" + "end of file reached before parsing could complete" } } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 53fe49f4ed..11552557ab 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -19,6 +19,7 @@ pub use self::chunk::Chunk; mod body; mod chunk; mod conn; +pub mod dispatch; mod io; mod h1; //mod h2; @@ -146,6 +147,8 @@ pub trait Http1Transaction { fn parse(bytes: &mut BytesMut) -> ParseResult; fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result; fn encode(head: MessageHead, has_body: bool, 
method: &mut Option, dst: &mut Vec) -> h1::Encoder; + + fn should_error_on_parse_eof() -> bool; } pub type ParseResult = ::Result, usize)>>; diff --git a/src/server/mod.rs b/src/server/mod.rs index 5da465126f..52594a41f6 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -66,6 +66,7 @@ where B: Stream, core: Core, listener: TcpListener, shutdown_timeout: Duration, + no_proto: bool, } impl + 'static> Http { @@ -121,6 +122,7 @@ impl + 'static> Http { listener: listener, protocol: self.clone(), shutdown_timeout: Duration::new(1, 0), + no_proto: false, }) } @@ -165,6 +167,30 @@ impl + 'static> Http { }) } + /// Bind a connection together with a Service. + /// + /// This returns a Future that must be polled in order for HTTP to be + /// driven on the connection. + /// + /// This additionally skips the tokio-proto infrastructure internally. + pub fn no_proto(&self, io: I, service: S) -> Connection + where S: Service, Error = ::Error> + 'static, + Bd: Stream + 'static, + I: AsyncRead + AsyncWrite + 'static, + + { + let ka = if self.keep_alive { + proto::KA::Busy + } else { + proto::KA::Disabled + }; + let mut conn = proto::Conn::new(io, ka); + conn.set_flush_pipeline(self.pipeline); + Connection { + conn: proto::dispatch::Dispatcher::new(proto::dispatch::Server::new(service), conn), + } + } + /// Bind a `Service` using types from the `http` crate. /// /// See `Http::bind_connection`. @@ -185,6 +211,67 @@ impl + 'static> Http { } } +/// A future binding a connection with a Service. +/// +/// Polling this future will drive HTTP forward. +#[must_use = "futures do nothing unless polled"] +pub struct Connection +where S: Service, + B: Stream, + B::Item: AsRef<[u8]>, +{ + conn: proto::dispatch::Dispatcher, B, I, B::Item, proto::ServerTransaction, proto::KA>, +} + +impl Future for Connection +where S: Service, Error = ::Error> + 'static, + I: AsyncRead + AsyncWrite + 'static, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + type Item = self::unnameable::Opaque; + type Error = ::Error; + + fn poll(&mut self) -> Poll { + try_ready!(self.conn.poll()); + Ok(self::unnameable::opaque().into()) + } +} + +impl fmt::Debug for Connection +where S: Service, + B: Stream, + B::Item: AsRef<[u8]>, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Connection") + .finish() + } +} + +mod unnameable { + // This type is specifically not exported outside the crate, + // so no one can actually name the type. With no methods, we make no + // promises about this type. + // + // All of that to say we can eventually replace the type returned + // to something else, and it would not be a breaking change. + // + // We may want to eventually yield the `T: AsyncRead + AsyncWrite`, which + // doesn't have a `Debug` bound. So, this type can't implement `Debug` + // either, so the type change doesn't break people. 
+ #[allow(missing_debug_implementations)] + pub struct Opaque { + _inner: (), + } + + pub fn opaque() -> Opaque { + Opaque { + _inner: (), + } + } +} + impl Clone for Http { fn clone(&self) -> Http { Http { @@ -207,7 +294,7 @@ impl fmt::Debug for Http { pub struct __ProtoRequest(proto::RequestHead); #[doc(hidden)] #[allow(missing_debug_implementations)] -pub struct __ProtoResponse(ResponseHead); +pub struct __ProtoResponse(proto::MessageHead<::StatusCode>); #[doc(hidden)] #[allow(missing_debug_implementations)] pub struct __ProtoTransport(proto::Conn); @@ -368,8 +455,6 @@ struct HttpService { remote_addr: SocketAddr, } -type ResponseHead = proto::MessageHead<::StatusCode>; - impl Service for HttpService where T: Service, Error=::Error>, B: Stream, @@ -420,6 +505,12 @@ impl Server self } + /// Configure this server to not use tokio-proto infrastructure internally. + pub fn no_proto(&mut self) -> &mut Self { + self.no_proto = true; + self + } + /// Execute this server infinitely. /// /// This method does not currently return, but it will return an error if @@ -444,7 +535,7 @@ impl Server pub fn run_until(self, shutdown_signal: F) -> ::Result<()> where F: Future, { - let Server { protocol, new_service, mut core, listener, shutdown_timeout } = self; + let Server { protocol, new_service, mut core, listener, shutdown_timeout, no_proto } = self; let handle = core.handle(); // Mini future to track the number of active services @@ -460,7 +551,14 @@ impl Server info: Rc::downgrade(&info), }; info.borrow_mut().active += 1; - protocol.bind_connection(&handle, socket, addr, s); + if no_proto { + let fut = protocol.no_proto(socket, s) + .map(|_| ()) + .map_err(|err| error!("no_proto error: {}", err)); + handle.spawn(fut); + } else { + protocol.bind_connection(&handle, socket, addr, s); + } Ok(()) }); diff --git a/tests/client.rs b/tests/client.rs index 823ba62b07..5a5cb5bfa6 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -2,6 +2,7 @@ extern crate hyper; extern crate futures; extern crate tokio_core; +extern crate tokio_io; extern crate pretty_env_logger; use std::io::{self, Read, Write}; @@ -18,13 +19,24 @@ use futures::sync::oneshot; use tokio_core::reactor::{Core, Handle}; fn client(handle: &Handle) -> Client { - Client::new(handle) + let mut config = Client::configure(); + if env("HYPER_NO_PROTO", "1") { + config = config.no_proto(); + } + config.build(handle) } fn s(buf: &[u8]) -> &str { ::std::str::from_utf8(buf).unwrap() } +fn env(name: &str, val: &str) -> bool { + match ::std::env::var(name) { + Ok(var) => var == val, + Err(_) => false, + } +} + macro_rules! test { ( name: $name:ident, @@ -49,51 +61,24 @@ macro_rules! 
test { #![allow(unused)] use hyper::header::*; let _ = pretty_env_logger::init(); - let server = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); - let client = client(&core.handle()); - let mut req = Request::new(Method::$client_method, format!($client_url, addr=addr).parse().unwrap()); - $( - req.headers_mut().set($request_headers); - )* - if let Some(body) = $request_body { - let body: &'static str = body; - req.set_body(body); - } - req.set_proxy($request_proxy); - - let res = client.request(req); - - let (tx, rx) = oneshot::channel(); - - let thread = thread::Builder::new() - .name(format!("tcp-server<{}>", stringify!($name))); - thread.spawn(move || { - let mut inc = server.accept().unwrap().0; - inc.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); - inc.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); - let expected = format!($server_expected, addr=addr); - let mut buf = [0; 4096]; - let mut n = 0; - while n < buf.len() && n < expected.len() { - n += match inc.read(&mut buf[n..]) { - Ok(n) => n, - Err(e) => panic!("failed to read request, partially read = {:?}, error: {}", s(&buf[..n]), e), - }; - } - assert_eq!(s(&buf[..n]), expected); - - inc.write_all($server_reply.as_ref()).unwrap(); - let _ = tx.send(()); - }).unwrap(); - - let rx = rx.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); - - let work = res.join(rx).map(|r| r.0); - - let res = core.run(work).unwrap(); + let res = test! { + INNER; + core: &mut core, + server: + expected: $server_expected, + reply: $server_reply, + client: + request: + method: $client_method, + url: $client_url, + headers: [ $($request_headers,)* ], + body: $request_body, + proxy: $request_proxy, + }.unwrap(); + + assert_eq!(res.status(), StatusCode::$client_status); $( assert_eq!(res.headers().get(), Some(&$response_headers)); @@ -106,6 +91,108 @@ macro_rules! test { assert_eq!(body.as_ref(), expected_res_body); } ); + ( + name: $name:ident, + server: + expected: $server_expected:expr, + reply: $server_reply:expr, + client: + request: + method: $client_method:ident, + url: $client_url:expr, + headers: [ $($request_headers:expr,)* ], + body: $request_body:expr, + proxy: $request_proxy:expr, + + error: $err:expr, + ) => ( + #[test] + fn $name() { + #![allow(unused)] + use hyper::header::*; + let _ = pretty_env_logger::init(); + let mut core = Core::new().unwrap(); + + let err = test! 
{ + INNER; + core: &mut core, + server: + expected: $server_expected, + reply: $server_reply, + client: + request: + method: $client_method, + url: $client_url, + headers: [ $($request_headers,)* ], + body: $request_body, + proxy: $request_proxy, + }.unwrap_err(); + if !$err(&err) { + panic!("unexpected error: {:?}", err) + } + } + ); + + ( + INNER; + core: $core:expr, + server: + expected: $server_expected:expr, + reply: $server_reply:expr, + client: + request: + method: $client_method:ident, + url: $client_url:expr, + headers: [ $($request_headers:expr,)* ], + body: $request_body:expr, + proxy: $request_proxy:expr, + ) => ({ + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = $core; + let client = client(&core.handle()); + let mut req = Request::new(Method::$client_method, format!($client_url, addr=addr).parse().unwrap()); + $( + req.headers_mut().set($request_headers); + )* + + if let Some(body) = $request_body { + let body: &'static str = body; + req.set_body(body); + } + req.set_proxy($request_proxy); + + let res = client.request(req); + + let (tx, rx) = oneshot::channel(); + + let thread = thread::Builder::new() + .name(format!("tcp-server<{}>", stringify!($name))); + thread.spawn(move || { + let mut inc = server.accept().unwrap().0; + inc.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + inc.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let expected = format!($server_expected, addr=addr); + let mut buf = [0; 4096]; + let mut n = 0; + while n < buf.len() && n < expected.len() { + n += match inc.read(&mut buf[n..]) { + Ok(n) => n, + Err(e) => panic!("failed to read request, partially read = {:?}, error: {}", s(&buf[..n]), e), + }; + } + assert_eq!(s(&buf[..n]), expected); + + inc.write_all($server_reply.as_ref()).unwrap(); + let _ = tx.send(()); + }).unwrap(); + + let rx = rx.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + + let work = res.join(rx).map(|r| r.0); + + core.run(work) + }); } static REPLY_OK: &'static str = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"; @@ -266,6 +353,61 @@ test! { body: None, } +test! { + name: client_error_unexpected_eof, + + server: + expected: "\ + GET /err HTTP/1.1\r\n\ + Host: {addr}\r\n\ + \r\n\ + ", + reply: "\ + HTTP/1.1 200 OK\r\n\ + ", // unexpected eof before double CRLF + + client: + request: + method: Get, + url: "http://{addr}/err", + headers: [], + body: None, + proxy: false, + error: |err| match err { + &hyper::Error::Io(_) => true, + _ => false, + }, +} + +test! 
{ + name: client_error_parse_version, + + server: + expected: "\ + GET /err HTTP/1.1\r\n\ + Host: {addr}\r\n\ + \r\n\ + ", + reply: "\ + HEAT/1.1 200 OK\r\n\ + \r\n\ + ", + + client: + request: + method: Get, + url: "http://{addr}/err", + headers: [], + body: None, + proxy: false, + error: |err| match err { + &hyper::Error::Version if env("HYPER_NO_PROTO", "1") => true, + &hyper::Error::Io(_) if !env("HYPER_NO_PROTO", "1") => true, + _ => false, + }, + +} + #[test] fn client_keep_alive() { let server = TcpListener::bind("127.0.0.1:0").unwrap(); @@ -285,9 +427,10 @@ fn client_keep_alive() { sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").expect("write 1"); let _ = tx1.send(()); - sock.read(&mut buf).expect("read 2"); - let second_get = b"GET /b HTTP/1.1\r\n"; - assert_eq!(&buf[..second_get.len()], second_get); + let n2 = sock.read(&mut buf).expect("read 2"); + assert_ne!(n2, 0); + let second_get = "GET /b HTTP/1.1\r\n"; + assert_eq!(s(&buf[..second_get.len()]), second_get); sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").expect("write 2"); let _ = tx2.send(()); }); @@ -367,3 +510,104 @@ fn client_pooled_socket_disconnected() { assert_ne!(addr1, addr2); } */ + +#[test] +fn drop_body_before_eof_closes_connection() { + // https://github.com/hyperium/hyper/issues/1353 + use std::io::{self, Read, Write}; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; + use tokio_core::reactor::{Timeout}; + use tokio_core::net::TcpStream; + use tokio_io::{AsyncRead, AsyncWrite}; + use hyper::client::HttpConnector; + use hyper::server::Service; + use hyper::Uri; + + let _ = pretty_env_logger::init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &core.handle()), closes.clone())) + .no_proto() + .build(&handle); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + let body = vec![b'x'; 1024 * 128]; + write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head"); + let _ = sock.write_all(&body); + let _ = tx1.send(()); + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + + let res = client.get(uri).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + Timeout::new(Duration::from_secs(1), &handle).unwrap() + .from_err() + }); + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + + + + struct DebugConnector(HttpConnector, Arc); + + impl Service for DebugConnector { + type Request = Uri; + type Response = DebugStream; + type Error = io::Error; + type Future = Box>; + + fn call(&self, uri: Uri) -> Self::Future { + let counter = self.1.clone(); + Box::new(self.0.call(uri).map(move |s| DebugStream(s, counter))) + } + } + + struct DebugStream(TcpStream, Arc); + + impl Drop for DebugStream { + fn drop(&mut self) { + self.1.fetch_add(1, Ordering::SeqCst); + } + } + + impl Write for DebugStream { + fn write(&mut self, 
buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + } + + impl AsyncWrite for DebugStream { + fn shutdown(&mut self) -> futures::Poll<(), io::Error> { + AsyncWrite::shutdown(&mut self.0) + } + } + + impl Read for DebugStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + } + + impl AsyncRead for DebugStream {} +} diff --git a/tests/server.rs b/tests/server.rs index b804cecee1..c5ec8a8606 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -3,10 +3,15 @@ extern crate hyper; extern crate futures; extern crate spmc; extern crate pretty_env_logger; +extern crate tokio_core; use futures::{Future, Stream}; +use futures::future::{self, FutureResult}; use futures::sync::oneshot; +use tokio_core::net::TcpListener; +use tokio_core::reactor::Core; + use std::net::{TcpStream, SocketAddr}; use std::io::{Read, Write}; use std::sync::mpsc; @@ -387,6 +392,7 @@ fn disable_keep_alive() { .header(hyper::header::ContentLength(quux.len() as u64)) .body(quux); + // the write can possibly succeed, since it fills the kernel buffer on the first write let _ = req.write_all(b"\ GET /quux HTTP/1.1\r\n\ Host: example.domain\r\n\ @@ -394,7 +400,6 @@ fn disable_keep_alive() { \r\n\ "); - // the write can possibly succeed, since it fills the kernel buffer on the first write let mut buf = [0; 1024 * 8]; match req.read(&mut buf[..]) { // Ok(0) means EOF, so a proper shutdown @@ -504,6 +509,50 @@ fn pipeline_enabled() { assert_eq!(n, 0); } +#[test] +fn no_proto_empty_parse_eof_does_not_return_error() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let _tcp = connect(&addr); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::new().no_proto(socket, HelloWorld) + }); + + core.run(fut).unwrap(); +} + +#[test] +fn no_proto_nonempty_parse_eof_returns_error() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"GET / HTTP/1.1").unwrap(); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::new().no_proto(socket, HelloWorld) + .map(|_| ()) + }); + + core.run(fut).unwrap_err(); +} + // ------------------------------------------------- // the Server that is used to run all the tests with // ------------------------------------------------- @@ -628,6 +677,20 @@ impl Service for TestService { } +struct HelloWorld; + +impl Service for HelloWorld { + type Request = Request; + type Response = Response; + type Error = hyper::Error; + type Future = FutureResult; + + fn call(&self, _req: Request) -> Self::Future { + future::ok(Response::new()) + } +} + + fn connect(addr: &SocketAddr) -> TcpStream { let req = TcpStream::connect(addr).unwrap(); req.set_read_timeout(Some(Duration::from_secs(1))).unwrap(); @@ -639,13 +702,31 @@ fn serve() -> Serve { serve_with_options(Default::default()) } -#[derive(Default)] struct ServeOptions { keep_alive_disabled: bool, + no_proto: bool, pipeline: bool, timeout: Option, } +impl Default for ServeOptions { + fn default() -> Self 
{
+        ServeOptions {
+            keep_alive_disabled: false,
+            no_proto: env("HYPER_NO_PROTO", "1"),
+            pipeline: false,
+            timeout: None,
+        }
+    }
+}
+
+fn env(name: &str, val: &str) -> bool {
+    match ::std::env::var(name) {
+        Ok(var) => var == val,
+        Err(_) => false,
+    }
+}
+
 fn serve_with_options(options: ServeOptions) -> Serve {
     let _ = pretty_env_logger::init();
@@ -657,12 +738,13 @@ fn serve_with_options(options: ServeOptions) -> Serve {
     let addr = "127.0.0.1:0".parse().unwrap();
 
     let keep_alive = !options.keep_alive_disabled;
+    let no_proto = options.no_proto;
     let pipeline = options.pipeline;
     let dur = options.timeout;
     let thread_name = format!("test-server-{:?}", dur);
     let thread = thread::Builder::new().name(thread_name).spawn(move || {
-        let srv = Http::new()
+        let mut srv = Http::new()
             .keep_alive(keep_alive)
             .pipeline(pipeline)
             .bind(&addr, TestService {
@@ -670,6 +752,9 @@ fn serve_with_options(options: ServeOptions) -> Serve {
             _timeout: dur,
             reply: reply_rx,
         }).unwrap();
+        if no_proto {
+            srv.no_proto();
+        }
         addr_tx.send(srv.local_addr().unwrap()).unwrap();
         srv.run_until(shutdown_rx.then(|_| Ok(()))).unwrap();
     }).unwrap();
@@ -684,5 +769,3 @@ fn serve_with_options(options: ServeOptions) -> Serve {
         thread: Some(thread),
     }
 }
-
-
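
A minimal usage sketch of the new entry points (an illustration, not part of the
patch: it assumes hyper at this commit plus `futures` and `tokio-core`, and a
`HelloWorld` service mirroring the one in tests/server.rs):

extern crate futures;
extern crate hyper;
extern crate tokio_core;

use futures::{Future, Stream};
use futures::future::{self, FutureResult};
use hyper::server::{Http, Request, Response, Service};
use tokio_core::net::TcpListener;
use tokio_core::reactor::Core;

struct HelloWorld;

impl Service for HelloWorld {
    type Request = Request;
    type Response = Response;
    type Error = hyper::Error;
    type Future = FutureResult<Response, hyper::Error>;

    fn call(&self, _req: Request) -> Self::Future {
        future::ok(Response::new())
    }
}

fn main() {
    let mut core = Core::new().unwrap();
    let handle = core.handle();

    // Client side: Config::no_proto swaps tokio-proto's ClientProxy for
    // the internal dispatcher added in src/proto/dispatch.rs.
    let _client = hyper::Client::configure()
        .no_proto()
        .build(&handle);

    // Server side: Http::no_proto returns a `Connection` future that
    // drives HTTP on the socket; it must be polled (here, spawned on the
    // reactor) to make progress, and it resolves with any error raised
    // before a request is parsed or after a response is delivered.
    let addr = "127.0.0.1:3000".parse().unwrap();
    let listener = TcpListener::bind(&addr, &handle).unwrap();
    let http = Http::new();
    let inner = handle.clone();
    let server = listener.incoming().for_each(move |(socket, _peer)| {
        let conn = http.no_proto(socket, HelloWorld)
            .map(|_| ())
            .map_err(|err| println!("connection error: {}", err));
        inner.spawn(conn);
        Ok(())
    });
    core.run(server).unwrap();
}

For bound servers, `Server::no_proto` flips the same switch without
hand-driving each connection, as examples/hello.rs and examples/server.rs
show above.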