diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 8f3bebb1e..9a7b74c4e 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -20,6 +20,8 @@ use crate::async_impl::h3_client::H3Client; use crate::config::{RequestConfig, TotalTimeout}; #[cfg(unix)] use crate::connect::uds::UnixSocketProvider; +#[cfg(target_os = "windows")] +use crate::connect::windows_named_pipe::WindowsNamedPipeProvider; use crate::connect::{ sealed::{Conn, Unnameable}, BoxedConnectorLayer, BoxedConnectorService, Connector, ConnectorBuilder, @@ -227,6 +229,8 @@ struct Config { #[cfg(unix)] unix_socket: Option>, + #[cfg(target_os = "windows")] + windows_named_pipe: Option>, } impl Default for ClientBuilder { @@ -352,6 +356,8 @@ impl ClientBuilder { dns_resolver: None, #[cfg(unix)] unix_socket: None, + #[cfg(target_os = "windows")] + windows_named_pipe: None, }, } } @@ -886,6 +892,8 @@ impl ClientBuilder { // ways TLS can be configured... #[cfg(unix)] connector_builder.set_unix_socket(config.unix_socket); + #[cfg(target_os = "windows")] + connector_builder.set_windows_named_pipe(config.windows_named_pipe.clone()); let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); @@ -1720,6 +1728,26 @@ impl ClientBuilder { self } + /// Set that all connections will use this Windows named pipe. + /// + /// If a request URI uses the `https` scheme, TLS will still be used over + /// the Windows named pipe. + /// + /// # Note + /// + /// This option is not compatible with any of the TCP or Proxy options. + /// Setting this will ignore all those options previously set. + /// + /// Likewise, DNS resolution will not be done on the domain name. + #[cfg(target_os = "windows")] + pub fn windows_named_pipe(mut self, pipe: impl WindowsNamedPipeProvider) -> ClientBuilder { + self.config.windows_named_pipe = Some( + pipe.reqwest_windows_named_pipe_path(crate::connect::windows_named_pipe::Internal) + .into(), + ); + self + } + // TLS options /// Add a custom root certificate. diff --git a/src/connect.rs b/src/connect.rs index ca52dd5e0..23f2fad58 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -82,6 +82,8 @@ pub(crate) struct ConnectorBuilder { resolver: Option, #[cfg(unix)] unix_socket: Option>, + #[cfg(target_os = "windows")] + windows_named_pipe: Option>, } impl ConnectorBuilder { @@ -103,6 +105,8 @@ where { resolver: self.resolver.unwrap_or_else(DynResolver::gai), #[cfg(unix)] unix_socket: self.unix_socket, + #[cfg(target_os = "windows")] + windows_named_pipe: self.windows_named_pipe, }; #[cfg(unix)] @@ -110,6 +114,11 @@ where { base_service.proxies = Default::default(); log::trace!("unix_socket() set, proxies are ignored"); } + #[cfg(target_os = "windows")] + if base_service.windows_named_pipe.is_some() && !base_service.proxies.is_empty() { + base_service.proxies = Default::default(); + log::trace!("windows_named_pipe() set, proxies are ignored"); + } if layers.is_empty() { // we have no user-provided layers, only use concrete types @@ -208,6 +217,8 @@ where { resolver: None, #[cfg(unix)] unix_socket: None, + #[cfg(target_os = "windows")] + windows_named_pipe: None, } } @@ -319,6 +330,8 @@ where { resolver: None, #[cfg(unix)] unix_socket: None, + #[cfg(target_os = "windows")] + windows_named_pipe: None, } } @@ -392,6 +405,8 @@ where { resolver: None, #[cfg(unix)] unix_socket: None, + #[cfg(target_os = "windows")] + windows_named_pipe: None, } } @@ -457,6 +472,11 @@ where { pub(crate) fn set_unix_socket(&mut self, path: Option>) { self.unix_socket = path; } + + #[cfg(target_os = "windows")] + pub(crate) fn set_windows_named_pipe(&mut self, pipe: Option>) { + self.windows_named_pipe = pipe; + } } #[allow(missing_debug_implementations)] @@ -481,6 +501,8 @@ pub(crate) struct ConnectorService { /// If set, this always takes priority over TCP. #[cfg(unix)] unix_socket: Option>, + #[cfg(target_os = "windows")] + windows_named_pipe: Option>, } #[derive(Clone)] @@ -673,21 +695,37 @@ impl ConnectorService { } } - /// Connect over Unix Domain Socket (or Windows?). - #[cfg(unix)] + /// Connect over a local transport: Unix Domain Socket (on Unix) or Windows Named Pipe (on Windows). + #[cfg(any(unix, target_os = "windows"))] async fn connect_local_transport(self, dst: Uri) -> Result { - let path = self - .unix_socket - .as_ref() - .expect("connect local must have socket path") - .clone(); - let svc = tower::service_fn(move |_| { - let fut = tokio::net::UnixStream::connect(path.clone()); - async move { - let io = fut.await?; - Ok::<_, std::io::Error>(TokioIo::new(io)) - } - }); + #[cfg(unix)] + let svc = { + let path = self + .unix_socket + .as_ref() + .expect("connect local must have socket path") + .clone(); + tower::service_fn(move |_| { + let fut = tokio::net::UnixStream::connect(path.clone()); + async move { + let io = fut.await?; + Ok::<_, std::io::Error>(TokioIo::new(io)) + } + }) + }; + #[cfg(target_os = "windows")] + let svc = { + use tokio::net::windows::named_pipe::ClientOptions; + let pipe = self + .windows_named_pipe + .as_ref() + .expect("connect local must have pipe path") + .clone(); + tower::service_fn(move |_| { + let pipe = pipe.clone(); + async move { ClientOptions::new().open(pipe).map(TokioIo::new) } + }) + }; let is_proxy = false; match self.inner { #[cfg(not(feature = "__tls"))] @@ -852,6 +890,15 @@ impl ConnectorService { self.connect_with_maybe_proxy(proxy_dst, true).await } + + #[cfg(any(unix, target_os = "windows"))] + fn should_use_local_transport(&self) -> bool { + #[cfg(unix)] + return self.unix_socket.is_some(); + + #[cfg(target_os = "windows")] + return self.windows_named_pipe.is_some(); + } } async fn with_timeout(f: F, timeout: Option) -> Result @@ -882,9 +929,9 @@ impl Service for ConnectorService { log::debug!("starting new connection: {dst:?}"); let timeout = self.simple_timeout; - // Local transports (UDS) skip proxies - #[cfg(unix)] - if self.unix_socket.is_some() { + // Local transports (UDS, Windows Named Pipes) skip proxies + #[cfg(any(unix, target_os = "windows"))] + if self.should_use_local_transport() { return Box::pin(with_timeout( self.clone().connect_local_transport(dst), timeout, @@ -1104,6 +1151,120 @@ impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream Option { + None + } +} + +#[cfg(feature = "default-tls")] +#[cfg(target_os = "windows")] +impl TlsInfoFactory + for tokio_native_tls::TlsStream< + TokioIo>, + > +{ + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .peer_certificate() + .ok() + .flatten() + .and_then(|c| c.to_der().ok()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "default-tls")] +#[cfg(target_os = "windows")] +impl TlsInfoFactory + for tokio_native_tls::TlsStream< + TokioIo< + hyper_tls::MaybeHttpsStream>, + >, + > +{ + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .peer_certificate() + .ok() + .flatten() + .and_then(|c| c.to_der().ok()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "default-tls")] +#[cfg(target_os = "windows")] +impl TlsInfoFactory + for hyper_tls::MaybeHttpsStream> +{ + fn tls_info(&self) -> Option { + match self { + hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_tls::MaybeHttpsStream::Http(_) => None, + } + } +} + +#[cfg(feature = "__rustls")] +#[cfg(target_os = "windows")] +impl TlsInfoFactory + for tokio_rustls::client::TlsStream< + TokioIo>, + > +{ + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .1 + .peer_certificates() + .and_then(|certs| certs.first()) + .map(|c| c.to_vec()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "__rustls")] +#[cfg(target_os = "windows")] +impl TlsInfoFactory + for tokio_rustls::client::TlsStream< + TokioIo< + hyper_rustls::MaybeHttpsStream< + TokioIo, + >, + >, + > +{ + fn tls_info(&self) -> Option { + let peer_certificate = self + .get_ref() + .1 + .peer_certificates() + .and_then(|certs| certs.first()) + .map(|c| c.to_vec()); + Some(crate::tls::TlsInfo { peer_certificate }) + } +} + +#[cfg(feature = "__rustls")] +#[cfg(target_os = "windows")] +impl TlsInfoFactory + for hyper_rustls::MaybeHttpsStream> +{ + fn tls_info(&self) -> Option { + match self { + hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), + hyper_rustls::MaybeHttpsStream::Http(_) => None, + } + } +} + pub(crate) trait AsyncConn: Read + Write + Connection + Send + Sync + Unpin + 'static { @@ -1249,6 +1410,39 @@ pub(crate) mod uds { ]; } +// Sealed trait for Windows Named Pipe support +#[cfg(target_os = "windows")] +pub(crate) mod windows_named_pipe { + use std::ffi::OsStr; + /// A provider for Windows Named Pipe paths. + /// + /// This trait is sealed. This allows us to expand support in the future + /// by controlling who can implement the trait. + #[cfg(target_os = "windows")] + pub trait WindowsNamedPipeProvider { + #[doc(hidden)] + fn reqwest_windows_named_pipe_path(&self, _: Internal) -> &OsStr; + } + + #[allow(missing_debug_implementations)] + pub struct Internal; + + macro_rules! as_os_str { + ($($t:ty,)+) => { + $( + impl WindowsNamedPipeProvider for $t { + #[doc(hidden)] + fn reqwest_windows_named_pipe_path(&self, _: Internal) -> &OsStr { + self.as_ref() + } + } + )+ + } + } + + as_os_str![String, &'_ str,]; +} + pub(crate) type Connecting = Pin> + Send>>; #[cfg(feature = "default-tls")] @@ -1342,6 +1536,40 @@ mod native_tls_conn { } } + #[cfg(target_os = "windows")] + impl Connection + for NativeTlsConn>> + { + fn connected(&self) -> Connected { + let connected = Connected::new(); + #[cfg(feature = "native-tls-alpn")] + match self.inner.inner().get_ref().negotiated_alpn().ok() { + Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => connected.negotiated_h2(), + _ => connected, + } + #[cfg(not(feature = "native-tls-alpn"))] + connected + } + } + + #[cfg(target_os = "windows")] + impl Connection + for NativeTlsConn< + TokioIo>>, + > + { + fn connected(&self) -> Connected { + let connected = Connected::new(); + #[cfg(feature = "native-tls-alpn")] + match self.inner.inner().get_ref().negotiated_alpn().ok() { + Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => connected.negotiated_h2(), + _ => connected, + } + #[cfg(not(feature = "native-tls-alpn"))] + connected + } + } + impl Read for NativeTlsConn { fn poll_read( self: Pin<&mut Self>, @@ -1491,6 +1719,46 @@ mod rustls_tls_conn { } } + #[cfg(target_os = "windows")] + impl Connection + for RustlsTlsConn>> + { + fn connected(&self) -> Connected { + if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") { + self.inner + .inner() + .get_ref() + .0 + .inner() + .connected() + .negotiated_h2() + } else { + self.inner.inner().get_ref().0.inner().connected() + } + } + } + + #[cfg(target_os = "windows")] + impl Connection + for RustlsTlsConn< + TokioIo>>, + > + { + fn connected(&self) -> Connected { + if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") { + self.inner + .inner() + .get_ref() + .0 + .inner() + .connected() + .negotiated_h2() + } else { + self.inner.inner().get_ref().0.inner().connected() + } + } + } + impl Read for RustlsTlsConn { fn poll_read( self: Pin<&mut Self>,