Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(server): use a timeout for Server keep-alive #661

Merged
merged 1 commit into from
Oct 9, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 90 additions & 60 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ use std::fmt;
use std::io::{self, ErrorKind, BufWriter, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::thread::{self, JoinHandle};

#[cfg(feature = "timeouts")]
use std::time::Duration;

use num_cpus;
Expand Down Expand Up @@ -146,20 +144,16 @@ mod listener;
#[derive(Debug)]
pub struct Server<L = HttpListener> {
listener: L,
_timeouts: Timeouts,
timeouts: Timeouts,
}

#[cfg(feature = "timeouts")]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts {
read: Option<Duration>,
write: Option<Duration>,
keep_alive: Option<Duration>,
}

#[cfg(not(feature = "timeouts"))]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts;

macro_rules! try_option(
($e:expr) => {{
match $e {
Expand All @@ -175,18 +169,30 @@ impl<L: NetworkListener> Server<L> {
pub fn new(listener: L) -> Server<L> {
Server {
listener: listener,
_timeouts: Timeouts::default(),
timeouts: Timeouts::default(),
}
}

/// Enables keep-alive for this server.
///
/// The timeout duration passed will be used to determine how long
/// to keep the connection alive before dropping it.
///
/// **NOTE**: The timeout will only be used when the `timeouts` feature
/// is enabled for hyper, and rustc is 1.4 or greater.
#[inline]
pub fn keep_alive(&mut self, timeout: Duration) {
self.timeouts.keep_alive = Some(timeout);
}

#[cfg(feature = "timeouts")]
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.read = dur;
self.timeouts.read = dur;
}

#[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.write = dur;
self.timeouts.write = dur;
}


Expand Down Expand Up @@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static {

debug!("threads = {:?}", threads);
let pool = ListenerPool::new(server.listener);
let worker = Worker::new(handler, server._timeouts);
let worker = Worker::new(handler, server.timeouts);
let work = move |mut stream| worker.handle_connection(&mut stream);

let guard = thread::spawn(move || pool.accept(work, threads));
Expand All @@ -241,15 +247,15 @@ L: NetworkListener + Send + 'static {

struct Worker<H: Handler + 'static> {
handler: H,
_timeouts: Timeouts,
timeouts: Timeouts,
}

impl<H: Handler + 'static> Worker<H> {

fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
Worker {
handler: handler,
_timeouts: timeouts,
timeouts: timeouts,
}
}

Expand All @@ -258,7 +264,7 @@ impl<H: Handler + 'static> Worker<H> {

self.handler.on_connection_start();

if let Err(e) = self.set_timeouts(stream) {
if let Err(e) = self.set_timeouts(&*stream) {
error!("set_timeouts error: {:?}", e);
return;
}
Expand All @@ -273,73 +279,97 @@ impl<H: Handler + 'static> Worker<H> {

// FIXME: Use Type ascription
let stream_clone: &mut NetworkStream = &mut stream.clone();
let rdr = BufReader::new(stream_clone);
let wrt = BufWriter::new(stream);
let mut rdr = BufReader::new(stream_clone);
let mut wrt = BufWriter::new(stream);

self.keep_alive_loop(rdr, wrt, addr);
while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) {
error!("set_read_timeout keep_alive {:?}", e);
break;
}
}

self.handler.on_connection_end();

debug!("keep_alive loop ending for {}", addr);
}

fn set_timeouts(&self, s: &NetworkStream) -> io::Result<()> {
try!(self.set_read_timeout(s, self.timeouts.read));
self.set_write_timeout(s, self.timeouts.write)
}


#[cfg(not(feature = "timeouts"))]
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream {
fn set_write_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream {
try!(s.set_read_timeout(self._timeouts.read));
s.set_write_timeout(self._timeouts.write)
fn set_write_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_write_timeout(timeout)
}

fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>,
mut wrt: W, addr: SocketAddr) {
let mut keep_alive = true;
while keep_alive {
let req = match Request::new(&mut rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
break;
}
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
break;
}
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
break;
}
};
#[cfg(not(feature = "timeouts"))]
fn set_read_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_read_timeout(timeout)
}

if !self.handle_expect(&req, &mut wrt) {
break;
fn keep_alive_loop<W: Write>(&self, mut rdr: &mut BufReader<&mut NetworkStream>,
wrt: &mut W, addr: SocketAddr) -> bool {
let req = match Request::new(rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
return false;
}

keep_alive = http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
return false;
}
{
let mut res = Response::new(&mut wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
return false;
}
};

// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}

debug!("keep_alive = {:?} for {}", keep_alive, addr);
if !self.handle_expect(&req, wrt) {
return false;
}

if let Err(e) = req.set_read_timeout(self.timeouts.read) {
error!("set_read_timeout {:?}", e);
return false;
}

let mut keep_alive = self.timeouts.keep_alive.is_some() &&
http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
}
{
let mut res = Response::new(wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
}

// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}

debug!("keep_alive = {:?} for {}", keep_alive, addr);
keep_alive
}

fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
Expand Down
14 changes: 14 additions & 0 deletions src/server/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! target URI, headers, and message body.
use std::io::{self, Read};
use std::net::SocketAddr;
use std::time::Duration;

use buffer::BufReader;
use net::NetworkStream;
Expand Down Expand Up @@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> {
})
}

/// Set the read timeout of the underlying NetworkStream.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
self.body.get_ref().get_ref().set_read_timeout(timeout)
}

/// Set the read timeout of the underlying NetworkStream.
#[cfg(not(feature = "timeouts"))]
#[inline]
pub fn set_read_timeout(&self, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}
/// Get a reference to the underlying `NetworkStream`.
#[inline]
pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> {
Expand Down