From ba2a144f8b81042247088215425f91760d8694a1 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 13 Jan 2020 11:45:28 -0800 Subject: [PATCH] fix(client): strip path from Uri before calling Connector (#2109) --- Cargo.toml | 1 + src/client/mod.rs | 36 +++++++++++++++++++++--------------- src/client/pool.rs | 23 +++++++++++++---------- src/client/tests.rs | 33 ++++++++++++++++++++++++--------- 4 files changed, 59 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e633f13a16..796c114007 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ serde_derive = "1.0" serde_json = "1.0" tokio = { version = "0.2.2", features = ["fs", "macros", "io-std", "rt-util", "sync", "time", "test-util"] } tokio-test = "0.2" +tower-util = "0.3" url = "1.0" [features] diff --git a/src/client/mod.rs b/src/client/mod.rs index 1071de2f82..ce598d40b3 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -50,7 +50,6 @@ use std::fmt; use std::mem; -use std::sync::Arc; use std::time::Duration; use futures_channel::oneshot; @@ -230,14 +229,13 @@ where other => return ResponseFuture::error_version(other), }; - let domain = match extract_domain(req.uri_mut(), is_http_connect) { + let pool_key = match extract_domain(req.uri_mut(), is_http_connect) { Ok(s) => s, Err(err) => { return ResponseFuture::new(Box::new(future::err(err))); } }; - let pool_key = Arc::new(domain); ResponseFuture::new(Box::new(self.retryably_send_request(req, pool_key))) } @@ -281,7 +279,7 @@ where mut req: Request, pool_key: PoolKey, ) -> impl Future, ClientError>> + Unpin { - let conn = self.connection_for(req.uri().clone(), pool_key); + let conn = self.connection_for(pool_key); let set_host = self.config.set_host; let executor = self.conn_builder.exec.clone(); @@ -377,7 +375,6 @@ where fn connection_for( &self, - uri: Uri, pool_key: PoolKey, ) -> impl Future>, ClientError>> { // This actually races 2 different futures to try to get a ready @@ -394,7 +391,7 @@ where // connection future is spawned into the runtime to complete, // and then be inserted into the pool as an idle connection. let checkout = self.pool.checkout(pool_key.clone()); - let connect = self.connect_to(uri, pool_key); + let connect = self.connect_to(pool_key); let executor = self.conn_builder.exec.clone(); // The order of the `select` is depended on below... @@ -455,7 +452,6 @@ where fn connect_to( &self, - uri: Uri, pool_key: PoolKey, ) -> impl Lazy>>> + Unpin { let executor = self.conn_builder.exec.clone(); @@ -464,7 +460,7 @@ where let ver = self.config.ver; let is_ver_h2 = ver == Ver::Http2; let connector = self.connector.clone(); - let dst = uri; + let dst = domain_as_uri(pool_key.clone()); hyper_lazy(move || { // Try to take a "connecting lock". // @@ -794,22 +790,22 @@ fn authority_form(uri: &mut Uri) { }; } -fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result { +fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result { let uri_clone = uri.clone(); match (uri_clone.scheme(), uri_clone.authority()) { - (Some(scheme), Some(auth)) => Ok(format!("{}://{}", scheme, auth)), + (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), (None, Some(auth)) if is_http_connect => { let scheme = match auth.port_u16() { Some(443) => { set_scheme(uri, Scheme::HTTPS); - "https" + Scheme::HTTPS } _ => { set_scheme(uri, Scheme::HTTP); - "http" + Scheme::HTTP } }; - Ok(format!("{}://{}", scheme, auth)) + Ok((scheme, auth.clone())) } _ => { debug!("Client requires absolute-form URIs, received: {:?}", uri); @@ -818,6 +814,15 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result } } +fn domain_as_uri((scheme, auth): PoolKey) -> Uri { + http::uri::Builder::new() + .scheme(scheme) + .authority(auth) + .path_and_query("/") + .build() + .expect("domain is valid Uri") +} + fn set_scheme(uri: &mut Uri, scheme: Scheme) { debug_assert!( uri.scheme().is_none(), @@ -1126,7 +1131,8 @@ mod unit_tests { #[test] fn test_extract_domain_connect_no_port() { let mut uri = "hyper.rs".parse().unwrap(); - let domain = extract_domain(&mut uri, true).expect("extract domain"); - assert_eq!(domain, "http://hyper.rs"); + let (scheme, host) = extract_domain(&mut uri, true).expect("extract domain"); + assert_eq!(scheme, *"http"); + assert_eq!(host, "hyper.rs"); } } diff --git a/src/client/pool.rs b/src/client/pool.rs index a9f3991cd2..e98dcddecf 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -52,7 +52,7 @@ pub(super) enum Reservation { } /// Simple type alias in case the key type needs to be adjusted. -pub(super) type Key = Arc; +pub(super) type Key = (http::uri::Scheme, http::uri::Authority); //Arc; struct PoolInner { // A flag that a connection is being established, and the connection @@ -755,7 +755,6 @@ impl WeakOpt { #[cfg(test)] mod tests { - use std::sync::Arc; use std::task::Poll; use std::time::Duration; @@ -787,6 +786,10 @@ mod tests { } } + fn host_key(s: &str) -> Key { + (http::uri::Scheme::HTTP, s.parse().expect("host key")) + } + fn pool_no_timer() -> Pool { pool_max_idle_no_timer(::std::usize::MAX) } @@ -807,7 +810,7 @@ mod tests { #[tokio::test] async fn test_pool_checkout_smoke() { let pool = pool_no_timer(); - let key = Arc::new("foo".to_string()); + let key = host_key("foo"); let pooled = pool.pooled(c(key.clone()), Uniq(41)); drop(pooled); @@ -839,7 +842,7 @@ mod tests { #[tokio::test] async fn test_pool_checkout_returns_none_if_expired() { let pool = pool_no_timer(); - let key = Arc::new("foo".to_string()); + let key = host_key("foo"); let pooled = pool.pooled(c(key.clone()), Uniq(41)); drop(pooled); @@ -854,7 +857,7 @@ mod tests { #[tokio::test] async fn test_pool_checkout_removes_expired() { let pool = pool_no_timer(); - let key = Arc::new("foo".to_string()); + let key = host_key("foo"); pool.pooled(c(key.clone()), Uniq(41)); pool.pooled(c(key.clone()), Uniq(5)); @@ -876,7 +879,7 @@ mod tests { #[test] fn test_pool_max_idle_per_host() { let pool = pool_max_idle_no_timer(2); - let key = Arc::new("foo".to_string()); + let key = host_key("foo"); pool.pooled(c(key.clone()), Uniq(41)); pool.pooled(c(key.clone()), Uniq(5)); @@ -904,7 +907,7 @@ mod tests { &Exec::Default, ); - let key = Arc::new("foo".to_string()); + let key = host_key("foo"); pool.pooled(c(key.clone()), Uniq(41)); pool.pooled(c(key.clone()), Uniq(5)); @@ -929,7 +932,7 @@ mod tests { use futures_util::FutureExt; let pool = pool_no_timer(); - let key = Arc::new("foo".to_string()); + let key = host_key("foo"); let pooled = pool.pooled(c(key.clone()), Uniq(41)); let checkout = join(pool.checkout(key), async { @@ -948,7 +951,7 @@ mod tests { #[tokio::test] async fn test_pool_checkout_drop_cleans_up_waiters() { let pool = pool_no_timer::>(); - let key = Arc::new("localhost:12345".to_string()); + let key = host_key("foo"); let mut checkout1 = pool.checkout(key.clone()); let mut checkout2 = pool.checkout(key.clone()); @@ -993,7 +996,7 @@ mod tests { #[test] fn pooled_drop_if_closed_doesnt_reinsert() { let pool = pool_no_timer(); - let key = Arc::new("localhost:12345".to_string()); + let key = host_key("foo"); pool.pooled( c(key.clone()), CanClose { diff --git a/src/client/tests.rs b/src/client/tests.rs index 8088f0eaa4..cf3248c7eb 100644 --- a/src/client/tests.rs +++ b/src/client/tests.rs @@ -1,15 +1,30 @@ -// FIXME: re-implement tests with `async/await` -/* -#![cfg(feature = "runtime")] +use std::io; + +use futures_util::future; +use tokio::net::TcpStream; -use futures::{Async, Future, Stream}; -use futures::future::poll_fn; -use futures::sync::oneshot; -use tokio::runtime::current_thread::Runtime; +use super::Client; -use crate::mock::MockConnector; -use super::*; +#[tokio::test] +async fn client_connect_uri_argument() { + let connector = tower_util::service_fn(|dst: http::Uri| { + assert_eq!(dst.scheme(), Some(&http::uri::Scheme::HTTP)); + assert_eq!(dst.host(), Some("example.local")); + assert_eq!(dst.port(), None); + assert_eq!(dst.path(), "/", "path should be removed"); + + future::err::(io::Error::new(io::ErrorKind::Other, "expect me")) + }); + let client = Client::builder().build::<_, crate::Body>(connector); + let _ = client + .get("http://example.local/and/a/path".parse().unwrap()) + .await + .expect_err("response should fail"); +} + +/* +// FIXME: re-implement tests with `async/await` #[test] fn retryable_request() { let _ = pretty_env_logger::try_init();