diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs index 69289865a9a..832fd5bae13 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs @@ -1,23 +1,20 @@ -use std::error::Error; -use std::fmt; - use pyo3::{exceptions::PyRuntimeError, PyErr}; +use thiserror::Error; -#[derive(Debug)] +/// Possible middleware errors that might arise. +#[derive(Error, Debug)] pub enum PyMiddlewareError { - ResponseAlreadyGone, + /// Returned when `next` is called multiple times. + #[error("next already called")] + NextAlreadyCalled, + /// Returned when request is accessed after `next` is called. + #[error("request is gone")] + RequestGone, + /// Returned when response is called after it is returned. + #[error("response is gone")] + ResponseGone, } -impl fmt::Display for PyMiddlewareError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Self::ResponseAlreadyGone => write!(f, "response is already consumed"), - } - } -} - -impl Error for PyMiddlewareError {} - impl From for PyErr { fn from(err: PyMiddlewareError) -> PyErr { PyRuntimeError::new_err(err.to_string()) diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 22bcfedbcd6..6a14205c0cf 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -36,10 +36,10 @@ impl PyNext { let req = py_req .borrow_mut(py) .take_inner() - .ok_or_else(|| PyRuntimeError::new_err("already called"))?; + .ok_or(PyMiddlewareError::RequestGone)?; let mut inner = self .take_inner() - .ok_or_else(|| PyRuntimeError::new_err("already called"))?; + .ok_or(PyMiddlewareError::NextAlreadyCalled)?; pyo3_asyncio::tokio::future_into_py(py, async move { let res = inner .call(req) @@ -105,6 +105,6 @@ impl PyMiddlewareHandler { Ok::<_, PyErr>(py_res.take_inner()) })?; - response.ok_or_else(|| PyMiddlewareError::ResponseAlreadyGone.into()) + response.ok_or_else(|| PyMiddlewareError::ResponseGone.into()) } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs index 9a23844282e..982cc3e067e 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -238,3 +238,57 @@ def second_middleware(request, next): assert_eq!(body, "hello client from Python second middleware"); Ok(()) } + +#[pyo3_asyncio::tokio::test] +async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> { + let locals = Python::with_gil(|py| { + Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) + })?; + let handler = Python::with_gil(|py| { + let locals = PyDict::new(py); + py.run( + r#" +async def middleware(request, next): + uri = request.uri + response = await next(request) + uri = request.uri # <- fails + return response +"#, + None, + Some(locals), + )?; + let handler = locals.get_item("middleware").unwrap().into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) + })?; + + let layer = PyMiddlewareLayer::::new(handler, locals); + let (mut service, mut handle) = mock::spawn_with(|svc| { + let svc = svc.map_err(|err| panic!("service failed: {err}")); + let svc = layer.layer(svc); + svc + }); + assert_ready_ok!(service.poll_ready()); + + let th = tokio::spawn(async move { + let (req, send_response) = handle.next_request().await.unwrap(); + let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + assert_eq!(req_body, "hello server"); + send_response.send_response( + Response::builder() + .body(to_boxed("hello client")) + .expect("could not create response"), + ); + }); + + let request = Request::builder() + .body(Body::from("hello server")) + .expect("could not create request"); + let response = service.call(request); + + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"RuntimeError: request is gone"}"#); + th.await.unwrap(); + Ok(()) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/util.rs b/rust-runtime/aws-smithy-http-server-python/src/util.rs index 118c9cac58b..c2dd5bf675a 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/util.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/util.rs @@ -40,6 +40,8 @@ mod tests { #[test] fn function_metadata() -> PyResult<()> { + pyo3::prepare_freethreaded_python(); + Python::with_gil(|py| { let module = PyModule::from_code( py,