From 20369f561688576a54a9eeb44976bcac418d50e1 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:13:48 +0000 Subject: [PATCH 01/31] Upgrade to hyper 1 and http 1 Upgrades only in Cargo.toml Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/Cargo.toml | 16 +++++++++------- interop/Cargo.toml | 6 +++--- tests/compression/Cargo.toml | 8 +++++--- tests/integration_tests/Cargo.toml | 7 ++++--- tonic-web/Cargo.toml | 6 +++--- tonic-web/tests/integration/Cargo.toml | 5 ++++- tonic/Cargo.toml | 17 ++++++++++------- 7 files changed, 38 insertions(+), 27 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 57d05d3e3..a672287d8 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -311,16 +311,18 @@ serde_json = { version = "1.0", optional = true } tracing = { version = "0.1.16", optional = true } tracing-subscriber = { version = "0.3", features = ["tracing-log", "fmt"], optional = true } prost-types = { version = "0.12", optional = true } -http = { version = "0.2", optional = true } -http-body = { version = "0.4.2", optional = true } -hyper = { version = "0.14", optional = true } +http = { version = "1", optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +hyper = { version = "1", optional = true } +hyper-util = { version = "0.1", optional = true } listenfd = { version = "1.0", optional = true } bytes = { version = "1", optional = true } h2 = { version = "0.3", optional = true } -tokio-rustls = { version = "0.24.0", optional = true } -hyper-rustls = { version = "0.24.0", features = ["http2"], optional = true } -rustls-pemfile = { version = "1", optional = true } -tower-http = { version = "0.4", optional = true } +tokio-rustls = { version = "0.26", optional = true, features = ["ring", "tls12"], default-features = false } +hyper-rustls = { version = "0.27.0", features = ["http2", "ring", "tls12"], optional = true, default-features = false } +rustls-pemfile = { version = "2.0.0", optional = true } +tower-http = { version = "0.5", optional = true } [build-dependencies] tonic-build = { path = "../tonic-build", features = ["prost"] } diff --git a/interop/Cargo.toml b/interop/Cargo.toml index a58ef64cf..9a32b2a1d 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -19,9 +19,9 @@ async-stream = "0.3" strum = {version = "0.26", features = ["derive"]} pico-args = {version = "0.5", features = ["eq-separator"]} console = "0.15" -http = "0.2" -http-body = "0.4.2" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" prost = "0.12" tokio = {version = "1.0", features = ["rt-multi-thread", "time", "macros"]} tokio-stream = "0.1" diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 5bc87c829..4ba549cdc 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -8,9 +8,11 @@ version = "0.1.0" [dependencies] bytes = "1" -http = "0.2" -http-body = "0.4" -hyper = "0.14.3" +http = "1" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" paste = "1.0.12" pin-project = "1.0" prost = "0.12" diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 222d1919c..6a7ec8052 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -17,9 +17,10 @@ tracing-subscriber = {version = "0.3"} [dev-dependencies] async-stream = "0.3" -http = "0.2" -http-body = "0.4" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" +hyper-util = "0.1" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} tower-http = { version = "0.4", features = ["set-header", "trace"] } diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index d6649f65c..5813fd6a3 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -18,9 +18,9 @@ version = "0.11.0" base64 = "0.22" bytes = "1" tokio-stream = "0.1" -http = "0.2" -http-body = "0.4" -hyper = {version = "0.14", default-features = false, features = ["stream"]} +http = "1" +http-body = "1" +http-body-util = "0.1" pin-project = "1" tonic = {version = "0.11", path = "../tonic", default-features = false} tower-service = "0.3" diff --git a/tonic-web/tests/integration/Cargo.toml b/tonic-web/tests/integration/Cargo.toml index 5c6d5727e..38fd9ff32 100644 --- a/tonic-web/tests/integration/Cargo.toml +++ b/tonic-web/tests/integration/Cargo.toml @@ -9,7 +9,10 @@ license = "MIT" [dependencies] base64 = "0.22" bytes = "1.0" -hyper = "0.14" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" prost = "0.12" tokio = { version = "1", features = ["macros", "rt", "net"] } tokio-stream = { version = "0.1", features = ["net"] } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 0d2be669f..1934c416e 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -51,10 +51,11 @@ channel = [] [dependencies] base64 = "0.22" bytes = "1.0" -http = "0.2" +http = "1" tracing = "0.1" -http-body = "0.4.4" +http-body = "1" +http-body-util = "0.1" percent-encoding = "2.1" pin-project = "1.0.11" tower-layer = "0.3" @@ -68,11 +69,13 @@ async-trait = {version = "0.1.13", optional = true} # transport async-stream = {version = "0.3", optional = true} -h2 = {version = "0.3.24", optional = true} -hyper = {version = "0.14.26", features = ["full"], optional = true} -hyper-timeout = {version = "0.4", optional = true} -tokio = {version = "1.0.1", optional = true} -tokio-stream = "0.1" +h2 = {version = "0.4", optional = true} +hyper = {version = "1", features = ["full"], optional = true} +hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true } +hyper-timeout = {version = "0.5", optional = true} +socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } +tokio = {version = "1", default-features = false, optional = true} +tokio-stream = { version = "0.1", features = ["net"] } tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} axum = {version = "0.6.9", default-features = false, optional = true} From aaede1e47bc2396001106a82fc97fff1979407f1 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:17:49 +0000 Subject: [PATCH 02/31] Convert from hyper::Body to http_body::BoxedBody When appropriate, we replace `hyper::Body` with `http_body::BoxedBody`, a good general purpose replacement for `hyper::Body`. Hyper does provide `hyper::body::Incoming`, but we cannot construct that, so anywhere we might need a body that we can construct (even most Service trait impls) we must use something like `http_body::BoxedBody`. When a service accepts `BoxedBody` and not `Incoming`, this indicates that the service is designed to run in places where it is not adjacent to hyper, for example, after routing (which is managed by Axum) Additionally, http >= 1 requires that extension types are `Clone`, so this bound has been added where appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/src/interceptor/server.rs | 1 + examples/src/tower/client.rs | 3 +- examples/src/tower/server.rs | 10 +++-- interop/src/server.rs | 6 +-- tests/integration_tests/tests/extensions.rs | 9 +++-- tests/integration_tests/tests/origin.rs | 1 + tonic-web/src/layer.rs | 2 +- tonic-web/src/lib.rs | 10 ++--- tonic-web/src/service.rs | 39 +++++++++---------- tonic-web/tests/integration/tests/grpc_web.rs | 10 +++-- tonic/src/body.rs | 6 +-- tonic/src/extensions.rs | 2 +- tonic/src/request.rs | 2 +- tonic/src/transport/server/mod.rs | 11 ++++-- tonic/src/transport/service/connection.rs | 5 +-- 15 files changed, 62 insertions(+), 55 deletions(-) diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs index 263348a6d..fd0cf462f 100644 --- a/examples/src/interceptor/server.rs +++ b/examples/src/interceptor/server.rs @@ -57,6 +57,7 @@ fn intercept(mut req: Request<()>) -> Result, Status> { Ok(req) } +#[derive(Clone)] struct MyExtension { some_piece_of_data: String, } diff --git a/examples/src/tower/client.rs b/examples/src/tower/client.rs index 0a33fffae..39fec5d47 100644 --- a/examples/src/tower/client.rs +++ b/examples/src/tower/client.rs @@ -44,7 +44,6 @@ mod service { use std::pin::Pin; use std::task::{Context, Poll}; use tonic::body::BoxBody; - use tonic::transport::Body; use tonic::transport::Channel; use tower::Service; @@ -59,7 +58,7 @@ mod service { } impl Service> for AuthSvc { - type Response = Response; + type Response = Response; type Error = Box; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; diff --git a/examples/src/tower/server.rs b/examples/src/tower/server.rs index cc85d62e5..b7066a1b6 100644 --- a/examples/src/tower/server.rs +++ b/examples/src/tower/server.rs @@ -1,4 +1,3 @@ -use hyper::Body; use std::{ pin::Pin, task::{Context, Poll}, @@ -84,9 +83,12 @@ struct MyMiddleware { type BoxFuture<'a, T> = Pin + Send + 'a>>; -impl Service> for MyMiddleware +impl Service> for MyMiddleware where - S: Service, Response = hyper::Response> + Clone + Send + 'static, + S: Service, Response = hyper::Response> + + Clone + + Send + + 'static, S::Future: Send + 'static, { type Response = S::Response; @@ -97,7 +99,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: hyper::Request) -> Self::Future { + fn call(&mut self, req: hyper::Request) -> Self::Future { // This is necessary because tonic internally uses `tower::buffer::Buffer`. // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 // for details on why this is necessary diff --git a/interop/src/server.rs b/interop/src/server.rs index b32468866..aef7b0d45 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -180,9 +180,9 @@ impl EchoHeadersSvc { } } -impl Service> for EchoHeadersSvc +impl Service> for EchoHeadersSvc where - S: Service, Response = http::Response> + Send, + S: Service, Response = http::Response> + Send, S::Future: Send + 'static, { type Response = S::Response; @@ -193,7 +193,7 @@ where Ok(()).into() } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned(); let echo_trailer = req diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs index b112f8e66..b2380181d 100644 --- a/tests/integration_tests/tests/extensions.rs +++ b/tests/integration_tests/tests/extensions.rs @@ -1,4 +1,4 @@ -use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; +use hyper::{Request as HyperRequest, Response as HyperResponse}; use integration_tests::{ pb::{test_client, test_server, Input, Output}, BoxFuture, @@ -16,6 +16,7 @@ use tonic::{ }; use tower_service::Service; +#[derive(Clone)] struct ExtensionValue(i32); #[tokio::test] @@ -112,9 +113,9 @@ struct InterceptedService { inner: S, } -impl Service> for InterceptedService +impl Service> for InterceptedService where - S: Service, Response = HyperResponse> + S: Service, Response = HyperResponse> + NamedService + Clone + Send @@ -129,7 +130,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, mut req: HyperRequest) -> Self::Future { + fn call(&mut self, mut req: HyperRequest) -> Self::Future { let clone = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, clone); diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index f149dc68d..c8140c79f 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -7,6 +7,7 @@ use std::time::Duration; use tokio::sync::oneshot; use tonic::codegen::http::Request; use tonic::{ + body::BoxBody, transport::{Endpoint, Server}, Response, Status, }; diff --git a/tonic-web/src/layer.rs b/tonic-web/src/layer.rs index 77b03c77e..7834f1990 100644 --- a/tonic-web/src/layer.rs +++ b/tonic-web/src/layer.rs @@ -24,7 +24,7 @@ impl Default for GrpcWebLayer { impl Layer for GrpcWebLayer where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 16e57e19d..50ed8c0a8 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -127,7 +127,7 @@ type BoxError = Box; /// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice. pub fn enable(service: S) -> CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -159,9 +159,9 @@ where #[derive(Debug, Clone)] pub struct CorsGrpcWeb(tower_http::cors::Cors>); -impl Service> for CorsGrpcWeb +impl Service> for CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -169,7 +169,7 @@ where type Response = S::Response; type Error = S::Error; type Future = - > as Service>>::Future; + > as Service>>::Future; fn poll_ready( &mut self, @@ -178,7 +178,7 @@ where self.0.poll_ready(cx) } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { self.0.call(req) } } diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index af4c5276f..da65ba832 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; -use hyper::Body; +use http_body_util::BodyExt; use pin_project::pin_project; use tonic::{ body::{empty_body, BoxBody}, @@ -50,7 +50,7 @@ impl GrpcWebService { impl GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, { fn response(&self, status: StatusCode) -> ResponseFuture { ResponseFuture { @@ -66,9 +66,9 @@ where } } -impl Service> for GrpcWebService +impl Service> for GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { @@ -80,7 +80,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { match RequestKind::new(req.headers(), req.method(), req.version()) { // A valid grpc-web request, regardless of HTTP version. // @@ -202,7 +202,7 @@ impl<'a> RequestKind<'a> { // Mutating request headers to conform to a gRPC request is not really // necessary for us at this point. We could remove most of these except // maybe for inserting `header::TE`, which tonic should check? -fn coerce_request(mut req: Request, encoding: Encoding) -> Request { +fn coerce_request(mut req: Request, encoding: Encoding) -> Request { req.headers_mut().remove(header::CONTENT_LENGTH); req.headers_mut() @@ -216,8 +216,7 @@ fn coerce_request(mut req: Request, encoding: Encoding) -> Request { HeaderValue::from_static("identity,deflate,gzip"), ); - req.map(|b| GrpcWebCall::request(b, encoding)) - .map(Body::wrap_stream) + req.map(|b| GrpcWebCall::request(b, encoding).boxed_unsync()) } fn coerce_response(res: Response, encoding: Encoding) -> Response { @@ -246,7 +245,7 @@ mod tests { #[derive(Debug, Clone)] struct Svc; - impl tower_service::Service> for Svc { + impl tower_service::Service> for Svc { type Response = Response; type Error = String; type Future = BoxFuture; @@ -255,7 +254,7 @@ mod tests { Poll::Ready(Ok(())) } - fn call(&mut self, _: Request) -> Self::Future { + fn call(&mut self, _: Request) -> Self::Future { Box::pin(async { Ok(Response::new(empty_body())) }) } } @@ -266,15 +265,14 @@ mod tests { mod grpc_web { use super::*; - use http::HeaderValue; use tower_layer::Layer; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::POST) .header(CONTENT_TYPE, GRPC_WEB) .header(ORIGIN, "http://example.com") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -350,13 +348,13 @@ mod tests { mod options { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::OPTIONS) .header(ORIGIN, "http://example.com") .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web") .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -371,13 +369,12 @@ mod tests { mod grpc { use super::*; - use http::HeaderValue; - fn request() -> Request { + fn request() -> Request { Request::builder() .version(Version::HTTP_2) .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -397,7 +394,7 @@ mod tests { let req = Request::builder() .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap(); let res = svc.call(req).await.unwrap(); @@ -425,10 +422,10 @@ mod tests { mod other { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .header(CONTENT_TYPE, "application/text") - .body(Body::empty()) + .body(empty_body()) .unwrap() } diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 3343d754c..037ff8dad 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -2,11 +2,13 @@ use std::net::SocketAddr; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http_body_util::{BodyExt as _, Full}; use hyper::http::{header, StatusCode}; -use hyper::{Body, Client, Method, Request, Uri}; +use hyper::{Method, Request, Uri}; use prost::Message; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; +use tonic::body::BoxBody; use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; @@ -102,7 +104,7 @@ fn encode_body() -> Bytes { buf.split_to(len + 5).freeze() } -fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { +fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { use header::{ACCEPT, CONTENT_TYPE, ORIGIN}; let request_uri = format!("{}/{}/{}", base_uri, "test.Test", "UnaryCall") @@ -123,7 +125,9 @@ fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request< .header(ORIGIN, "http://example.com") .header(ACCEPT, format!("application/{}", accept)) .uri(request_uri) - .body(Body::from(bytes)) + .body(BoxBody::new( + Full::new(bytes).map_err(|err| Status::internal(err.to_string())), + )) .unwrap() } diff --git a/tonic/src/body.rs b/tonic/src/body.rs index ef95eec47..428c0dade 100644 --- a/tonic/src/body.rs +++ b/tonic/src/body.rs @@ -1,9 +1,9 @@ //! HTTP specific body utilities. -use http_body::Body; +use http_body_util::BodyExt; /// A type erased HTTP body used for tonic services. -pub type BoxBody = http_body::combinators::UnsyncBoxBody; +pub type BoxBody = http_body_util::combinators::UnsyncBoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. pub(crate) fn boxed(body: B) -> BoxBody @@ -16,7 +16,7 @@ where /// Create an empty `BoxBody` pub fn empty_body() -> BoxBody { - http_body::Empty::new() + http_body_util::Empty::new() .map_err(|err| match err {}) .boxed_unsync() } diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index 37d84b87b..32b9ad021 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -24,7 +24,7 @@ impl Extensions { /// If a extension of this type already existed, it will /// be returned. #[inline] - pub fn insert(&mut self, val: T) -> Option { + pub fn insert(&mut self, val: T) -> Option { self.inner.insert(val) } diff --git a/tonic/src/request.rs b/tonic/src/request.rs index a27a7070c..f2cca7c74 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -313,6 +313,7 @@ impl Request { /// ```no_run /// use tonic::{Request, service::interceptor}; /// + /// #[derive(Clone)] // Extensions must be Clone /// struct MyExtension { /// some_piece_of_data: String, /// } @@ -440,7 +441,6 @@ pub(crate) enum SanitizeHeaders { #[cfg(test)] mod tests { use super::*; - use crate::metadata::MetadataValue; use http::Uri; #[test] diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7f2ffde2b..ad930c617 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -35,12 +35,13 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; +use crate::body::boxed; use crate::body::BoxBody; use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; -use http_body::Body as _; -use hyper::{server::accept, Body}; +use http_body_util::BodyExt; +use hyper::server::accept; use pin_project::pin_project; use std::{ convert::Infallible, @@ -63,9 +64,11 @@ use tower::{ Service, ServiceBuilder, }; -type BoxHttpBody = http_body::combinators::UnsyncBoxBody; -type BoxService = tower::util::BoxService, Response, crate::Error>; type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; +type BoxHttpBody = crate::body::BoxBody; +type Body = hyper::body::Incoming; // Temporary type alias to ease transition +type BoxError = crate::Error; +type BoxService = tower::util::BoxCloneService, Response, crate::Error>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 46a88dda5..b3428aa2c 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,6 +1,6 @@ use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{ - body::BoxBody, + body::{boxed, BoxBody}, transport::{BoxFuture, Endpoint}, }; use http::Uri; @@ -21,8 +21,7 @@ use tower::{ }; use tower_service::Service; -pub(crate) type Request = http::Request; -pub(crate) type Response = http::Response; +pub(crate) use crate::transport::{Request, Response}; pub(crate) struct Connection { inner: BoxService, From ac2698db50c1b41481e7d00272f8f76cb01a7943 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:29:37 +0000 Subject: [PATCH 03/31] Convert tonic::codec::decode to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev --- tonic/benches/decode.rs | 21 ++++++------ tonic/src/codec/decode.rs | 71 ++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 49 deletions(-) diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..22ab6d9d4 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -1,6 +1,6 @@ use bencher::{benchmark_group, benchmark_main, Bencher}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use http_body::Body; +use http_body::{Body, Frame, SizeHint}; use std::{ fmt::{Error, Formatter}, pin::Pin, @@ -58,23 +58,24 @@ impl Body for MockBody { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { if self.data.has_remaining() { let split = std::cmp::min(self.chunk_size, self.data.remaining()); - Poll::Ready(Some(Ok(self.data.split_to(split)))) + Poll::Ready(Some(Ok(Frame::data(self.data.split_to(split))))) } else { Poll::Ready(None) } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + fn is_end_stream(&self) -> bool { + !self.data.is_empty() + } + + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.data.len() as u64) } } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 020551704..e5aee85f2 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -2,8 +2,9 @@ use super::compression::{decompress, CompressionEncoding, CompressionSettings}; use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use http_body::Body; +use http_body_util::BodyExt; use std::{ fmt, future, pin::Pin, @@ -27,7 +28,7 @@ struct StreamingInner { state: State, direction: Direction, buf: BytesMut, - trailers: Option, + trailers: Option, decompress_buf: BytesMut, encoding: Option, max_message_size: Option, @@ -121,7 +122,7 @@ impl Streaming { decoder: Box::new(decoder), inner: StreamingInner { body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_frame(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))) .map_err(|err| Status::map_error(err.into())) .boxed_unsync(), state: State::ReadHeader, @@ -239,8 +240,8 @@ impl StreamingInner { } // Returns Some(()) if data was found or None if the loop in `poll_next` should break - fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { + let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { Some(Ok(d)) => Some(d), Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { @@ -254,9 +255,18 @@ impl StreamingInner { None => None, }; - Poll::Ready(if let Some(data) = chunk { - self.buf.put(data); - Ok(Some(())) + Poll::Ready(if let Some(frame) = chunk { + match frame { + frame if frame.is_data() => { + self.buf.put(frame.into_data().unwrap()); + Ok(Some(())) + } + frame if frame.is_trailers() => { + self.trailers = Some(frame.into_trailers().unwrap()); + Ok(None) + } + frame => panic!("unexpected frame: {:?}", frame), + } } else { // FIXME: improve buf usage. if self.buf.has_remaining() { @@ -271,27 +281,18 @@ impl StreamingInner { }) } - fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { + fn response(&mut self) -> Result<(), Status> { if let Direction::Response(status) = self.direction { - match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { - Ok(trailer) => { - if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - if let Some(e) = e { - return Poll::Ready(Err(e)); - } else { - return Poll::Ready(Ok(())); - } - } else { - self.trailers = trailer.map(MetadataMap::from_headers); - } - } - Err(status) => { - debug!("decoder inner trailers error: {:?}", status); - return Poll::Ready(Err(status)); + if let Err(e) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) { + if let Some(e) = e { + // If the trailers contain a grpc-status, then we should return that as the error + // and otherwise stop the stream (by taking the error state) + self.trailers.take(); + return Err(e); } } } - Poll::Ready(Ok(())) + Ok(()) } } @@ -351,7 +352,7 @@ impl Streaming { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } // To fetch the trailers we must clear the body and drop it. @@ -360,16 +361,11 @@ impl Streaming { // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } - // Trailers were not caught during poll_next and thus lets poll for - // them manually. - let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) - .await - .map_err(|e| Status::from_error(Box::new(e))); - - map.map(|x| x.map(MetadataMap::from_headers)) + // We've polled through all the frames, and still no trailers, return None + Ok(None) } fn decode_chunk(&mut self) -> Result, Status> { @@ -395,20 +391,17 @@ impl Stream for Streaming { return Poll::Ready(None); } - // FIXME: implement the ability to poll trailers when we _know_ that - // the consumer of this stream will only poll for the first message. - // This means we skip the poll_trailers step. if let Some(item) = self.decode_chunk()? { return Poll::Ready(Some(Ok(item))); } - match ready!(self.inner.poll_data(cx))? { + match ready!(self.inner.poll_frame(cx))? { Some(()) => (), None => break, } } - Poll::Ready(match ready!(self.inner.poll_response(cx)) { + Poll::Ready(match self.inner.response() { Ok(()) => None, Err(err) => Some(Err(err)), }) From 3ef308c1cfbd01fb7f5857bcfb628320d9bca197 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:35:14 +0000 Subject: [PATCH 04/31] Convert tonic::transport::channel to use http >= 1 body types tonic::transport::channel previously used `hyper::Body` as the response body type. This type no longer exists in hyper >= 1, and so has been converted to a `BoxBody` provided by `http_body_util` designed for interoperability between http crates. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/transport/channel/mod.rs | 6 +++--- tonic/src/transport/service/connection.rs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..6a857dff1 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -38,7 +38,7 @@ use tower::{ Service, }; -type Svc = Either, Response, crate::Error>>; +type Svc = Either, Response, crate::Error>>; const DEFAULT_BUFFER_SIZE: usize = 1024; @@ -201,7 +201,7 @@ impl Channel { } impl Service> for Channel { - type Response = http::Response; + type Response = http::Response; type Error = super::Error; type Future = ResponseFuture; @@ -217,7 +217,7 @@ impl Service> for Channel { } impl Future for ResponseFuture { - type Output = Result, super::Error>; + type Output = Result, super::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let val = ready!(Pin::new(&mut self.inner).poll(cx)).map_err(super::Error::from_source)?; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index b3428aa2c..1fa059c96 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -21,7 +21,8 @@ use tower::{ }; use tower_service::Service; -pub(crate) use crate::transport::{Request, Response}; +pub(crate) type Response = http::Response; +pub(crate) type Request = http::Request; pub(crate) struct Connection { inner: BoxService, From 8a24c79dd4bd78616a57d0b470dc5c0d2af0c856 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:31:02 +0000 Subject: [PATCH 05/31] [tests] Convert tonic::codec::prost::tests to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. This also handles the return types which should now be wrapped in `Frame` when appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/codec/prost.rs | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 4e925ba6e..8237daf32 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -156,6 +156,7 @@ mod tests { use crate::{Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http_body::Body; + use http_body_util::BodyExt as _; use std::pin::pin; const LEN: usize = 10000; @@ -238,7 +239,7 @@ mod tests { None, )); - while let Some(r) = body.data().await { + while let Some(r) = body.frame().await { r.unwrap(); } } @@ -260,12 +261,15 @@ mod tests { Some(MAX_MESSAGE_SIZE), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "11" @@ -292,12 +296,15 @@ mod tests { Some(usize::MAX), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "8" @@ -343,7 +350,7 @@ mod tests { mod body { use crate::Status; use bytes::Bytes; - use http_body::Body; + use http_body::{Body, Frame}; use std::{ pin::Pin, task::{Context, Poll}, @@ -374,10 +381,10 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { // every other call to poll_data returns data let should_send = self.count % 2 == 0; let data_len = self.data.len(); @@ -395,18 +402,11 @@ mod tests { }; // make some fake progress self.count += 1; - result + result.map(|opt| opt.map(|res| res.map(|data| Frame::data(data)))) } else { Poll::Ready(None) } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } } } From e0f0f6cc06d1b151b762d17b259e168e2eed7257 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:30:04 +0000 Subject: [PATCH 06/31] Convert tonic::codec::encode to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/codec/encode.rs | 46 ++++++++++++++------------------------- 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 88ec9568e..82b4eb61d 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -5,7 +5,7 @@ use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, H use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project::pin_project; use std::{ pin::Pin, @@ -298,22 +298,21 @@ where } impl EncodeState { - fn trailers(&mut self) -> Result, Status> { + fn trailers(&mut self) -> Option> { match self.role { - Role::Client => Ok(None), + Role::Client => None, Role::Server => { if self.is_end_stream { - return Ok(None); + return None; } + self.is_end_stream = true; let status = if let Some(status) = self.error.take() { - self.is_end_stream = true; status } else { Status::new(Code::Ok, "") }; - - Ok(Some(status.to_header_map()?)) + Some(status.to_header_map()) } } } @@ -330,38 +329,25 @@ where self.state.is_end_stream } - fn size_hint(&self) -> http_body::SizeHint { - let sh = self.inner.size_hint(); - let mut size_hint = http_body::SizeHint::new(); - size_hint.set_lower(sh.0 as u64); - if let Some(upper) = sh.1 { - size_hint.set_upper(upper as u64); - } - size_hint - } - - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let self_proj = self.project(); match ready!(self_proj.inner.poll_next(cx)) { - Some(Ok(d)) => Some(Ok(d)).into(), + Some(Ok(d)) => Some(Ok(Frame::data(d))).into(), Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { - self_proj.state.error = Some(status); - None.into() + self_proj.state.is_end_stream = true; + Some(Ok(Frame::trailers(status.to_header_map()?))).into() } }, - None => None.into(), + None => self_proj + .state + .trailers() + .map(|t| t.map(Frame::trailers)) + .into(), } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Status>> { - Poll::Ready(self.project().state.trailers()) - } } From 51cf7dc399ba9a6a53ce04eeda905f83efa8ab1c Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:31:37 +0000 Subject: [PATCH 07/31] [tests] Convert tonic::service::interceptor::tests to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. This also handles the return types which should now be wrapped in `Frame` when appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/service/interceptor.rs | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index cadff466f..ebe78093d 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -232,11 +232,8 @@ where mod tests { #[allow(unused_imports)] use super::*; - use http::header::HeaderMap; - use std::{ - pin::Pin, - task::{Context, Poll}, - }; + use http_body::Frame; + use http_body_util::Empty; use tower::ServiceExt; #[derive(Debug, Default)] @@ -246,19 +243,12 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { Poll::Ready(None) } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } #[tokio::test] @@ -318,17 +308,17 @@ mod tests { #[tokio::test] async fn doesnt_change_http_method() { - let svc = tower::service_fn(|request: http::Request| async move { + let svc = tower::service_fn(|request: http::Request>| async move { assert_eq!(request.method(), http::Method::OPTIONS); - Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + Ok::<_, hyper::Error>(hyper::Response::new(Empty::new())) }); let svc = InterceptedService::new(svc, Ok); let request = http::Request::builder() .method(http::Method::OPTIONS) - .body(hyper::Body::empty()) + .body(Empty::new()) .unwrap(); svc.oneshot(request).await.unwrap(); From 4bf003b1824fffcae6d00b75dd278ea14b548eeb Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:39:22 +0000 Subject: [PATCH 08/31] Convert tonic::transport to use http >= 1 body types Here, we must update some body types which are no longer valid. (A) BoxBody no longer has an `empty` method, instead we provide a helper in `tonic::body` for creating an empty boxed body via `http_body_util`. As well, `hyper::Body` is no longer a type, and instead, `hyper::Incoming` is used when directly recieving a Request from hyper, and `BoxBody` is used when the request may have passed through an axum router. In tonic, we prefer `BoxBody` as it allows for services to be used downstream from other components which enforce a specific body type (e.g. Axum), at the cost of making Body streaming opaque. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/transport/mod.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 758bdb7d8..faa3b49be 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -1,7 +1,7 @@ //! Batteries included server and client. //! //! This module provides a set of batteries included, fully featured and -//! fast set of HTTP/2 server and client's. These components each provide a or +//! fast set of HTTP/2 server and client's. These components each provide a //! `rustls` tls backend when the respective feature flag is enabled, and //! provides builders to configure transport behavior. //! @@ -22,6 +22,7 @@ //! # use tonic::transport::{Channel, Certificate, ClientTlsConfig}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; +//! # use tonic::body::empty_body; //! # use tonic::client::GrpcService;; //! # use http::Request; //! # #[cfg(feature = "rustls")] @@ -38,7 +39,7 @@ //! .connect() //! .await?; //! -//! channel.call(Request::new(BoxBody::empty())).await?; +//! channel.call(Request::new(empty_body())).await?; //! # Ok(()) //! # } //! ``` @@ -46,21 +47,23 @@ //! ## Server //! //! ```no_run +//! # use std::convert::Infallible; //! # #[cfg(feature = "rustls")] //! # use tonic::transport::{Server, Identity, ServerTlsConfig}; +//! # use tonic::body::BoxBody; //! # use tower::Service; //! # #[cfg(feature = "rustls")] //! # async fn do_thing() -> Result<(), Box> { //! # #[derive(Clone)] //! # pub struct Svc; -//! # impl Service> for Svc { //! # type Response = hyper::Response; -//! # type Error = tonic::Status; +//! # impl Service> for Svc { +//! # type Error = Infallible; //! # type Future = std::future::Ready>; //! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { //! # Ok(()).into() //! # } -//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { +//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { //! # unimplemented!() //! # } //! # } @@ -123,5 +126,8 @@ pub use self::server::ServerTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Identity; +use crate::body::BoxBody; type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; +pub(crate) type Response = http::Response; +pub(crate) type Request = http::Request; From 70abf84ede85c912750010c54d314ccd5e355deb Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:40:14 +0000 Subject: [PATCH 09/31] Convert tonic::transport::server::recover_error to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev Co-authored-by: Ludea --- tonic/src/transport/server/recover_error.rs | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs index fdb14a66a..60b0d9a7b 100644 --- a/tonic/src/transport/server/recover_error.rs +++ b/tonic/src/transport/server/recover_error.rs @@ -1,5 +1,6 @@ use crate::Status; use http::Response; +use http_body::Frame; use pin_project::pin_project; use std::{ future::Future, @@ -98,26 +99,16 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.as_pin_mut() { - Some(b) => b.poll_data(cx), + Some(b) => b.poll_frame(cx), None => Poll::Ready(None), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.as_pin_mut() { - Some(b) => b.poll_trailers(cx), - None => Poll::Ready(Ok(None)), - } - } - fn is_end_stream(&self) -> bool { match &self.inner { Some(b) => b.is_end_stream(), From c5d0d1345b1218813ff2947876b6ec3a79797778 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:25:12 +0000 Subject: [PATCH 10/31] Convert h2c examples to use http >= 1 body types In h2c, when a service is receiving from hyper, it has to accept a `hyper::body::Incoming` in hyper >= 1. Additionally, response bodies must be built from `http_body_util` combinators and become BoxBody objects. --- examples/src/h2c/client.rs | 11 +++++++---- examples/src/h2c/server.rs | 22 ++++++++++------------ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 31076b1ac..624ea175f 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -34,15 +34,18 @@ mod h2c { }; use hyper::{client::HttpConnector, Client}; - use tonic::body::BoxBody; + use hyper::body::Incoming; + use hyper_util::{ + rt::TokioExecutor, + use tonic::body::{empty_body, BoxBody}; use tower::Service; pub struct H2cChannel { - pub client: Client, + pub client: Client, } impl Service> for H2cChannel { - type Response = http::Response; + type Response = http::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -60,7 +63,7 @@ mod h2c { let h2c_req = hyper::Request::builder() .uri(origin) .header(http::header::UPGRADE, "h2c") - .body(hyper::Body::empty()) + .body(empty_body()) .unwrap(); let res = client.request(h2c_req).await.unwrap(); diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 92d08a417..21dcc1f35 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -49,7 +49,9 @@ mod h2c { use std::pin::Pin; use http::{Request, Response}; - use hyper::Body; + use hyper::body::Incoming; + use hyper_util::{rt::TokioExecutor, service::TowerToHyperService}; + use tonic::{body::empty_body, transport::AxumBoxBody}; use tower::Service; #[derive(Clone)] @@ -59,17 +61,14 @@ mod h2c { type BoxError = Box; - impl Service> for H2c + impl Service> for H2c where - S: Service, Response = Response> - + Clone - + Send - + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Sync + Send + 'static, S::Response: Send + 'static, { - type Response = hyper::Response; + type Response = hyper::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -81,20 +80,19 @@ mod h2c { std::task::Poll::Ready(Ok(())) } - fn call(&mut self, mut req: hyper::Request) -> Self::Future { + fn call(&mut self, mut req: hyper::Request) -> Self::Future { let svc = self.s.clone(); Box::pin(async move { tokio::spawn(async move { let upgraded_io = hyper::upgrade::on(&mut req).await.unwrap(); - hyper::server::conn::Http::new() - .http2_only(true) - .serve_connection(upgraded_io, svc) + hyper::server::conn::http2::Builder::new(TokioExecutor::new()) + .serve_connection(upgraded_io, TowerToHyperService::new(svc)) .await .unwrap(); }); - let mut res = hyper::Response::new(hyper::Body::empty()); + let mut res = hyper::Response::new(empty_body()); *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS; res.headers_mut().insert( hyper::header::UPGRADE, From c3ce3b3725011f44e98cdd8362be0605b5a68816 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:54:14 +0000 Subject: [PATCH 11/31] [tests] Convert MergeTrailers body wrapper in interop server The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- interop/src/server.rs | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/interop/src/server.rs b/interop/src/server.rs index aef7b0d45..38b1be65e 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -1,10 +1,10 @@ use crate::pb::{self, *}; use async_stream::try_stream; -use http::header::{HeaderMap, HeaderName, HeaderValue}; +use http::header::{HeaderName, HeaderValue}; use http_body::Body; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio_stream::StreamExt; use tonic::{body::BoxBody, server::NamedService, Code, Request, Response, Status}; @@ -235,25 +235,19 @@ impl Body for MergeTrailers { type Data = B::Data; type Error = B::Error; - fn poll_data( - mut self: Pin<&mut Self>, + fn poll_frame( + self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(&mut self.inner).poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { - h.map(|mut headers| { - if let Some((key, value)) = &self.trailer { - headers.insert(key.clone(), value.clone()); + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + let mut frame = ready!(Pin::new(&mut this.inner).poll_frame(cx)?); + if let Some(frame) = frame.as_mut() { + if let Some(trailers) = frame.trailers_mut() { + if let Some((key, value)) = &this.trailer { + trailers.insert(key.clone(), value.clone()); } - - headers - }) - }) + } + } + Poll::Ready(frame.map(Ok)) } } From f2f7870bfcee4bc30e9bb822fff1f11ea14fddea Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:42:34 +0000 Subject: [PATCH 12/31] [tests] Convert compression tests to use hyper 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tests/compression/Cargo.toml | 2 +- tests/compression/src/util.rs | 81 ++++++++++++++++++++++++----------- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 4ba549cdc..cf4da321b 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -20,7 +20,7 @@ tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tokio-stream = "0.1" tonic = {path = "../../tonic", features = ["gzip", "zstd"]} tower = {version = "0.4", features = []} -tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]} +tower-http = {version = "0.5", features = ["map-response-body", "map-request-body"]} [build-dependencies] tonic-build = {path = "../../tonic-build" } diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 28fa5d96a..99afded3f 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -1,6 +1,8 @@ use super::*; -use bytes::Bytes; -use http_body::Body; +use bytes::{Buf, Bytes}; +use http_body::{Body, Frame}; +use http_body_util::BodyExt as _; +use hyper_util::rt::TokioIo; use pin_project::pin_project; use std::{ pin::Pin, @@ -11,6 +13,7 @@ use std::{ task::{ready, Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::body::BoxBody; use tonic::codec::CompressionEncoding; use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; @@ -46,29 +49,22 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); let counter: Arc = this.counter.clone(); - match ready!(this.inner.poll_data(cx)) { + match ready!(this.inner.poll_frame(cx)) { Some(Ok(chunk)) => { - println!("response body chunk size = {}", chunk.len()); - counter.fetch_add(chunk.len(), SeqCst); + println!("response body chunk size = {}", frame_data_length(&chunk)); + counter.fetch_add(frame_data_length(&chunk), SeqCst); Poll::Ready(Some(Ok(chunk))) } x => Poll::Ready(x), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -78,28 +74,61 @@ where } } +fn frame_data_length(frame: &http_body::Frame) -> usize { + if let Some(data) = frame.data_ref() { + data.len() + } else { + 0 + } +} + +#[pin_project] +struct ChannelBody { + #[pin] + rx: tokio::sync::mpsc::Receiver>, +} + +impl ChannelBody { + pub fn new() -> (tokio::sync::mpsc::Sender>, Self) { + let (tx, rx) = tokio::sync::mpsc::channel(32); + (tx, Self { rx }) + } +} + +impl Body for ChannelBody +where + T: Buf, +{ + type Data = T; + type Error = tonic::Status; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let frame = ready!(self.project().rx.poll_recv(cx)); + Poll::Ready(frame.map(Ok)) + } +} + #[allow(dead_code)] pub fn measure_request_body_size_layer( bytes_sent_counter: Arc, -) -> MapRequestBodyLayer hyper::Body + Clone> { - MapRequestBodyLayer::new(move |mut body: hyper::Body| { - let (mut tx, new_body) = hyper::Body::channel(); +) -> MapRequestBodyLayer BoxBody + Clone> { + MapRequestBodyLayer::new(move |mut body: BoxBody| { + let (tx, new_body) = ChannelBody::new(); let bytes_sent_counter = bytes_sent_counter.clone(); tokio::spawn(async move { - while let Some(chunk) = body.data().await { + while let Some(chunk) = body.frame().await { let chunk = chunk.unwrap(); - println!("request body chunk size = {}", chunk.len()); - bytes_sent_counter.fetch_add(chunk.len(), SeqCst); - tx.send_data(chunk).await.unwrap(); - } - - if let Some(trailers) = body.trailers().await.unwrap() { - tx.send_trailers(trailers).await.unwrap(); + println!("request body chunk size = {}", frame_data_length(&chunk)); + bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst); + tx.send(chunk).await.unwrap(); } }); - new_body + new_body.boxed_unsync() }) } From 829eb3b2691a0af64aceec8a2aa9fb9104a0fe48 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:55:15 +0000 Subject: [PATCH 13/31] [tests] Convert complex_tower_middleware Body for hyper 1 The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tests/integration_tests/Cargo.toml | 2 +- .../tests/complex_tower_middleware.rs | 11 ++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 6a7ec8052..cfeebf725 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -23,7 +23,7 @@ hyper = "1" hyper-util = "0.1" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} -tower-http = { version = "0.4", features = ["set-header", "trace"] } +tower-http = { version = "0.5", features = ["set-header", "trace"] } tower-service = "0.3" tracing = "0.1" diff --git a/tests/integration_tests/tests/complex_tower_middleware.rs b/tests/integration_tests/tests/complex_tower_middleware.rs index 5d7690be3..b1b669426 100644 --- a/tests/integration_tests/tests/complex_tower_middleware.rs +++ b/tests/integration_tests/tests/complex_tower_middleware.rs @@ -97,17 +97,10 @@ where type Data = B::Data; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - unimplemented!() - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { + ) -> Poll, Self::Error>>> { unimplemented!() } } From 55f9edbc0cab3ea1c8950580aa11c7864cc27bfd Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:06:13 +0000 Subject: [PATCH 14/31] [tests] Convert integration_tests::origin to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tests/integration_tests/tests/origin.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index c8140c79f..e41287245 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -7,7 +7,6 @@ use std::time::Duration; use tokio::sync::oneshot; use tonic::codegen::http::Request; use tonic::{ - body::BoxBody, transport::{Endpoint, Server}, Response, Status, }; @@ -77,9 +76,9 @@ struct OriginService { inner: S, } -impl Service> for OriginService +impl Service> for OriginService where - T: Service>, + T: Service>, T::Future: Send + 'static, T::Error: Into>, { @@ -91,7 +90,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { assert_eq!(req.uri().host(), Some("docs.rs")); let fut = self.inner.call(req); From 26e640e542cb6c4a776fe153d2346943a939c30f Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:56:09 +0000 Subject: [PATCH 15/31] Convert tonic-web to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tonic-web/Cargo.toml | 2 +- tonic-web/src/call.rs | 134 ++++++++++-------- tonic-web/tests/integration/tests/grpc_web.rs | 5 +- 3 files changed, 78 insertions(+), 63 deletions(-) diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index 5813fd6a3..157605a95 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -25,7 +25,7 @@ pin-project = "1" tonic = {version = "0.11", path = "../tonic", default-features = false} tower-service = "0.3" tower-layer = "0.3" -tower-http = { version = "0.4", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors"] } tracing = "0.1" [dev-dependencies] diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index f52087e9e..178e620ae 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -5,7 +5,7 @@ use std::task::{ready, Context, Poll}; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{header, HeaderMap, HeaderName, HeaderValue}; -use http_body::{Body, SizeHint}; +use http_body::{Body, Frame, SizeHint}; use pin_project::pin_project; use tokio_stream::Stream; use tonic::Status; @@ -63,9 +63,9 @@ pub struct GrpcWebCall { #[pin] inner: B, buf: BytesMut, + decoded: BytesMut, direction: Direction, encoding: Encoding, - poll_trailers: bool, client: bool, trailers: Option, } @@ -75,9 +75,9 @@ impl Default for GrpcWebCall { Self { inner: Default::default(), buf: Default::default(), + decoded: Default::default(), direction: Direction::Empty, encoding: Encoding::None, - poll_trailers: Default::default(), client: Default::default(), trailers: Default::default(), } @@ -108,9 +108,12 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(match direction { + Direction::Decode => BUFFER_SIZE, + _ => 0, + }), direction, encoding, - poll_trailers: true, client: true, trailers: None, } @@ -123,9 +126,9 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(0), direction, encoding, - poll_trailers: true, client: false, trailers: None, } @@ -160,24 +163,37 @@ where B: Body, B::Error: Error, { + // Poll body for data, decoding (e.g. via Base64 if necessary) and returning frames + // to the caller. If the caller is a client, it should look for trailers before + // returning these frames. fn poll_decode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { match self.encoding { Encoding::Base64 => loop { if let Some(bytes) = self.as_mut().decode_chunk()? { - return Poll::Ready(Some(Ok(bytes))); + return Poll::Ready(Some(Ok(Frame::data(bytes)))); } let mut this = self.as_mut().project(); - match ready!(this.inner.as_mut().poll_data(cx)) { - Some(Ok(data)) => this.buf.put(data), + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => this.buf.put(frame.into_data().unwrap()), + Some(Ok(frame)) if frame.is_trailers() => { + return Poll::Ready(Some(Err(internal_error( + "malformed base64 request has unencoded trailers", + )))) + } + Some(Ok(_)) => { + return Poll::Ready(Some(Err(internal_error("unexpected frame type")))) + } Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))), None => { return if this.buf.has_remaining() { Poll::Ready(Some(Err(internal_error("malformed base64 request")))) + } else if let Some(trailers) = this.trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } @@ -185,7 +201,7 @@ where } }, - Encoding::None => match ready!(self.project().inner.poll_data(cx)) { + Encoding::None => match ready!(self.project().inner.poll_frame(cx)) { Some(res) => Poll::Ready(Some(res.map_err(internal_error))), None => Poll::Ready(None), }, @@ -195,37 +211,33 @@ where fn poll_encode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { let mut this = self.as_mut().project(); - if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) { - if *this.encoding == Encoding::Base64 { - res = res.map(|b| crate::util::base64::STANDARD.encode(b).into()) - } - - return Poll::Ready(Some(res.map_err(internal_error))); - } - - // this flag is needed because the inner stream never - // returns Poll::Ready(None) when polled for trailers - if *this.poll_trailers { - return match ready!(this.inner.poll_trailers(cx)) { - Ok(Some(map)) => { - let mut frame = make_trailers_frame(map); + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => { + let mut res = frame.into_data().unwrap(); - if *this.encoding == Encoding::Base64 { - frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); - } + if *this.encoding == Encoding::Base64 { + let mut buf = Vec::with_capacity(res.len()); + buf.extend_from_slice(&res); + res = crate::util::base64::STANDARD.encode(buf).into(); + } - *this.poll_trailers = false; - Poll::Ready(Some(Ok(frame.into()))) + Poll::Ready(Some(Ok(Frame::data(res)))) + } + Some(Ok(frame)) if frame.is_trailers() => { + let trailers = frame.into_trailers().expect("must be trailers"); + let mut frame = make_trailers_frame(trailers); + if *this.encoding == Encoding::Base64 { + frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); } - Ok(None) => Poll::Ready(None), - Err(e) => Poll::Ready(Some(Err(internal_error(e)))), - }; + Poll::Ready(Some(Ok(Frame::data(frame.into())))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexepected frame type")))), + Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))), + None => Poll::Ready(None), } - - Poll::Ready(None) } } @@ -237,28 +249,34 @@ where type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { if self.client && self.direction == Direction::Decode { let mut me = self.as_mut(); loop { - let incoming_buf = match ready!(me.as_mut().poll_decode(cx)) { - Some(Ok(incoming_buf)) => incoming_buf, - None => { - // TODO: Consider eofing here? - // Even if the buffer has more data, this will hit the eof branch - // of decode in tonic - return Poll::Ready(None); + match ready!(me.as_mut().poll_decode(cx)) { + Some(Ok(incoming_buf)) if incoming_buf.is_data() => { + me.as_mut() + .project() + .decoded + .put(incoming_buf.into_data().unwrap()); + } + Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => { + let trailers = incoming_buf.into_trailers().unwrap(); + me.as_mut().project().trailers.replace(trailers); + continue; } + Some(Ok(_)) => unreachable!("unexpected frame type"), + None => {} // No more data to decode, time to look for trailers Some(Err(e)) => return Poll::Ready(Some(Err(e))), }; - let buf = &mut me.as_mut().project().buf; - - buf.put(incoming_buf); + // Hold the incoming, decoded data until we have a full message + // or trailers to return. + let buf = me.as_mut().project().decoded; return match find_trailers(&buf[..])? { FindTrailers::Trailer(len) => { @@ -266,20 +284,24 @@ where let msg_buf = buf.copy_to_bytes(len); match decode_trailers_frame(buf.split().freeze()) { Ok(Some(trailers)) => { - self.project().trailers.replace(trailers); + me.as_mut().project().trailers.replace(trailers); } Err(e) => return Poll::Ready(Some(Err(e))), _ => {} } if msg_buf.has_remaining() { - Poll::Ready(Some(Ok(msg_buf))) + Poll::Ready(Some(Ok(Frame::data(msg_buf)))) + } else if let Some(trailers) = me.as_mut().project().trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } } FindTrailers::IncompleteBuf => continue, - FindTrailers::Done(len) => Poll::Ready(Some(Ok(buf.split_to(len).freeze()))), + FindTrailers::Done(len) => { + Poll::Ready(Some(Ok(Frame::data(buf.split_to(len).freeze())))) + } }; } } @@ -291,14 +313,6 @@ where } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>, Self::Error>> { - let trailers = self.project().trailers.take(); - Poll::Ready(Ok(trailers)) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -313,10 +327,10 @@ where B: Body, B::Error: Error, { - type Item = Result; + type Item = Result, Status>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Body::poll_data(self, cx) + self.poll_frame(cx) } } diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 037ff8dad..96720b19e 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http_body_util::{BodyExt as _, Full}; +use hyper::body::Incoming; use hyper::http::{header, StatusCode}; use hyper::{Method, Request, Uri}; use prost::Message; @@ -131,8 +132,8 @@ fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request< .unwrap() } -async fn decode_body(body: Body, content_type: &str) -> (Output, Bytes) { - let mut body = hyper::body::to_bytes(body).await.unwrap(); +async fn decode_body(body: Incoming, content_type: &str) -> (Output, Bytes) { + let mut body = body.collect().await.unwrap().to_bytes(); if content_type == "application/grpc-web-text+proto" { body = integration::util::base64::STANDARD From 0e903cbfdd4eb6acbc8dd29d2bc4cbd3a1214b10 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:26:02 +0000 Subject: [PATCH 16/31] Adapt for hyper-specific IO traits hyper >= 1 provides its own I/O traits (Read & Write) instead of relying on the equivalent traits from `tokio`. Then, `hyper-util` provides adaptor structs to wrap `tokio` I/O objects and implement the hyper equivalents. Therefore, we update the appropriate bounds to use the hyper traits, and update the I/O objects so that they are wrapped in the tokio to hyper adaptor. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/Cargo.toml | 12 +++++----- examples/src/grpc-web/client.rs | 1 + examples/src/h2c/client.rs | 2 ++ examples/src/h2c/server.rs | 1 + examples/src/mock/mock.rs | 3 ++- examples/src/uds/client.rs | 6 +++-- tests/compression/src/util.rs | 2 +- tests/integration_tests/tests/connect_info.rs | 8 ++++++- .../tests/max_message_size.rs | 5 ++-- tests/integration_tests/tests/status.rs | 3 ++- tonic-web/tests/integration/tests/grpc_web.rs | 3 +++ tonic/Cargo.toml | 3 +-- tonic/src/transport/channel/endpoint.rs | 5 ++-- tonic/src/transport/channel/mod.rs | 10 ++++---- tonic/src/transport/server/mod.rs | 1 + tonic/src/transport/service/connection.rs | 8 +++---- tonic/src/transport/service/connector.rs | 24 +++++++++++-------- tonic/src/transport/service/io.rs | 13 +++++----- 18 files changed, 66 insertions(+), 44 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a672287d8..e04868826 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -271,21 +271,21 @@ routeguide = ["dep:async-stream", "tokio-stream", "dep:rand", "dep:serde", "dep: reflection = ["dep:tonic-reflection"] autoreload = ["tokio-stream/net", "dep:listenfd"] health = ["dep:tonic-health"] -grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber", "dep:tower"] +grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:hyper-util", "dep:tracing-subscriber", "dep:tower"] tracing = ["dep:tracing", "dep:tracing-subscriber"] -uds = ["tokio-stream/net", "dep:tower", "dep:hyper"] +uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["tokio-stream", "dep:h2"] -mock = ["tokio-stream", "dep:tower"] -tower = ["dep:hyper", "dep:tower", "dep:http"] +mock = ["tokio-stream", "dep:tower", "dep:hyper-util"] +tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] -tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] +tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] dynamic-load-balance = ["dep:tower"] timeout = ["tokio/time", "dep:tower"] tls-client-auth = ["tonic/tls"] types = ["dep:tonic-types"] -h2c = ["dep:hyper", "dep:tower", "dep:http"] +h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] cancellation = ["dep:tokio-util"] full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "tower", "json-codec", "compression", "tls", "tls-rustls", "dynamic-load-balance", "timeout", "tls-client-auth", "types", "cancellation", "h2c"] diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index a16ac674a..fd20a788b 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -1,4 +1,5 @@ use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioExecutor; use tonic_web::GrpcWebClientLayer; pub mod hello_world { diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 624ea175f..2f9f90a79 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -2,6 +2,7 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use http::Uri; use hyper::Client; +use hyper_util::rt::TokioExecutor; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -12,6 +13,7 @@ async fn main() -> Result<(), Box> { let origin = Uri::from_static("http://[::1]:50051"); let h2c_client = h2c::H2cChannel { client: Client::new(), + client: Client::builder(TokioExecutor::new()).build_http(), }; let mut client = GreeterClient::with_origin(h2c_client, origin); diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 21dcc1f35..b1d4c0a8d 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -1,3 +1,4 @@ +use hyper_util::rt::{TokioExecutor, TokioIo}; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; diff --git a/examples/src/mock/mock.rs b/examples/src/mock/mock.rs index 0d3754921..6c26a6735 100644 --- a/examples/src/mock/mock.rs +++ b/examples/src/mock/mock.rs @@ -1,3 +1,4 @@ +use hyper_util::rt::TokioIo; use tonic::{ transport::{Endpoint, Server, Uri}, Request, Response, Status, @@ -36,7 +37,7 @@ async fn main() -> Result<(), Box> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/examples/src/uds/client.rs b/examples/src/uds/client.rs index e78531ac4..9a09e6981 100644 --- a/examples/src/uds/client.rs +++ b/examples/src/uds/client.rs @@ -5,6 +5,7 @@ pub mod hello_world { } use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixStream; use tonic::transport::{Endpoint, Uri}; @@ -16,12 +17,13 @@ async fn main() -> Result<(), Box> { // We will ignore this uri because uds do not use it // if your connector does use the uri it will be provided // as the request to the `MakeConnection`. + let channel = Endpoint::try_from("http://[::]:50051")? - .connect_with_connector(service_fn(|_: Uri| { + .connect_with_connector(service_fn(|_: Uri| async { let path = "/tmp/tonic/helloworld"; // Connect to a Uds socket - UnixStream::connect(path) + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path).await?)) })) .await?; diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 99afded3f..d7e250ce4 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -139,7 +139,7 @@ pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector(service_fn(move |_: Uri| { - let client = client.take().unwrap(); + let client = TokioIo::new(client.take().unwrap()); async move { Ok::<_, std::io::Error>(client) } })) .await diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs index 94fac8221..e87bb858f 100644 --- a/tests/integration_tests/tests/connect_info.rs +++ b/tests/integration_tests/tests/connect_info.rs @@ -51,6 +51,9 @@ async fn getting_connect_info() { #[cfg(unix)] pub mod unix { + use std::io; + + use hyper_util::rt::TokioIo; use tokio::{ net::{UnixListener, UnixStream}, sync::oneshot, @@ -106,7 +109,10 @@ pub mod unix { let path = unix_socket_path.clone(); let channel = Endpoint::try_from("http://[::]:50051") .unwrap() - .connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone()))) + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + async move { Ok::<_, io::Error>(TokioIo::new(UnixStream::connect(path).await?)) } + })) .await .unwrap(); diff --git a/tests/integration_tests/tests/max_message_size.rs b/tests/integration_tests/tests/max_message_size.rs index 9ae524dbc..f03699cdf 100644 --- a/tests/integration_tests/tests/max_message_size.rs +++ b/tests/integration_tests/tests/max_message_size.rs @@ -1,5 +1,6 @@ use std::pin::Pin; +use hyper_util::rt::TokioIo; use integration_tests::{ pb::{test1_client, test1_server, Input1, Output1}, trace_init, @@ -163,7 +164,7 @@ async fn response_stream_limit() { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, @@ -332,7 +333,7 @@ async fn max_message_run(case: &TestCase) -> Result<(), Status> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 3fdabcd36..df6bc4b3b 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -1,5 +1,6 @@ use bytes::Bytes; use http::Uri; +use hyper_util::rt::TokioIo; use integration_tests::mock::MockStream; use integration_tests::pb::{ test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, @@ -183,7 +184,7 @@ async fn status_from_server_stream_with_source() { let channel = Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move { - Err::(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) + Err::, _>(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) })); let mut client = test_stream_client::TestStreamClient::new(channel); diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 96720b19e..b46d98d45 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -6,6 +6,7 @@ use http_body_util::{BodyExt as _, Full}; use hyper::body::Incoming; use hyper::http::{header, StatusCode}; use hyper::{Method, Request, Uri}; +use hyper_util::rt::TokioExecutor; use prost::Message; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; @@ -20,6 +21,7 @@ use tonic_web::GrpcWebLayer; async fn binary_request() { let server_url = spawn().await; let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web", "grpc-web"); let res = client.request(req).await.unwrap(); @@ -43,6 +45,7 @@ async fn binary_request() { async fn text_request() { let server_url = spawn().await; let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); let res = client.request(req).await.unwrap(); diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 1934c416e..51b38a6e6 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,10 +37,9 @@ transport = [ "dep:axum", "channel", "dep:h2", - "dep:hyper", "dep:tokio", "tokio?/net", "tokio?/time", + "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", "dep:tower", - "dep:hyper-timeout", ] channel = [] diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 995e2a15b..584c56f8c 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -7,6 +7,7 @@ use crate::transport::service::TlsConnector; use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; +use hyper::rt; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; use tower::make::MakeConnection; @@ -369,7 +370,7 @@ impl Endpoint { pub async fn connect_with_connector(&self, connector: C) -> Result where C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -394,7 +395,7 @@ impl Endpoint { pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 6a857dff1..3e5869bcb 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -25,11 +25,9 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{channel, Sender}, -}; +use tokio::sync::mpsc::{channel, Sender}; +use hyper::rt; use tower::balance::p2c::Balance; use tower::{ buffer::{self, Buffer}, @@ -152,7 +150,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -169,7 +167,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index ad930c617..0c64ba6d0 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -13,6 +13,7 @@ pub use super::service::Routes; pub use super::service::RoutesBuilder; pub use conn::{Connected, TcpConnectInfo}; +use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 1fa059c96..8e1f52c5f 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -7,11 +7,11 @@ use http::Uri; use hyper::client::conn::Builder; use hyper::client::connect::Connection as HyperConnection; use hyper::client::service::Connect as HyperConnect; +use hyper::rt; use std::{ fmt, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tower::load::Load; use tower::{ layer::Layer, @@ -34,7 +34,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { let mut settings = Builder::new() .http2_initial_stream_window_size(endpoint.init_stream_window_size) @@ -83,7 +83,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, false).ready_oneshot().await } @@ -93,7 +93,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, true) } diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 12336813a..8219fe8d9 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -6,7 +6,11 @@ use http::Uri; #[cfg(feature = "tls")] use std::fmt; use std::task::{Context, Poll}; -use tower::make::MakeConnection; + +use hyper::rt; + +#[cfg(feature = "tls")] +use hyper_util::rt::TokioIo; use tower_service::Service; pub(crate) struct Connector { @@ -51,8 +55,8 @@ impl Connector { impl Service for Connector where - C: MakeConnection, - C::Connection: Unpin + Send + 'static, + C: Service, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -61,7 +65,7 @@ where type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into) + self.inner.poll_ready(cx).map_err(Into::into) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -73,7 +77,7 @@ where #[cfg(feature = "tls")] let is_https = uri.scheme_str() == Some("https"); - let connect = self.inner.make_connection(uri); + let connect = self.inner.call(uri); Box::pin(async move { let io = connect.await?; @@ -81,12 +85,12 @@ where #[cfg(feature = "tls")] { if let Some(tls) = tls { - if is_https { - let conn = tls.connect(io).await?; - return Ok(BoxedIo::new(conn)); + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) } else { - return Ok(BoxedIo::new(io)); - } + Ok(BoxedIo::new(io)) + }; } else if is_https { return Err(HttpsUriWithoutTlsSupport(()).into()); } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2230b9b2e..cb2296cac 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,5 +1,6 @@ use crate::transport::server::Connected; -use hyper::client::connect::{Connected as HyperConnected, Connection}; +use hyper::rt; +use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -9,11 +10,11 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static + rt::Read + rt::Write + Send + 'static { } -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} +impl Io for T where T: rt::Read + rt::Write + Send + 'static {} pub(crate) struct BoxedIo(Pin>); @@ -40,17 +41,17 @@ impl Connected for BoxedIo { #[derive(Copy, Clone)] pub(crate) struct NoneConnectInfo; -impl AsyncRead for BoxedIo { +impl rt::Read for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } -impl AsyncWrite for BoxedIo { +impl rt::Write for BoxedIo { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, From e1e1ffee64c8306dfc1237de1dfffb654d20c369 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:00:15 +0000 Subject: [PATCH 17/31] Upgrade axum to 0.7 Axum must be >= 0.7 to support hyper >= 1 Doing this also involves changing the Body type used. Since hyper >= 1 does not provide a generic body type, Axum and tonic both use `BoxBody` to provide a pointer to a Body. This changes the trait bounds required for methods which accept additional Serivces to be run alongside the primary GRPC service, since those will be routed with Axum, and therefore must accept a BoxBody. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/Cargo.toml | 2 +- tonic/src/transport/mod.rs | 5 +- tonic/src/transport/server/incoming.rs | 4 +- tonic/src/transport/server/mod.rs | 340 ++++++++++++++++++++----- tonic/src/transport/service/router.rs | 53 +++- 5 files changed, 324 insertions(+), 80 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 51b38a6e6..da4482291 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -76,7 +76,7 @@ socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } tokio = {version = "1", default-features = false, optional = true} tokio-stream = { version = "0.1", features = ["net"] } tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} -axum = {version = "0.6.9", default-features = false, optional = true} +axum = {version = "0.7", default-features = false, optional = true} # rustls rustls-pemfile = { version = "2.0", optional = true } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index faa3b49be..978bdfee0 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -56,8 +56,8 @@ //! # async fn do_thing() -> Result<(), Box> { //! # #[derive(Clone)] //! # pub struct Svc; -//! # type Response = hyper::Response; //! # impl Service> for Svc { +//! # type Response = hyper::Response; //! # type Error = Infallible; //! # type Future = std::future::Ready>; //! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { @@ -126,8 +126,5 @@ pub use self::server::ServerTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Identity; -use crate::body::BoxBody; type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; -pub(crate) type Response = http::Response; -pub(crate) type Request = http::Request; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index bc1bb7650..ede62a32d 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -139,13 +139,13 @@ impl TcpIncoming { /// ```no_run /// # use tower_service::Service; /// # use http::{request::Request, response::Response}; - /// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}}; + /// # use tonic::{body::BoxBody, server::NamedService, transport::{Server, server::TcpIncoming}}; /// # use core::convert::Infallible; /// # use std::error::Error; /// # fn main() { } // Cannot have type parameters, hence instead define: /// # fn run(some_service: S) -> Result<(), Box> /// # where - /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, + /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, /// # S::Future: Send + 'static, /// # { /// // Find a free port diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 0c64ba6d0..fb63058ad 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,6 +9,13 @@ mod tls; #[cfg(unix)] mod unix; +use tokio_stream::StreamExt as _; +use tower::util::BoxCloneService; +use tower::util::Oneshot; +use tower::ServiceExt; +use tracing::debug; +use tracing::trace; + pub use super::service::Routes; pub use super::service::RoutesBuilder; @@ -42,15 +49,16 @@ use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; use http_body_util::BodyExt; -use hyper::server::accept; +use hyper::body::Incoming; use pin_project::pin_project; +use std::future::poll_fn; use std::{ convert::Infallible, fmt, future::{self, Future}, marker::PhantomData, net::SocketAddr, - pin::Pin, + pin::{pin, Pin}, sync::Arc, task::{ready, Context, Poll}, time::Duration, @@ -65,18 +73,17 @@ use tower::{ Service, ServiceBuilder, }; -type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; type BoxHttpBody = crate::body::BoxBody; -type Body = hyper::body::Incoming; // Temporary type alias to ease transition type BoxError = crate::Error; -type BoxService = tower::util::BoxCloneService, Response, crate::Error>; +type BoxService = tower::util::BoxCloneService, Response, crate::Error>; +type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; /// A default batteries included `transport` server. /// -/// This is a wrapper around [`hyper::Server`] and provides an easy builder -/// pattern style builder [`Server`]. This builder exposes easy configuration parameters +/// This provides an easy builder pattern style builder [`Server`] on top of +/// `hyper` connections. This builder exposes easy configuration parameters /// for providing a fully featured http2 based gRPC server. This should provide /// a very good out of the box http2 server for use with tonic but is also a /// reference implementation that should be a good starting point for anyone @@ -126,7 +133,7 @@ impl Default for Server { } } -/// A stack based `Service` router. +/// A stack based [`Service`] router. #[derive(Debug)] pub struct Router { server: Server, @@ -363,7 +370,7 @@ impl Server { /// route around different services. pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -384,7 +391,7 @@ impl Server { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -498,9 +505,11 @@ impl Server { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send + 'static, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IO::ConnectInfo: Clone + Send + Sync + 'static, @@ -527,10 +536,8 @@ impl Server { let svc = self.service_builder.service(svc); - let tcp = incoming::tcp_incoming(incoming, self); - let incoming = accept::from_stream::<_, _, crate::Error>(tcp); - - let svc = MakeSvc { + let incoming = incoming::tcp_incoming(incoming, self); + let mut svc = MakeSvc { inner: svc, concurrency_limit, timeout, @@ -538,31 +545,204 @@ impl Server { _io: PhantomData, }; - let server = hyper::Server::builder(incoming) - .http2_only(http2_only) - .http2_initial_connection_window_size(init_connection_window_size) - .http2_initial_stream_window_size(init_stream_window_size) - .http2_max_concurrent_streams(max_concurrent_streams) - .http2_keep_alive_interval(http2_keepalive_interval) - .http2_keep_alive_timeout(http2_keepalive_timeout) - .http2_adaptive_window(http2_adaptive_window.unwrap_or_default()) - .http2_max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) - .http2_max_frame_size(max_frame_size); - - if let Some(signal) = signal { - server - .serve(svc) - .with_graceful_shutdown(signal) - .await - .map_err(super::Error::from_source)? - } else { - server.serve(svc).await.map_err(super::Error::from_source)?; + let server = { + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + + if http2_only { + builder = builder.http2_only(); + } + + builder + .http2() + .initial_connection_window_size(init_connection_window_size) + .initial_stream_window_size(init_stream_window_size) + .max_concurrent_streams(max_concurrent_streams) + .keep_alive_interval(http2_keepalive_interval) + .keep_alive_timeout(http2_keepalive_timeout) + .adaptive_window(http2_adaptive_window.unwrap_or_default()) + .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) + .max_frame_size(max_frame_size); + + builder + }; + + let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + let graceful = signal.is_some(); + let mut sig = pin!(Fuse { inner: signal }); + let mut incoming = pin!(incoming); + + loop { + tokio::select! { + _ = &mut sig => { + trace!("signal received, shutting down"); + break; + }, + io = incoming.next() => { + let io = match io { + Some(Ok(io)) => io, + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + continue; + }, + None => { + break + }, + }; + + trace!("connection accepted"); + + poll_fn(|cx| svc.poll_ready(cx)) + .await + .map_err(super::Error::from_source)?; + + let req_svc = svc + .call(&io) + .await + .map_err(super::Error::from_source)?; + let hyper_svc = TowerToHyperService::new(req_svc); + + serve_connection(io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone())); + } + } + } + + if graceful { + let _ = signal_tx.send(()); + drop(signal_rx); + trace!( + "waiting for {} connections to close", + signal_tx.receiver_count() + ); + + // Wait for all connections to close + signal_tx.closed().await; } Ok(()) } } +// This is moved to its own function as a way to get around +// https://github.com/rust-lang/rust/issues/102211 +fn serve_connection( + io: ServerIo, + hyper_svc: TowerToHyperService, + builder: ConnectionBuilder, + mut watcher: Option>, +) where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, +{ + tokio::spawn(async move { + { + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + let mut conn = pin!(builder.serve_connection(TokioIo::new(io), hyper_svc)); + + loop { + tokio::select! { + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + } + } + } + + drop(watcher); + trace!("connection closed"); + }); +} + +type ConnectionBuilder = hyper_util::server::conn::auto::Builder; + +/// An adaptor which converts a [`tower::Service`] to a [`hyper::service::Service`]. +/// +/// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, +/// and does not support the `poll_ready` method that is used by tower services. +#[derive(Debug, Copy, Clone)] +pub struct TowerToHyperService { + service: S, +} + +impl TowerToHyperService { + /// Create a new `TowerToHyperService` from a tower service. + pub fn new(service: S) -> Self { + Self { service } + } + + /// Extract the inner tower service. + pub fn into_inner(self) -> S { + self.service + } + + /// Get a reference to the inner tower service. + pub fn as_inner(&self) -> &S { + &self.service + } + + /// Get a mutable reference to the inner tower service. + pub fn as_inner_mut(&mut self) -> &mut S { + &mut self.service + } +} + +impl hyper::service::Service> for TowerToHyperService +where + S: tower_service::Service> + Clone, + S::Error: Into + 'static, +{ + type Response = S::Response; + type Error = super::Error; + type Future = TowerToHyperServiceFuture>; + + fn call(&self, req: Request) -> Self::Future { + let req = req.map(crate::body::boxed); + TowerToHyperServiceFuture { + future: self.service.clone().oneshot(req), + } + } +} + +/// Future returned by [`TowerToHyperService`]. +#[derive(Debug)] +#[pin_project] +pub struct TowerToHyperServiceFuture +where + S: tower_service::Service, +{ + #[pin] + future: Oneshot, +} + +impl Future for TowerToHyperServiceFuture +where + S: tower_service::Service, + S::Error: Into + 'static, +{ + type Output = Result; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .future + .poll(cx) + .map_err(super::Error::from_source) + } +} + impl Router { pub(crate) fn new(server: Server, routes: Routes) -> Self { Self { server, routes } @@ -573,7 +753,7 @@ impl Router { /// Add a new service to this router. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -592,7 +772,7 @@ impl Router { #[allow(clippy::type_complexity)] pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -617,10 +797,12 @@ impl Router { /// [tokio]: https://docs.rs/tokio pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> where - L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L: Layer + Clone, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -648,9 +830,11 @@ impl Router { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -677,9 +861,11 @@ impl Router { IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -712,9 +898,11 @@ impl Router { IE: Into, F: Future, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -727,9 +915,11 @@ impl Router { pub fn into_service(self) -> L::Service where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -743,14 +933,15 @@ impl fmt::Debug for Server { } } +#[derive(Clone)] struct Svc { inner: S, trace_interceptor: Option, } -impl Service> for Svc +impl Service> for Svc where - S: Service, Response = Response>, + S: Service, Response = Response>, S::Error: Into, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, @@ -763,7 +954,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -806,7 +997,7 @@ where let _guard = this.span.enter(); let response: Response = ready!(this.inner.poll(cx)).map_err(Into::into)?; - let response = response.map(|body| body.map_err(Into::into).boxed_unsync()); + let response = response.map(|body| boxed(body.map_err(Into::into))); Poll::Ready(Ok(response)) } } @@ -817,6 +1008,7 @@ impl fmt::Debug for Svc { } } +#[derive(Clone)] struct MakeSvc { concurrency_limit: Option, timeout: Option, @@ -828,7 +1020,7 @@ struct MakeSvc { impl Service<&ServerIo> for MakeSvc where IO: Connected, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, ResBody: http_body::Body + Send + 'static, @@ -857,8 +1049,8 @@ where .service(svc); let svc = ServiceBuilder::new() - .layer(BoxService::layer()) - .map_request(move |mut request: Request| { + .layer(BoxCloneService::layer()) + .map_request(move |mut request: Request| { match &conn_info { tower::util::Either::A(inner) => { request.extensions_mut().insert(inner.clone()); @@ -889,3 +1081,29 @@ where future::ready(Ok(svc)) } } + +// From `futures-util` crate, borrowed since this is the only dependency tonic requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +struct Fuse { + #[pin] + inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 85636c4d4..c43782ba9 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,9 +1,9 @@ use crate::{ body::{boxed, BoxBody}, server::NamedService, + transport::BoxFuture, }; use http::{Request, Response}; -use hyper::Body; use pin_project::pin_project; use std::{ convert::Infallible, @@ -12,7 +12,6 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tower::ServiceExt; use tower_service::Service; /// A [`Service`] router. @@ -31,7 +30,7 @@ impl RoutesBuilder { /// Add a new service. pub fn add_service(&mut self, svc: S) -> &mut Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -53,7 +52,7 @@ impl Routes { /// Create a new routes with `svc` already added to it. pub fn new(svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -68,7 +67,7 @@ impl Routes { /// Add a new service. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -76,10 +75,10 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let svc = svc.map_response(|res| res.map(axum::body::boxed)); - self.router = self - .router - .route_service(&format!("/{}/*rest", S::NAME), svc); + self.router = self.router.route_service( + &format!("/{}/*rest", S::NAME), + AxumBodyService { service: svc }, + ); self } @@ -103,7 +102,7 @@ async fn unimplemented() -> impl axum::response::IntoResponse { (status, headers) } -impl Service> for Routes { +impl Service> for Routes { type Response = Response; type Error = crate::Error; type Future = RoutesFuture; @@ -113,13 +112,13 @@ impl Service> for Routes { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { RoutesFuture(self.router.call(req)) } } #[pin_project] -pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); +pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); impl fmt::Debug for RoutesFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -137,3 +136,33 @@ impl Future for RoutesFuture { } } } + +#[derive(Clone)] +struct AxumBodyService { + service: S, +} + +impl Service> for AxumBodyService +where + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.service.call(req.map(|body| boxed(body))); + Box::pin(async move { + fut.await + .map(|res| res.map(|body| axum::body::Body::new(body))) + }) + } +} From 7a5d95c12c9016b8031262195f5eac83ff14c73c Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:05:19 +0000 Subject: [PATCH 18/31] Convert service connector for hyper-1.0 Hyper >= 1 no longer includes automatic http2/http1 combined connections, and so we must swtich to the `http2::Builder` type (this is okay, we set http2_only(true) anyhow). As well, hyper >= 1 is generic over executors and does not directly depend on tokio. Since http2 connections can be multiplexed, they require some additional background task to handle sending and receiving requests. Additionally, these background tasks do not natively implement `tower::Service` since hyper >= 1 does not depend on `tower`. Therefore, we re-implement the `SendRequest` task as a tower::Service, so that it can be used within `Connection`, which expects to operate on a tower::Service to serve connections. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/transport/service/connection.rs | 113 +++++++++++++++++++--- tonic/src/transport/service/executor.rs | 9 +- 2 files changed, 103 insertions(+), 19 deletions(-) diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 8e1f52c5f..a31c9868b 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,13 +1,12 @@ +use super::SharedExec; use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{ body::{boxed, BoxBody}, transport::{BoxFuture, Endpoint}, }; use http::Uri; -use hyper::client::conn::Builder; -use hyper::client::connect::Connection as HyperConnection; -use hyper::client::service::Connect as HyperConnect; use hyper::rt; +use hyper::{client::conn::http2::Builder, rt::Executor}; use std::{ fmt, task::{Context, Poll}, @@ -36,24 +35,22 @@ impl Connection { C::Future: Unpin + Send, C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - let mut settings = Builder::new() - .http2_initial_stream_window_size(endpoint.init_stream_window_size) - .http2_initial_connection_window_size(endpoint.init_connection_window_size) - .http2_only(true) - .http2_keep_alive_interval(endpoint.http2_keep_alive_interval) - .executor(endpoint.executor.clone()) + let mut settings: Builder = Builder::new(endpoint.executor.clone()) + .initial_stream_window_size(endpoint.init_stream_window_size) + .initial_connection_window_size(endpoint.init_connection_window_size) + .keep_alive_interval(endpoint.http2_keep_alive_interval) .clone(); if let Some(val) = endpoint.http2_keep_alive_timeout { - settings.http2_keep_alive_timeout(val); + settings.keep_alive_timeout(val); } if let Some(val) = endpoint.http2_keep_alive_while_idle { - settings.http2_keep_alive_while_idle(val); + settings.keep_alive_while_idle(val); } if let Some(val) = endpoint.http2_adaptive_window { - settings.http2_adaptive_window(val); + settings.adaptive_window(val); } let stack = ServiceBuilder::new() @@ -68,13 +65,13 @@ impl Connection { .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); - let connector = HyperConnect::new(connector, settings); - let conn = Reconnect::new(connector, endpoint.uri.clone(), is_lazy); + let make_service = + MakeSendRequestService::new(connector, endpoint.executor.clone(), settings); - let inner = stack.layer(conn); + let conn = Reconnect::new(make_service, endpoint.uri.clone(), is_lazy); Self { - inner: BoxService::new(inner), + inner: BoxService::new(stack.layer(conn)), } } @@ -126,3 +123,87 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +struct SendRequest { + inner: hyper::client::conn::http2::SendRequest, +} + +impl From> for SendRequest { + fn from(inner: hyper::client::conn::http2::SendRequest) -> Self { + Self { inner } + } +} + +impl tower::Service> for SendRequest { + type Response = http::Response; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.inner.send_request(req); + + Box::pin(async move { + fut.await + .map_err(Into::into) + .map(|res| res.map(|body| boxed(body))) + }) + } +} + +struct MakeSendRequestService { + connector: C, + executor: super::SharedExec, + settings: Builder, +} + +impl MakeSendRequestService { + fn new(connector: C, executor: SharedExec, settings: Builder) -> Self { + Self { + connector, + executor, + settings, + } + } +} + +impl tower::Service for MakeSendRequestService +where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, +{ + type Response = SendRequest; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.connector.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let fut = self.connector.call(req); + let builder = self.settings.clone(); + let executor = self.executor.clone(); + + Box::pin(async move { + let io = fut.await.map_err(Into::into)?; + let (send_request, conn) = builder.handshake(io).await?; + + Executor::>::execute( + &executor, + Box::pin(async move { + if let Err(e) = conn.await { + tracing::debug!("connection task error: {:?}", e); + } + }) as _, + ); + + Ok(SendRequest::from(send_request)) + }) + } +} diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/service/executor.rs index de3cfbe6e..7b699c307 100644 --- a/tonic/src/transport/service/executor.rs +++ b/tonic/src/transport/service/executor.rs @@ -36,8 +36,11 @@ impl SharedExec { } } -impl Executor> for SharedExec { - fn execute(&self, fut: BoxFuture<'static, ()>) { - self.inner.execute(fut) +impl Executor for SharedExec +where + F: Future + Send + 'static, +{ + fn execute(&self, fut: F) { + self.inner.execute(Box::pin(fut)) } } From ef542bc5b82bcbf198144a70ddcd206915529261 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:07:21 +0000 Subject: [PATCH 19/31] Convert hyper::Client to hyper_util::legacy::Client `hyper::Client` has been moved to `hyper_util::legacy::Client` in version 1. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/src/grpc-web/client.rs | 2 +- examples/src/h2c/client.rs | 16 ++++++++-------- examples/src/tls_rustls/client.rs | 5 +++-- tonic-web/tests/integration/tests/grpc_web.rs | 4 ++-- tonic/src/transport/channel/endpoint.rs | 11 ++++++----- tonic/src/transport/channel/mod.rs | 4 ++-- tonic/src/transport/service/discover.rs | 3 ++- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index fd20a788b..fa64dd506 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -9,7 +9,7 @@ pub mod hello_world { #[tokio::main] async fn main() -> Result<(), Box> { // Must use hyper directly... - let client = hyper::Client::builder().build_http(); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build_http(); let svc = tower::ServiceBuilder::new() .layer(GrpcWebClientLayer::new()) diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 2f9f90a79..b162fcc08 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -1,7 +1,7 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use http::Uri; -use hyper::Client; +use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; pub mod hello_world { @@ -12,7 +12,6 @@ pub mod hello_world { async fn main() -> Result<(), Box> { let origin = Uri::from_static("http://[::1]:50051"); let h2c_client = h2c::H2cChannel { - client: Client::new(), client: Client::builder(TokioExecutor::new()).build_http(), }; @@ -35,10 +34,11 @@ mod h2c { task::{Context, Poll}, }; - use hyper::{client::HttpConnector, Client}; use hyper::body::Incoming; use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, rt::TokioExecutor, + }; use tonic::body::{empty_body, BoxBody}; use tower::Service; @@ -77,11 +77,11 @@ mod h2c { let upgraded_io = hyper::upgrade::on(res).await.unwrap(); // In an ideal world you would somehow cache this connection - let (mut h2_client, conn) = hyper::client::conn::Builder::new() - .http2_only(true) - .handshake(upgraded_io) - .await - .unwrap(); + let (mut h2_client, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(upgraded_io) + .await + .unwrap(); tokio::spawn(conn); h2_client.send_request(request).await diff --git a/examples/src/tls_rustls/client.rs b/examples/src/tls_rustls/client.rs index 23d6e8130..f03f70051 100644 --- a/examples/src/tls_rustls/client.rs +++ b/examples/src/tls_rustls/client.rs @@ -5,7 +5,8 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::{client::HttpConnector, Uri}; +use hyper::Uri; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use pb::{echo_client::EchoClient, EchoRequest}; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; @@ -47,7 +48,7 @@ async fn main() -> Result<(), Box> { .map_request(|_| Uri::from_static("https://[::1]:50051")) .service(http); - let client = hyper::Client::builder().build(connector); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector); // Using `with_origin` will let the codegenerated client set the `scheme` and // `authority` from the porvided `Uri`. diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index b46d98d45..2c57f2680 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -6,6 +6,7 @@ use http_body_util::{BodyExt as _, Full}; use hyper::body::Incoming; use hyper::http::{header, StatusCode}; use hyper::{Method, Request, Uri}; +use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use prost::Message; use tokio::net::TcpListener; @@ -15,12 +16,12 @@ use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; use integration::Svc; +use tonic::Status; use tonic_web::GrpcWebLayer; #[tokio::test] async fn binary_request() { let server_url = spawn().await; - let client = Client::new(); let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web", "grpc-web"); @@ -44,7 +45,6 @@ async fn binary_request() { #[tokio::test] async fn text_request() { let server_url = spawn().await; - let client = Client::new(); let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 584c56f8c..6014960a8 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -8,8 +8,9 @@ use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use hyper::rt; +use hyper_util::client::legacy::connect::HttpConnector; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; -use tower::make::MakeConnection; +use tower_service::Service; /// Channel builder. /// @@ -333,7 +334,7 @@ impl Endpoint { /// Create a channel from this config. pub async fn connect(&self) -> Result { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -349,7 +350,7 @@ impl Endpoint { /// The channel returned by this method does not attempt to connect to the endpoint until first /// use. pub fn connect_lazy(&self) -> Channel { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -369,7 +370,7 @@ impl Endpoint { /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied. pub async fn connect_with_connector(&self, connector: C) -> Result where - C: MakeConnection + Send + 'static, + C: Service + Send + 'static, C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, @@ -394,7 +395,7 @@ impl Endpoint { /// uses a Unix socket transport. pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where - C: MakeConnection + Send + 'static, + C: Service + Send + 'static, C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 3e5869bcb..0983725f8 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -17,7 +17,7 @@ use http::{ uri::{InvalidUri, Uri}, Request, Response, }; -use hyper::client::connect::Connection as HyperConnection; +use hyper_util::client::legacy::connect::Connection as HyperConnection; use std::{ fmt, future::Future, @@ -42,7 +42,7 @@ const DEFAULT_BUFFER_SIZE: usize = 1024; /// A default batteries included `transport` channel. /// -/// This provides a fully featured http2 gRPC client based on [`hyper::Client`] +/// This provides a fully featured http2 gRPC client based on `hyper` /// and `tower` services. /// /// # Multiplexing requests diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 2d23ca74c..b9356110e 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -1,6 +1,7 @@ use super::connection::Connection; use crate::transport::Endpoint; +use hyper_util::client::legacy::connect::HttpConnector; use std::{ hash::Hash, pin::Pin, @@ -32,7 +33,7 @@ impl Stream for DynamicServiceStream { Poll::Pending | Poll::Ready(None) => Poll::Pending, Poll::Ready(Some(change)) => match change { Change::Insert(k, endpoint) => { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.set_nodelay(endpoint.tcp_nodelay); http.set_keepalive(endpoint.tcp_keepalive); http.set_connect_timeout(endpoint.connect_timeout); From 9c6b63a8b7be83eb4dea31a731630abdc2f618d0 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 18:54:31 +0000 Subject: [PATCH 20/31] Identify and propogate connect errors hyper::Error no longer provides information about Connect errors, especially since hyper_util now contains the connection implementation, it does not provide a separate error type. Instead, we create an internal Error type which is used in our own connectors, and then checked when figuring out what the gRPC status should be. --- tonic/src/status.rs | 13 +++--- tonic/src/transport/mod.rs | 2 + tonic/src/transport/service/connector.rs | 58 ++++++++++++++++-------- tonic/src/transport/service/mod.rs | 1 + 4 files changed, 49 insertions(+), 25 deletions(-) diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 108ee3cf2..0783bd8e2 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -412,13 +412,7 @@ impl Status { // > status. Note that the frequency of PINGs is highly dependent on the network // > environment, implementations are free to adjust PING frequency based on network and // > application requirements, which is why it's mapped to unavailable here. - // - // Likewise, if we are unable to connect to the server, map this to UNAVAILABLE. This is - // consistent with the behavior of a C++ gRPC client when the server is not running, and - // matches the spec of: - // > The service is currently unavailable. This is most likely a transient condition that - // > can be corrected if retried with a backoff. - if err.is_timeout() || err.is_connect() { + if err.is_timeout() { return Some(Status::unavailable(err.to_string())); } @@ -620,6 +614,11 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { return Some(Status::cancelled(timeout.to_string())); } + #[cfg(feature = "transport")] + if let Some(connect) = err.downcast_ref::() { + return Some(Status::unavailable(connect.to_string())); + } + #[cfg(feature = "transport")] if let Some(hyper) = err .downcast_ref::() diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 978bdfee0..719f260e9 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -107,6 +107,8 @@ pub use self::error::Error; pub use self::server::Server; #[doc(inline)] pub use self::service::grpc_timeout::TimeoutExpired; +pub(crate) use self::service::ConnectError; + #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 8219fe8d9..978441d75 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -3,7 +3,6 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; -#[cfg(feature = "tls")] use std::fmt; use std::task::{Context, Poll}; @@ -13,6 +12,23 @@ use hyper::rt; use hyper_util::rt::TokioIo; use tower_service::Service; +/// Wrapper type to indicate that an error occurs during the connection +/// process, so that the appropriate gRPC Status can be inferred. +#[derive(Debug)] +pub(crate) struct ConnectError(pub(crate) crate::Error); + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl std::error::Error for ConnectError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.0.as_ref()) + } +} + pub(crate) struct Connector { inner: C, #[cfg(feature = "tls")] @@ -61,11 +77,13 @@ where crate::Error: From + Send + 'static, { type Response = BoxedIo; - type Error = crate::Error; + type Error = ConnectError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Into::into) + self.inner + .poll_ready(cx) + .map_err(|err| ConnectError(From::from(err))) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -80,23 +98,27 @@ where let connect = self.inner.call(uri); Box::pin(async move { - let io = connect.await?; - - #[cfg(feature = "tls")] - { - if let Some(tls) = tls { - return if is_https { - let io = tls.connect(TokioIo::new(io)).await?; - Ok(io) - } else { - Ok(BoxedIo::new(io)) - }; - } else if is_https { - return Err(HttpsUriWithoutTlsSupport(()).into()); + async { + let io = connect.await?; + + #[cfg(feature = "tls")] + { + if let Some(tls) = tls { + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) + } else { + Ok(BoxedIo::new(io)) + }; + } else if is_https { + return Err(HttpsUriWithoutTlsSupport(()).into()); + } } - } - Ok(BoxedIo::new(io)) + Ok::<_, crate::Error>(BoxedIo::new(io)) + } + .await + .map_err(|err| ConnectError(From::from(err))) }) } } diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 69d850f10..2b2a84070 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -13,6 +13,7 @@ mod user_agent; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; +pub(crate) use self::connector::ConnectError; pub(crate) use self::connector::Connector; pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::executor::SharedExec; From 39e9d571e46fc861a0430faa096c7a3db2625e37 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:10:24 +0000 Subject: [PATCH 21/31] Remove hyper::server::conn::AddrStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hyper >= 1 has deprecated all of `hyper::server`, including `AddrStream` Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang Replace hyper::server::Accept hyper::server is deprectaed. Instead, we implement our own TCP-incoming based on the now removed hyper::server::Accept. In order to set `TCP_KEEPALIVE` we require the socket2 crate, since this option is not exposed in the standard library’s API. The implementaiton is inspired by that of hyper v0.14 --- tonic/Cargo.toml | 3 +- tonic/src/transport/server/conn.rs | 12 ----- tonic/src/transport/server/incoming.rs | 66 +++++++++++++++++++------- 3 files changed, 50 insertions(+), 31 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index da4482291..b795cc5f9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,8 +37,9 @@ transport = [ "dep:axum", "channel", "dep:h2", - "dep:tokio", "tokio?/net", "tokio?/time", "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", + "dep:socket2", + "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", "dep:tower", ] channel = [] diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 122f13baf..49c086a59 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,4 +1,3 @@ -use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; @@ -86,17 +85,6 @@ impl TcpConnectInfo { } } -impl Connected for AddrStream { - type ConnectInfo = TcpConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - TcpConnectInfo { - local_addr: Some(self.local_addr()), - remote_addr: Some(self.remote_addr()), - } - } -} - impl Connected for TcpStream { type ConnectInfo = TcpConnectInfo; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index ede62a32d..7f5f76c25 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,20 +1,18 @@ use super::{Connected, Server}; use crate::transport::service::ServerIo; -use hyper::server::{ - accept::Accept, - conn::{AddrIncoming, AddrStream}, -}; use std::{ - net::SocketAddr, + net::{SocketAddr, TcpListener as StdTcpListener}, pin::{pin, Pin}, - task::{Context, Poll}, + task::{ready, Context, Poll}, time::Duration, }; use tokio::{ io::{AsyncRead, AsyncWrite}, - net::TcpListener, + net::{TcpListener, TcpStream}, }; +use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::{Stream, StreamExt}; +use tracing::warn; #[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( @@ -127,7 +125,9 @@ enum SelectOutput { /// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address. #[derive(Debug)] pub struct TcpIncoming { - inner: AddrIncoming, + inner: TcpListenerStream, + nodelay: bool, + keepalive: Option, } impl TcpIncoming { @@ -167,10 +167,15 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::bind(&addr)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + let std_listener = StdTcpListener::bind(addr)?; + std_listener.set_nonblocking(true)?; + + let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?); + Ok(Self { + inner, + nodelay, + keepalive, + }) } /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. @@ -179,18 +184,43 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::from_listener(listener)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + Ok(Self { + inner: TcpListenerStream::new(listener), + nodelay, + keepalive, + }) } } impl Stream for TcpIncoming { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_accept(cx) + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(stream)) => { + set_accepted_socket_options(&stream, self.nodelay, self.keepalive); + Some(Ok(stream)).into() + } + other => Poll::Ready(other), + } + } +} + +// Consistent with hyper-0.14, this function does not return an error. +fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option) { + if nodelay { + if let Err(e) = stream.set_nodelay(true) { + warn!("error trying to set TCP nodelay: {}", e); + } + } + + if let Some(timeout) = keepalive { + let sock_ref = socket2::SockRef::from(&stream); + let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); + + if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { + warn!("error trying to set TCP keepalive: {}", e); + } } } From ca9a859e73c6906b289da6eff4984ae57bf06b48 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:13:39 +0000 Subject: [PATCH 22/31] [examples] In h2c, replace hyper::Server with an accept loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hyper::Server is deprecated, with no current common replacement. Instead of implementing (or using tonic’s new) full server in here, we write a simple accept loop, which is sufficient to demonstrate the functionality of h2c. --- examples/src/h2c/server.rs | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index b1d4c0a8d..da5c3425c 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -1,9 +1,14 @@ +use std::net::SocketAddr; + use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use hyper_util::service::TowerToHyperService; +use tokio::net::TcpListener; +// use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; use hello_world::{HelloReply, HelloRequest}; -use tower::make::Shared; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -29,21 +34,36 @@ impl Greeter for MyGreeter { #[tokio::main] async fn main() -> Result<(), Box> { - let addr = "[::1]:50051".parse().unwrap(); + let addr: SocketAddr = "[::1]:50051".parse().unwrap(); let greeter = MyGreeter::default(); println!("GreeterServer listening on {}", addr); + let incoming = TcpListener::bind(addr).await?; let svc = Server::builder() .add_service(GreeterServer::new(greeter)) .into_router(); let h2c = h2c::H2c { s: svc }; - let server = hyper::Server::bind(&addr).serve(Shared::new(h2c)); - server.await.unwrap(); - - Ok(()) + loop { + match incoming.accept().await { + Ok((io, _)) => { + let router = h2c.clone(); + tokio::spawn(async move { + let builder = Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades( + TokioIo::new(io), + TowerToHyperService::new(router), + ); + let _ = conn.await; + }); + } + Err(e) => { + eprintln!("Error accepting connection: {}", e); + } + } + } } mod h2c { From 32c91838122eaaab076b2c4606751a680afe5b46 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:22:29 +0000 Subject: [PATCH 23/31] Upgrade tls dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hyper-rustls requires version 0.27.0 to support hyper >= 1, bringing a few other tls bumps along. Importantly, we add the “ring” and “tls12” features to use ring as the crypto backend, consistent with previous versions of tonic. A future version of tonic might support selecting backends via features. Co-authored-by: Ivan Krivosheev --- examples/src/tls_rustls/client.rs | 5 ++-- examples/src/tls_rustls/server.rs | 39 ++++++++++++++++++------------ tonic/src/transport/service/tls.rs | 3 ++- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/examples/src/tls_rustls/client.rs b/examples/src/tls_rustls/client.rs index f03f70051..4c39a2c46 100644 --- a/examples/src/tls_rustls/client.rs +++ b/examples/src/tls_rustls/client.rs @@ -18,11 +18,10 @@ async fn main() -> Result<(), Box> { let mut roots = RootCertStore::empty(); let mut buf = std::io::BufReader::new(&fd); - let certs = rustls_pemfile::certs(&mut buf)?; - roots.add_parsable_certificates(&certs); + let certs = rustls_pemfile::certs(&mut buf).collect::, _>>()?; + roots.add_parsable_certificates(certs.into_iter()); let tls = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth(); diff --git a/examples/src/tls_rustls/server.rs b/examples/src/tls_rustls/server.rs index 82f009344..5630edfa1 100644 --- a/examples/src/tls_rustls/server.rs +++ b/examples/src/tls_rustls/server.rs @@ -2,45 +2,51 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::server::conn::Http; +use hyper::server::conn::http2::Builder; +use hyper_util::rt::{TokioExecutor, TokioIo}; use pb::{EchoRequest, EchoResponse}; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{ - rustls::{Certificate, PrivateKey, ServerConfig}, + rustls::{ + pki_types::{CertificateDer, PrivatePkcs8KeyDer}, + ServerConfig, + }, TlsAcceptor, }; +use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use tower_http::ServiceBuilderExt; #[tokio::main] async fn main() -> Result<(), Box> { let data_dir = std::path::PathBuf::from_iter([std::env!("CARGO_MANIFEST_DIR"), "data"]); - let certs = { + let certs: Vec> = { let fd = std::fs::File::open(data_dir.join("tls/server.pem"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::certs(&mut buf)? + rustls_pemfile::certs(&mut buf) .into_iter() - .map(Certificate) - .collect() + .map(|res| res.map(|cert| cert.to_owned())) + .collect::, _>>()? }; - let key = { + let key: PrivatePkcs8KeyDer<'static> = { let fd = std::fs::File::open(data_dir.join("tls/server.key"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::pkcs8_private_keys(&mut buf)? + let key = rustls_pemfile::pkcs8_private_keys(&mut buf) .into_iter() - .map(PrivateKey) .next() - .unwrap() + .unwrap()? + .clone_key(); + + key // let key = std::fs::read(data_dir.join("tls/server.key"))?; // PrivateKey(key) }; let mut tls = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_single_cert(certs, key.into())?; tls.alpn_protocols = vec![b"h2".to_vec()]; let server = EchoServer::default(); @@ -49,8 +55,7 @@ async fn main() -> Result<(), Box> { .add_service(pb::echo_server::EchoServer::new(server)) .into_service(); - let mut http = Http::new(); - http.http2_only(true); + let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("[::1]:50051").await?; let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); @@ -86,7 +91,9 @@ async fn main() -> Result<(), Box> { .add_extension(Arc::new(ConnInfo { addr, certificates })) .service(svc); - http.serve_connection(conn, svc).await.unwrap(); + http.serve_connection(TokioIo::new(conn), TowerToHyperService::new(svc)) + .await + .unwrap(); }); } } @@ -94,7 +101,7 @@ async fn main() -> Result<(), Box> { #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, - certificates: Vec, + certificates: Vec>, } type EchoResult = Result, Status>; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 2a6394a4f..2ce9dc5da 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -18,6 +18,7 @@ use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; +use hyper_util::rt::TokioIo; /// h2 alpn in plain format for rustls. const ALPN_H2: &[u8] = b"h2"; @@ -88,7 +89,7 @@ impl TlsConnector { if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { return Err(TlsError::H2NotNegotiated.into()); } - Ok(BoxedIo::new(io)) + Ok(BoxedIo::new(TokioIo::new(io))) } } From ed74e45771266d70d27c328f9df7341ca99e5a19 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 31 May 2024 04:52:20 +0000 Subject: [PATCH 24/31] Combine trailers when streaming decode body We aren't sure if multiple trailers should even be legal, but if we get multiple trailers in an HTTP body stream, we'll combine them all, to preserve their data. Alternatively we'd have to pick the first or last trailers, and that might lose information. --- tonic/src/codec/decode.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index e5aee85f2..2117df8b5 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -262,7 +262,15 @@ impl StreamingInner { Ok(Some(())) } frame if frame.is_trailers() => { - self.trailers = Some(frame.into_trailers().unwrap()); + match &mut self.trailers { + Some(trailers) => { + trailers.extend(frame.into_trailers().unwrap()); + } + None => { + self.trailers = Some(frame.into_trailers().unwrap()); + } + } + Ok(None) } frame => panic!("unexpected frame: {:?}", frame), From 855ec61af61cfd7457acc9be9a5c44f0b474fe66 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 31 May 2024 04:53:19 +0000 Subject: [PATCH 25/31] Tweak imports in transport example Example used `empty_body()`, which is now fully qualified as `tonic::body::empty_body()` to make clear that this is a tonic helper method for creating an empty BoxBody. --- tonic/src/transport/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 719f260e9..8cd1fdf12 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -22,7 +22,6 @@ //! # use tonic::transport::{Channel, Certificate, ClientTlsConfig}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; -//! # use tonic::body::empty_body; //! # use tonic::client::GrpcService;; //! # use http::Request; //! # #[cfg(feature = "rustls")] @@ -39,7 +38,7 @@ //! .connect() //! .await?; //! -//! channel.call(Request::new(empty_body())).await?; +//! channel.call(Request::new(tonic::body::empty_body())).await?; //! # Ok(()) //! # } //! ``` From bd9de54b1fcd703250a5aca6176382912ccee24d Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 01:15:40 +0000 Subject: [PATCH 26/31] Remove commented out code from examples/h2c --- examples/src/h2c/server.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index da5c3425c..cf981f957 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -4,7 +4,6 @@ use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::server::conn::auto::Builder; use hyper_util::service::TowerToHyperService; use tokio::net::TcpListener; -// use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; From deab62b9c40debc0eb57591e5752581772e46d7d Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 01:16:16 +0000 Subject: [PATCH 27/31] tonic-web: avoid copy to vector to base64 encode --- tonic-web/src/call.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index 178e620ae..1389441f6 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -219,9 +219,7 @@ where let mut res = frame.into_data().unwrap(); if *this.encoding == Encoding::Base64 { - let mut buf = Vec::with_capacity(res.len()); - buf.extend_from_slice(&res); - res = crate::util::base64::STANDARD.encode(buf).into(); + res = crate::util::base64::STANDARD.encode(res).into(); } Poll::Ready(Some(Ok(Frame::data(res)))) From 39f9f111b9b65d6f482161bbee7daf3922969de4 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 01:17:06 +0000 Subject: [PATCH 28/31] tonic-web: Merge subsequent trailer frames Ideally, a body should only return a single trailer frame. If multiple trailers are returned, merge them together. --- tonic-web/src/call.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index 1389441f6..5f67e4d5c 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -264,7 +264,14 @@ where } Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => { let trailers = incoming_buf.into_trailers().unwrap(); - me.as_mut().project().trailers.replace(trailers); + match me.as_mut().project().trailers { + Some(current_trailers) => { + current_trailers.extend(trailers); + } + None => { + me.as_mut().project().trailers.replace(trailers); + } + } continue; } Some(Ok(_)) => unreachable!("unexpected frame type"), From 791b138504ab4be2a37d6ae38ca49e4c5e1633cc Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 01:18:00 +0000 Subject: [PATCH 29/31] Comment in tonic::status::find_status_in_source_chain MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Comment mentions why we choose “Unavailable” for connection errors --- tonic/src/status.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 0783bd8e2..968693f87 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -614,6 +614,11 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { return Some(Status::cancelled(timeout.to_string())); } + // If we are unable to connect to the server, map this to UNAVAILABLE. This is + // consistent with the behavior of a C++ gRPC client when the server is not running, and + // matches the spec of: + // > The service is currently unavailable. This is most likely a transient condition that + // > can be corrected if retried with a backoff. #[cfg(feature = "transport")] if let Some(connect) = err.downcast_ref::() { return Some(Status::unavailable(connect.to_string())); From 8345dfe5dc8b40f1af7e52003de4acf857eeceef Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 01:20:41 +0000 Subject: [PATCH 30/31] Make TowerToHyperService crate-private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This also requires vendoring it in the rustls example, which doesn’t use a server type. Making the type crate-private means we can delete some unused methods. --- examples/Cargo.toml | 5 ++- examples/src/tls_rustls/server.rs | 72 ++++++++++++++++++++++++++++++- tonic/src/transport/server/mod.rs | 40 ++++------------- 3 files changed, 82 insertions(+), 35 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index e04868826..deab25fd4 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -280,7 +280,7 @@ tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] -tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] +tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:pin-project", "dep:http-body-util"] dynamic-load-balance = ["dep:tower"] timeout = ["tokio/time", "dep:tower"] tls-client-auth = ["tonic/tls"] @@ -315,7 +315,7 @@ http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } hyper = { version = "1", optional = true } -hyper-util = { version = "0.1", optional = true } +hyper-util = { version = ">=0.1.4, <0.2", optional = true } listenfd = { version = "1.0", optional = true } bytes = { version = "1", optional = true } h2 = { version = "0.3", optional = true } @@ -323,6 +323,7 @@ tokio-rustls = { version = "0.26", optional = true, features = ["ring", "tls12"] hyper-rustls = { version = "0.27.0", features = ["http2", "ring", "tls12"], optional = true, default-features = false } rustls-pemfile = { version = "2.0.0", optional = true } tower-http = { version = "0.5", optional = true } +pin-project = { version = "1.0.11", optional = true } [build-dependencies] tonic-build = { path = "../tonic-build", features = ["prost"] } diff --git a/examples/src/tls_rustls/server.rs b/examples/src/tls_rustls/server.rs index 5630edfa1..0fb31f8fa 100644 --- a/examples/src/tls_rustls/server.rs +++ b/examples/src/tls_rustls/server.rs @@ -2,6 +2,7 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } +use http_body_util::BodyExt; use hyper::server::conn::http2::Builder; use hyper_util::rt::{TokioExecutor, TokioIo}; use pb::{EchoRequest, EchoResponse}; @@ -14,8 +15,8 @@ use tokio_rustls::{ }, TlsAcceptor, }; -use tonic::transport::server::TowerToHyperService; -use tonic::{transport::Server, Request, Response, Status}; +use tonic::{body::BoxBody, transport::Server, Request, Response, Status}; +use tower::{BoxError, ServiceExt}; use tower_http::ServiceBuilderExt; #[tokio::main] @@ -122,3 +123,70 @@ impl pb::echo_server::Echo for EchoServer { Ok(Response::new(EchoResponse { message })) } } + +/// An adaptor which converts a [`tower::Service`] to a [`hyper::service::Service`]. +/// +/// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, +/// and does not support the `poll_ready` method that is used by tower services. +/// +/// This is provided here because the equivalent adaptor in hyper-util does not support +/// tonic::body::BoxBody bodies. +#[derive(Debug, Clone)] +struct TowerToHyperService { + service: S, +} + +impl TowerToHyperService { + /// Create a new `TowerToHyperService` from a tower service. + fn new(service: S) -> Self { + Self { service } + } +} + +impl hyper::service::Service> for TowerToHyperService +where + S: tower::Service> + Clone, + S::Error: Into + 'static, +{ + type Response = S::Response; + type Error = BoxError; + type Future = TowerToHyperServiceFuture>; + + fn call(&self, req: hyper::Request) -> Self::Future { + let req = req.map(|incoming| { + incoming + .map_err(|err| Status::from_error(err.into())) + .boxed_unsync() + }); + TowerToHyperServiceFuture { + future: self.service.clone().oneshot(req), + } + } +} + +/// Future returned by [`TowerToHyperService`]. +#[derive(Debug)] +#[pin_project::pin_project] +struct TowerToHyperServiceFuture +where + S: tower::Service, +{ + #[pin] + future: tower::util::Oneshot, +} + +impl std::future::Future for TowerToHyperServiceFuture +where + S: tower::Service, + S::Error: Into + 'static, +{ + type Output = Result; + + #[inline] + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.project().future.poll(cx).map_err(Into::into) + } +} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index fb63058ad..3fa406c77 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -10,14 +10,9 @@ mod tls; mod unix; use tokio_stream::StreamExt as _; -use tower::util::BoxCloneService; -use tower::util::Oneshot; -use tower::ServiceExt; -use tracing::debug; -use tracing::trace; +use tracing::{debug, trace}; -pub use super::service::Routes; -pub use super::service::RoutesBuilder; +pub use super::service::{Routes, RoutesBuilder}; pub use conn::{Connected, TcpConnectInfo}; use hyper_util::rt::{TokioExecutor, TokioIo}; @@ -43,19 +38,17 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; -use crate::body::boxed; -use crate::body::BoxBody; +use crate::body::{boxed, BoxBody}; use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; use http_body_util::BodyExt; use hyper::body::Incoming; use pin_project::pin_project; -use std::future::poll_fn; use std::{ convert::Infallible, fmt, - future::{self, Future}, + future::{self, poll_fn, Future}, marker::PhantomData, net::SocketAddr, pin::{pin, Pin}, @@ -69,8 +62,8 @@ use tower::{ layer::util::{Identity, Stack}, layer::Layer, limit::concurrency::ConcurrencyLimitLayer, - util::Either, - Service, ServiceBuilder, + util::{BoxCloneService, Either, Oneshot}, + Service, ServiceBuilder, ServiceExt, }; type BoxHttpBody = crate::body::BoxBody; @@ -673,30 +666,15 @@ type ConnectionBuilder = hyper_util::server::conn::auto::Builder; /// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, /// and does not support the `poll_ready` method that is used by tower services. #[derive(Debug, Copy, Clone)] -pub struct TowerToHyperService { +pub(crate) struct TowerToHyperService { service: S, } impl TowerToHyperService { /// Create a new `TowerToHyperService` from a tower service. - pub fn new(service: S) -> Self { + pub(crate) fn new(service: S) -> Self { Self { service } } - - /// Extract the inner tower service. - pub fn into_inner(self) -> S { - self.service - } - - /// Get a reference to the inner tower service. - pub fn as_inner(&self) -> &S { - &self.service - } - - /// Get a mutable reference to the inner tower service. - pub fn as_inner_mut(&mut self) -> &mut S { - &mut self.service - } } impl hyper::service::Service> for TowerToHyperService @@ -719,7 +697,7 @@ where /// Future returned by [`TowerToHyperService`]. #[derive(Debug)] #[pin_project] -pub struct TowerToHyperServiceFuture +pub(crate) struct TowerToHyperServiceFuture where S: tower_service::Service, { From 68a5bbdb5e6182660257abfcf5bcd9db1c933c3b Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 01:39:15 +0000 Subject: [PATCH 31/31] Fixup imports in tonic::transport --- tonic/src/transport/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 8cd1fdf12..767534748 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -111,8 +111,8 @@ pub(crate) use self::service::ConnectError; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; -pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; -pub use hyper::{Body, Uri}; +pub use axum::{body::Body as AxumBoxBody, Router as AxumRouter}; +pub use hyper::{body::Body, Uri}; #[cfg(feature = "tls")] pub use tokio_rustls::rustls::pki_types::CertificateDer;