Skip to content

Commit

Permalink
Make middleware layer infallible
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Oct 21, 2022
1 parent 1fd9c73 commit 6f5aca2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,25 +234,22 @@ class PythonApplicationGenerator(
}
rustTemplate(
"""
use #{tower}::{Layer, ServiceExt};
let mut service = #{tower}::util::BoxCloneService::new(builder.build());
let mut middlewares = self.middlewares.clone();
// Reverse the middlewares, so they run with same order as they defined
middlewares.reverse();
for handler in middlewares {
let locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let name = handler.name.clone();
let layer = #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(handler, locals);
service = #{tower}::util::BoxCloneService::new(layer.layer(service).map_err(move |err| {
// TODO: correctly propogate errors.
tracing::error!(
err = %err,
"'{}' middleware failed",
name,
);
loop {}
}));
{
use #{tower}::Layer;
tracing::debug!("adding middlewares to rust python router");
let mut middlewares = self.middlewares.clone();
// Reverse the middlewares, so they run with same order as they defined
middlewares.reverse();
for handler in middlewares {
tracing::debug!("adding python middleware '{}'", &handler.name);
let locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let layer = #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(handler, locals);
service = #{tower}::util::BoxCloneService::new(layer.layer(service));
}
}
Ok(service)
""",
"Protocol" to protocol.markerStruct(),
Expand Down
17 changes: 12 additions & 5 deletions rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
*/

//! Tower layer implementation of Python middleware handling.
use std::{
convert::Infallible,
marker::PhantomData,
task::{Context, Poll},
};
Expand All @@ -16,7 +18,7 @@ use aws_smithy_http_server::{
use futures::{future::BoxFuture, TryFutureExt};
use http::{Request, Response};
use pyo3_asyncio::TaskLocals;
use tower::{util::BoxService, BoxError, Layer, Service, ServiceExt};
use tower::{util::BoxService, Layer, Service, ServiceExt};

use super::PyMiddlewareHandler;
use crate::PyMiddlewareException;
Expand Down Expand Up @@ -85,16 +87,21 @@ impl<S> PyMiddlewareService<S> {

impl<S> Service<Request<Body>> for PyMiddlewareService<S>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Error: Into<BoxError>,
S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send,
{
type Response = S::Response;
type Error = BoxError;
// We are making `Service` `Infallible` because we convert errors to responses via
// `PyMiddlewareException::into_response` which has `IntoResponse<Protocol>` bound,
// so we always return a protocol specific error response instead of erroring out.
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|err| err.into())
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<Body>) -> Self::Future {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use pyo3::{
};
use pyo3_asyncio::TaskLocals;
use tokio_test::assert_ready_ok;
use tower::layer::util::Stack;
use tower::{layer::util::Stack, Layer, ServiceExt};
use tower_test::mock;

#[pyo3_asyncio::tokio::test]
Expand All @@ -39,8 +39,12 @@ async def identity_middleware(request, next):
is_coroutine: true,
};

let (mut service, mut handle) =
mock::spawn_layer(PyMiddlewareLayer::<RestJson1>::new(handler, locals));
let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, mut handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let th = tokio::spawn(async move {
Expand Down Expand Up @@ -91,8 +95,12 @@ def middleware(request, next):
is_coroutine: false,
};

let (mut service, _handle) =
mock::spawn_layer(PyMiddlewareLayer::<RestJson1>::new(handler, locals));
let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, _handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let request = Request::builder()
Expand Down Expand Up @@ -143,10 +151,15 @@ def second_middleware(request, next):
is_coroutine: false,
};

let (mut service, _handle) = mock::spawn_layer(Stack::new(
let layer = Stack::new(
PyMiddlewareLayer::<RestJson1>::new(second_handler, locals.clone()),
PyMiddlewareLayer::<RestJson1>::new(first_handler, locals.clone()),
));
);
let (mut service, _handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let request = Request::builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl PyRequest {
req.headers_mut().insert(key, value);
Ok(())
}
None => return Err(PyRuntimeError::new_err("request is gone")),
None => Err(PyRuntimeError::new_err("request is gone")),
}
}
}

0 comments on commit 6f5aca2

Please sign in to comment.