Skip to content

Commit

Permalink
Improve middleware errors
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Oct 21, 2022
1 parent 8417633 commit 1b08bf5
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 18 deletions.
27 changes: 12 additions & 15 deletions rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
use std::error::Error;
use std::fmt;

use pyo3::{exceptions::PyRuntimeError, PyErr};
use thiserror::Error;

#[derive(Debug)]
/// Possible middleware errors that might arise.
#[derive(Error, Debug)]
pub enum PyMiddlewareError {
ResponseAlreadyGone,
/// Returned when `next` is called multiple times.
#[error("next already called")]
NextAlreadyCalled,
/// Returned when request is accessed after `next` is called.
#[error("request is gone")]
RequestGone,
/// Returned when response is called after it is returned.
#[error("response is gone")]
ResponseGone,
}

impl fmt::Display for PyMiddlewareError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::ResponseAlreadyGone => write!(f, "response is already consumed"),
}
}
}

impl Error for PyMiddlewareError {}

impl From<PyMiddlewareError> for PyErr {
fn from(err: PyMiddlewareError) -> PyErr {
PyRuntimeError::new_err(err.to_string())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ impl PyNext {
let req = py_req
.borrow_mut(py)
.take_inner()
.ok_or_else(|| PyRuntimeError::new_err("already called"))?;
.ok_or(PyMiddlewareError::RequestGone)?;
let mut inner = self
.take_inner()
.ok_or_else(|| PyRuntimeError::new_err("already called"))?;
.ok_or(PyMiddlewareError::NextAlreadyCalled)?;
pyo3_asyncio::tokio::future_into_py(py, async move {
let res = inner
.call(req)
Expand Down Expand Up @@ -105,6 +105,6 @@ impl PyMiddlewareHandler {
Ok::<_, PyErr>(py_res.take_inner())
})?;

response.ok_or_else(|| PyMiddlewareError::ResponseAlreadyGone.into())
response.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,57 @@ def second_middleware(request, next):
assert_eq!(body, "hello client from Python second middleware");
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let handler = Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
r#"
async def middleware(request, next):
uri = request.uri
response = await next(request)
uri = request.uri # <- fails
return response
"#,
None,
Some(locals),
)?;
let handler = locals.get_item("middleware").unwrap().into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, mut handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let th = tokio::spawn(async move {
let (req, send_response) = handle.next_request().await.unwrap();
let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap();
assert_eq!(req_body, "hello server");
send_response.send_response(
Response::builder()
.body(to_boxed("hello client"))
.expect("could not create response"),
);
});

let request = Request::builder()
.body(Body::from("hello server"))
.expect("could not create request");
let response = service.call(request);

let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(body, r#"{"message":"RuntimeError: request is gone"}"#);
th.await.unwrap();
Ok(())
}
2 changes: 2 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ mod tests {

#[test]
fn function_metadata() -> PyResult<()> {
pyo3::prepare_freethreaded_python();

Python::with_gil(|py| {
let module = PyModule::from_code(
py,
Expand Down

0 comments on commit 1b08bf5

Please sign in to comment.