diff --git a/src/client/connect/dns.rs b/src/client/connect/dns.rs index 0160a268a3..17d3c7df67 100644 --- a/src/client/connect/dns.rs +++ b/src/client/connect/dns.rs @@ -1,11 +1,26 @@ -//! The `Resolve` trait, support types, and some basic implementations. +//! DNS Resolution used by the `HttpConnector`. //! //! This module contains: //! //! - A [`GaiResolver`](dns::GaiResolver) that is the default resolver for the //! `HttpConnector`. -//! - The [`Resolve`](dns::Resolve) trait and related types to build a custom -//! resolver for use with the `HttpConnector`. +//! - The `Name` type used as an argument to custom resolvers. +//! +//! # Resolvers are `Service`s +//! +//! A resolver is just a +//! `Service>`. +//! +//! A simple resolver that ignores the name and always returns a specific +//! address: +//! +//! ```rust,ignore +//! use std::{convert::Infallible, iter, net::IpAddr}; +//! +//! let resolver = tower::service_fn(|_name| async { +//! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1]))) +//! }); +//! ``` use std::{fmt, io, vec}; use std::error::Error; use std::net::{ @@ -15,19 +30,10 @@ use std::net::{ }; use std::str::FromStr; -use tokio_sync::{mpsc, oneshot}; +use tower_service::Service; +use crate::common::{Future, Pin, Poll, task}; -use crate::common::{Future, Never, Pin, Poll, task}; - -/// Resolve a hostname to a set of IP addresses. -pub trait Resolve { - /// The set of IP addresses to try to connect to. - type Addrs: Iterator; - /// A Future of the resolved set of addresses. - type Future: Future>; - /// Resolve a hostname. - fn resolve(&self, name: Name) -> Self::Future; -} +pub(super) use self::sealed::Resolve; /// A domain name to resolve into IP addresses. #[derive(Clone, Hash, Eq, PartialEq)] @@ -41,15 +47,12 @@ pub struct GaiResolver { _priv: (), } -#[derive(Clone)] -struct ThreadPoolKeepAlive(mpsc::Sender); - /// An iterator of IP addresses returned from `getaddrinfo`. pub struct GaiAddrs { inner: IpAddrs, } -/// A future to resole a name returned by `GaiResolver`. +/// A future to resolve a name returned by `GaiResolver`. pub struct GaiFuture { inner: tokio_executor::blocking::Blocking>, } @@ -110,11 +113,16 @@ impl GaiResolver { } } -impl Resolve for GaiResolver { - type Addrs = GaiAddrs; +impl Service for GaiResolver { + type Response = GaiAddrs; + type Error = io::Error; type Future = GaiFuture; - fn resolve(&self, name: Name) -> Self::Future { + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { let blocking = tokio_executor::blocking::run(move || { debug!("resolving host={:?}", name.host); (&*name.host, 0).to_socket_addrs() @@ -164,39 +172,6 @@ impl fmt::Debug for GaiAddrs { } } - -pub(super) struct GaiBlocking { - host: String, - tx: Option>>, -} - -impl GaiBlocking { - fn block(&self) -> io::Result { - debug!("resolving host={:?}", self.host); - (&*self.host, 0).to_socket_addrs() - .map(|i| IpAddrs { iter: i }) - - } -} - -impl Future for GaiBlocking { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - if self.tx.as_mut().expect("polled after complete").poll_closed(cx).is_ready() { - trace!("resolve future canceled for {:?}", self.host); - return Poll::Ready(()); - } - - let res = self.block(); - - let tx = self.tx.take().expect("polled after complete"); - let _ = tx.send(res); - - Poll::Ready(()) - } -} - pub(super) struct IpAddrs { iter: vec::IntoIter, } @@ -276,11 +251,16 @@ impl TokioThreadpoolGaiResolver { } #[cfg(feature = "runtime")] -impl Resolve for TokioThreadpoolGaiResolver { - type Addrs = GaiAddrs; +impl Service for TokioThreadpoolGaiResolver { + type Response = GaiAddrs; + type Error = io::Error; type Future = TokioThreadpoolGaiFuture; - fn resolve(&self, name: Name) -> TokioThreadpoolGaiFuture { + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { TokioThreadpoolGaiFuture { name } } } @@ -299,6 +279,41 @@ impl Future for TokioThreadpoolGaiFuture { } } +mod sealed { + use tower_service::Service; + use crate::common::{Future, Poll, task}; + use super::{IpAddr, Name}; + + // "Trait alias" for `Service` + pub trait Resolve { + type Addrs: Iterator; + type Error: Into>; + type Future: Future>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll>; + fn resolve(&mut self, name: Name) -> Self::Future; + } + + impl Resolve for S + where + S: Service, + S::Response: Iterator, + S::Error: Into>, + { + type Addrs = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + Service::poll_ready(self, cx) + } + + fn resolve(&mut self, name: Name) -> Self::Future { + Service::call(self, name) + } + } +} + #[cfg(test)] mod tests { use std::net::{Ipv4Addr, Ipv6Addr}; diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index f85802d99d..7a4c9cccc7 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -228,11 +228,18 @@ impl HttpConnector { } } +static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; +static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; +static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; + impl HttpConnector { - fn invalid_url(&self, err: InvalidUrl) -> HttpConnecting { + fn invalid_url(&self, msg: impl Into>) -> HttpConnecting { HttpConnecting { config: self.config.clone(), - state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))), + state: State::Error(Some(ConnectError { + msg: msg.into(), + cause: None, + })), port: 0, } } @@ -252,14 +259,11 @@ where R::Future: Send, { type Response = (TcpStream, Connected); - type Error = io::Error; + type Error = ConnectError; type Future = HttpConnecting; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - // For now, always ready. - // TODO: When `Resolve` becomes an alias for `Service`, check - // the resolver's readiness. - drop(cx); + ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; Poll::Ready(Ok(())) } @@ -273,15 +277,15 @@ where if self.config.enforce_http { if dst.uri.scheme_part() != Some(&Scheme::HTTP) { - return self.invalid_url(InvalidUrl::NotHttp); + return self.invalid_url(INVALID_NOT_HTTP); } } else if dst.uri.scheme_part().is_none() { - return self.invalid_url(InvalidUrl::MissingScheme); + return self.invalid_url(INVALID_MISSING_SCHEME); } let host = match dst.uri.host() { Some(s) => s, - None => return self.invalid_url(InvalidUrl::MissingAuthority), + None => return self.invalid_url(INVALID_MISSING_HOST), }; let port = match dst.uri.port_part() { Some(port) => port.as_u16(), @@ -302,7 +306,7 @@ where R::Future: Send, { type Response = TcpStream; - type Error = io::Error; + type Error = ConnectError; type Future = Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { @@ -324,28 +328,73 @@ impl HttpInfo { } } -#[derive(Debug, Clone, Copy)] -enum InvalidUrl { - MissingScheme, - NotHttp, - MissingAuthority, +// Not publicly exported (so missing_docs doesn't trigger). +pub struct ConnectError { + msg: Box, + cause: Option>, +} + +impl ConnectError { + fn new(msg: S, cause: E) -> ConnectError + where + S: Into>, + E: Into>, + { + ConnectError { + msg: msg.into(), + cause: Some(cause.into()), + } + } + + fn dns(cause: E) -> ConnectError + where + E: Into>, + { + ConnectError::new("dns error", cause) + } + + fn m(msg: S) -> impl FnOnce(E) -> ConnectError + where + S: Into>, + E: Into>, + { + move |cause| { + ConnectError::new(msg, cause) + } + } } -impl fmt::Display for InvalidUrl { +impl fmt::Debug for ConnectError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.description()) + if let Some(ref cause) = self.cause { + f.debug_tuple("ConnectError") + .field(&self.msg) + .field(cause) + .finish() + } else { + self.msg.fmt(f) + } } } -impl StdError for InvalidUrl { - fn description(&self) -> &str { - match *self { - InvalidUrl::MissingScheme => "invalid URL, missing scheme", - InvalidUrl::NotHttp => "invalid URL, scheme must be http", - InvalidUrl::MissingAuthority => "invalid URL, missing domain", +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.msg)?; + + if let Some(ref cause) = self.cause { + write!(f, ": {}", cause)?; } + + Ok(()) } } + +impl StdError for ConnectError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause.as_ref().map(|e| &**e as _) + } +} + /// A Future representing work to connect to a URL. #[must_use = "futures do nothing unless polled"] #[pin_project] @@ -361,11 +410,11 @@ enum State { Lazy(R, String), Resolving(#[pin] R::Future), Connecting(ConnectingTcp), - Error(Option), + Error(Option), } impl Future for HttpConnecting { - type Output = Result<(TcpStream, Connected), io::Error>; + type Output = Result<(TcpStream, Connected), ConnectError>; #[project] fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { @@ -375,19 +424,20 @@ impl Future for HttpConnecting { let state; #[project] match me.state.as_mut().project() { - State::Lazy(ref resolver, ref mut host) => { + State::Lazy(ref mut resolver, ref mut host) => { // If the host is already an IP addr (v4 or v6), // skip resolving the dns and start connecting right away. if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) { state = State::Connecting(ConnectingTcp::new( config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); } else { + ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?; let name = dns::Name::new(mem::replace(host, String::new())); state = State::Resolving(resolver.resolve(name)); } }, State::Resolving(future) => { - let addrs = ready!(future.poll(cx))?; + let addrs = ready!(future.poll(cx)).map_err(ConnectError::dns)?; let port = *me.port; let addrs = addrs .map(|addr| SocketAddr::new(addr, port)) @@ -397,24 +447,25 @@ impl Future for HttpConnecting { config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); }, State::Connecting(ref mut c) => { - let sock = ready!(c.poll(cx, &config.handle))?; + let sock = ready!(c.poll(cx, &config.handle)) + .map_err(ConnectError::m("tcp connect error"))?; if let Some(dur) = config.keep_alive_timeout { - sock.set_keepalive(Some(dur))?; + sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?; } if let Some(size) = config.send_buffer_size { - sock.set_send_buffer_size(size)?; + sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?; } if let Some(size) = config.recv_buffer_size { - sock.set_recv_buffer_size(size)?; + sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; } - sock.set_nodelay(config.nodelay)?; + sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?; let extra = HttpInfo { - remote_addr: sock.peer_addr()?, + remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?, }; let connected = Connected::new() .extra(extra); @@ -642,7 +693,6 @@ impl ConnectingTcp { mod tests { use std::io; - use tokio::runtime::current_thread::Runtime; use tokio_net::driver::Handle; use super::{Connected, Destination, HttpConnector}; @@ -655,55 +705,29 @@ mod tests { connector.connect(super::super::sealed::Internal, dst).await } - #[test] - fn test_errors_missing_authority() { - let mut rt = Runtime::new().unwrap(); - let uri = "/foo/bar?baz".parse().unwrap(); - let dst = Destination { - uri, - }; - let connector = HttpConnector::new(); - - rt.block_on(async { - assert_eq!( - connect(connector, dst).await.unwrap_err().kind(), - io::ErrorKind::InvalidInput, - ); - }) - } - - #[test] - fn test_errors_enforce_http() { - let mut rt = Runtime::new().unwrap(); + #[tokio::test] + async fn test_errors_enforce_http() { let uri = "https://example.domain/foo/bar?baz".parse().unwrap(); let dst = Destination { uri, }; let connector = HttpConnector::new(); - rt.block_on(async { - assert_eq!( - connect(connector, dst).await.unwrap_err().kind(), - io::ErrorKind::InvalidInput, - ); - }) + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); } - #[test] - fn test_errors_missing_scheme() { - let mut rt = Runtime::new().unwrap(); + #[tokio::test] + async fn test_errors_missing_scheme() { let uri = "example.domain".parse().unwrap(); let dst = Destination { uri, }; - let connector = HttpConnector::new(); + let mut connector = HttpConnector::new(); + connector.enforce_http(false); - rt.block_on(async { - assert_eq!( - connect(connector, dst).await.unwrap_err().kind(), - io::ErrorKind::InvalidInput, - ); - }); + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); } #[test] diff --git a/tests/client.rs b/tests/client.rs index 9401cf1af9..da1a852892 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1722,7 +1722,7 @@ mod dispatch_impl { impl hyper::service::Service for DebugConnector { type Response = (DebugStream, Connected); - type Error = io::Error; + type Error = >::Error; type Future = Pin > + Send>>;