diff --git a/src/client/mod.rs b/src/client/mod.rs index fdc53413a5..5587edb2f0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -46,15 +46,17 @@ use status::StatusClass::Redirection; use {Url, HttpResult}; use HttpError::HttpUriError; +pub use self::pool::Pool; pub use self::request::Request; pub use self::response::Response; +pub mod pool; pub mod request; pub mod response; /// A Client to use additional features with Requests. /// -/// Clients can handle things such as: redirect policy. +/// Clients can handle things such as: redirect policy, connection pooling. pub struct Client { connector: Connector, redirect_policy: RedirectPolicy, @@ -64,7 +66,12 @@ impl Client { /// Create a new Client. pub fn new() -> Client { - Client::with_connector(HttpConnector(None)) + Client::with_pool_config(Default::default()) + } + + /// Create a new Client with a configured Pool Config. + pub fn with_pool_config(config: pool::Config) -> Client { + Client::with_connector(Pool::new(config)) } /// Create a new client with a specific connector. @@ -78,7 +85,10 @@ impl Client { /// Set the SSL verifier callback for use with OpenSSL. pub fn set_ssl_verifier(&mut self, verifier: ContextVerifier) { - self.connector = with_connector(HttpConnector(Some(verifier))); + self.connector = with_connector(Pool::with_connector( + Default::default(), + HttpConnector(Some(verifier)) + )); } /// Set the RedirectPolicy. diff --git a/src/client/pool.rs b/src/client/pool.rs new file mode 100644 index 0000000000..1b3a8dacba --- /dev/null +++ b/src/client/pool.rs @@ -0,0 +1,227 @@ +//! Client Connection Pooling +use std::borrow::ToOwned; +use std::collections::HashMap; +use std::io::{self, Read, Write}; +use std::net::{SocketAddr, Shutdown}; +use std::sync::{Arc, Mutex}; + +use net::{NetworkConnector, NetworkStream, HttpConnector}; + +/// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`. +pub struct Pool { + connector: C, + inner: Arc::Stream>>> +} + +/// Config options for the `Pool`. +#[derive(Debug)] +pub struct Config { + /// The maximum idle connections *per host*. + pub max_idle: usize, +} + +impl Default for Config { + #[inline] + fn default() -> Config { + Config { + max_idle: 5, + } + } +} + +#[derive(Debug)] +struct PoolImpl { + conns: HashMap>, + config: Config, +} + +type Key = (String, u16, Scheme); + +fn key>(host: &str, port: u16, scheme: T) -> Key { + (host.to_owned(), port, scheme.into()) +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +enum Scheme { + Http, + Https, + Other(String) +} + +impl<'a> From<&'a str> for Scheme { + fn from(s: &'a str) -> Scheme { + match s { + "http" => Scheme::Http, + "https" => Scheme::Https, + s => Scheme::Other(String::from(s)) + } + } +} + +impl Pool { + /// Creates a `Pool` with an `HttpConnector`. + #[inline] + pub fn new(config: Config) -> Pool { + Pool::with_connector(config, HttpConnector(None)) + } +} + +impl Pool { + /// Creates a `Pool` with a specified `NetworkConnector`. + #[inline] + pub fn with_connector(config: Config, connector: C) -> Pool { + Pool { + connector: connector, + inner: Arc::new(Mutex::new(PoolImpl { + conns: HashMap::new(), + config: config, + })) + } + } + + /// Clear all idle connections from the Pool, closing them. + #[inline] + pub fn clear_idle(&mut self) { + self.inner.lock().unwrap().conns.clear(); + } +} + +impl PoolImpl { + fn reuse(&mut self, key: Key, conn: S) { + trace!("reuse {:?}", key); + let conns = self.conns.entry(key).or_insert(vec![]); + if conns.len() < self.config.max_idle { + conns.push(conn); + } + } +} + +impl, S: NetworkStream + Send> NetworkConnector for Pool { + type Stream = PooledStream; + fn connect(&mut self, host: &str, port: u16, scheme: &str) -> io::Result> { + let key = key(host, port, scheme); + let mut locked = self.inner.lock().unwrap(); + let mut should_remove = false; + let conn = match locked.conns.get_mut(&key) { + Some(ref mut vec) => { + should_remove = vec.len() == 1; + vec.pop().unwrap() + } + _ => try!(self.connector.connect(host, port, scheme)) + }; + if should_remove { + locked.conns.remove(&key); + } + Ok(PooledStream { + inner: Some((key, conn)), + is_closed: false, + is_drained: false, + pool: self.inner.clone() + }) + } +} + +/// A Stream that will try to be returned to the Pool when dropped. +pub struct PooledStream { + inner: Option<(Key, S)>, + is_closed: bool, + is_drained: bool, + pool: Arc>> +} + +impl Read for PooledStream { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.inner.as_mut().unwrap().1.read(buf) { + Ok(0) => { + self.is_drained = true; + Ok(0) + } + r => r + } + } +} + +impl Write for PooledStream { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.as_mut().unwrap().1.write(buf) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.inner.as_mut().unwrap().1.flush() + } +} + +impl NetworkStream for PooledStream { + #[inline] + fn peer_addr(&mut self) -> io::Result { + self.inner.as_mut().unwrap().1.peer_addr() + } + + #[inline] + fn close(&mut self, how: Shutdown) -> io::Result<()> { + self.is_closed = true; + self.inner.as_mut().unwrap().1.close(how) + } +} + +impl Drop for PooledStream { + fn drop(&mut self) { + trace!("PooledStream.drop, is_closed={}, is_drained={}", self.is_closed, self.is_drained); + if !self.is_closed && self.is_drained { + self.inner.take().map(|(key, conn)| { + if let Ok(mut pool) = self.pool.lock() { + pool.reuse(key, conn); + } + // else poisoned, give up + }); + } + } +} + +#[cfg(test)] +mod tests { + use std::net::Shutdown; + use mock::MockConnector; + use net::{NetworkConnector, NetworkStream}; + + use super::{Pool, key}; + + macro_rules! mocked { + () => ({ + Pool::with_connector(Default::default(), MockConnector) + }) + } + + #[test] + fn test_connect_and_drop() { + let mut pool = mocked!(); + let key = key("127.0.0.1", 3000, "http"); + pool.connect("127.0.0.1", 3000, "http").unwrap().is_drained = true; + { + let locked = pool.inner.lock().unwrap(); + assert_eq!(locked.conns.len(), 1); + assert_eq!(locked.conns.get(&key).unwrap().len(), 1); + } + pool.connect("127.0.0.1", 3000, "http").unwrap().is_drained = true; //reused + { + let locked = pool.inner.lock().unwrap(); + assert_eq!(locked.conns.len(), 1); + assert_eq!(locked.conns.get(&key).unwrap().len(), 1); + } + } + + #[test] + fn test_closed() { + let mut pool = mocked!(); + let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap(); + stream.close(Shutdown::Both).unwrap(); + drop(stream); + let locked = pool.inner.lock().unwrap(); + assert_eq!(locked.conns.len(), 0); + } + + +} diff --git a/src/client/request.rs b/src/client/request.rs index 1ad4ee7679..eb2e0ea336 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -1,6 +1,7 @@ //! Client Requests use std::marker::PhantomData; use std::io::{self, Write, BufWriter}; +use std::net::Shutdown; use url::Url; @@ -8,7 +9,7 @@ use method::{self, Method}; use header::Headers; use header::{self, Host}; use net::{NetworkStream, NetworkConnector, HttpConnector, Fresh, Streaming}; -use http::{HttpWriter, LINE_ENDING}; +use http::{self, HttpWriter, LINE_ENDING}; use http::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter}; use version; use HttpResult; @@ -154,7 +155,10 @@ impl Request { /// /// Consumes the Request. pub fn send(self) -> HttpResult { - let raw = try!(self.body.end()).into_inner().unwrap(); // end() already flushes + let mut raw = try!(self.body.end()).into_inner().unwrap(); // end() already flushes + if !http::should_keep_alive(self.version, &self.headers) { + try!(raw.close(Shutdown::Write)); + } Response::new(raw) } } diff --git a/src/client/response.rs b/src/client/response.rs index 56e0528cf8..74175140d4 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,6 +1,7 @@ //! Client Responses use std::io::{self, Read}; use std::marker::PhantomData; +use std::net::Shutdown; use buffer::BufReader; use header; @@ -42,6 +43,10 @@ impl Response { debug!("version={:?}, status={:?}", head.version, status); debug!("headers={:?}", headers); + if !http::should_keep_alive(head.version, &headers) { + try!(stream.get_mut().close(Shutdown::Write)); + } + let body = if headers.has::() { match headers.get::() { Some(&TransferEncoding(ref codings)) => { diff --git a/src/http.rs b/src/http.rs index 593e1e0cde..a8204ed52b 100644 --- a/src/http.rs +++ b/src/http.rs @@ -7,7 +7,8 @@ use std::fmt; use httparse; use buffer::BufReader; -use header::Headers; +use header::{Headers, Connection}; +use header::ConnectionOption::{Close, KeepAlive}; use method::Method; use status::StatusCode; use uri::RequestUri; @@ -443,6 +444,15 @@ pub const LINE_ENDING: &'static str = "\r\n"; #[derive(Clone, PartialEq, Debug)] pub struct RawStatus(pub u16, pub Cow<'static, str>); +/// Checks if a connection should be kept alive. +pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool { + match (version, headers.get::()) { + (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, + (Http11, Some(conn)) if conn.contains(&Close) => false, + _ => true + } +} + #[cfg(test)] mod tests { use std::io::{self, Write}; diff --git a/src/net.rs b/src/net.rs index 30ec4dc9f8..2146d413e7 100644 --- a/src/net.rs +++ b/src/net.rs @@ -2,7 +2,7 @@ use std::any::{Any, TypeId}; use std::fmt; use std::io::{self, Read, Write}; -use std::net::{SocketAddr, ToSocketAddrs, TcpStream, TcpListener}; +use std::net::{SocketAddr, ToSocketAddrs, TcpStream, TcpListener, Shutdown}; use std::mem; use std::path::Path; use std::sync::Arc; @@ -57,6 +57,10 @@ impl<'a, N: NetworkListener + 'a> Iterator for NetworkConnections<'a, N> { pub trait NetworkStream: Read + Write + Any + Send + Typeable { /// Get the remote address of the underlying connection. fn peer_addr(&mut self) -> io::Result; + /// This will be called when Stream should no longer be kept alive. + fn close(&mut self, _how: Shutdown) -> io::Result<()> { + Ok(()) + } } /// A connector creates a NetworkStream. @@ -123,6 +127,7 @@ impl NetworkStream + Send { } /// If the underlying type is T, extract it. + #[inline] pub fn downcast(self: Box) -> Result, Box> { if self.is::() { @@ -277,12 +282,21 @@ impl Write for HttpStream { } impl NetworkStream for HttpStream { + #[inline] fn peer_addr(&mut self) -> io::Result { match *self { HttpStream::Http(ref mut inner) => inner.0.peer_addr(), HttpStream::Https(ref mut inner) => inner.get_mut().0.peer_addr() } } + + #[inline] + fn close(&mut self, how: Shutdown) -> io::Result<()> { + match *self { + HttpStream::Http(ref mut inner) => inner.0.shutdown(how), + HttpStream::Https(ref mut inner) => inner.get_mut().0.shutdown(how) + } + } } /// A connector that will produce HttpStreams. diff --git a/src/server/mod.rs b/src/server/mod.rs index 84773dc84e..25986a861f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -36,13 +36,13 @@ pub use net::{Fresh, Streaming}; use HttpError::HttpIoError; use {HttpResult}; use buffer::BufReader; -use header::{Headers, Connection, Expect}; -use header::ConnectionOption::{Close, KeepAlive}; +use header::{Headers, Expect}; +use http; use method::Method; use net::{NetworkListener, NetworkStream, HttpListener}; use status::StatusCode; use uri::RequestUri; -use version::HttpVersion::{Http10, Http11}; +use version::HttpVersion::Http11; use self::listener::ListenerPool; @@ -206,11 +206,7 @@ where S: NetworkStream + Clone, H: Handler { } } - keep_alive = match (req.version, req.headers.get::()) { - (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, - (Http11, Some(conn)) if conn.contains(&Close) => false, - _ => true - }; + keep_alive = http::should_keep_alive(req.version, &req.headers); let mut res = Response::new(&mut wrt); res.version = req.version; handler.handle(req, res);