From 67630d504702361913940e8fa494a2cb56b69e29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Fri, 21 Oct 2022 13:55:34 +0100 Subject: [PATCH] Use message and status code from `PyMiddlewareException` --- .../src/error.rs | 14 ++- .../src/middleware/error.rs | 8 ++ .../src/middleware/handler.rs | 2 +- .../src/middleware/pytests/layer.rs | 95 ++++++++++++++++++- 4 files changed, 108 insertions(+), 11 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 73bfe3ce4dd..06e20e9b520 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -15,7 +15,6 @@ use aws_smithy_http_server::{ use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr}; use thiserror::Error; -use tower::BoxError; /// Python error that implements foreign errors. #[derive(Error, Debug)] @@ -64,13 +63,12 @@ impl PyMiddlewareException { impl From for PyMiddlewareException { fn from(other: PyErr) -> Self { - Self::newpy(other.to_string(), None) - } -} - -impl From for PyMiddlewareException { - fn from(other: BoxError) -> Self { - Self::newpy(other.to_string(), None) + // Try to extract `PyMiddlewareException` from `PyErr` and use that if succeed + let middleware_err = Python::with_gil(|py| other.to_object(py).extract::(py)); + match middleware_err { + Ok(err) => err, + Err(_) => Self::newpy(other.to_string(), None), + } } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs index 65f22180c03..69289865a9a 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs @@ -1,6 +1,8 @@ use std::error::Error; use std::fmt; +use pyo3::{exceptions::PyRuntimeError, PyErr}; + #[derive(Debug)] pub enum PyMiddlewareError { ResponseAlreadyGone, @@ -15,3 +17,9 @@ impl fmt::Display for PyMiddlewareError { } impl Error for PyMiddlewareError {} + +impl From for PyErr { + fn from(err: PyMiddlewareError) -> PyErr { + PyRuntimeError::new_err(err.to_string()) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index fc35c95e0e4..3ce9121876b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -66,7 +66,7 @@ impl PyMiddlewareHandler { req: Request, next: PyNextInner, locals: TaskLocals, - ) -> Result, BoxError> { + ) -> PyResult> { let py_req = PyRequest::new(req); let py_next = PyNext::new(next); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs index ea5ffbbb1fb..cc61aea82c7 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -2,9 +2,9 @@ use aws_smithy_http_server::body::to_boxed; use aws_smithy_http_server::proto::rest_json_1::RestJson1; use aws_smithy_http_server_python::{ middleware::{PyMiddlewareHandler, PyMiddlewareLayer}, - PyResponse, + PyMiddlewareException, PyResponse, }; -use http::{Request, Response}; +use http::{Request, Response, StatusCode}; use hyper::Body; use pretty_assertions::assert_eq; use pyo3::{ @@ -115,6 +115,97 @@ def middleware(request, next): Ok(()) } +#[pyo3_asyncio::tokio::test] +async fn convert_exception_from_middleware_to_protocol_specific_response() -> PyResult<()> { + let locals = Python::with_gil(|py| { + Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) + })?; + let py_handler = Python::with_gil(|py| { + let locals = PyDict::new(py); + py.run( + r#" +def middleware(request, next): + raise RuntimeError("fail") +"#, + None, + Some(locals), + )?; + Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into()) + })?; + let handler = PyMiddlewareHandler { + func: py_handler, + name: "middleware".to_string(), + is_coroutine: false, + }; + + let layer = PyMiddlewareLayer::::new(handler, locals); + let (mut service, _handle) = mock::spawn_with(|svc| { + let svc = svc.map_err(|err| panic!("service failed: {err}")); + let svc = layer.layer(svc); + svc + }); + assert_ready_ok!(service.poll_ready()); + + let request = Request::builder() + .body(Body::from("hello server")) + .expect("could not create request"); + let response = service.call(request); + + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"RuntimeError: fail"}"#); + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn uses_status_code_and_message_from_middleware_exception() -> PyResult<()> { + let locals = Python::with_gil(|py| { + Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) + })?; + let py_handler = Python::with_gil(|py| { + let globals = [( + "MiddlewareException", + py.get_type::(), + )] + .into_py_dict(py); + let locals = PyDict::new(py); + py.run( + r#" +def middleware(request, next): + raise MiddlewareException("access denied", 401) +"#, + Some(globals), + Some(locals), + )?; + Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into()) + })?; + let handler = PyMiddlewareHandler { + func: py_handler, + name: "middleware".to_string(), + is_coroutine: false, + }; + + let layer = PyMiddlewareLayer::::new(handler, locals); + let (mut service, _handle) = mock::spawn_with(|svc| { + let svc = svc.map_err(|err| panic!("service failed: {err}")); + let svc = layer.layer(svc); + svc + }); + assert_ready_ok!(service.poll_ready()); + + let request = Request::builder() + .body(Body::from("hello server")) + .expect("could not create request"); + let response = service.call(request); + + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"access denied"}"#); + Ok(()) +} + #[pyo3_asyncio::tokio::test] async fn nested_middlewares() -> PyResult<()> { let locals = Python::with_gil(|py| {