diff --git a/src/server/mod.rs b/src/server/mod.rs index 14fa36dac2..9bf8aac116 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -111,8 +111,6 @@ use std::fmt; use std::io::{self, ErrorKind, BufWriter, Write}; use std::net::{SocketAddr, ToSocketAddrs}; use std::thread::{self, JoinHandle}; - -#[cfg(feature = "timeouts")] use std::time::Duration; use num_cpus; @@ -146,20 +144,16 @@ mod listener; #[derive(Debug)] pub struct Server { listener: L, - _timeouts: Timeouts, + timeouts: Timeouts, } -#[cfg(feature = "timeouts")] #[derive(Clone, Copy, Default, Debug)] struct Timeouts { read: Option, write: Option, + keep_alive: Option, } -#[cfg(not(feature = "timeouts"))] -#[derive(Clone, Copy, Default, Debug)] -struct Timeouts; - macro_rules! try_option( ($e:expr) => {{ match $e { @@ -175,18 +169,30 @@ impl Server { pub fn new(listener: L) -> Server { Server { listener: listener, - _timeouts: Timeouts::default(), + timeouts: Timeouts::default(), } } + /// Enables keep-alive for this server. + /// + /// The timeout duration passed will be used to determine how long + /// to keep the connection alive before dropping it. + /// + /// **NOTE**: The timeout will only be used when the `timeouts` feature + /// is enabled for hyper, and rustc is 1.4 or greater. + #[inline] + pub fn keep_alive(&mut self, timeout: Duration) { + self.timeouts.keep_alive = Some(timeout); + } + #[cfg(feature = "timeouts")] pub fn set_read_timeout(&mut self, dur: Option) { - self._timeouts.read = dur; + self.timeouts.read = dur; } #[cfg(feature = "timeouts")] pub fn set_write_timeout(&mut self, dur: Option) { - self._timeouts.write = dur; + self.timeouts.write = dur; } @@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static { debug!("threads = {:?}", threads); let pool = ListenerPool::new(server.listener); - let worker = Worker::new(handler, server._timeouts); + let worker = Worker::new(handler, server.timeouts); let work = move |mut stream| worker.handle_connection(&mut stream); let guard = thread::spawn(move || pool.accept(work, threads)); @@ -241,7 +247,7 @@ L: NetworkListener + Send + 'static { struct Worker { handler: H, - _timeouts: Timeouts, + timeouts: Timeouts, } impl Worker { @@ -249,7 +255,7 @@ impl Worker { fn new(handler: H, timeouts: Timeouts) -> Worker { Worker { handler: handler, - _timeouts: timeouts, + timeouts: timeouts, } } @@ -258,7 +264,7 @@ impl Worker { self.handler.on_connection_start(); - if let Err(e) = self.set_timeouts(stream) { + if let Err(e) = self.set_timeouts(&*stream) { error!("set_timeouts error: {:?}", e); return; } @@ -273,73 +279,97 @@ impl Worker { // FIXME: Use Type ascription let stream_clone: &mut NetworkStream = &mut stream.clone(); - let rdr = BufReader::new(stream_clone); - let wrt = BufWriter::new(stream); + let mut rdr = BufReader::new(stream_clone); + let mut wrt = BufWriter::new(stream); - self.keep_alive_loop(rdr, wrt, addr); + while self.keep_alive_loop(&mut rdr, &mut wrt, addr) { + if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) { + error!("set_read_timeout keep_alive {:?}", e); + break; + } + } self.handler.on_connection_end(); debug!("keep_alive loop ending for {}", addr); } + fn set_timeouts(&self, s: &NetworkStream) -> io::Result<()> { + try!(self.set_read_timeout(s, self.timeouts.read)); + self.set_write_timeout(s, self.timeouts.write) + } + + #[cfg(not(feature = "timeouts"))] - fn set_timeouts(&self, _: &mut S) -> io::Result<()> where S: NetworkStream { + fn set_write_timeout(&self, _s: &NetworkStream, _timeout: Option) -> io::Result<()> { Ok(()) } #[cfg(feature = "timeouts")] - fn set_timeouts(&self, s: &mut S) -> io::Result<()> where S: NetworkStream { - try!(s.set_read_timeout(self._timeouts.read)); - s.set_write_timeout(self._timeouts.write) + fn set_write_timeout(&self, s: &NetworkStream, timeout: Option) -> io::Result<()> { + s.set_write_timeout(timeout) } - fn keep_alive_loop(&self, mut rdr: BufReader<&mut NetworkStream>, - mut wrt: W, addr: SocketAddr) { - let mut keep_alive = true; - while keep_alive { - let req = match Request::new(&mut rdr, addr) { - Ok(req) => req, - Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { - trace!("tcp closed, cancelling keep-alive loop"); - break; - } - Err(Error::Io(e)) => { - debug!("ioerror in keepalive loop = {:?}", e); - break; - } - Err(e) => { - //TODO: send a 400 response - error!("request error = {:?}", e); - break; - } - }; + #[cfg(not(feature = "timeouts"))] + fn set_read_timeout(&self, _s: &NetworkStream, _timeout: Option) -> io::Result<()> { + Ok(()) + } + #[cfg(feature = "timeouts")] + fn set_read_timeout(&self, s: &NetworkStream, timeout: Option) -> io::Result<()> { + s.set_read_timeout(timeout) + } - if !self.handle_expect(&req, &mut wrt) { - break; + fn keep_alive_loop(&self, mut rdr: &mut BufReader<&mut NetworkStream>, + wrt: &mut W, addr: SocketAddr) -> bool { + let req = match Request::new(rdr, addr) { + Ok(req) => req, + Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { + trace!("tcp closed, cancelling keep-alive loop"); + return false; } - - keep_alive = http::should_keep_alive(req.version, &req.headers); - let version = req.version; - let mut res_headers = Headers::new(); - if !keep_alive { - res_headers.set(Connection::close()); + Err(Error::Io(e)) => { + debug!("ioerror in keepalive loop = {:?}", e); + return false; } - { - let mut res = Response::new(&mut wrt, &mut res_headers); - res.version = version; - self.handler.handle(req, res); + Err(e) => { + //TODO: send a 400 response + error!("request error = {:?}", e); + return false; } + }; - // if the request was keep-alive, we need to check that the server agrees - // if it wasn't, then the server cannot force it to be true anyways - if keep_alive { - keep_alive = http::should_keep_alive(version, &res_headers); - } - debug!("keep_alive = {:?} for {}", keep_alive, addr); + if !self.handle_expect(&req, wrt) { + return false; + } + + if let Err(e) = req.set_read_timeout(self.timeouts.read) { + error!("set_read_timeout {:?}", e); + return false; + } + + let mut keep_alive = self.timeouts.keep_alive.is_some() && + http::should_keep_alive(req.version, &req.headers); + let version = req.version; + let mut res_headers = Headers::new(); + if !keep_alive { + res_headers.set(Connection::close()); } + { + let mut res = Response::new(wrt, &mut res_headers); + res.version = version; + self.handler.handle(req, res); + } + + // if the request was keep-alive, we need to check that the server agrees + // if it wasn't, then the server cannot force it to be true anyways + if keep_alive { + keep_alive = http::should_keep_alive(version, &res_headers); + } + + debug!("keep_alive = {:?} for {}", keep_alive, addr); + keep_alive } fn handle_expect(&self, req: &Request, wrt: &mut W) -> bool { diff --git a/src/server/request.rs b/src/server/request.rs index 732aaa33c7..851a426f17 100644 --- a/src/server/request.rs +++ b/src/server/request.rs @@ -4,6 +4,7 @@ //! target URI, headers, and message body. use std::io::{self, Read}; use std::net::SocketAddr; +use std::time::Duration; use buffer::BufReader; use net::NetworkStream; @@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> { }) } + /// Set the read timeout of the underlying NetworkStream. + #[cfg(feature = "timeouts")] + #[inline] + pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + self.body.get_ref().get_ref().set_read_timeout(timeout) + } + + /// Set the read timeout of the underlying NetworkStream. + #[cfg(not(feature = "timeouts"))] + #[inline] + pub fn set_read_timeout(&self, _timeout: Option) -> io::Result<()> { + Ok(()) + } /// Get a reference to the underlying `NetworkStream`. #[inline] pub fn downcast_ref(&self) -> Option<&T> {