diff --git a/examples/src/grpc-web/server.rs b/examples/src/grpc-web/server.rs index 08883fa79..7204b07a2 100644 --- a/examples/src/grpc-web/server.rs +++ b/examples/src/grpc-web/server.rs @@ -2,6 +2,7 @@ use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; use hello_world::{HelloReply, HelloRequest}; +use tonic_web::GrpcWebLayer; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -33,14 +34,12 @@ async fn main() -> Result<(), Box> { let greeter = MyGreeter::default(); let greeter = GreeterServer::new(greeter); - let greeter = tonic_web::config() - .allow_origins(vec!["127.0.0.1"]) - .enable(greeter); println!("GreeterServer listening on {}", addr); Server::builder() .accept_http1(true) + .layer(GrpcWebLayer::new()) .add_service(greeter) .serve(addr) .await?; diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index d8e4b7bb6..f4285bc01 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -25,6 +25,7 @@ pin-project = "1" tonic = {version = "0.8", path = "../tonic", default-features = false, features = ["transport"]} tower-service = "0.3" tower-layer = "0.3" +tower-http = { version = "0.3", features = ["cors"] } tracing = "0.1" [dev-dependencies] diff --git a/tonic-web/src/config.rs b/tonic-web/src/config.rs deleted file mode 100644 index 5c965a18a..000000000 --- a/tonic-web/src/config.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::collections::{BTreeSet, HashSet}; -use std::convert::TryFrom; -use std::time::Duration; - -use http::{header::HeaderName, HeaderValue}; -use tonic::body::BoxBody; -use tower_service::Service; - -use crate::service::GrpcWeb; -use crate::BoxError; - -const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); - -const DEFAULT_EXPOSED_HEADERS: [&str; 2] = ["grpc-status", "grpc-message"]; - -/// A Configuration builder for grpc_web services. -/// -/// `Config` can be used to tweak the behavior of tonic_web services. Currently, -/// `Config` instances only expose cors settings. However, since tonic_web is designed to work -/// with grpc-web compliant clients only, some cors options have specific default values and not -/// all settings are configurable. -/// -/// ## Default values and configuration options -/// -/// * `allow-origin`: All origins allowed by default. Configurable, but null and wildcard origins -/// are not supported. -/// * `allow-methods`: `[POST,OPTIONS]`. Not configurable. -/// * `allow-headers`: Set to whatever the `OPTIONS` request carries. Not configurable. -/// * `allow-credentials`: `true`. Configurable. -/// * `max-age`: `86400`. Configurable. -/// * `expose-headers`: `grpc-status,grpc-message`. Configurable but values can only be added. -/// `grpc-status` and `grpc-message` will always be exposed. -#[derive(Debug, Clone)] -pub struct Config { - pub(crate) allowed_origins: AllowedOrigins, - pub(crate) exposed_headers: HashSet, - pub(crate) max_age: Option, - pub(crate) allow_credentials: bool, -} - -#[derive(Debug, Clone)] -pub(crate) enum AllowedOrigins { - Any, - #[allow(clippy::mutable_key_type)] - Only(BTreeSet), -} - -impl AllowedOrigins { - pub(crate) fn is_allowed(&self, origin: &HeaderValue) -> bool { - match self { - AllowedOrigins::Any => true, - AllowedOrigins::Only(origins) => origins.contains(origin), - } - } -} - -impl Config { - pub(crate) fn new() -> Config { - Config { - allowed_origins: AllowedOrigins::Any, - exposed_headers: DEFAULT_EXPOSED_HEADERS - .iter() - .cloned() - .map(HeaderName::from_static) - .collect(), - max_age: Some(DEFAULT_MAX_AGE), - allow_credentials: true, - } - } - - /// Allow any origin to access this resource. - /// - /// This is the default value. - pub fn allow_all_origins(self) -> Config { - Self { - allowed_origins: AllowedOrigins::Any, - ..self - } - } - - /// Only allow a specific set of origins to access this resource. - /// - /// ## Example - /// - /// ``` - /// tonic_web::config().allow_origins(vec!["http://a.com", "http://b.com"]); - /// ``` - pub fn allow_origins(self, origins: I) -> Config - where - I: IntoIterator, - HeaderValue: TryFrom, - { - // false positive when using HeaderValue, which uses Bytes internally - // https://rust-lang.github.io/rust-clippy/master/index.html#mutable_key_type - #[allow(clippy::mutable_key_type)] - let origins = origins - .into_iter() - .map(|v| match TryFrom::try_from(v) { - Ok(uri) => uri, - Err(_) => panic!("invalid origin"), - }) - .collect(); - - Self { - allowed_origins: AllowedOrigins::Only(origins), - ..self - } - } - - /// Adds multiple headers to the list of exposed headers. - /// - /// Default: `grpc-status,grpc-message`. These will always be included. - pub fn expose_headers(mut self, headers: I) -> Config - where - I: IntoIterator, - HeaderName: TryFrom, - { - let iter = headers - .into_iter() - .map(|header| match TryFrom::try_from(header) { - Ok(header) => header, - Err(_) => panic!("invalid header"), - }); - - self.exposed_headers.extend(iter); - self - } - - /// Defines the maximum cache lifetime for operations allowed on this - /// resource. - /// - /// Default: "86400" (24 hours) - pub fn max_age>>(self, max_age: T) -> Config { - Self { - max_age: max_age.into(), - ..self - } - } - - /// If true, the `access-control-allow-credentials` will be sent. - /// - /// Default: true - pub fn allow_credentials(self, allow_credentials: bool) -> Config { - Self { - allow_credentials, - ..self - } - } - - /// enable a tonic service to handle grpc-web requests with this configuration values. - pub fn enable(&self, service: S) -> GrpcWeb - where - S: Service, Response = http::Response>, - S: Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into + Send, - { - GrpcWeb::new(service, self.clone()) - } -} - -impl Default for Config { - fn default() -> Self { - Config::new() - } -} diff --git a/tonic-web/src/cors.rs b/tonic-web/src/cors.rs deleted file mode 100644 index 7b49d35b0..000000000 --- a/tonic-web/src/cors.rs +++ /dev/null @@ -1,402 +0,0 @@ -use std::sync::Arc; - -pub(crate) use http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS as ALLOW_CREDENTIALS; -pub(crate) use http::header::ACCESS_CONTROL_ALLOW_HEADERS as ALLOW_HEADERS; -pub(crate) use http::header::ACCESS_CONTROL_ALLOW_METHODS as ALLOW_METHODS; -pub(crate) use http::header::ACCESS_CONTROL_ALLOW_ORIGIN as ALLOW_ORIGIN; -pub(crate) use http::header::ACCESS_CONTROL_EXPOSE_HEADERS as EXPOSE_HEADERS; -pub(crate) use http::header::ACCESS_CONTROL_MAX_AGE as MAX_AGE; -pub(crate) use http::header::ACCESS_CONTROL_REQUEST_HEADERS as REQUEST_HEADERS; -pub(crate) use http::header::ACCESS_CONTROL_REQUEST_METHOD as REQUEST_METHOD; -pub(crate) use http::header::ORIGIN; -use http::{header, HeaderMap, HeaderValue, Method}; -use tracing::debug; - -use crate::config::Config; - -const DEFAULT_ALLOWED_METHODS: &[Method; 2] = &[Method::POST, Method::OPTIONS]; - -#[derive(Debug, Clone)] -pub(crate) struct Cors { - cache: Arc, -} - -#[derive(Debug, PartialEq)] -pub(crate) enum Error { - OriginNotAllowed, - MethodNotAllowed, -} - -#[derive(Clone, Debug)] -struct Cache { - config: Config, - expose_headers: HeaderValue, - allow_methods: HeaderValue, - allow_credentials: HeaderValue, -} - -impl Cors { - pub(crate) fn new(config: Config) -> Cors { - let expose_headers = join_header_value(&config.exposed_headers).unwrap(); - let allow_methods = HeaderValue::from_static("POST,OPTIONS"); - let allow_credentials = HeaderValue::from_static("true"); - - let cache = Arc::new(Cache { - config, - expose_headers, - allow_methods, - allow_credentials, - }); - - Cors { cache } - } - - fn is_method_allowed(&self, header: Option<&HeaderValue>) -> bool { - match header { - Some(value) => match Method::from_bytes(value.as_bytes()) { - Ok(method) => DEFAULT_ALLOWED_METHODS.contains(&method), - Err(_) => { - debug!("access-control-request-method {:?} is not valid", value); - false - } - }, - None => { - debug!("access-control-request-method is missing"); - false - } - } - } - - pub(crate) fn preflight( - &self, - req_headers: &HeaderMap, - origin: &HeaderValue, - request_headers_header: &HeaderValue, - ) -> Result { - if !self.is_origin_allowed(origin) { - return Err(Error::OriginNotAllowed); - } - - if !self.is_method_allowed(req_headers.get(REQUEST_METHOD)) { - return Err(Error::MethodNotAllowed); - } - - let mut headers = self.common_headers(origin.clone()); - headers.insert(ALLOW_METHODS, self.cache.allow_methods.clone()); - headers.insert(ALLOW_HEADERS, request_headers_header.clone()); - - if let Some(max_age) = self.cache.config.max_age { - headers.insert(MAX_AGE, HeaderValue::from(max_age.as_secs())); - } - - Ok(headers) - } - - pub(crate) fn simple(&self, headers: &HeaderMap) -> Result { - match headers.get(header::ORIGIN) { - Some(origin) if self.is_origin_allowed(origin) => { - Ok(self.common_headers(origin.clone())) - } - Some(_) => Err(Error::OriginNotAllowed), - None => Ok(HeaderMap::new()), - } - } - - fn common_headers(&self, origin: HeaderValue) -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(ALLOW_ORIGIN, origin); - headers.insert(EXPOSE_HEADERS, self.cache.expose_headers.clone()); - - if self.cache.config.allow_credentials { - headers.insert(ALLOW_CREDENTIALS, self.cache.allow_credentials.clone()); - } - - headers - } - - fn is_origin_allowed(&self, origin: &HeaderValue) -> bool { - self.cache.config.allowed_origins.is_allowed(origin) - } - - #[cfg(test)] - pub(crate) fn __check_preflight(&self, headers: &HeaderMap) -> Result { - self.preflight( - headers, - headers.get(ORIGIN).unwrap(), - headers.get(REQUEST_HEADERS).unwrap(), - ) - } -} - -#[cfg(test)] -impl Default for Cors { - fn default() -> Self { - Cors::new(Config::default()) - } -} - -fn join_header_value(values: I) -> Result -where - I: IntoIterator, - I::Item: AsRef, -{ - let mut values = values.into_iter(); - let mut value = Vec::new(); - - if let Some(v) = values.next() { - value.extend(v.as_ref().as_bytes()); - } - for v in values { - value.push(b','); - value.extend(v.as_ref().as_bytes()); - } - HeaderValue::from_bytes(&value) -} - -#[cfg(test)] -mod tests { - use super::*; - - macro_rules! assert_value_eq { - ($header:expr, $expected:expr) => { - fn sorted(value: &str) -> Vec<&str> { - let mut vec = value.split(",").collect::>(); - vec.sort_unstable(); - vec - } - - assert_eq!(sorted($header.to_str().unwrap()), sorted($expected)) - }; - } - - fn value(s: &str) -> HeaderValue { - s.parse().unwrap() - } - - impl From for Cors { - fn from(c: Config) -> Self { - Cors::new(c) - } - } - - #[test] - #[should_panic] - #[ignore] - fn origin_is_valid_url() { - Config::new().allow_origins(vec!["foo"]); - } - - mod preflight { - use super::*; - - fn preflight_headers() -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(ORIGIN, value("http://example.com")); - headers.insert(REQUEST_METHOD, value("POST")); - headers.insert(REQUEST_HEADERS, value("x-grpc-web")); - headers - } - - #[test] - fn default_config() { - let cors = Cors::default(); - let headers = cors.__check_preflight(&preflight_headers()).unwrap(); - - assert_eq!(headers[ALLOW_ORIGIN], "http://example.com"); - assert_eq!(headers[ALLOW_METHODS], "POST,OPTIONS"); - assert_eq!(headers[ALLOW_HEADERS], "x-grpc-web"); - assert_eq!(headers[ALLOW_CREDENTIALS], "true"); - assert_eq!(headers[MAX_AGE], "86400"); - assert_value_eq!(&headers[EXPOSE_HEADERS], "grpc-status,grpc-message"); - } - - #[test] - fn any_origin() { - let cors: Cors = Config::new().allow_all_origins().into(); - - assert!(cors.__check_preflight(&preflight_headers()).is_ok()); - } - - #[test] - fn origin_list() { - let cors: Cors = Config::new() - .allow_origins(vec![ - HeaderValue::from_static("http://a.com"), - HeaderValue::from_static("http://b.com"), - ]) - .into(); - - let mut req_headers = preflight_headers(); - req_headers.insert(ORIGIN, value("http://b.com")); - - assert!(cors.__check_preflight(&req_headers).is_ok()); - } - - #[test] - fn origin_not_allowed() { - let cors: Cors = Config::new().allow_origins(vec!["http://a.com"]).into(); - - let err = cors.__check_preflight(&preflight_headers()).unwrap_err(); - - assert_eq!(err, Error::OriginNotAllowed) - } - - #[test] - fn disallow_credentials() { - let cors = Cors::new(Config::new().allow_credentials(false)); - let headers = cors.__check_preflight(&preflight_headers()).unwrap(); - - assert!(!headers.contains_key(ALLOW_CREDENTIALS)); - } - - #[test] - fn expose_headers_are_merged() { - let cors = Cors::new(Config::new().expose_headers(vec!["x-request-id"])); - let headers = cors.__check_preflight(&preflight_headers()).unwrap(); - - assert_value_eq!( - &headers[EXPOSE_HEADERS], - "x-request-id,grpc-message,grpc-status" - ); - } - - #[test] - fn allow_headers_echo_request_headers() { - let cors = Cors::default(); - let mut request_headers = preflight_headers(); - request_headers.insert(REQUEST_HEADERS, value("x-grpc-web,foo,x-request-id")); - - let headers = cors.__check_preflight(&request_headers).unwrap(); - - assert_value_eq!(&headers[ALLOW_HEADERS], "x-grpc-web,foo,x-request-id"); - } - - #[test] - fn missing_request_method() { - let cors = Cors::default(); - let mut request_headers = preflight_headers(); - request_headers.remove(REQUEST_METHOD); - - let err = cors.__check_preflight(&request_headers).unwrap_err(); - - assert_eq!(err, Error::MethodNotAllowed); - } - - #[test] - fn only_options_and_post_allowed() { - let cors = Cors::default(); - - for method in &[ - Method::GET, - Method::DELETE, - Method::TRACE, - Method::PATCH, - Method::PUT, - Method::HEAD, - ] { - let mut request_headers = preflight_headers(); - request_headers.insert(REQUEST_METHOD, value(method.as_str())); - - assert_eq!( - cors.__check_preflight(&request_headers).unwrap_err(), - Error::MethodNotAllowed, - ) - } - } - - #[test] - fn custom_max_age() { - use std::time::Duration; - - let cors = Cors::new(Config::new().max_age(Duration::from_secs(99))); - let headers = cors.__check_preflight(&preflight_headers()).unwrap(); - - assert_eq!(headers[MAX_AGE], "99"); - } - - #[test] - fn no_max_age() { - let cors = Cors::new(Config::new().max_age(None)); - let headers = cors.__check_preflight(&preflight_headers()).unwrap(); - - assert!(!headers.contains_key(MAX_AGE)); - } - } - - mod simple { - use super::*; - - fn request_headers() -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(ORIGIN, value("http://example.com")); - headers - } - - #[test] - fn default_config() { - let cors = Cors::default(); - let headers = cors.simple(&request_headers()).unwrap(); - - assert_eq!(headers[ALLOW_ORIGIN], "http://example.com"); - assert_eq!(headers[ALLOW_CREDENTIALS], "true"); - assert_value_eq!(&headers[EXPOSE_HEADERS], "grpc-message,grpc-status"); - - assert!(!headers.contains_key(ALLOW_HEADERS)); - assert!(!headers.contains_key(ALLOW_METHODS)); - assert!(!headers.contains_key(MAX_AGE)); - } - - #[test] - fn any_origin() { - let cors: Cors = Config::new().allow_all_origins().into(); - - assert!(cors.simple(&request_headers()).is_ok()); - } - - #[test] - fn origin_list() { - let cors: Cors = Config::new() - .allow_origins(vec![ - HeaderValue::from_static("http://a.com"), - HeaderValue::from_static("http://b.com"), - ]) - .into(); - - let mut req_headers = request_headers(); - req_headers.insert(ORIGIN, value("http://b.com")); - - assert!(cors.simple(&req_headers).is_ok()); - } - - #[test] - fn origin_not_allowed() { - let cors: Cors = Config::new().allow_origins(vec!["http://a.com"]).into(); - - let err = cors.simple(&request_headers()).unwrap_err(); - - assert_eq!(err, Error::OriginNotAllowed) - } - - #[test] - fn disallow_credentials() { - let cors = Cors::new(Config::new().allow_credentials(false)); - let headers = cors.simple(&request_headers()).unwrap(); - - assert!(!headers.contains_key(ALLOW_CREDENTIALS)); - } - - #[test] - fn expose_headers_are_merged() { - let cors: Cors = Config::new() - .expose_headers(vec!["x-hello", "custom-1"]) - .into(); - - let headers = cors.simple(&request_headers()).unwrap(); - - assert_value_eq!( - &headers[EXPOSE_HEADERS], - "grpc-message,grpc-status,x-hello,custom-1" - ); - } - } -} diff --git a/tonic-web/src/layer.rs b/tonic-web/src/layer.rs index 95ca2ef66..5efa7466b 100644 --- a/tonic-web/src/layer.rs +++ b/tonic-web/src/layer.rs @@ -1,4 +1,4 @@ -use super::{BoxBody, BoxError, Config, GrpcWeb}; +use super::{BoxBody, BoxError, GrpcWebService}; use tower_layer::Layer; use tower_service::Service; @@ -23,9 +23,9 @@ where S::Future: Send + 'static, S::Error: Into + Send, { - type Service = GrpcWeb; + type Service = GrpcWebService; fn layer(&self, inner: S) -> Self::Service { - Config::default().enable(inner) + GrpcWebService::new(inner) } } diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index f7d1281b7..e1c640c6c 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -34,7 +34,8 @@ //! //! ``` //! This will apply a default configuration that works well with grpc-web clients out of the box. -//! See the [`Config`] documentation for details. +//! +//! You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice. //! //! Alternatively, if you have a tls enabled server, you could skip setting `accept_http1` to `true`. //! This works because the browser will handle `ALPN`. @@ -77,7 +78,6 @@ //! [grpc-web]: https://github.com/grpc/grpc-web //! [tower]: https://github.com/tower-rs/tower //! [`enable`]: crate::enable() -//! [`Config`]: crate::Config #![warn( missing_debug_implementations, missing_docs, @@ -87,52 +87,56 @@ #![doc(html_root_url = "https://docs.rs/tonic-web/0.4.0")] #![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")] -pub use config::Config; pub use layer::GrpcWebLayer; -pub use service::GrpcWeb; +pub use service::GrpcWebService; mod call; -mod config; -mod cors; mod layer; mod service; +use http::header::HeaderName; use std::future::Future; use std::pin::Pin; +use std::time::Duration; use tonic::body::BoxBody; +use tower_http::cors::{AllowOrigin, Cors, CorsLayer}; +use tower_layer::Layer; use tower_service::Service; -pub use layer::GrpcWebLayer; +const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); +const DEFAULT_EXPOSED_HEADERS: [&str; 3] = ["grpc-status", "grpc-message", "grpc-status-details-bin"]; +const DEFAULT_ALLOW_HEADERS: [&str; 4] = ["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"]; -/// enable a tonic service to handle grpc-web requests with the default configuration. +type BoxError = Box; +type BoxFuture = Pin> + Send>>; + +/// Enable a tonic service to handle grpc-web requests with the default configuration. /// -/// Shortcut for `tonic_web::config().enable(service)` -pub fn enable(service: S) -> GrpcWeb +/// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice. +pub fn enable(service: S) -> Cors> where S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { - config().enable(service) + CorsLayer::new() + .allow_origin(AllowOrigin::mirror_request()) + .allow_credentials(true) + .max_age(DEFAULT_MAX_AGE) + .expose_headers( + DEFAULT_EXPOSED_HEADERS + .iter() + .cloned() + .map(HeaderName::from_static) + .collect::>(), + ) + .allow_headers( + DEFAULT_ALLOW_HEADERS + .iter() + .cloned() + .map(HeaderName::from_static) + .collect::>(), + ) + .layer(GrpcWebService::new(service)) } - -/// returns a default [`Config`] instance for configuring services. -/// -/// ## Example -/// -/// ``` -/// let config = tonic_web::config() -/// .allow_origins(vec!["http://foo.com"]) -/// .allow_credentials(false) -/// .expose_headers(vec!["x-request-id"]); -/// -/// // let greeter = config.enable(Greeter); -/// // let route_guide = config.enable(RouteGuide); -/// ``` -pub fn config() -> Config { - Config::default() -} - -type BoxError = Box; -type BoxFuture = Pin> + Send>>; diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index c4beac1a0..ac4ef9a91 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -9,17 +9,14 @@ use tracing::{debug, trace}; use crate::call::content_types::is_grpc_web; use crate::call::{Encoding, GrpcWebCall}; -use crate::cors::Cors; -use crate::cors::{ORIGIN, REQUEST_HEADERS}; -use crate::{BoxError, BoxFuture, Config}; +use crate::{BoxError, BoxFuture}; const GRPC: &str = "application/grpc"; /// Service implementing the grpc-web protocol. #[derive(Debug, Clone)] -pub struct GrpcWeb { +pub struct GrpcWebService { inner: S, - cors: Cors, } #[derive(Debug, PartialEq)] @@ -36,44 +33,20 @@ enum RequestKind<'a> { encoding: Encoding, accept: Encoding, }, - // The request is considered a grpc-web preflight request if all these - // conditions are met: - // - // - the request method is `OPTIONS` - // - request headers include `origin` - // - `access-control-request-headers` header is present and includes `x-grpc-web` - GrpcWebPreflight { - origin: &'a HeaderValue, - request_headers: &'a HeaderValue, - }, // All other requests, including `application/grpc` Other(http::Version), } -impl GrpcWeb { - pub(crate) fn new(inner: S, config: Config) -> Self { - GrpcWeb { - inner, - cors: Cors::new(config), - } +impl GrpcWebService { + pub(crate) fn new(inner: S) -> Self { + GrpcWebService { inner } } } -impl GrpcWeb +impl GrpcWebService where S: Service, Response = Response> + Send + 'static, { - fn no_content(&self, headers: HeaderMap) -> BoxFuture { - let mut res = Response::builder() - .status(StatusCode::NO_CONTENT) - .body(empty_body()) - .unwrap(); - - res.headers_mut().extend(headers); - - Box::pin(async { Ok(res) }) - } - fn response(&self, status: StatusCode) -> BoxFuture { Box::pin(async move { Ok(Response::builder() @@ -84,7 +57,7 @@ where } } -impl Service> for GrpcWeb +impl Service> for GrpcWebService where S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, @@ -113,23 +86,12 @@ where method: &Method::POST, encoding, accept, - } => match self.cors.simple(req.headers()) { - Ok(headers) => { - trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept); - - let fut = self.inner.call(coerce_request(req, encoding)); - - Box::pin(async move { - let mut res = coerce_response(fut.await?, accept); - res.headers_mut().extend(headers); - Ok(res) - }) - } - Err(e) => { - debug!(kind = "simple", error=?e, ?req); - self.response(StatusCode::FORBIDDEN) - } - }, + } => { + trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept); + + let fut = self.inner.call(coerce_request(req, encoding)); + Box::pin(async move { Ok(coerce_response(fut.await?, accept)) }) + } // The request's content-type matches one of the 4 supported grpc-web // content-types, but the request method is not `POST`. @@ -139,24 +101,8 @@ where self.response(StatusCode::METHOD_NOT_ALLOWED) } - // A valid grpc-web preflight request, regardless of HTTP version. - // This is handled by the cors module. - RequestKind::GrpcWebPreflight { - origin, - request_headers, - } => match self.cors.preflight(req.headers(), origin, request_headers) { - Ok(headers) => { - trace!(kind = "preflight", path = ?req.uri().path(), ?origin); - self.no_content(headers) - } - Err(e) => { - debug!(kind = "preflight", error = ?e, ?req); - self.response(StatusCode::FORBIDDEN) - } - }, - - // All http/2 requests that are not grpc-web or grpc-web preflight - // are passed through to the inner service, whatever they are. + // All http/2 requests that are not grpc-web are passed through to the inner service, + // whatever they are. RequestKind::Other(Version::HTTP_2) => { debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE)); Box::pin(self.inner.call(req)) @@ -171,7 +117,7 @@ where } } -impl NamedService for GrpcWeb { +impl NamedService for GrpcWebService { const NAME: &'static str = S::NAME; } @@ -185,20 +131,6 @@ impl<'a> RequestKind<'a> { }; } - if let (&Method::OPTIONS, Some(origin), Some(value)) = - (method, headers.get(ORIGIN), headers.get(REQUEST_HEADERS)) - { - match value.to_str() { - Ok(h) if h.contains("x-grpc-web") => { - return RequestKind::GrpcWebPreflight { - origin, - request_headers: value, - }; - } - _ => {} - } - } - RequestKind::Other(version) } } @@ -241,9 +173,11 @@ fn coerce_response(res: Response, encoding: Encoding) -> Response> for Svc { @@ -307,18 +241,7 @@ mod tests { } #[tokio::test] - async fn origin_not_allowed() { - let mut svc = crate::config() - .allow_origins(vec!["http://localhost"]) - .enable(Svc); - - let res = svc.call(request()).await.unwrap(); - - assert_eq!(res.status(), StatusCode::FORBIDDEN) - } - - #[tokio::test] - async fn only_post_allowed() { + async fn only_post_and_options_allowed() { let mut svc = crate::enable(Svc); for method in &[ @@ -326,7 +249,6 @@ mod tests { Method::PUT, Method::DELETE, Method::HEAD, - Method::OPTIONS, Method::PATCH, ] { let mut req = request(); @@ -361,127 +283,23 @@ mod tests { mod options { use super::*; - use crate::cors::{REQUEST_HEADERS, REQUEST_METHOD}; - use http::HeaderValue; - - const SUCCESS: StatusCode = StatusCode::NO_CONTENT; fn request() -> Request { Request::builder() .method(Method::OPTIONS) .header(ORIGIN, "http://example.com") - .header(REQUEST_HEADERS, "x-grpc-web") - .header(REQUEST_METHOD, "POST") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web") + .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") .body(Body::empty()) .unwrap() } - #[tokio::test] - async fn origin_not_allowed() { - let mut svc = crate::config() - .allow_origins(vec!["http://foo.com"]) - .enable(Svc); - - let res = svc.call(request()).await.unwrap(); - - assert_eq!(res.status(), StatusCode::FORBIDDEN); - } - - #[tokio::test] - async fn missing_request_method() { - let mut svc = crate::enable(Svc); - - let mut req = request(); - req.headers_mut().remove(REQUEST_METHOD); - - let res = svc.call(req).await.unwrap(); - - assert_eq!(res.status(), StatusCode::FORBIDDEN); - } - - #[tokio::test] - async fn only_post_and_options_allowed() { - let mut svc = crate::enable(Svc); - - for method in &[ - Method::GET, - Method::PUT, - Method::DELETE, - Method::HEAD, - Method::PATCH, - ] { - let mut req = request(); - req.headers_mut().insert( - REQUEST_METHOD, - HeaderValue::from_maybe_shared(method.to_string()).unwrap(), - ); - - let res = svc.call(req).await.unwrap(); - - assert_eq!( - res.status(), - StatusCode::FORBIDDEN, - "{} should not be allowed", - method - ); - } - } - - #[tokio::test] - async fn h1_missing_origin_is_err() { - let mut svc = crate::enable(Svc); - let mut req = request(); - req.headers_mut().remove(ORIGIN); - - let res = svc.call(req).await.unwrap(); - - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn h2_missing_origin_is_ok() { - let mut svc = crate::enable(Svc); - - let mut req = request(); - *req.version_mut() = Version::HTTP_2; - req.headers_mut().remove(ORIGIN); - - let res = svc.call(req).await.unwrap(); - - assert_eq!(res.status(), StatusCode::OK); - } - - #[tokio::test] - async fn h1_missing_x_grpc_web_header_is_err() { - let mut svc = crate::enable(Svc); - - let mut req = request(); - req.headers_mut().remove(REQUEST_HEADERS); - - let res = svc.call(req).await.unwrap(); - - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn h2_missing_x_grpc_web_header_is_ok() { - let mut svc = crate::enable(Svc); - - let mut req = request(); - *req.version_mut() = Version::HTTP_2; - req.headers_mut().remove(REQUEST_HEADERS); - - let res = svc.call(req).await.unwrap(); - - assert_eq!(res.status(), StatusCode::OK); - } - #[tokio::test] async fn valid_grpc_web_preflight() { let mut svc = crate::enable(Svc); let res = svc.call(request()).await.unwrap(); - assert_eq!(res.status(), SUCCESS); + assert_eq!(res.status(), StatusCode::OK); } } diff --git a/tonic-web/tests/integration/tests/grpc.rs b/tonic-web/tests/integration/tests/grpc.rs index 5515adaef..8e81ef68c 100644 --- a/tonic-web/tests/integration/tests/grpc.rs +++ b/tonic-web/tests/integration/tests/grpc.rs @@ -11,6 +11,7 @@ use tonic::{Response, Streaming}; use integration::pb::{test_client::TestClient, test_server::TestServer, Input}; use integration::Svc; +use tonic_web::GrpcWebLayer; #[tokio::test] async fn smoke_unary() { @@ -113,13 +114,10 @@ async fn grpc(accept_h1: bool) -> (impl Future>, Stri async fn grpc_web(accept_h1: bool) -> (impl Future>, String) { let (listener, url) = bind().await; - let svc = tonic_web::config() - .allow_origins(vec!["http://foo.com"]) - .enable(TestServer::new(Svc)); - let fut = Server::builder() .accept_http1(accept_h1) - .add_service(svc) + .layer(GrpcWebLayer::new()) + .add_service(TestServer::new(Svc)) .serve_with_incoming(TcpListenerStream::new(listener)); (fut, url) diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 61b686652..dc68ca25a 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -10,10 +10,11 @@ use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; use integration::Svc; +use tonic_web::GrpcWebLayer; #[tokio::test] async fn binary_request() { - let server_url = spawn("http://example.com").await; + let server_url = spawn().await; let client = Client::new(); let req = build_request(server_url, "grpc-web", "grpc-web"); @@ -36,7 +37,7 @@ async fn binary_request() { #[tokio::test] async fn text_request() { - let server_url = spawn("http://example.com").await; + let server_url = spawn().await; let client = Client::new(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); @@ -57,31 +58,17 @@ async fn text_request() { assert_eq!(&trailers[..], b"grpc-status:0\r\n"); } -#[tokio::test] -async fn origin_not_allowed() { - let server_url = spawn("http://foo.com").await; - let client = Client::new(); - - let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); - let res = client.request(req).await.unwrap(); - - assert_eq!(res.status(), StatusCode::FORBIDDEN); -} - -async fn spawn(allowed_origin: &str) -> String { +async fn spawn() -> String { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let listener = TcpListener::bind(addr).await.expect("listener"); let url = format!("http://{}", listener.local_addr().unwrap()); let listener_stream = TcpListenerStream::new(listener); - let svc = tonic_web::config() - .allow_origins(vec![allowed_origin]) - .enable(TestServer::new(Svc)); - let _ = tokio::spawn(async move { Server::builder() .accept_http1(true) - .add_service(svc) + .layer(GrpcWebLayer::new()) + .add_service(TestServer::new(Svc)) .serve_with_incoming(listener_stream) .await .unwrap()