diff --git a/tower-http/src/auth/async_require_authorization.rs b/tower-http/src/auth/async_require_authorization.rs new file mode 100644 index 00000000..cccaa226 --- /dev/null +++ b/tower-http/src/auth/async_require_authorization.rs @@ -0,0 +1,353 @@ +//! Authorize requests using the [`Authorization`] header asynchronously. +//! +//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +//! +//! # Example +//! +//! ``` +//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use futures_util::future::BoxFuture; +//! +//! #[derive(Clone, Copy)] +//! struct MyAuth; +//! +//! impl AsyncAuthorizeRequest for MyAuth { +//! type Output = UserId; +//! type Future = BoxFuture<'static, Option>; +//! type ResponseBody = Body; +//! +//! fn authorize(&mut self, request: &Request) -> Self::Future { +//! Box::pin(async { +//! // ... +//! # None +//! }) +//! } +//! +//! fn on_authorized(&mut self, request: &mut Request, user_id: UserId) { +//! // Set `user_id` as a request extension so it can be accessed by other +//! // services down the stack. +//! request.extensions_mut().insert(user_id); +//! } +//! +//! fn unauthorized_response(&mut self, request: &Request) -> Response { +//! Response::builder() +//! .status(StatusCode::UNAUTHORIZED) +//! .body(Body::empty()) +//! .unwrap() +//! } +//! } +//! +//! #[derive(Debug)] +//! struct UserId(String); +//! +//! async fn handle(request: Request) -> Result, Error> { +//! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the +//! // request was authorized and `UserId` will be present. +//! let user_id = request +//! .extensions() +//! .get::() +//! .expect("UserId will be there if request was authorized"); +//! +//! println!("request from {:?}", user_id); +//! +//! Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! // Authorize requests using `MyAuth` +//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use futures_core::ready; +use http::{Request, Response}; +use http_body::Body; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the +/// [`Authorization`] header. +/// +/// See the [module docs](crate::auth::async_require_authorization) for an example. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +#[derive(Debug, Clone)] +pub struct AsyncRequireAuthorizationLayer { + auth: T, +} + +impl AsyncRequireAuthorizationLayer +where + T: AsyncAuthorizeRequest, +{ + /// Authorize requests using a custom scheme. + pub fn new(auth: T) -> AsyncRequireAuthorizationLayer { + Self { auth } + } +} + +impl Layer for AsyncRequireAuthorizationLayer +where + T: Clone + AsyncAuthorizeRequest, +{ + type Service = AsyncRequireAuthorization; + + fn layer(&self, inner: S) -> Self::Service { + AsyncRequireAuthorization::new(inner, self.auth.clone()) + } +} + +/// Middleware that authorizes all requests using the [`Authorization`] header. +/// +/// See the [module docs](crate::auth::async_require_authorization) for an example. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +#[derive(Clone, Debug)] +pub struct AsyncRequireAuthorization { + inner: S, + auth: T, +} + +impl AsyncRequireAuthorization { + define_inner_service_accessors!(); +} + +impl AsyncRequireAuthorization +where + T: AsyncAuthorizeRequest, +{ + /// Authorize requests using a custom scheme. + /// + /// The `Authorization` header is required to have the value provided. + pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization { + Self { inner, auth } + } +} + +impl Service> for AsyncRequireAuthorization +where + S: Service, Response = Response> + Clone, + ResBody: Default, + T: AsyncAuthorizeRequest + Clone, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let auth = self.auth.clone(); + let inner = self.inner.clone(); + let authorize = self.auth.authorize(&req); + + ResponseFuture { + auth, + state: State::Authorize { + authorize, + req: Some(req), + }, + service: inner, + } + } +} + +#[pin_project(project = StateProj)] +enum State { + Authorize { + #[pin] + authorize: A, + req: Option>, + }, + Authorized { + #[pin] + fut: SFut, + }, +} + +/// Response future for [`AsyncRequireAuthorization`]. +#[pin_project] +pub struct ResponseFuture +where + Auth: AsyncAuthorizeRequest, + S: Service>, +{ + auth: Auth, + #[pin] + state: State, + service: S, +} + +impl Future for ResponseFuture +where + Auth: AsyncAuthorizeRequest, + S: Service, Response = Response>, +{ + type Output = Result, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + StateProj::Authorize { authorize, req } => { + let auth = ready!(authorize.poll(cx)); + let mut req = req.take().expect("future polled after completion"); + match auth { + Some(output) => { + this.auth.on_authorized(&mut req, output); + let fut = this.service.call(req); + this.state.set(State::Authorized { fut }) + } + None => { + let res = this.auth.unauthorized_response(&req); + return Poll::Ready(Ok(res)); + } + }; + } + StateProj::Authorized { fut } => { + return fut.poll(cx); + } + } + } + } +} + +/// Trait for authorizing requests. +pub trait AsyncAuthorizeRequest { + /// The output type of doing the authorization. + /// + /// Use `()` if authorization doesn't produce any meaningful output. + type Output; + + /// The Future type returned by `authorize` + type Future: Future>; + + /// The body type used for responses to unauthorized requests. + type ResponseBody: Body; + + /// Authorize the request. + /// + /// If the future resolves to `Some(_)` then the request is allowed through, otherwise not. + fn authorize(&mut self, request: &Request) -> Self::Future; + + /// Callback for when a request has been successfully authorized. + /// + /// For example this allows you to save `Self::Output` in a [request extension][] to make it + /// available to services further down the stack. This could for example be the "claims" for a + /// valid [JWT]. + /// + /// Defaults to doing nothing. + /// + /// See the [module docs](crate::auth::async_require_authorization) for an example. + /// + /// [request extension]: https://docs.rs/http/latest/http/struct.Extensions.html + /// [JWT]: https://jwt.io + #[inline] + fn on_authorized(&mut self, _request: &mut Request, _output: Self::Output) {} + + /// Create the response for an unauthorized request. + fn unauthorized_response(&mut self, request: &Request) -> Response; +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use futures_util::future::BoxFuture; + use http::{header, StatusCode}; + use hyper::Body; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[derive(Clone, Copy)] + struct MyAuth; + + impl AsyncAuthorizeRequest for MyAuth { + type Output = UserId; + type Future = BoxFuture<'static, Option>; + type ResponseBody = Body; + + fn authorize(&mut self, request: &Request) -> Self::Future { + let authorized = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|it| it.to_str().ok()) + .and_then(|it| it.strip_prefix("Bearer ")) + .map(|it| it == "69420") + .unwrap_or(false); + + Box::pin(async move { + if authorized { + Some(UserId(String::from("6969"))) + } else { + None + } + }) + } + + fn on_authorized(&mut self, request: &mut Request, user_id: UserId) { + request.extensions_mut().insert(user_id); + } + + fn unauthorized_response(&mut self, _request: &Request) -> Response { + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap() + } + } + + #[derive(Debug)] + struct UserId(String); + + #[tokio::test] + async fn require_async_auth_works() { + let mut service = ServiceBuilder::new() + .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer 69420") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn require_async_auth_401() { + let mut service = ServiceBuilder::new() + .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer deez") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/tower-http/src/auth/mod.rs b/tower-http/src/auth/mod.rs index da40b149..01b95c9d 100644 --- a/tower-http/src/auth/mod.rs +++ b/tower-http/src/auth/mod.rs @@ -1,10 +1,14 @@ //! Authorization related middleware. pub mod add_authorization; +pub mod async_require_authorization; pub mod require_authorization; #[doc(inline)] pub use self::{ add_authorization::{AddAuthorization, AddAuthorizationLayer}, + async_require_authorization::{ + AsyncAuthorizeRequest, AsyncRequireAuthorization, AsyncRequireAuthorizationLayer, + }, require_authorization::{AuthorizeRequest, RequireAuthorization, RequireAuthorizationLayer}, };