diff --git a/examples/src/tower/server.rs b/examples/src/tower/server.rs index 59d884ad8..d1e27cf30 100644 --- a/examples/src/tower/server.rs +++ b/examples/src/tower/server.rs @@ -48,6 +48,7 @@ async fn main() -> Result<(), Box> { .layer(MyMiddlewareLayer::default()) // Interceptors can be also be applied as middleware .layer(tonic::service::interceptor(intercept)) + .layer(tonic::service::async_interceptor(async_intercept)) .into_inner(); Server::builder() @@ -65,6 +66,11 @@ fn intercept(req: Request<()>) -> Result, Status> { Ok(req) } +// An async interceptor function. +async fn async_intercept(req: Request<()>) -> Result, Status> { + Ok(req) +} + #[derive(Debug, Clone, Default)] struct MyMiddlewareLayer; diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs index c60c98414..e41f87d02 100644 --- a/tests/integration_tests/tests/extensions.rs +++ b/tests/integration_tests/tests/extensions.rs @@ -8,6 +8,7 @@ use std::{ use tokio::sync::oneshot; use tonic::{ body::BoxBody, + service::{async_interceptor, interceptor}, transport::{Endpoint, NamedService, Server}, Request, Response, Status, }; @@ -60,6 +61,100 @@ async fn setting_extension_from_interceptor() { jh.await.unwrap(); } +#[tokio::test] +async fn setting_extension_from_interceptor_layer() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let value = req.extensions().get::().unwrap(); + assert_eq!(value.0, 42); + + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + let interceptor_layer = interceptor(|mut req: Request<()>| { + req.extensions_mut().insert(ExtensionValue(42)); + Ok(req) + }); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .layer(interceptor_layer) + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1325".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1325") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[tokio::test] +async fn setting_extension_from_async_interceptor_layer() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let value = req.extensions().get::().unwrap(); + assert_eq!(value.0, 42); + + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + let interceptor_layer = async_interceptor(|mut req: Request<()>| { + req.extensions_mut().insert(ExtensionValue(42)); + futures::future::ok(req) + }); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .layer(interceptor_layer) + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1326".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1326") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + #[tokio::test] async fn setting_extension_from_tower() { struct Svc; diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index cb8d46bdb..96348a384 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -5,13 +5,15 @@ use crate::{ body::{boxed, BoxBody}, request::SanitizeHeaders, - Status, + Request, Status, }; use bytes::Bytes; +use http::{Method, Uri, Version}; use pin_project::pin_project; use std::{ fmt, future::Future, + mem, pin::Pin, task::{Context, Poll}, }; @@ -57,6 +59,26 @@ where } } +/// Async version of `Interceptor`. +pub trait AsyncInterceptor { + /// The Future returned by the interceptor. + type Future: Future, Status>>; + /// Intercept a request before it is sent, optionally cancelling it. + fn call(&mut self, request: crate::Request<()>) -> Self::Future; +} + +impl AsyncInterceptor for F +where + F: FnMut(crate::Request<()>) -> U, + U: Future, Status>>, +{ + type Future = U; + + fn call(&mut self, request: crate::Request<()>) -> Self::Future { + self(request) + } +} + /// Create a new interceptor layer. /// /// See [`Interceptor`] for more details. @@ -81,6 +103,16 @@ where interceptor(f) } +/// Create a new async interceptor layer. +/// +/// See [`AsyncInterceptor`] and [`Interceptor`] for more details. +pub fn async_interceptor(f: F) -> AsyncInterceptorLayer +where + F: AsyncInterceptor, +{ + AsyncInterceptorLayer { f } +} + /// A gRPC interceptor that can be used as a [`Layer`], /// created by calling [`interceptor`]. /// @@ -101,6 +133,27 @@ where } } +/// A gRPC async interceptor that can be used as a [`Layer`], +/// created by calling [`async_interceptor`]. +/// +/// See [`AsyncInterceptor`] for more details. +#[derive(Debug, Clone, Copy)] +pub struct AsyncInterceptorLayer { + f: F, +} + +impl Layer for AsyncInterceptorLayer +where + S: Clone, + F: AsyncInterceptor + Clone, +{ + type Service = AsyncInterceptedService; + + fn layer(&self, service: S) -> Self::Service { + AsyncInterceptedService::new(service, self.f.clone()) + } +} + #[deprecated( since = "0.5.1", note = "Please use the `InterceptorLayer` type instead" @@ -143,13 +196,63 @@ where } } +// Components and attributes of a request, without metadata or extensions. +#[derive(Debug)] +struct DecomposedRequest { + uri: Uri, + method: Method, + http_version: Version, + msg: ReqBody, +} + +/// Decompose the request into its contents and properties, and create a new request without a body. +/// +/// It is bad practice to modify the body (i.e. Message) of the request via an interceptor. +/// To avoid exposing the body of the request to the interceptor function, we first remove it +/// here, allow the interceptor to modify the metadata and extensions, and then recreate the +/// HTTP request with the original message body with the `recompose` function. Also note that Tonic +/// requests do not preserve the URI, HTTP version, and HTTP method of the HTTP request, so we +/// extract them here and then add them back in `recompose`. +fn decompose(req: http::Request) -> (DecomposedRequest, Request<()>) { + let uri = req.uri().clone(); + let method = req.method().clone(); + let http_version = req.version(); + let req = crate::Request::from_http(req); + let (metadata, extensions, msg) = req.into_parts(); + + let dreq = DecomposedRequest { + uri, + method, + http_version, + msg, + }; + let req_without_body = crate::Request::from_parts(metadata, extensions, ()); + + (dreq, req_without_body) +} + +/// Combine the modified metadata and extensions with the original message body and attributes. +fn recompose( + dreq: DecomposedRequest, + modified_req: Request<()>, +) -> http::Request { + let (metadata, extensions, _) = modified_req.into_parts(); + let req = crate::Request::from_parts(metadata, extensions, dreq.msg); + + req.into_http( + dreq.uri, + dreq.method, + dreq.http_version, + SanitizeHeaders::No, + ) +} + impl Service> for InterceptedService where - ResBody: Default + http_body::Body + Send + 'static, F: Interceptor, S: Service, Response = http::Response>, S::Error: Into, - ResBody: http_body::Body + Send + 'static, + ResBody: Default + http_body::Body + Send + 'static, ResBody::Error: Into, { type Response = http::Response; @@ -162,26 +265,13 @@ where } fn call(&mut self, req: http::Request) -> Self::Future { - // It is bad practice to modify the body (i.e. Message) of the request via an interceptor. - // To avoid exposing the body of the request to the interceptor function, we first remove it - // here, allow the interceptor to modify the metadata and extensions, and then recreate the - // HTTP request with the body. Tonic requests do not preserve the URI, HTTP version, and - // HTTP method of the HTTP request, so we extract them here and then add them back in below. - let uri = req.uri().clone(); - let method = req.method().clone(); - let version = req.version(); - let req = crate::Request::from_http(req); - let (metadata, extensions, msg) = req.into_parts(); - - match self - .f - .call(crate::Request::from_parts(metadata, extensions, ())) - { - Ok(req) => { - let (metadata, extensions, _) = req.into_parts(); - let req = crate::Request::from_parts(metadata, extensions, msg); - let req = req.into_http(uri, method, version, SanitizeHeaders::No); - ResponseFuture::future(self.inner.call(req)) + let (dreq, req_without_body) = decompose(req); + + match self.f.call(req_without_body) { + Ok(modified_req) => { + let modified_req_with_body = recompose(dreq, modified_req); + + ResponseFuture::future(self.inner.call(modified_req_with_body)) } Err(status) => ResponseFuture::status(status), } @@ -197,6 +287,67 @@ where const NAME: &'static str = S::NAME; } +/// A service wrapped in an async interceptor middleware. +/// +/// See [`AsyncInterceptor`] for more details. +#[derive(Clone, Copy)] +pub struct AsyncInterceptedService { + inner: S, + f: F, +} + +impl AsyncInterceptedService { + /// Create a new `AsyncInterceptedService` that wraps `S` and intercepts each request with the + /// function `F`. + pub fn new(service: S, f: F) -> Self { + Self { inner: service, f } + } +} + +impl fmt::Debug for AsyncInterceptedService +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AsyncInterceptedService") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Service> for AsyncInterceptedService +where + F: AsyncInterceptor + Clone, + S: Service, Response = http::Response> + Clone, + S::Error: Into, + ReqBody: Default, + ResBody: Default + http_body::Body + Send + 'static, + ResBody::Error: Into, +{ + type Response = http::Response; + type Error = S::Error; + type Future = AsyncResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + AsyncResponseFuture::new(req, &mut self.f, self.inner.clone()) + } +} + +// required to use `AsyncInterceptedService` with `Router` +#[cfg(feature = "transport")] +impl crate::transport::NamedService for AsyncInterceptedService +where + S: crate::transport::NamedService, +{ + const NAME: &'static str = S::NAME; +} + /// Response future for [`InterceptedService`]. #[pin_project] #[derive(Debug)] @@ -253,6 +404,122 @@ where } } +#[pin_project(project = PinnedOptionProj)] +#[derive(Debug)] +enum PinnedOption { + Some(#[pin] F), + None, +} + +/// Response future for [`AsyncInterceptedService`]. +/// +/// Handles the call to the async interceptor, then calls the inner service and wraps the result in +/// [`ResponseFuture`]. +#[pin_project(project = AsyncResponseFutureProj)] +#[derive(Debug)] +pub struct AsyncResponseFuture +where + S: Service>, + S::Error: Into, + I: Future, Status>>, +{ + #[pin] + interceptor_fut: PinnedOption, + #[pin] + inner_fut: PinnedOption>, + inner: S, + dreq: DecomposedRequest, +} + +impl AsyncResponseFuture +where + S: Service>, + S::Error: Into, + I: Future, Status>>, + ReqBody: Default, +{ + fn new>( + req: http::Request, + interceptor: &mut A, + inner: S, + ) -> Self { + let (dreq, req_without_body) = decompose(req); + let interceptor_fut = interceptor.call(req_without_body); + + AsyncResponseFuture { + interceptor_fut: PinnedOption::Some(interceptor_fut), + inner_fut: PinnedOption::None, + inner, + dreq, + } + } + + /// Calls the inner service with the intercepted request (which has been modified by the + /// async interceptor func). + fn create_inner_fut( + this: &mut AsyncResponseFutureProj<'_, S, I, ReqBody>, + intercepted_req: Result, Status>, + ) -> ResponseFuture { + match intercepted_req { + Ok(req) => { + // We can't move the message body out of the pin projection. So, to + // avoid copying it, we swap its memory with an empty body and then can + // move it into the recomposed request. + let msg = mem::take(&mut this.dreq.msg); + let movable_dreq = DecomposedRequest { + uri: this.dreq.uri.clone(), + method: this.dreq.method.clone(), + http_version: this.dreq.http_version, + msg, + }; + let modified_req_with_body = recompose(movable_dreq, req); + + ResponseFuture::future(this.inner.call(modified_req_with_body)) + } + Err(status) => ResponseFuture::status(status), + } + } +} + +impl Future for AsyncResponseFuture +where + S: Service, Response = http::Response>, + I: Future, Status>>, + S::Error: Into, + ReqBody: Default, + ResBody: Default + http_body::Body + Send + 'static, + ResBody::Error: Into, +{ + type Output = Result, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + // The struct was initialized (via `new`) with interceptor func future, which we poll here. + if let PinnedOptionProj::Some(f) = this.interceptor_fut.as_mut().project() { + match f.poll(cx) { + Poll::Ready(intercepted_req) => { + let inner_fut = AsyncResponseFuture::::create_inner_fut( + &mut this, + intercepted_req, + ); + // Set the inner service future and clear the interceptor future. + this.inner_fut.set(PinnedOption::Some(inner_fut)); + this.interceptor_fut.set(PinnedOption::None); + } + Poll::Pending => return Poll::Pending, + } + } + // At this point, inner_fut should always be Some. + let inner_fut = match this.inner_fut.project() { + PinnedOptionProj::None => panic!(), + PinnedOptionProj::Some(f) => f, + }; + + inner_fut.poll(cx) + } +} + #[cfg(test)] mod tests { #[allow(unused_imports)] @@ -320,6 +587,39 @@ mod tests { svc.oneshot(request).await.unwrap(); } + #[tokio::test] + async fn async_interceptor_doesnt_remove_headers() { + let svc = tower::service_fn(|request: http::Request| async move { + assert_eq!( + request + .headers() + .get("user-agent") + .expect("missing in leaf service"), + "test-tonic" + ); + + Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + }); + + let svc = AsyncInterceptedService::new(svc, |request: crate::Request<()>| { + assert_eq!( + request + .metadata() + .get("user-agent") + .expect("missing in interceptor"), + "test-tonic" + ); + std::future::ready(Ok(request)) + }); + + let request = http::Request::builder() + .header("user-agent", "test-tonic") + .body(hyper::Body::empty()) + .unwrap(); + + svc.oneshot(request).await.unwrap(); + } + #[tokio::test] async fn handles_intercepted_status_as_response() { let message = "Blocked by the interceptor"; @@ -341,6 +641,27 @@ mod tests { assert_eq!(expected.headers(), response.headers()); } + #[tokio::test] + async fn async_interceptor_handles_intercepted_status_as_response() { + let message = "Blocked by the interceptor"; + let expected = Status::permission_denied(message).to_http(); + + let svc = tower::service_fn(|_: http::Request| async { + Ok::<_, Status>(http::Response::new(TestBody)) + }); + + let svc = AsyncInterceptedService::new(svc, |_: crate::Request<()>| { + std::future::ready(Err(Status::permission_denied(message))) + }); + + let request = http::Request::builder().body(TestBody).unwrap(); + let response = svc.oneshot(request).await.unwrap(); + + assert_eq!(expected.status(), response.status()); + assert_eq!(expected.version(), response.version()); + assert_eq!(expected.headers(), response.headers()); + } + #[tokio::test] async fn doesnt_change_http_method() { let svc = tower::service_fn(|request: http::Request| async move { @@ -358,4 +679,24 @@ mod tests { svc.oneshot(request).await.unwrap(); } + + #[tokio::test] + async fn async_interceptor_doesnt_change_http_method() { + let svc = tower::service_fn(|request: http::Request| async move { + assert_eq!(request.method(), http::Method::OPTIONS); + + Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + }); + + let svc = AsyncInterceptedService::new(svc, |request: crate::Request<()>| { + std::future::ready(Ok(request)) + }); + + let request = http::Request::builder() + .method(http::Method::OPTIONS) + .body(hyper::Body::empty()) + .unwrap(); + + svc.oneshot(request).await.unwrap(); + } } diff --git a/tonic/src/service/mod.rs b/tonic/src/service/mod.rs index 125ac7980..f464d0670 100644 --- a/tonic/src/service/mod.rs +++ b/tonic/src/service/mod.rs @@ -4,4 +4,6 @@ pub mod interceptor; #[doc(inline)] #[allow(deprecated)] -pub use self::interceptor::{interceptor, interceptor_fn, Interceptor}; +pub use self::interceptor::{ + async_interceptor, interceptor, interceptor_fn, AsyncInterceptor, Interceptor, +};