Skip to content

Commit

Permalink
Add AsyncRequireAuthorization (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranile authored Sep 18, 2021
1 parent 1ec7dcc commit 773a781
Show file tree
Hide file tree
Showing 2 changed files with 357 additions and 0 deletions.
353 changes: 353 additions & 0 deletions tower-http/src/auth/async_require_authorization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
//! Authorize requests using the [`Authorization`] header asynchronously.
//!
//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
//!
//! # Example
//!
//! ```
//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
//! use hyper::{Request, Response, Body, Error};
//! use http::{StatusCode, header::AUTHORIZATION};
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
//! use futures_util::future::BoxFuture;
//!
//! #[derive(Clone, Copy)]
//! struct MyAuth;
//!
//! impl AsyncAuthorizeRequest for MyAuth {
//! type Output = UserId;
//! type Future = BoxFuture<'static, Option<UserId>>;
//! type ResponseBody = Body;
//!
//! fn authorize<B>(&mut self, request: &Request<B>) -> Self::Future {
//! Box::pin(async {
//! // ...
//! # None
//! })
//! }
//!
//! fn on_authorized<B>(&mut self, request: &mut Request<B>, user_id: UserId) {
//! // Set `user_id` as a request extension so it can be accessed by other
//! // services down the stack.
//! request.extensions_mut().insert(user_id);
//! }
//!
//! fn unauthorized_response<B>(&mut self, request: &Request<B>) -> Response<Body> {
//! Response::builder()
//! .status(StatusCode::UNAUTHORIZED)
//! .body(Body::empty())
//! .unwrap()
//! }
//! }
//!
//! #[derive(Debug)]
//! struct UserId(String);
//!
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> {
//! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the
//! // request was authorized and `UserId` will be present.
//! let user_id = request
//! .extensions()
//! .get::<UserId>()
//! .expect("UserId will be there if request was authorized");
//!
//! println!("request from {:?}", user_id);
//!
//! Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let service = ServiceBuilder::new()
//! // Authorize requests using `MyAuth`
//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
//! .service_fn(handle);
//! # Ok(())
//! # }
//! ```

use futures_core::ready;
use http::{Request, Response};
use http_body::Body;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;

/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the
/// [`Authorization`] header.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorizationLayer<T> {
auth: T,
}

impl<T> AsyncRequireAuthorizationLayer<T>
where
T: AsyncAuthorizeRequest,
{
/// Authorize requests using a custom scheme.
pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
Self { auth }
}
}

impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
where
T: Clone + AsyncAuthorizeRequest,
{
type Service = AsyncRequireAuthorization<S, T>;

fn layer(&self, inner: S) -> Self::Service {
AsyncRequireAuthorization::new(inner, self.auth.clone())
}
}

/// Middleware that authorizes all requests using the [`Authorization`] header.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
inner: S,
auth: T,
}

impl<S, T> AsyncRequireAuthorization<S, T> {
define_inner_service_accessors!();
}

impl<S, T> AsyncRequireAuthorization<S, T>
where
T: AsyncAuthorizeRequest,
{
/// Authorize requests using a custom scheme.
///
/// The `Authorization` header is required to have the value provided.
pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
Self { inner, auth }
}
}

impl<ReqBody, ResBody, S, T> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ResBody: Default,
T: AsyncAuthorizeRequest<ResponseBody = ResBody> + Clone,
{
type Response = Response<ResBody>;
type Error = S::Error;
type Future = ResponseFuture<T, S, ReqBody>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let auth = self.auth.clone();
let inner = self.inner.clone();
let authorize = self.auth.authorize(&req);

ResponseFuture {
auth,
state: State::Authorize {
authorize,
req: Some(req),
},
service: inner,
}
}
}

#[pin_project(project = StateProj)]
enum State<A, ReqBody, SFut> {
Authorize {
#[pin]
authorize: A,
req: Option<Request<ReqBody>>,
},
Authorized {
#[pin]
fut: SFut,
},
}

/// Response future for [`AsyncRequireAuthorization`].
#[pin_project]
pub struct ResponseFuture<Auth, S, ReqBody>
where
Auth: AsyncAuthorizeRequest,
S: Service<Request<ReqBody>>,
{
auth: Auth,
#[pin]
state: State<Auth::Future, ReqBody, S::Future>,
service: S,
}

impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody>
where
Auth: AsyncAuthorizeRequest<ResponseBody = B>,
S: Service<Request<ReqBody>, Response = Response<B>>,
{
type Output = Result<Response<B>, S::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

loop {
match this.state.as_mut().project() {
StateProj::Authorize { authorize, req } => {
let auth = ready!(authorize.poll(cx));
let mut req = req.take().expect("future polled after completion");
match auth {
Some(output) => {
this.auth.on_authorized(&mut req, output);
let fut = this.service.call(req);
this.state.set(State::Authorized { fut })
}
None => {
let res = this.auth.unauthorized_response(&req);
return Poll::Ready(Ok(res));
}
};
}
StateProj::Authorized { fut } => {
return fut.poll(cx);
}
}
}
}
}

/// Trait for authorizing requests.
pub trait AsyncAuthorizeRequest {
/// The output type of doing the authorization.
///
/// Use `()` if authorization doesn't produce any meaningful output.
type Output;

/// The Future type returned by `authorize`
type Future: Future<Output = Option<Self::Output>>;

/// The body type used for responses to unauthorized requests.
type ResponseBody: Body;

/// Authorize the request.
///
/// If the future resolves to `Some(_)` then the request is allowed through, otherwise not.
fn authorize<B>(&mut self, request: &Request<B>) -> Self::Future;

/// Callback for when a request has been successfully authorized.
///
/// For example this allows you to save `Self::Output` in a [request extension][] to make it
/// available to services further down the stack. This could for example be the "claims" for a
/// valid [JWT].
///
/// Defaults to doing nothing.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [request extension]: https://docs.rs/http/latest/http/struct.Extensions.html
/// [JWT]: https://jwt.io
#[inline]
fn on_authorized<B>(&mut self, _request: &mut Request<B>, _output: Self::Output) {}

/// Create the response for an unauthorized request.
fn unauthorized_response<B>(&mut self, request: &Request<B>) -> Response<Self::ResponseBody>;
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use futures_util::future::BoxFuture;
use http::{header, StatusCode};
use hyper::Body;
use tower::{BoxError, ServiceBuilder, ServiceExt};

#[derive(Clone, Copy)]
struct MyAuth;

impl AsyncAuthorizeRequest for MyAuth {
type Output = UserId;
type Future = BoxFuture<'static, Option<UserId>>;
type ResponseBody = Body;

fn authorize<B>(&mut self, request: &Request<B>) -> Self::Future {
let authorized = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|it| it.to_str().ok())
.and_then(|it| it.strip_prefix("Bearer "))
.map(|it| it == "69420")
.unwrap_or(false);

Box::pin(async move {
if authorized {
Some(UserId(String::from("6969")))
} else {
None
}
})
}

fn on_authorized<B>(&mut self, request: &mut Request<B>, user_id: UserId) {
request.extensions_mut().insert(user_id);
}

fn unauthorized_response<B>(&mut self, _request: &Request<B>) -> Response<Body> {
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap()
}
}

#[derive(Debug)]
struct UserId(String);

#[tokio::test]
async fn require_async_auth_works() {
let mut service = ServiceBuilder::new()
.layer(AsyncRequireAuthorizationLayer::new(MyAuth))
.service_fn(echo);

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer 69420")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn require_async_auth_401() {
let mut service = ServiceBuilder::new()
.layer(AsyncRequireAuthorizationLayer::new(MyAuth))
.service_fn(echo);

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer deez")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
}
4 changes: 4 additions & 0 deletions tower-http/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
//! Authorization related middleware.

pub mod add_authorization;
pub mod async_require_authorization;
pub mod require_authorization;

#[doc(inline)]
pub use self::{
add_authorization::{AddAuthorization, AddAuthorizationLayer},
async_require_authorization::{
AsyncAuthorizeRequest, AsyncRequireAuthorization, AsyncRequireAuthorizationLayer,
},
require_authorization::{AuthorizeRequest, RequireAuthorization, RequireAuthorizationLayer},
};

0 comments on commit 773a781

Please sign in to comment.