From 27db8b0061f85d89ec94e07295463e8d1030d94f Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 18 Jun 2018 15:21:41 -0700 Subject: [PATCH] feat(client): add `set_scheme`, `set_host`, and `set_port` for `Destination` Closes #1564 --- src/client/connect.rs | 257 +++++++++++++++++++++++++++++++++++++++++- src/error.rs | 6 + 2 files changed, 262 insertions(+), 1 deletion(-) diff --git a/src/client/connect.rs b/src/client/connect.rs index fdcc84a530..d347426e3c 100644 --- a/src/client/connect.rs +++ b/src/client/connect.rs @@ -6,9 +6,11 @@ //! establishes connections over TCP. //! - The [`Connect`](Connect) trait and related types to build custom connectors. use std::error::Error as StdError; +use std::mem; +use bytes::{BufMut, BytesMut}; use futures::Future; -use http::Uri; +use http::{uri, Uri}; use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] pub use self::http::HttpConnector; @@ -79,6 +81,144 @@ impl Destination { self.uri.port() } + /// Update the scheme of this destination. + /// + /// # Example + /// + /// ```rust + /// # use hyper::client::connect::Destination; + /// # fn with_dst(mut dst: Destination) { + /// // let mut dst = some_destination... + /// // Change from "http://"... + /// assert_eq!(dst.scheme(), "http"); + /// + /// // to "ws://"... + /// dst.set_scheme("ws"); + /// assert_eq!(dst.scheme(), "ws"); + /// # } + /// ``` + /// + /// # Error + /// + /// Returns an error if the string is not a valid scheme. + pub fn set_scheme(&mut self, scheme: &str) -> ::Result<()> { + let scheme = scheme.parse().map_err(::error::Parse::from)?; + self.update_uri(move |parts| { + parts.scheme = Some(scheme); + }) + } + + /// Update the host of this destination. + /// + /// # Example + /// + /// ```rust + /// # use hyper::client::connect::Destination; + /// # fn with_dst(mut dst: Destination) { + /// // let mut dst = some_destination... + /// // Change from "hyper.rs"... + /// assert_eq!(dst.host(), "hyper.rs"); + /// + /// // to "some.proxy"... + /// dst.set_host("some.proxy"); + /// assert_eq!(dst.host(), "some.proxy"); + /// # } + /// ``` + /// + /// # Error + /// + /// Returns an error if the string is not a valid hostname. + pub fn set_host(&mut self, host: &str) -> ::Result<()> { + if host.contains(&['@',':'][..]) { + return Err(::error::Parse::Uri.into()); + } + let auth = if let Some(port) = self.port() { + format!("{}:{}", host, port).parse().map_err(::error::Parse::from)? + } else { + host.parse().map_err(::error::Parse::from)? + }; + self.update_uri(move |parts| { + parts.authority = Some(auth); + }) + } + + /// Update the port of this destination. + /// + /// # Example + /// + /// ```rust + /// # use hyper::client::connect::Destination; + /// # fn with_dst(mut dst: Destination) { + /// // let mut dst = some_destination... + /// // Change from "None"... + /// assert_eq!(dst.port(), None); + /// + /// // to "4321"... + /// dst.set_port(4321); + /// assert_eq!(dst.port(), Some(4321)); + /// + /// // Or remove the port... + /// dst.set_port(None); + /// assert_eq!(dst.port(), None); + /// # } + /// ``` + pub fn set_port

(&mut self, port: P) + where + P: Into>, + { + self.set_port_opt(port.into()); + } + + fn set_port_opt(&mut self, port: Option) { + use std::fmt::Write; + + let auth = if let Some(port) = port { + let host = self.host(); + // Need space to copy the hostname, plus ':', + // plus max 5 port digits... + let cap = host.len() + 1 + 5; + let mut buf = BytesMut::with_capacity(cap); + buf.put_slice(host.as_bytes()); + buf.put_u8(b':'); + write!(buf, "{}", port) + .expect("should have space for 5 digits"); + + uri::Authority::from_shared(buf.freeze()) + .expect("valid host + :port should be valid authority") + } else { + self.host().parse() + .expect("valid host without port should be valid authority") + }; + + self.update_uri(move |parts| { + parts.authority = Some(auth); + }) + .expect("valid uri should be valid with port"); + } + + fn update_uri(&mut self, f: F) -> ::Result<()> + where + F: FnOnce(&mut uri::Parts) + { + // Need to store a default Uri while we modify the current one... + let old_uri = mem::replace(&mut self.uri, Uri::default()); + // However, mutate a clone, so we can revert if there's an error... + let mut parts: uri::Parts = old_uri.clone().into(); + + f(&mut parts); + + match Uri::from_parts(parts) { + Ok(uri) => { + self.uri = uri; + Ok(()) + }, + Err(err) => { + self.uri = old_uri; + Err(::error::Parse::from(err).into()) + }, + } + } + /* /// Returns whether this connection must negotiate HTTP/2 via ALPN. pub fn must_h2(&self) -> bool { @@ -121,6 +261,121 @@ impl Connected { */ } +#[cfg(test)] +mod tests { + use super::Destination; + + #[test] + fn test_destination_set_scheme() { + let mut dst = Destination { + uri: "http://hyper.rs".parse().expect("initial parse"), + }; + + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + + dst.set_scheme("https").expect("set https"); + assert_eq!(dst.scheme(), "https"); + assert_eq!(dst.host(), "hyper.rs"); + + dst.set_scheme("").unwrap_err(); + assert_eq!(dst.scheme(), "https", "error doesn't modify dst"); + assert_eq!(dst.host(), "hyper.rs", "error doesn't modify dst"); + } + + #[test] + fn test_destination_set_host() { + let mut dst = Destination { + uri: "http://hyper.rs".parse().expect("initial parse"), + }; + + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), None); + + dst.set_host("seanmonstar.com").expect("set https"); + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "seanmonstar.com"); + assert_eq!(dst.port(), None); + + dst.set_host("/im-not a host! >:)").unwrap_err(); + assert_eq!(dst.scheme(), "http", "error doesn't modify dst"); + assert_eq!(dst.host(), "seanmonstar.com", "error doesn't modify dst"); + assert_eq!(dst.port(), None, "error doesn't modify dst"); + + // Also test that an exist port is set correctly. + let mut dst = Destination { + uri: "http://hyper.rs:8080".parse().expect("initial parse 2"), + }; + + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), Some(8080)); + + dst.set_host("seanmonstar.com").expect("set host"); + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "seanmonstar.com"); + assert_eq!(dst.port(), Some(8080)); + + dst.set_host("/im-not a host! >:)").unwrap_err(); + assert_eq!(dst.scheme(), "http", "error doesn't modify dst"); + assert_eq!(dst.host(), "seanmonstar.com", "error doesn't modify dst"); + assert_eq!(dst.port(), Some(8080), "error doesn't modify dst"); + + // Check port isn't snuck into `set_host`. + dst.set_host("seanmonstar.com:3030").expect_err("set_host sneaky port"); + assert_eq!(dst.scheme(), "http", "error doesn't modify dst"); + assert_eq!(dst.host(), "seanmonstar.com", "error doesn't modify dst"); + assert_eq!(dst.port(), Some(8080), "error doesn't modify dst"); + + // Check userinfo isn't snuck into `set_host`. + dst.set_host("sean@nope").expect_err("set_host sneaky userinfo"); + assert_eq!(dst.scheme(), "http", "error doesn't modify dst"); + assert_eq!(dst.host(), "seanmonstar.com", "error doesn't modify dst"); + assert_eq!(dst.port(), Some(8080), "error doesn't modify dst"); + } + + #[test] + fn test_destination_set_port() { + let mut dst = Destination { + uri: "http://hyper.rs".parse().expect("initial parse"), + }; + + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), None); + + dst.set_port(None); + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), None); + + dst.set_port(8080); + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), Some(8080)); + + // Also test that an exist port is set correctly. + let mut dst = Destination { + uri: "http://hyper.rs:8080".parse().expect("initial parse 2"), + }; + + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), Some(8080)); + + dst.set_port(3030); + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), Some(3030)); + + dst.set_port(None); + assert_eq!(dst.scheme(), "http"); + assert_eq!(dst.host(), "hyper.rs"); + assert_eq!(dst.port(), None); + } +} + #[cfg(feature = "runtime")] mod http { use super::*; diff --git a/src/error.rs b/src/error.rs index 337bea0c5b..45feec53d7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -350,6 +350,12 @@ impl From for Parse { } } +impl From for Parse { + fn from(_: http::uri::InvalidUriParts) -> Parse { + Parse::Uri + } +} + #[doc(hidden)] trait AssertSendSync: Send + Sync + 'static {} #[doc(hidden)]