From 6e56341d8ab8a870a278b49a805bc501b25225a7 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 9 Aug 2024 14:52:28 +0000 Subject: [PATCH] Split the client into service layers Individual pieces of client functionality are now separate Service implementors. --- src/client/builder.rs | 18 +- src/client/mod.rs | 8 +- src/client/pool/checkout.rs | 130 +++++--- src/client/pool/mod.rs | 5 + src/client/pool/service.rs | 455 +++++++++++++++++++++++++++ src/client/service.rs | 595 +++--------------------------------- src/service/client.rs | 197 ++++++++++++ src/service/error.rs | 163 ++++++++++ src/service/host.rs | 202 ++++++++++++ src/service/http.rs | 409 +++++++++++++++++++++++++ src/service/mod.rs | 14 + tests/client.rs | 4 + 12 files changed, 1604 insertions(+), 596 deletions(-) create mode 100644 src/client/pool/service.rs create mode 100644 src/service/client.rs create mode 100644 src/service/error.rs create mode 100644 src/service/host.rs diff --git a/src/client/builder.rs b/src/client/builder.rs index 5ccdc38d..a4260a18 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -19,8 +19,10 @@ use super::conn::Connection; use super::conn::Protocol; use super::conn::Transport; use super::pool::PoolableConnection; +use super::ConnectionPoolLayer; +use crate::service::RequestExecutor; +use crate::service::{Http1ChecksLayer, Http2ChecksLayer, SetHostHeaderLayer}; -use super::ClientService; use crate::client::conn::connection::ConnectionError; #[cfg(feature = "tls")] use crate::client::default_tls_config; @@ -484,12 +486,14 @@ where http::header::USER_AGENT, user_agent, )) - .service(ClientService { - transport, - protocol: self.protocol.build(), - pool: self.pool.map(super::pool::Pool::new), - _body: std::marker::PhantomData, - }); + .layer( + ConnectionPoolLayer::new(transport, self.protocol.build()) + .with_optional_pool(self.pool.clone()), + ) + .layer(SetHostHeaderLayer::new()) + .layer(Http2ChecksLayer::new()) + .layer(Http1ChecksLayer::new()) + .service(RequestExecutor::new()); SharedService::new(service) } diff --git a/src/client/mod.rs b/src/client/mod.rs index adac8201..b63dac7e 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,7 +4,7 @@ //! //! 1. The high-level [`Client`] API, which is the most user-friendly and abstracts away most of the details. //! It is "batteries-included", and supports features like redirects, retries and timeouts. -//! 2. The [`Service`][ClientService] API, which is a lower-level API that allows for more control over the request and response. +//! 2. The [`Service`][ConnectionPoolService] API, which is a lower-level API that allows for more control over the request and response. //! It presents a `tower::Service` that can be used to send requests and receive responses, and can be wrapped //! by middleware compatible with the tower ecosystem. //! 3. The [connection][self::conn] API, which is the lowest-level API that allows for direct control over the @@ -21,14 +21,14 @@ use tower::ServiceExt; use self::conn::protocol::auto; use self::conn::transport::tcp::TcpTransportConfig; -pub use self::service::ClientService; +pub use self::pool::service::ConnectionPoolLayer; +pub use self::pool::service::ConnectionPoolService; use crate::service::SharedService; mod builder; pub mod conn; mod error; pub mod pool; -mod service; pub use self::error::Error; pub use self::pool::Config as PoolConfig; @@ -82,7 +82,7 @@ impl ClientRef { /// A high-level async HTTP client. /// -/// This client is built on top of the [`Service`][ClientService] API and provides a more user-friendly interface, +/// This client is built on top of the [`Service`][ConnectionPoolService] API and provides a more user-friendly interface, /// including support for retries, redirects and timeouts. /// /// # Example diff --git a/src/client/pool/checkout.rs b/src/client/pool/checkout.rs index 2e4a50ef..bb635577 100644 --- a/src/client/pool/checkout.rs +++ b/src/client/pool/checkout.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::future::poll_fn; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -9,7 +10,6 @@ use std::task::Context; use std::task::Poll; use futures_util::future::BoxFuture; -use futures_util::FutureExt as _; use pin_project::pin_project; use pin_project::pinned_drop; use thiserror::Error; @@ -145,7 +145,7 @@ impl fmt::Debug } #[pin_project(PinnedDrop)] -pub(crate) struct Checkout { +pub(crate) struct Checkout { key: Key, pool: WeakOpt>>, #[pin] @@ -157,7 +157,7 @@ pub(crate) struct Checkout { id: CheckoutId, } -impl fmt::Debug for Checkout { +impl fmt::Debug for Checkout { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Checkout") .field("key", &self.key) @@ -168,7 +168,7 @@ impl fmt::Debug for Checkout Checkout { +impl Checkout { pub(crate) fn detached(key: Key, connector: Connector) -> Self { Self { key, @@ -262,6 +262,7 @@ where let transport: T = match this.inner { InnerCheckoutConnecting::Waiting => { // We're waiting on a connection to be ready. + // If that were still happening, we would bail out above. return Poll::Ready(Err(Error::Unavailable)); } InnerCheckoutConnecting::Connected => { @@ -274,11 +275,11 @@ where return Poll::Ready(Ok(self.as_mut().connected(connection))); } InnerCheckoutConnecting::Connecting(Connector { transport, .. }) => { - ready!(transport.poll_unpin(cx)).map_err(Error::Connecting)? + ready!(transport.as_mut().poll(cx)).map_err(Error::Connecting)? } InnerCheckoutConnecting::Handshaking(handshake) => { let connection = - ready!(handshake.poll_unpin(cx)).map_err(Error::Handshaking)?; + ready!(handshake.as_mut().poll(cx)).map_err(Error::Handshaking)?; return Poll::Ready(Ok(self.as_mut().connected(connection))); } }; @@ -307,7 +308,7 @@ where } } -impl Checkout { +impl Checkout { /// Checks the waiter to see if a new connection is ready and can be passed along. /// /// If there is no waiter, this function returns `Poll::Ready(Ok(None))`. If there is @@ -320,45 +321,104 @@ impl Checkout { } /// Called to register a new connection with the pool. - pub(crate) fn connected(self: Pin<&mut Self>, mut connection: C) -> Pooled { - if let Some(pool) = self.pool.upgrade() { - if let Ok(mut inner) = pool.lock() { - if let Some(reused) = connection.reuse() { - inner.push(self.key.clone(), reused, self.pool.clone()); - return Pooled { - connection: Some(connection), - is_reused: true, - key: self.key.clone(), - pool: WeakOpt::none(), - }; - } else { - return Pooled { - connection: Some(connection), - is_reused: false, - key: self.key.clone(), - pool: WeakOpt::downgrade(&pool), - }; - } + pub(crate) fn connected(self: Pin<&mut Self>, connection: C) -> Pooled { + register_connected(&self.pool, &self.key, connection) + } +} + +fn register_connected( + poolref: &WeakOpt>>, + key: &Key, + mut connection: C, +) -> Pooled +where + C: PoolableConnection, +{ + if let Some(pool) = poolref.upgrade() { + if let Ok(mut inner) = pool.lock() { + if let Some(reused) = connection.reuse() { + inner.push(key.clone(), reused, poolref.clone()); + return Pooled { + connection: Some(connection), + is_reused: true, + key: key.clone(), + pool: WeakOpt::none(), + }; + } else { + return Pooled { + connection: Some(connection), + is_reused: false, + key: key.clone(), + pool: WeakOpt::downgrade(&pool), + }; } } + } - // No pool or lock was available, so we can't add the connection to the pool. - Pooled { - connection: Some(connection), - is_reused: false, - key: self.key.clone(), - pool: WeakOpt::none(), - } + // No pool or lock was available, so we can't add the connection to the pool. + Pooled { + connection: Some(connection), + is_reused: false, + key: key.clone(), + pool: WeakOpt::none(), } } #[pinned_drop] -impl PinnedDrop for Checkout { - fn drop(self: Pin<&mut Self>) { +impl PinnedDrop for Checkout +where + E: 'static, +{ + fn drop(mut self: Pin<&mut Self>) { if let Some(pool) = self.pool.upgrade() { if let Ok(mut inner) = pool.lock() { inner.cancel_connection(&self.key); } + + let state = std::mem::replace(&mut self.inner, InnerCheckoutConnecting::Connected); + + match state { + InnerCheckoutConnecting::Connecting(mut connector) => { + let pool = self.pool.clone(); + let key = self.key.clone(); + tokio::spawn(async move { + let io: T = match poll_fn(|cx| connector.transport.as_mut().poll(cx)).await + { + Ok(io) => io, + Err(_) => { + tracing::error!(%key, "error connecting background transport"); + return; + } + }; + + let connection = match (connector.handshake)(io).await { + Ok(conn) => conn, + Err(_) => { + tracing::error!(%key, "error handshaking background connection"); + return; + } + }; + + register_connected(&pool, &key, connection); + }); + } + InnerCheckoutConnecting::Handshaking(handshake) => { + let pool = self.pool.clone(); + let key = self.key.clone(); + tokio::spawn(async move { + let connection = match handshake.await { + Ok(conn) => conn, + Err(_) => { + tracing::error!(key=%key, "error handshaking connection"); + return; + } + }; + + register_connected(&pool, &key, connection); + }); + } + _ => {} + } } } } diff --git a/src/client/pool/mod.rs b/src/client/pool/mod.rs index b8eaa48c..fd8a77bc 100644 --- a/src/client/pool/mod.rs +++ b/src/client/pool/mod.rs @@ -30,6 +30,7 @@ use tracing::trace; mod checkout; mod idle; pub(super) mod key; +pub(super) mod service; mod weakopt; pub(crate) use self::checkout::Checkout; @@ -273,6 +274,10 @@ pub trait PoolableConnection: Unpin + Send + Sized + 'static { fn reuse(&mut self) -> Option; } +/// Wrapper type for a connection which is managed by a pool. +/// +/// This type is used outside of the Pool to ensure that dropped +/// connections are returned to the pool. pub(crate) struct Pooled { connection: Option, is_reused: bool, diff --git a/src/client/pool/service.rs b/src/client/pool/service.rs new file mode 100644 index 00000000..f233702b --- /dev/null +++ b/src/client/pool/service.rs @@ -0,0 +1,455 @@ +use std::fmt; +use std::future::poll_fn; +use std::future::Future; +use std::task::Poll; + +use futures_util::FutureExt; +use http_body::Body; +use pin_project::pin_project; +use tower::util::Oneshot; +use tower::ServiceExt; + +use crate::client::conn::connection::ConnectionError; +use crate::client::conn::connection::HttpConnection; +use crate::client::conn::protocol::auto::HttpConnectionBuilder; +use crate::client::conn::protocol::HttpProtocol; +use crate::client::conn::transport::tcp::TcpTransport; +use crate::client::conn::transport::TransportStream; +use crate::client::conn::Connection; +use crate::client::conn::Protocol; +use crate::client::conn::TlsTransport; +use crate::client::conn::Transport; +use crate::client::pool; +use crate::client::pool::Checkout; +use crate::client::pool::Connector; +use crate::client::pool::PoolableConnection; +use crate::client::Error; +use crate::info::HasConnectionInfo; +use crate::service::client::ExecuteRequest; + +/// Layer which adds connection pooling and converts +/// to an inner service which accepts `ExecuteRequest` +/// from an outer service which accepts `http::Request`. +pub struct ConnectionPoolLayer { + transport: T, + protocol: P, + pool: Option, + _body: std::marker::PhantomData ()>, +} + +impl fmt::Debug for ConnectionPoolLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConnectionPoolLayer") + .field("transport", &self.transport) + .field("protocol", &self.protocol) + .field("pool", &self.pool) + .finish() + } +} + +impl ConnectionPoolLayer { + /// Layer for connection pooling. + pub fn new(transport: T, protocol: P) -> Self { + Self { + transport, + protocol, + pool: None, + _body: std::marker::PhantomData, + } + } + + /// Set the connection pool configuration. + pub fn with_pool(mut self, pool: pool::Config) -> Self { + self.pool = Some(pool); + self + } + + /// Set the connection pool configuration to an optional value. + pub fn with_optional_pool(mut self, pool: Option) -> Self { + self.pool = pool; + self + } + + /// Disable connection pooling. + pub fn without_pool(mut self) -> Self { + self.pool = None; + self + } +} + +impl Clone for ConnectionPoolLayer +where + T: Clone, + P: Clone, +{ + fn clone(&self) -> Self { + Self { + transport: self.transport.clone(), + protocol: self.protocol.clone(), + pool: self.pool.clone(), + _body: std::marker::PhantomData, + } + } +} + +impl tower::layer::Layer for ConnectionPoolLayer +where + T: Transport + Clone + Send + Sync + 'static, + P: Protocol + Clone + Send + Sync + 'static, + P::Connection: PoolableConnection, +{ + type Service = ConnectionPoolService; + + fn layer(&self, service: S) -> Self::Service { + let pool = self.pool.clone().map(pool::Pool::new); + + ConnectionPoolService { + transport: self.transport.clone(), + protocol: self.protocol.clone(), + service, + pool, + _body: std::marker::PhantomData, + } + } +} + +/// A service which gets a connection from a possible connection pool and passes it to +/// an inner service to execute that request. +/// +/// This service will accept [`http::Request`] objects, but expects the inner service +/// to accept [`ExecuteRequest`] objects, which bundle the connection with the request. +/// +/// The simplest interior service is [`crate::service::RequestExecutor`], which will execute the request +/// on the connection and return the response. +#[derive(Debug)] +pub struct ConnectionPoolService +where + T: Transport, + P: Protocol, + P::Connection: PoolableConnection, +{ + pub(super) transport: T, + pub(super) protocol: P, + pub(super) service: S, + pub(super) pool: Option>, + pub(super) _body: std::marker::PhantomData ()>, +} + +impl ConnectionPoolService +where + T: Transport, + P: Protocol, + P::Connection: PoolableConnection, +{ + /// Create a new client with the given transport, protocol, and pool configuration. + pub fn new(transport: T, protocol: P, service: S, pool: pool::Config) -> Self { + Self { + transport, + protocol, + service, + pool: Some(pool::Pool::new(pool)), + _body: std::marker::PhantomData, + } + } + + /// Disable connection pooling for this client. + pub fn without_pool(self) -> Self { + Self { pool: None, ..self } + } +} + +impl + ConnectionPoolService< + TlsTransport, + HttpConnectionBuilder, + crate::service::client::RequestExecutor, crate::Body>, + crate::Body, + > +{ + /// Create a new client with the default configuration for making requests over TCP + /// connections using the HTTP protocol. + /// + /// When the `tls` feature is enabled, this will also add support for `tls` when + /// using the `https` scheme, with a default TLS configuration that will rely + /// on the system's certificate store. + pub fn new_tcp_http() -> Self { + Self { + pool: Some(pool::Pool::new(pool::Config { + idle_timeout: Some(std::time::Duration::from_secs(90)), + max_idle_per_host: 32, + })), + + transport: Default::default(), + + protocol: HttpConnectionBuilder::default(), + + service: crate::service::client::RequestExecutor::new(), + + _body: std::marker::PhantomData, + } + } +} + +impl Clone for ConnectionPoolService +where + P: Protocol + Clone, + P::Connection: PoolableConnection, + T: Transport + Clone, + S: Clone, +{ + fn clone(&self) -> Self { + Self { + protocol: self.protocol.clone(), + transport: self.transport.clone(), + pool: self.pool.clone(), + service: self.service.clone(), + _body: std::marker::PhantomData, + } + } +} + +impl ConnectionPoolService +where + C: Connection + PoolableConnection, + P: Protocol + + Clone + + Send + + Sync + + 'static, + T: Transport + 'static, + T::IO: Unpin, + <::IO as HasConnectionInfo>::Addr: Send, +{ + #[allow(clippy::type_complexity)] + fn connect_to( + &self, + uri: http::Uri, + http_protocol: HttpProtocol, + ) -> Result, ConnectionError>, ConnectionError> + { + let key: pool::Key = uri.clone().try_into()?; + let mut protocol = self.protocol.clone(); + let mut transport = self.transport.clone(); + + let connector = Connector::new( + move || async move { + poll_fn(|cx| Transport::poll_ready(&mut transport, cx)) + .await + .map_err(|error| ConnectionError::Connecting(error.into()))?; + transport + .connect(uri) + .await + .map_err(|error| ConnectionError::Connecting(error.into())) + }, + Box::new(move |transport| { + Box::pin(async move { + poll_fn(|cx| Protocol::poll_ready(&mut protocol, cx)) + .await + .map_err(|error| ConnectionError::Handshake(error.into()))?; + protocol + .connect(transport, http_protocol) + .await + .map_err(|error| ConnectionError::Handshake(error.into())) + }) as _ + }), + ); + + if let Some(pool) = self.pool.as_ref() { + Ok(pool.checkout(key, http_protocol.multiplex(), connector)) + } else { + Ok(Checkout::detached(key, connector)) + } + } +} + +impl tower::Service> + for ConnectionPoolService +where + C: Connection + PoolableConnection, + P: Protocol + + Clone + + Send + + Sync + + 'static, + T: Transport + 'static, + T::IO: Unpin, + <::IO as HasConnectionInfo>::Addr: Send, + S: tower::Service, Response = http::Response> + Clone, + S::Error: Into, + BOut: Body + Unpin + 'static, + BIn: Body + Unpin + Send + 'static, + ::Data: Send, + ::Error: Into>, +{ + type Response = http::Response; + type Error = Error; + type Future = ResponseFuture, S, BIn, BOut>; + + fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: http::Request) -> Self::Future { + let uri = request.uri().clone(); + + let protocol: HttpProtocol = request.version().into(); + + match self.connect_to(uri, protocol) { + Ok(checkout) => ResponseFuture::new(checkout, request, self.service.clone()), + Err(error) => ResponseFuture::error(error), + } + } +} + +impl ConnectionPoolService +where + C: Connection + PoolableConnection, + P: Protocol + + Clone + + Send + + Sync + + 'static, + T: Transport + 'static, + T::IO: Unpin, + S: tower::Service, Response = http::Response> + Clone, + S::Error: Into, + BIn: Body + Unpin + Send + 'static, + ::Data: Send, + ::Error: Into>, + BOut: Body + Unpin + 'static, + <::IO as HasConnectionInfo>::Addr: Send, +{ + /// Send an http Request, and return a Future of the Response. + pub fn request(&self, request: http::Request) -> Oneshot> { + self.clone().oneshot(request) + } +} + +/// A future that resolves to an HTTP response. +#[pin_project] +pub struct ResponseFuture +where + C: Connection + pool::PoolableConnection, + T: pool::PoolableTransport, + S: tower::Service, Response = http::Response>, +{ + #[pin] + inner: ResponseFutureState, + _body: std::marker::PhantomData BOut>, +} + +impl fmt::Debug for ResponseFuture +where + C: Connection + pool::PoolableConnection, + T: pool::PoolableTransport, + S: tower::Service, Response = http::Response>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() + } +} + +impl ResponseFuture +where + C: Connection + pool::PoolableConnection, + T: pool::PoolableTransport, + S: tower::Service, Response = http::Response>, +{ + fn new( + checkout: Checkout, + request: http::Request, + service: S, + ) -> Self { + Self { + inner: ResponseFutureState::Checkout { + checkout, + request: Some(request), + service, + }, + _body: std::marker::PhantomData, + } + } + + fn error(error: ConnectionError) -> Self { + Self { + inner: ResponseFutureState::ConnectionError(Some(error)), + _body: std::marker::PhantomData, + } + } +} + +impl Future for ResponseFuture +where + C: Connection + pool::PoolableConnection, + T: pool::PoolableTransport, + S: tower::Service, Response = http::Response>, + S::Error: Into, + BOut: Body + Unpin + 'static, + BIn: Body + Unpin + Send + 'static, + ::Data: Send, + ::Error: Into>, +{ + type Output = Result, Error>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll { + loop { + let mut this = self.as_mut().project(); + let next = match this.inner.as_mut().project() { + ResponseFutureStateProj::Checkout { + checkout, + request, + service, + } => match checkout.poll_unpin(cx) { + Poll::Ready(Ok(conn)) => { + ResponseFutureState::Request(service.call(ExecuteRequest { + conn, + request: request.take().expect("request polled again"), + })) + } + Poll::Ready(Err(error)) => { + return Poll::Ready(Err(error.into())); + } + Poll::Pending => { + return Poll::Pending; + } + }, + ResponseFutureStateProj::Request(fut) => match fut.poll(cx) { + Poll::Ready(Ok(response)) => { + return Poll::Ready(Ok(response)); + } + Poll::Ready(Err(error)) => { + return Poll::Ready(Err(error.into())); + } + Poll::Pending => { + return Poll::Pending; + } + }, + ResponseFutureStateProj::ConnectionError(error) => { + return Poll::Ready(Err(Error::Connection( + error.take().expect("error polled again").into(), + ))); + } + }; + this.inner.set(next); + } + } +} + +#[pin_project(project=ResponseFutureStateProj)] +enum ResponseFutureState +where + C: Connection + pool::PoolableConnection, + T: pool::PoolableTransport, + S: tower::Service, Response = http::Response>, +{ + Checkout { + checkout: Checkout, + request: Option>, + service: S, + }, + ConnectionError(Option), + Request(#[pin] S::Future), +} diff --git a/src/client/service.rs b/src/client/service.rs index 13ccd194..aa4818ea 100644 --- a/src/client/service.rs +++ b/src/client/service.rs @@ -1,398 +1,79 @@ use std::fmt; -use std::future::poll_fn; -use std::future::Future; use std::task::Poll; use futures_util::future::BoxFuture; -use futures_util::FutureExt; -use http::uri::Port; -use http::uri::Scheme; -use http::HeaderValue; -use http::Uri; -use http::Version; use http_body::Body; -use tower::util::Oneshot; -use tower::ServiceExt; -use tracing::warn; - -use super::conn::connection::ConnectionError; -use super::conn::protocol::auto::HttpConnectionBuilder; -use super::conn::protocol::HttpProtocol; -use super::conn::transport::tcp::TcpTransport; -use super::conn::transport::TransportStream; + use super::conn::Connection; -use super::conn::Protocol; -use super::conn::TlsTransport; -use super::conn::Transport; -use super::pool; -use super::pool::Checkout; -use super::pool::Connector; use super::pool::PoolableConnection; use super::pool::Pooled; use super::Error; -use crate::info::HasConnectionInfo; - -/// A client which provides a simple HTTP `tower::Service`. -/// -/// Client Services combine a [transport][Transport] (e.g. TCP) and a [protocol][Protocol] (e.g. HTTP) -/// to provide a `tower::Service` that can be used to make requests to a server. Optionally, a connection -/// pool can be configured so that individual connections can be reused. -/// -/// To use a client service, you must first poll the service to readiness with `Service::poll_ready`, -/// and then make the request with `Service::call`. This can be simplified with the `tower::ServiceExt` -/// which provides a `Service::oneshot` method that combines these two steps into a single future. -#[derive(Debug)] -pub struct ClientService -where - T: Transport, - P: Protocol, - P::Connection: PoolableConnection, -{ - pub(super) transport: T, - pub(super) protocol: P, - pub(super) pool: Option>, - pub(super) _body: std::marker::PhantomData BOut>, -} - -impl ClientService -where - T: Transport, - P: Protocol, - P::Connection: PoolableConnection, -{ - /// Create a new client with the given transport, protocol, and pool configuration. - pub fn new(transport: T, protocol: P, pool: pool::Config) -> Self { - Self { - transport, - protocol, - pool: Some(pool::Pool::new(pool)), - _body: std::marker::PhantomData, - } - } - - /// Disable connection pooling for this client. - pub fn without_pool(self) -> Self { - Self { pool: None, ..self } - } -} - -impl - ClientService< - TlsTransport, - HttpConnectionBuilder, - crate::Body, - crate::Body, - > -{ - /// Create a new client with the default configuration for making requests over TCP - /// connections using the HTTP protocol. - /// - /// When the `tls` feature is enabled, this will also add support for `tls` when - /// using the `https` scheme, with a default TLS configuration that will rely - /// on the system's certificate store. - pub fn new_tcp_http() -> Self { - Self { - pool: Some(pool::Pool::new(pool::Config { - idle_timeout: Some(std::time::Duration::from_secs(90)), - max_idle_per_host: 32, - })), - - transport: Default::default(), - - protocol: HttpConnectionBuilder::default(), - - _body: std::marker::PhantomData, - } - } -} -impl Clone for ClientService -where - P: Protocol + Clone, - P::Connection: PoolableConnection, - T: Transport + Clone, -{ - fn clone(&self) -> Self { - Self { - protocol: self.protocol.clone(), - transport: self.transport.clone(), - pool: self.pool.clone(), - _body: std::marker::PhantomData, - } - } -} - -impl ClientService -where - C: Connection + PoolableConnection, - P: Protocol - + Clone - + Send - + Sync - + 'static, - T: Transport + 'static, - T::IO: Unpin, - <::IO as HasConnectionInfo>::Addr: Send, -{ - #[allow(clippy::type_complexity)] - fn connect_to( - &self, - uri: http::Uri, - http_protocol: HttpProtocol, - ) -> Result, ConnectionError>, ConnectionError> - { - let key: pool::Key = uri.clone().try_into()?; - let mut protocol = self.protocol.clone(); - let mut transport = self.transport.clone(); - - let connector = Connector::new( - move || async move { - poll_fn(|cx| Transport::poll_ready(&mut transport, cx)) - .await - .map_err(|error| ConnectionError::Connecting(error.into()))?; - transport - .connect(uri) - .await - .map_err(|error| ConnectionError::Connecting(error.into())) - }, - Box::new(move |transport| { - Box::pin(async move { - poll_fn(|cx| Protocol::poll_ready(&mut protocol, cx)) - .await - .map_err(|error| ConnectionError::Handshake(error.into()))?; - protocol - .connect(transport, http_protocol) - .await - .map_err(|error| ConnectionError::Handshake(error.into())) - }) as _ - }), - ); - - if let Some(pool) = self.pool.as_ref() { - Ok(pool.checkout(key, http_protocol.multiplex(), connector)) - } else { - Ok(Checkout::detached(key, connector)) - } - } -} - -impl tower::Service> for ClientService -where - C: Connection + PoolableConnection, - P: Protocol - + Clone - + Send - + Sync - + 'static, - T: Transport + 'static, - T::IO: Unpin, - <::IO as HasConnectionInfo>::Addr: Send, - BOut: Body + Unpin + 'static, - BIn: Body + Unpin + Send + 'static, - ::Data: Send, - ::Error: Into>, -{ - type Response = http::Response; - type Error = Error; - type Future = ResponseFuture, BIn, BOut>; - - fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, request: http::Request) -> Self::Future { - let uri = request.uri().clone(); - - let protocol: HttpProtocol = request.version().into(); - - match self.connect_to(uri, protocol) { - Ok(checkout) => ResponseFuture::new(checkout, request), - Err(error) => ResponseFuture::error(error), - } - } +#[derive(Debug)] +pub struct ExecuteRequest + PoolableConnection, B> { + /// The connection to use for the request. + pub(crate) conn: Pooled, + /// The request to execute. + pub(crate) request: http::Request, } -impl ClientService -where - C: Connection + PoolableConnection, - P: Protocol - + Clone - + Send - + Sync - + 'static, - T: Transport + 'static, - T::IO: Unpin, - BIn: Body + Unpin + Send + 'static, - ::Data: Send, - ::Error: Into>, - BOut: Body + Unpin + 'static, - <::IO as HasConnectionInfo>::Addr: Send, -{ - /// Send an http Request, and return a Future of the Response. - pub fn request(&self, request: http::Request) -> Oneshot> { - self.clone().oneshot(request) +impl + PoolableConnection, B> ExecuteRequest { + pub fn connection(&self) -> &C { + &self.conn } } -/// A future that resolves to an HTTP response. -pub struct ResponseFuture -where - C: pool::PoolableConnection, - T: pool::PoolableTransport, -{ - inner: ResponseFutureState, - _body: std::marker::PhantomData BOut>, +#[derive(Default)] +pub struct RequestExecutor + PoolableConnection, B> { + _private: std::marker::PhantomData ()>, } -impl fmt::Debug - for ResponseFuture -{ +impl + PoolableConnection, B> fmt::Debug for RequestExecutor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ResponseFuture").finish() + f.debug_struct("RequestExecutor").finish() } } -impl ResponseFuture -where - C: pool::PoolableConnection, - T: pool::PoolableTransport, -{ - fn new(checkout: Checkout, request: http::Request) -> Self { - Self { - inner: ResponseFutureState::Checkout { checkout, request }, - _body: std::marker::PhantomData, - } +impl + PoolableConnection, B> Clone for RequestExecutor { + fn clone(&self) -> Self { + Self::new() } +} - fn error(error: ConnectionError) -> Self { +impl + PoolableConnection, B> RequestExecutor { + pub fn new() -> Self { Self { - inner: ResponseFutureState::ConnectionError(error), - _body: std::marker::PhantomData, + _private: std::marker::PhantomData, } } } -impl Future for ResponseFuture +impl tower::Service> for RequestExecutor where - C: Connection + pool::PoolableConnection, - T: pool::PoolableTransport, - BOut: Body + Unpin + 'static, - BIn: Body + Unpin + Send + 'static, - ::Data: Send, - ::Error: Into>, + C: Connection + PoolableConnection, + B: Body + Unpin + Send + 'static, { - type Output = Result, Error>; - - fn poll( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll { - loop { - match std::mem::replace(&mut self.inner, ResponseFutureState::Empty) { - ResponseFutureState::Checkout { - mut checkout, - request, - } => match checkout.poll_unpin(cx) { - Poll::Ready(Ok(conn)) => { - self.inner = - ResponseFutureState::Request(execute_request(request, conn).boxed()); - } - Poll::Ready(Err(error)) => { - return Poll::Ready(Err(error.into())); - } - Poll::Pending => { - self.inner = ResponseFutureState::Checkout { checkout, request }; - return Poll::Pending; - } - }, - ResponseFutureState::Request(mut fut) => match fut.poll_unpin(cx) { - Poll::Ready(Ok(response)) => return Poll::Ready(Ok(response.map(Into::into))), - Poll::Ready(Err(error)) => return Poll::Ready(Err(error)), - Poll::Pending => { - self.inner = ResponseFutureState::Request(fut); - return Poll::Pending; - } - }, - ResponseFutureState::ConnectionError(error) => { - return Poll::Ready(Err(Error::Connection(error.into()))); - } - ResponseFutureState::Empty => { - panic!("future polled after completion"); - } - } - } - } -} + type Response = http::Response; -enum ResponseFutureState { - Empty, - Checkout { - checkout: Checkout, - request: http::Request, - }, - ConnectionError(ConnectionError), - Request(BoxFuture<'static, Result, Error>>), -} + type Error = Error; -/// Prepare a request for sending over the connection. -fn prepare_request + PoolableConnection>( - request: &mut http::Request, - conn: &Pooled, -) -> Result<(), Error> { - request - .headers_mut() - .entry(http::header::USER_AGENT) - .or_insert_with(|| { - HeaderValue::from_static(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )) - }); + type Future = BoxFuture<'static, Result>; - if conn.version() == Version::HTTP_11 { - if request.version() == Version::HTTP_2 || request.version() == Version::HTTP_3 { - warn!( - "refusing to send {:?} request to HTTP/1.1 connection", - request.version() - ); - return Err(Error::UnsupportedProtocol); - } + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } - //TODO: Configure set host header? - set_host_header(request); - - if request.method() == http::Method::CONNECT { - authority_form(request.uri_mut()); - - // If the URI is to HTTPS, and the connector claimed to be a proxy, - // then it *should* have tunneled, and so we don't want to send - // absolute-form in that case. - if request.uri().scheme() == Some(&Scheme::HTTPS) { - origin_form(request.uri_mut()); - } - } else if request.uri().scheme().is_none() || request.uri().authority().is_none() { - absolute_form(request.uri_mut()); - } else { - origin_form(request.uri_mut()); - } - } else if request.method() == http::Method::CONNECT { - return Err(Error::InvalidMethod(http::Method::CONNECT)); - } else if conn.version() == Version::HTTP_2 { - *request.version_mut() = Version::HTTP_2; + fn call(&mut self, req: ExecuteRequest) -> Self::Future { + Box::pin(execute_request(req)) } - Ok(()) } async fn execute_request( - mut request: http::Request, - mut conn: Pooled, + ExecuteRequest { request, mut conn }: ExecuteRequest, ) -> Result, Error> where C: Connection + PoolableConnection, { - prepare_request(&mut request, &conn)?; - tracing::trace!(request.uri=%request.uri(), conn.version=?conn.version(), req.version=?request.version(), "sending request"); let response = conn @@ -405,92 +86,15 @@ where // Only re-insert the connection when it is ready again. Spawn // a task to wait for the connection to become ready before dropping. tokio::spawn(async move { - let _ = conn.when_ready().await.map_err(|_| ()); + if let Err(error) = conn.when_ready().await { + tracing::trace!(conn.version=?conn.version(), error=%error, "Connection errored while polling for readiness"); + }; }); } Ok(response.map(Into::into)) } -/// Convert the URI to authority-form, if it is not already. -/// -/// This is the form of the URI with just the authority and a default -/// path and scheme. This is used in HTTP/1 CONNECT requests. -fn authority_form(uri: &mut Uri) { - *uri = match uri.authority() { - Some(auth) => { - let mut parts = ::http::uri::Parts::default(); - parts.authority = Some(auth.clone()); - Uri::from_parts(parts).expect("authority is valid") - } - None => { - unreachable!("authority_form with relative uri"); - } - }; -} - -fn absolute_form(uri: &mut Uri) { - debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); - debug_assert!( - uri.authority().is_some(), - "absolute_form needs an authority" - ); -} - -/// Convert the URI to origin-form, if it is not already. -/// -/// This form of the URI has no scheme or authority, and contains just -/// the path, usually used in HTTP/1 requests. -fn origin_form(uri: &mut Uri) { - let path = match uri.path_and_query() { - Some(path) if path.as_str() != "/" => { - let mut parts = ::http::uri::Parts::default(); - parts.path_and_query = Some(path.clone()); - Uri::from_parts(parts).expect("path is valid uri") - } - _none_or_just_slash => { - debug_assert!(Uri::default() == "/"); - Uri::default() - } - }; - *uri = path -} - -/// Returns the port if it is not the default port for the scheme. -fn get_non_default_port(uri: &Uri) -> Option> { - match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { - (Some(443), true) => None, - (Some(80), false) => None, - _ => uri.port(), - } -} - -/// Returns true if the URI scheme is presumed secure. -fn is_schema_secure(uri: &Uri) -> bool { - uri.scheme_str() - .map(|scheme_str| matches!(scheme_str, "wss" | "https")) - .unwrap_or_default() -} - -/// Set the Host header on the request if it is not already set, -/// using the authority from the URI. -fn set_host_header(request: &mut http::Request) { - let uri = request.uri().clone(); - request - .headers_mut() - .entry(http::header::HOST) - .or_insert_with(|| { - let hostname = uri.host().expect("authority implies host"); - if let Some(port) = get_non_default_port(&uri) { - let s = format!("{}:{}", hostname, port); - HeaderValue::from_str(&s) - } else { - HeaderValue::from_str(hostname) - } - .expect("uri host is valid header value") - }); -} - #[cfg(test)] mod tests { @@ -506,128 +110,17 @@ mod tests { use super::*; - #[test] - fn test_set_host_header() { - let mut request = http::Request::new(()); - *request.uri_mut() = "http://example.com".parse().unwrap(); - set_host_header(&mut request); - assert_eq!( - request.headers().get(http::header::HOST).unwrap(), - "example.com" - ); - - let mut request = http::Request::new(()); - *request.uri_mut() = "http://example.com:8080".parse().unwrap(); - set_host_header(&mut request); - assert_eq!( - request.headers().get(http::header::HOST).unwrap(), - "example.com:8080" - ); - - let mut request = http::Request::new(()); - *request.uri_mut() = "https://example.com".parse().unwrap(); - set_host_header(&mut request); - assert_eq!( - request.headers().get(http::header::HOST).unwrap(), - "example.com" - ); - - let mut request = http::Request::new(()); - *request.uri_mut() = "https://example.com:8443".parse().unwrap(); - set_host_header(&mut request); - assert_eq!( - request.headers().get(http::header::HOST).unwrap(), - "example.com:8443" - ); - } - - #[test] - fn test_is_schema_secure() { - let uri = "http://example.com".parse().unwrap(); - assert!(!is_schema_secure(&uri)); - - let uri = "https://example.com".parse().unwrap(); - assert!(is_schema_secure(&uri)); - - let uri = "ws://example.com".parse().unwrap(); - assert!(!is_schema_secure(&uri)); - - let uri = "wss://example.com".parse().unwrap(); - assert!(is_schema_secure(&uri)); - } - - #[test] - fn test_get_non_default_port() { - let uri = "http://example.com".parse().unwrap(); - assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), None); - - let uri = "http://example.com:8080".parse().unwrap(); - assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), Some(8080)); - - let uri = "https://example.com".parse().unwrap(); - assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), None); - - let uri = "https://example.com:8443".parse().unwrap(); - assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), Some(8443)); - } - - #[test] - fn test_origin_form() { - let mut uri = "http://example.com".parse().unwrap(); - origin_form(&mut uri); - assert_eq!(uri, "/"); - - let mut uri = "/some/path/here".parse().unwrap(); - origin_form(&mut uri); - assert_eq!(uri, "/some/path/here"); - - let mut uri = "http://example.com:8080/some/path?query#fragment" - .parse() - .unwrap(); - origin_form(&mut uri); - assert_eq!(uri, "/some/path?query"); - - let mut uri = "/".parse().unwrap(); - origin_form(&mut uri); - assert_eq!(uri, "/"); - } - - #[test] - fn test_absolute_form() { - let mut uri = "http://example.com".parse().unwrap(); - absolute_form(&mut uri); - assert_eq!(uri, "http://example.com"); - - let mut uri = "http://example.com:8080".parse().unwrap(); - absolute_form(&mut uri); - assert_eq!(uri, "http://example.com:8080"); - - let mut uri = "https://example.com/some/path?query".parse().unwrap(); - absolute_form(&mut uri); - assert_eq!(uri, "https://example.com/some/path?query"); - - let mut uri = "https://example.com:8443".parse().unwrap(); - absolute_form(&mut uri); - assert_eq!(uri, "https://example.com:8443"); - - let mut uri = "http://example.com:443".parse().unwrap(); - absolute_form(&mut uri); - assert_eq!(uri, "http://example.com:443"); - - let mut uri = "https://example.com:80".parse().unwrap(); - absolute_form(&mut uri); - assert_eq!(uri, "https://example.com:80"); - } - #[cfg(feature = "mocks")] #[tokio::test] async fn test_client_mock_transport() { + use crate::client::ConnectionPoolService; + let transport = MockTransport::new(false); let protocol = MockProtocol::default(); let pool = PoolConfig::default(); - let client: ClientService = - ClientService::new(transport, protocol, pool); + let client: ConnectionPoolService = + ConnectionPoolService::new(transport, protocol, RequestExecutor::new(), pool); client .request( @@ -643,12 +136,14 @@ mod tests { #[cfg(feature = "mocks")] #[tokio::test] async fn test_client_mock_connection_error() { + use crate::client::{conn::connection::ConnectionError, ConnectionPoolService}; + let transport = MockTransport::connection_error(); let protocol = MockProtocol::default(); let pool = PoolConfig::default(); - let client: ClientService = - ClientService::new(transport, protocol, pool); + let client: ConnectionPoolService = + ConnectionPoolService::new(transport, protocol, RequestExecutor::new(), pool); let result = client .request( diff --git a/src/service/client.rs b/src/service/client.rs new file mode 100644 index 00000000..a1ecbea9 --- /dev/null +++ b/src/service/client.rs @@ -0,0 +1,197 @@ +//! Core service for implementing a Client. +//! +//! This service accepts an `ExecuteRequest` and returns a `http::Response`. +//! +//! The `ExecuteRequest` contains the connection to use for the request and the +//! request to execute. + +use std::fmt; +use std::task::Poll; + +use futures_util::future::BoxFuture; +use http_body::Body; + +use crate::client::conn::Connection; +use crate::client::pool::PoolableConnection; +use crate::client::pool::Pooled; +use crate::client::Error; + +/// A wrapper that binds a connection and request together. +#[derive(Debug)] +pub struct ExecuteRequest + PoolableConnection, B> { + /// The connection to use for the request. + pub(crate) conn: Pooled, + /// The request to execute. + pub(crate) request: http::Request, +} + +impl + PoolableConnection, B> ExecuteRequest { + /// A reference to the connection. + pub fn connection(&self) -> &C { + &self.conn + } + + /// A reference to the request. + pub fn request(&self) -> &http::Request { + &self.request + } + + /// A mutable reference to the request. + pub fn request_mut(&mut self) -> &mut http::Request { + &mut self.request + } +} + +/// A service that executes requests on associated connections. +pub struct RequestExecutor + PoolableConnection, B> { + _private: std::marker::PhantomData ()>, +} + +impl + PoolableConnection, B> Default for RequestExecutor { + fn default() -> Self { + Self::new() + } +} + +impl + PoolableConnection, B> fmt::Debug for RequestExecutor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequestExecutor").finish() + } +} + +impl + PoolableConnection, B> Clone for RequestExecutor { + fn clone(&self) -> Self { + Self::new() + } +} + +impl + PoolableConnection, B> RequestExecutor { + /// Create a new `RequestExecutor`. + pub fn new() -> Self { + Self { + _private: std::marker::PhantomData, + } + } +} + +impl tower::Service> for RequestExecutor +where + C: Connection + PoolableConnection, + B: Body + Unpin + Send + 'static, +{ + type Response = http::Response; + + type Error = Error; + + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: ExecuteRequest) -> Self::Future { + Box::pin(execute_request(req)) + } +} + +async fn execute_request( + ExecuteRequest { request, mut conn }: ExecuteRequest, +) -> Result, Error> +where + C: Connection + PoolableConnection, +{ + tracing::trace!(request.uri=%request.uri(), conn.version=?conn.version(), req.version=?request.version(), "sending request"); + + let response = conn + .send_request(request) + .await + .map_err(|error| Error::Connection(error.into()))?; + + // Shared connections are already in the pool, no need to do this. + if !conn.can_share() { + // Only re-insert the connection when it is ready again. Spawn + // a task to wait for the connection to become ready before dropping. + tokio::spawn(async move { + if let Err(error) = conn.when_ready().await { + tracing::trace!(conn.version=?conn.version(), error=%error, "Connection errored while polling for readiness"); + }; + }); + } + + Ok(response.map(Into::into)) +} + +#[cfg(test)] +mod tests { + + #[cfg(feature = "mocks")] + use crate::Body; + + #[cfg(feature = "mocks")] + use crate::client::conn::protocol::mock::MockProtocol; + #[cfg(feature = "mocks")] + use crate::client::conn::transport::mock::{MockConnectionError, MockTransport}; + + use crate::client::pool::Config as PoolConfig; + + use super::*; + + #[cfg(feature = "mocks")] + #[tokio::test] + async fn test_client_mock_transport() { + use crate::client::ConnectionPoolService; + + let transport = MockTransport::new(false); + let protocol = MockProtocol::default(); + let pool = PoolConfig::default(); + + let client: ConnectionPoolService = + ConnectionPoolService::new(transport, protocol, RequestExecutor::new(), pool); + + client + .request( + http::Request::builder() + .uri("mock://somewhere") + .body(crate::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + } + + #[cfg(feature = "mocks")] + #[tokio::test] + async fn test_client_mock_connection_error() { + use crate::client::{conn::connection::ConnectionError, ConnectionPoolService}; + + let transport = MockTransport::connection_error(); + let protocol = MockProtocol::default(); + let pool = PoolConfig::default(); + + let client: ConnectionPoolService = + ConnectionPoolService::new(transport, protocol, RequestExecutor::new(), pool); + + let result = client + .request( + http::Request::builder() + .uri("mock://somewhere") + .body(crate::Body::empty()) + .unwrap(), + ) + .await; + + let err = result.unwrap_err(); + + let Error::Connection(err) = err else { + panic!("unexpected error: {:?}", err); + }; + + let err = err.downcast::().unwrap(); + + let ConnectionError::Connecting(err) = *err else { + panic!("unexpected error: {:?}", err); + }; + + err.downcast::().unwrap(); + } +} diff --git a/src/service/error.rs b/src/service/error.rs new file mode 100644 index 00000000..c9178953 --- /dev/null +++ b/src/service/error.rs @@ -0,0 +1,163 @@ +//! Helpers for middleware that might error. + +use std::fmt; + +pub use self::future::MaybeErrorFuture; + +/// A middleware that calls some (non-async) function before +/// calling the inner service, and skips the inner service if the function +/// returns an error. +#[derive(Clone)] +pub struct PreprocessService { + inner: S, + preprocessor: F, +} + +impl fmt::Debug for PreprocessService { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PreprocessService") + .field("inner", &self.inner) + .finish() + } +} + +impl PreprocessService { + /// Helper Service for middleware that might error. + pub fn new(inner: S, preprocessor: F) -> Self { + Self { + inner, + preprocessor, + } + } + + /// Get a reference to the inner service. + pub fn service(&self) -> &S { + &self.inner + } +} + +impl tower::Service for PreprocessService +where + S: tower::Service, + F: Fn(R) -> Result, +{ + type Response = S::Response; + + type Error = S::Error; + + type Future = MaybeErrorFuture; + + #[inline] + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: R) -> Self::Future { + match (self.preprocessor)(req) { + Ok(req) => MaybeErrorFuture::future(self.inner.call(req)), + Err(error) => MaybeErrorFuture::error(error), + } + } +} + +/// A layer that wraps a service with a preprocessor function. +#[derive(Clone)] +pub struct PreprocessLayer { + preprocessor: F, +} + +impl PreprocessLayer { + /// Create a new `PreprocessLayer` wrapping the given preprocessor function. + pub fn new(preprocessor: F) -> Self { + Self { preprocessor } + } +} + +impl fmt::Debug for PreprocessLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PreprocessLayer").finish() + } +} + +impl tower::layer::Layer for PreprocessLayer { + type Service = PreprocessService; + + fn layer(&self, inner: S) -> Self::Service { + PreprocessService::new(inner, self.preprocessor.clone()) + } +} + +mod future { + + use std::{fmt, future::Future, marker::PhantomData, task::Poll}; + + use pin_project::pin_project; + + #[pin_project(project = MaybeErrorFutureStateProj)] + enum MaybeErrorFutureState { + Inner(#[pin] F), + Error(Option), + } + + impl fmt::Debug for MaybeErrorFutureState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Inner(_) => f.debug_tuple("Inner").finish(), + Self::Error(error) => f.debug_tuple("Error").field(error).finish(), + } + } + } + + /// Future for when a service either errors before yielding, + /// or continues. This is us + #[derive(Debug)] + #[pin_project] + pub struct MaybeErrorFuture { + #[pin] + state: MaybeErrorFutureState, + response: PhantomData R>, + } + + impl MaybeErrorFuture { + /// Create a future that resolves to the contained service + pub fn future(inner: F) -> Self { + Self { + state: MaybeErrorFutureState::Inner(inner), + response: PhantomData, + } + } + + /// Create a future that immediately resolves to an error. + pub fn error(error: E) -> Self { + Self { + state: MaybeErrorFutureState::Error(Some(error)), + response: PhantomData, + } + } + } + + impl Future for MaybeErrorFuture + where + F: Future>, + { + type Output = Result; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll { + let mut this = self.project(); + + match this.state.as_mut().project() { + MaybeErrorFutureStateProj::Inner(inner) => inner.poll(cx), + MaybeErrorFutureStateProj::Error(error) => { + Poll::Ready(Err(error.take().expect("polled after error"))) + } + } + } + } +} diff --git a/src/service/host.rs b/src/service/host.rs new file mode 100644 index 00000000..edc49494 --- /dev/null +++ b/src/service/host.rs @@ -0,0 +1,202 @@ +//! Middleware for setting the Host header of a request. +use http; +use http::uri::Port; +use http::HeaderValue; +use http::Uri; + +use crate::client::conn::Connection; +use crate::client::pool::PoolableConnection; + +use super::ExecuteRequest; + +/// Returns true if the URI scheme is presumed secure. +fn is_schema_secure(uri: &Uri) -> bool { + uri.scheme_str() + .map(|scheme_str| matches!(scheme_str, "wss" | "https")) + .unwrap_or_default() +} + +/// Returns the port if it is not the default port for the scheme. +fn get_non_default_port(uri: &Uri) -> Option> { + match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { + (Some(443), true) => None, + (Some(80), false) => None, + _ => uri.port(), + } +} + +/// Set the Host header on the request if it is not already set, +/// using the authority from the URI. +fn set_host_header(request: &mut http::Request) { + if request.uri().host().is_none() { + tracing::debug!(uri=%request.uri(), "request uri has no host"); + return; + } + + let uri = request.uri().clone(); + + request + .headers_mut() + .entry(http::header::HOST) + .or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = get_non_default_port(&uri) { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); +} + +/// Middleware which sets the Host header on requests. +#[derive(Debug, Default, Clone)] +pub struct SetHostHeader { + inner: S, +} + +impl tower::Service> for SetHostHeader +where + S: tower::Service>, +{ + type Response = S::Response; + + type Error = S::Error; + + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: http::Request) -> Self::Future { + if req.version() < http::Version::HTTP_2 { + set_host_header(&mut req); + } + + self.inner.call(req) + } +} + +impl tower::Service> for SetHostHeader +where + S: tower::Service>, + C: Connection + PoolableConnection, +{ + type Response = S::Response; + + type Error = S::Error; + + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: ExecuteRequest) -> Self::Future { + if req.connection().version() < http::Version::HTTP_2 { + set_host_header(req.request_mut()); + } + + self.inner.call(req) + } +} + +/// Layer for setting the Host header on requests. +#[derive(Debug, Default, Clone)] +pub struct SetHostHeaderLayer { + _priv: (), +} + +impl SetHostHeaderLayer { + /// Create a new SetHostHeaderLayer. + pub fn new() -> Self { + Self { _priv: () } + } +} + +impl tower::layer::Layer for SetHostHeaderLayer { + type Service = SetHostHeader; + + fn layer(&self, inner: S) -> Self::Service { + SetHostHeader { inner } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_set_host_header() { + let mut request = http::Request::new(()); + *request.uri_mut() = "http://example.com".parse().unwrap(); + set_host_header(&mut request); + assert_eq!( + request.headers().get(http::header::HOST).unwrap(), + "example.com" + ); + + let mut request = http::Request::new(()); + *request.uri_mut() = "http://example.com:8080".parse().unwrap(); + set_host_header(&mut request); + assert_eq!( + request.headers().get(http::header::HOST).unwrap(), + "example.com:8080" + ); + + let mut request = http::Request::new(()); + *request.uri_mut() = "https://example.com".parse().unwrap(); + set_host_header(&mut request); + assert_eq!( + request.headers().get(http::header::HOST).unwrap(), + "example.com" + ); + + let mut request = http::Request::new(()); + *request.uri_mut() = "https://example.com:8443".parse().unwrap(); + set_host_header(&mut request); + assert_eq!( + request.headers().get(http::header::HOST).unwrap(), + "example.com:8443" + ); + } + + #[test] + fn test_is_schema_secure() { + let uri = "http://example.com".parse().unwrap(); + assert!(!is_schema_secure(&uri)); + + let uri = "https://example.com".parse().unwrap(); + assert!(is_schema_secure(&uri)); + + let uri = "ws://example.com".parse().unwrap(); + assert!(!is_schema_secure(&uri)); + + let uri = "wss://example.com".parse().unwrap(); + assert!(is_schema_secure(&uri)); + } + + #[test] + fn test_get_non_default_port() { + let uri = "http://example.com".parse().unwrap(); + assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), None); + + let uri = "http://example.com:8080".parse().unwrap(); + assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), Some(8080)); + + let uri = "https://example.com".parse().unwrap(); + assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), None); + + let uri = "https://example.com:8443".parse().unwrap(); + assert_eq!(get_non_default_port(&uri).map(|p| p.as_u16()), Some(8443)); + } +} diff --git a/src/service/http.rs b/src/service/http.rs index 82a95c88..f11b3d8c 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -47,6 +47,415 @@ where } } +#[cfg(feature = "client")] +pub(super) mod http1 { + + use std::fmt; + use std::task::{Context, Poll}; + + use ::http; + use http::uri::Scheme; + use http::Uri; + + use crate::client::conn::Connection; + use crate::client::pool::PoolableConnection; + use crate::client::Error; + use crate::service::client::ExecuteRequest; + use crate::service::error::MaybeErrorFuture; + use crate::service::error::PreprocessService; + + type PreprocessFn = fn(ExecuteRequest) -> Result, E>; + + /// A service that checks if the request is HTTP/1.1 compatible. + #[derive(Debug)] + pub struct Http1ChecksService + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + inner: PreprocessService>, + } + + impl tower::Service> for Http1ChecksService + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + type Response = S::Response; + + type Error = S::Error; + + type Future = MaybeErrorFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: ExecuteRequest) -> Self::Future { + self.inner.call(req) + } + } + + impl Clone for Http1ChecksService + where + S: tower::Service, Error = Error> + Clone, + C: Connection + PoolableConnection, + { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } + } + + impl Http1ChecksService + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + /// Create a new `Http1ChecksService`. + pub fn new(service: S) -> Self { + Self { + inner: PreprocessService::new(service, check_http1_request), + } + } + } + + /// A layer that checks if the request is HTTP/1.1 compatible. + pub struct Http1ChecksLayer { + processor: std::marker::PhantomData, + } + + impl Http1ChecksLayer { + /// Create a new `Http1ChecksLayer`. + pub fn new() -> Self { + Self { + processor: std::marker::PhantomData, + } + } + } + + impl Default for Http1ChecksLayer { + fn default() -> Self { + Self::new() + } + } + + impl Clone for Http1ChecksLayer { + fn clone(&self) -> Self { + Self::new() + } + } + + impl fmt::Debug for Http1ChecksLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Http1ChecksLayer").finish() + } + } + + impl tower::layer::Layer for Http1ChecksLayer + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + type Service = Http1ChecksService; + + fn layer(&self, service: S) -> Self::Service { + Http1ChecksService::new(service) + } + } + + fn check_http1_request( + mut req: ExecuteRequest, + ) -> Result, Error> + where + C: Connection + PoolableConnection, + { + if req.connection().version() >= http::Version::HTTP_2 { + return Ok(req); + } + + if req.request().method() == http::Method::CONNECT { + authority_form(req.request_mut().uri_mut()); + + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if req.request().uri().scheme() == Some(&Scheme::HTTPS) { + origin_form(req.request_mut().uri_mut()); + } + } else if req.request().uri().scheme().is_none() + || req.request().uri().authority().is_none() + { + absolute_form(req.request_mut().uri_mut()); + } else { + origin_form(req.request_mut().uri_mut()); + } + + Ok(req) + } + + /// Convert the URI to authority-form, if it is not already. + /// + /// This is the form of the URI with just the authority and a default + /// path and scheme. This is used in HTTP/1 CONNECT requests. + fn authority_form(uri: &mut Uri) { + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; + } + + fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + } + + /// Convert the URI to origin-form, if it is not already. + /// + /// This form of the URI has no scheme or authority, and contains just + /// the path, usually used in HTTP/1 requests. + fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path + } + + #[cfg(test)] + mod tests { + + use super::*; + + #[test] + fn test_origin_form() { + let mut uri = "http://example.com".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri, "/"); + + let mut uri = "/some/path/here".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri, "/some/path/here"); + + let mut uri = "http://example.com:8080/some/path?query#fragment" + .parse() + .unwrap(); + origin_form(&mut uri); + assert_eq!(uri, "/some/path?query"); + + let mut uri = "/".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn test_absolute_form() { + let mut uri = "http://example.com".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri, "http://example.com"); + + let mut uri = "http://example.com:8080".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri, "http://example.com:8080"); + + let mut uri = "https://example.com/some/path?query".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri, "https://example.com/some/path?query"); + + let mut uri = "https://example.com:8443".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri, "https://example.com:8443"); + + let mut uri = "http://example.com:443".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri, "http://example.com:443"); + + let mut uri = "https://example.com:80".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri, "https://example.com:80"); + } + } +} + +#[cfg(feature = "client")] +pub(super) mod http2 { + use std::fmt; + use std::task::{Context, Poll}; + + use ::http; + + use crate::client::conn::Connection; + use crate::client::pool::PoolableConnection; + use crate::client::Error; + use crate::service::client::ExecuteRequest; + use crate::service::error::{MaybeErrorFuture, PreprocessService}; + + const CONNECTION_HEADERS: [http::HeaderName; 5] = [ + http::header::CONNECTION, + http::HeaderName::from_static("proxy-connection"), + http::HeaderName::from_static("keep-alive"), + http::header::TRANSFER_ENCODING, + http::header::UPGRADE, + ]; + + type PreprocessFn = fn(ExecuteRequest) -> Result, E>; + + /// A service that checks if the request is HTTP/2 compatible. + #[derive(Debug)] + pub struct Http2ChecksService + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + inner: PreprocessService>, + } + + impl Clone for Http2ChecksService + where + S: tower::Service, Error = Error> + Clone, + C: Connection + PoolableConnection, + { + fn clone(&self) -> Self { + Self::new(self.inner.service().clone()) + } + } + + impl Http2ChecksService + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + /// Create a new `Http2ChecksService`. + pub fn new(inner: S) -> Self { + Self { + inner: PreprocessService::new(inner, check_http2_request), + } + } + } + + fn check_http2_request( + mut req: ExecuteRequest, + ) -> Result, Error> + where + C: Connection + PoolableConnection, + { + if req.connection().version() == http::Version::HTTP_2 { + if req.request().method() == http::Method::CONNECT { + return Err(Error::InvalidMethod(http::Method::CONNECT)); + } + + *req.request_mut().version_mut() = http::Version::HTTP_2; + + for connection_header in &CONNECTION_HEADERS { + if req + .request_mut() + .headers_mut() + .remove(connection_header) + .is_some() + { + tracing::warn!( + "removed illegal connection header {:?} from HTTP/2 request", + connection_header + ); + }; + } + + if req + .request_mut() + .headers_mut() + .remove(http::header::HOST) + .is_some() + { + tracing::warn!("removed illegal header `host` from HTTP/2 request"); + } + } + Ok(req) + } + + impl tower::Service> for Http2ChecksService + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + type Response = S::Response; + + type Error = S::Error; + + type Future = MaybeErrorFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: ExecuteRequest) -> Self::Future { + self.inner.call(req) + } + } + + /// A `Layer` that applies HTTP/2 checks to requests. + pub struct Http2ChecksLayer { + _marker: std::marker::PhantomData, + } + + impl Http2ChecksLayer { + /// Create a new `Http2ChecksLayer`. + pub fn new() -> Self { + Self { + _marker: std::marker::PhantomData, + } + } + } + + impl Default for Http2ChecksLayer { + fn default() -> Self { + Self::new() + } + } + + impl fmt::Debug for Http2ChecksLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Http2ChecksLayer").finish() + } + } + + impl Clone for Http2ChecksLayer { + fn clone(&self) -> Self { + Self::new() + } + } + + impl tower::layer::Layer for Http2ChecksLayer + where + S: tower::Service, Error = Error>, + C: Connection + PoolableConnection, + { + type Service = Http2ChecksService; + + fn layer(&self, inner: S) -> Self::Service { + Http2ChecksService::new(inner) + } + } +} + #[cfg(test)] #[allow(dead_code)] mod tests { diff --git a/src/service/mod.rs b/src/service/mod.rs index ca3e3535..250c42ee 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,5 +1,10 @@ //! A collection of utilities for working with `Service` types and Servers. +#[cfg(feature = "client")] +pub(crate) mod client; +mod error; +#[cfg(feature = "client")] +mod host; mod http; #[cfg(feature = "incoming")] mod incoming; @@ -11,6 +16,15 @@ mod serviceref; mod shared; mod timeout; +#[cfg(feature = "client")] +pub use self::client::{ExecuteRequest, RequestExecutor}; +pub use self::error::{MaybeErrorFuture, PreprocessLayer, PreprocessService}; +#[cfg(feature = "client")] +pub use self::host::{SetHostHeader, SetHostHeaderLayer}; +#[cfg(feature = "client")] +pub use self::http::http1::{Http1ChecksLayer, Http1ChecksService}; +#[cfg(feature = "client")] +pub use self::http::http2::{Http2ChecksLayer, Http2ChecksService}; pub use self::http::HttpService; #[cfg(feature = "incoming")] pub use self::incoming::{AdaptIncomingLayer, AdaptIncomingService}; diff --git a/tests/client.rs b/tests/client.rs index c563bf5b..edfd0719 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -11,6 +11,8 @@ type BoxError = Box; #[tokio::test] async fn client() -> Result<(), BoxError> { + let _ = tracing_subscriber::fmt::try_init(); + let (tx, incoming) = hyperdriver::stream::duplex::pair(); let acceptor: hyperdriver::server::conn::Acceptor = @@ -38,6 +40,8 @@ async fn client() -> Result<(), BoxError> { #[tokio::test] async fn client_h2() -> Result<(), BoxError> { + let _ = tracing_subscriber::fmt::try_init(); + let (tx, incoming) = hyperdriver::stream::duplex::pair(); let acceptor: hyperdriver::server::conn::Acceptor =