Skip to content

Commit

Permalink
Bonus: removed unnecessary boxing
Browse files Browse the repository at this point in the history
Signed-off-by: slinkydeveloper <[email protected]>
  • Loading branch information
slinkydeveloper committed Oct 31, 2022
1 parent 53f2390 commit 216f410
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 16 deletions.
5 changes: 1 addition & 4 deletions tonic-web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,13 @@
#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")]

pub use layer::GrpcWebLayer;
pub use service::GrpcWebService;
pub use service::{GrpcWebService, ResponseFuture};

mod call;
mod layer;
mod service;

use http::header::HeaderName;
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use tonic::body::BoxBody;
use tower_http::cors::{AllowOrigin, Cors, CorsLayer};
Expand All @@ -110,7 +108,6 @@ const DEFAULT_ALLOW_HEADERS: [&str; 4] =
["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"];

type BoxError = Box<dyn std::error::Error + Send + Sync>;
type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

/// Enable a tonic service to handle grpc-web requests with the default configuration.
///
Expand Down
89 changes: 77 additions & 12 deletions tonic-web/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use futures_core::ready;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
use hyper::Body;
use pin_project::pin_project;
use tonic::body::{empty_body, BoxBody};
use tonic::transport::NamedService;
use tower_service::Service;
use tracing::{debug, trace};

use crate::call::content_types::is_grpc_web;
use crate::call::{Encoding, GrpcWebCall};
use crate::{BoxError, BoxFuture};
use crate::BoxError;

const GRPC: &str = "application/grpc";

Expand Down Expand Up @@ -47,13 +51,17 @@ impl<S> GrpcWebService<S>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Send + 'static,
{
fn response(&self, status: StatusCode) -> BoxFuture<S::Response, S::Error> {
Box::pin(async move {
Ok(Response::builder()
.status(status)
.body(empty_body())
.unwrap())
})
fn response(&self, status: StatusCode) -> ResponseFuture<S::Future> {
ResponseFuture {
case: Case::ImmediateResponse {
res: Some(
Response::builder()
.status(status)
.body(empty_body())
.unwrap(),
),
},
}
}
}

Expand All @@ -65,7 +73,7 @@ where
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<Self::Response, Self::Error>;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
Expand All @@ -89,8 +97,12 @@ where
} => {
trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);

let fut = self.inner.call(coerce_request(req, encoding));
Box::pin(async move { Ok(coerce_response(fut.await?, accept)) })
ResponseFuture {
case: Case::GrpcWeb {
future: self.inner.call(coerce_request(req, encoding)),
accept,
},
}
}

// The request's content-type matches one of the 4 supported grpc-web
Expand All @@ -105,7 +117,11 @@ where
// whatever they are.
RequestKind::Other(Version::HTTP_2) => {
debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
Box::pin(self.inner.call(req))
ResponseFuture {
case: Case::Other {
future: self.inner.call(req),
},
}
}

// Return HTTP 400 for all other requests.
Expand All @@ -117,6 +133,53 @@ where
}
}

/// Response future for the [`GrpcWebService`].
#[allow(missing_debug_implementations)]
#[pin_project]
#[must_use = "futures do nothing unless polled"]
pub struct ResponseFuture<F> {
#[pin]
case: Case<F>,
}

#[pin_project(project = CaseProj)]
enum Case<F> {
GrpcWeb {
#[pin]
future: F,
accept: Encoding,
},
Other {
#[pin]
future: F,
},
ImmediateResponse {
res: Option<Response<BoxBody>>,
},
}

impl<F, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<BoxBody>, E>> + Send + 'static,
E: Into<BoxError> + Send,
{
type Output = Result<Response<BoxBody>, E>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

match this.case.as_mut().project() {
CaseProj::GrpcWeb { future, accept } => {
let res = ready!(future.poll(cx))?;

Poll::Ready(Ok(coerce_response(res, *accept)))
}
CaseProj::Other { future } => future.poll(cx),
CaseProj::ImmediateResponse { res } => Poll::Ready(Ok(res.take().unwrap())),
}
}
}

impl<S: NamedService> NamedService for GrpcWebService<S> {
const NAME: &'static str = S::NAME;
}
Expand Down Expand Up @@ -177,6 +240,8 @@ mod tests {
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
};

type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

#[derive(Debug, Clone)]
struct Svc;

Expand Down

0 comments on commit 216f410

Please sign in to comment.