diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 2ac3e2fde..56790c82d 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -88,15 +88,13 @@ #![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")] pub use layer::GrpcWebLayer; -pub use service::GrpcWebService; +pub use service::{GrpcWebService, ResponseFuture}; mod call; mod layer; mod service; use http::header::HeaderName; -use std::future::Future; -use std::pin::Pin; use std::time::Duration; use tonic::body::BoxBody; use tower_http::cors::{AllowOrigin, Cors, CorsLayer}; @@ -110,7 +108,6 @@ const DEFAULT_ALLOW_HEADERS: [&str; 4] = ["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"]; type BoxError = Box; -type BoxFuture = Pin> + Send>>; /// Enable a tonic service to handle grpc-web requests with the default configuration. /// diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index ac4ef9a91..b071ec449 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -1,7 +1,11 @@ +use futures_core::ready; +use std::future::Future; +use std::pin::Pin; use std::task::{Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; use hyper::Body; +use pin_project::pin_project; use tonic::body::{empty_body, BoxBody}; use tonic::transport::NamedService; use tower_service::Service; @@ -9,7 +13,7 @@ use tracing::{debug, trace}; use crate::call::content_types::is_grpc_web; use crate::call::{Encoding, GrpcWebCall}; -use crate::{BoxError, BoxFuture}; +use crate::BoxError; const GRPC: &str = "application/grpc"; @@ -47,13 +51,17 @@ impl GrpcWebService where S: Service, Response = Response> + Send + 'static, { - fn response(&self, status: StatusCode) -> BoxFuture { - Box::pin(async move { - Ok(Response::builder() - .status(status) - .body(empty_body()) - .unwrap()) - }) + fn response(&self, status: StatusCode) -> ResponseFuture { + ResponseFuture { + case: Case::ImmediateResponse { + res: Some( + Response::builder() + .status(status) + .body(empty_body()) + .unwrap(), + ), + }, + } } } @@ -65,7 +73,7 @@ where { type Response = S::Response; type Error = S::Error; - type Future = BoxFuture; + type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) @@ -89,8 +97,12 @@ where } => { trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept); - let fut = self.inner.call(coerce_request(req, encoding)); - Box::pin(async move { Ok(coerce_response(fut.await?, accept)) }) + ResponseFuture { + case: Case::GrpcWeb { + future: self.inner.call(coerce_request(req, encoding)), + accept, + }, + } } // The request's content-type matches one of the 4 supported grpc-web @@ -105,7 +117,11 @@ where // whatever they are. RequestKind::Other(Version::HTTP_2) => { debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE)); - Box::pin(self.inner.call(req)) + ResponseFuture { + case: Case::Other { + future: self.inner.call(req), + }, + } } // Return HTTP 400 for all other requests. @@ -117,6 +133,53 @@ where } } +/// Response future for the [`GrpcWebService`]. +#[allow(missing_debug_implementations)] +#[pin_project] +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + #[pin] + case: Case, +} + +#[pin_project(project = CaseProj)] +enum Case { + GrpcWeb { + #[pin] + future: F, + accept: Encoding, + }, + Other { + #[pin] + future: F, + }, + ImmediateResponse { + res: Option>, + }, +} + +impl Future for ResponseFuture +where + F: Future, E>> + Send + 'static, + E: Into + Send, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + match this.case.as_mut().project() { + CaseProj::GrpcWeb { future, accept } => { + let res = ready!(future.poll(cx))?; + + Poll::Ready(Ok(coerce_response(res, *accept))) + } + CaseProj::Other { future } => future.poll(cx), + CaseProj::ImmediateResponse { res } => Poll::Ready(Ok(res.take().unwrap())), + } + } +} + impl NamedService for GrpcWebService { const NAME: &'static str = S::NAME; } @@ -177,6 +240,8 @@ mod tests { ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN, }; + type BoxFuture = Pin> + Send>>; + #[derive(Debug, Clone)] struct Svc;