diff --git a/src/client/connect.rs b/src/client/connect.rs index fb15e2e870..ec08d7df55 100644 --- a/src/client/connect.rs +++ b/src/client/connect.rs @@ -1,3 +1,4 @@ +use std::error::Error as StdError; use std::fmt; use std::io; //use std::net::SocketAddr; @@ -42,6 +43,7 @@ where T: Service + 'static, #[derive(Clone)] pub struct HttpConnector { dns: dns::Dns, + enforce_http: bool, handle: Handle, } @@ -50,15 +52,26 @@ impl HttpConnector { /// Construct a new HttpConnector. /// /// Takes number of DNS worker threads. + #[inline] pub fn new(threads: usize, handle: &Handle) -> HttpConnector { HttpConnector { dns: dns::Dns::new(threads), + enforce_http: true, handle: handle.clone(), } } + + /// Option to enforce all `Uri`s have the `http` scheme. + /// + /// Enabled by default. + #[inline] + pub fn enforce_http(&mut self, is_enforced: bool) { + self.enforce_http = is_enforced; + } } impl fmt::Debug for HttpConnector { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("HttpConnector") .finish() @@ -73,12 +86,18 @@ impl Service for HttpConnector { fn call(&self, uri: Uri) -> Self::Future { debug!("Http::connect({:?})", uri); + + if self.enforce_http { + if uri.scheme() != Some("http") { + return invalid_url(InvalidUrl::NotHttp, &self.handle); + } + } else if uri.scheme().is_none() { + return invalid_url(InvalidUrl::MissingScheme, &self.handle); + } + let host = match uri.host() { Some(s) => s, - None => return HttpConnecting { - state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, "invalid url"))), - handle: self.handle.clone(), - }, + None => return invalid_url(InvalidUrl::MissingAuthority, &self.handle), }; let port = match uri.port() { Some(port) => port, @@ -94,7 +113,37 @@ impl Service for HttpConnector { handle: self.handle.clone(), } } +} + +#[inline] +fn invalid_url(err: InvalidUrl, handle: &Handle) -> HttpConnecting { + HttpConnecting { + state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))), + handle: handle.clone(), + } +} + +#[derive(Debug, Clone, Copy)] +enum InvalidUrl { + MissingScheme, + NotHttp, + MissingAuthority, +} + +impl fmt::Display for InvalidUrl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.description()) + } +} +impl StdError for InvalidUrl { + fn description(&self) -> &str { + match *self { + InvalidUrl::MissingScheme => "invalid URL, missing scheme", + InvalidUrl::NotHttp => "invalid URL, scheme must be http", + InvalidUrl::MissingAuthority => "invalid URL, missing domain", + } + } } /// A Future representing work to connect to a URL. @@ -195,7 +244,7 @@ mod tests { use super::{Connect, HttpConnector}; #[test] - fn test_non_http_url() { + fn test_errors_missing_authority() { let mut core = Core::new().unwrap(); let url = "/foo/bar?baz".parse().unwrap(); let connector = HttpConnector::new(1, &core.handle()); @@ -203,4 +252,22 @@ mod tests { assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); } + #[test] + fn test_errors_enforce_http() { + let mut core = Core::new().unwrap(); + let url = "https://example.domain/foo/bar?baz".parse().unwrap(); + let connector = HttpConnector::new(1, &core.handle()); + + assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + } + + + #[test] + fn test_errors_missing_scheme() { + let mut core = Core::new().unwrap(); + let url = "example.domain".parse().unwrap(); + let connector = HttpConnector::new(1, &core.handle()); + + assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + } } diff --git a/src/uri.rs b/src/uri.rs index 1e04b2a212..c4479b0173 100644 --- a/src/uri.rs +++ b/src/uri.rs @@ -95,6 +95,7 @@ impl Uri { } /// Get the path of this `Uri`. + #[inline] pub fn path(&self) -> &str { let index = self.path_start(); let end = self.path_end(); @@ -135,6 +136,7 @@ impl Uri { } /// Get the scheme of this `Uri`. + #[inline] pub fn scheme(&self) -> Option<&str> { if let Some(end) = self.scheme_end { Some(&self.source[..end]) @@ -144,6 +146,7 @@ impl Uri { } /// Get the authority of this `Uri`. + #[inline] pub fn authority(&self) -> Option<&str> { if let Some(end) = self.authority_end { let index = self.scheme_end.map(|i| i + 3).unwrap_or(0); @@ -155,6 +158,7 @@ impl Uri { } /// Get the host of this `Uri`. + #[inline] pub fn host(&self) -> Option<&str> { if let Some(auth) = self.authority() { auth.split(":").next() @@ -164,6 +168,7 @@ impl Uri { } /// Get the port of this `Uri`. + #[inline] pub fn port(&self) -> Option { match self.authority() { Some(auth) => auth.find(":").and_then(|i| u16::from_str(&auth[i+1..]).ok()), @@ -172,6 +177,7 @@ impl Uri { } /// Get the query string of this `Uri`, starting after the `?`. + #[inline] pub fn query(&self) -> Option<&str> { self.query_start.map(|start| { // +1 to remove '?'