Skip to content

Commit

Permalink
Use message and status code from PyMiddlewareException
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Oct 21, 2022
1 parent 6f5aca2 commit 67630d5
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 11 deletions.
14 changes: 6 additions & 8 deletions rust-runtime/aws-smithy-http-server-python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use aws_smithy_http_server::{
use aws_smithy_types::date_time::{ConversionError, DateTimeParseError};
use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr};
use thiserror::Error;
use tower::BoxError;

/// Python error that implements foreign errors.
#[derive(Error, Debug)]
Expand Down Expand Up @@ -64,13 +63,12 @@ impl PyMiddlewareException {

impl From<PyErr> for PyMiddlewareException {
fn from(other: PyErr) -> Self {
Self::newpy(other.to_string(), None)
}
}

impl From<BoxError> for PyMiddlewareException {
fn from(other: BoxError) -> Self {
Self::newpy(other.to_string(), None)
// Try to extract `PyMiddlewareException` from `PyErr` and use that if succeed
let middleware_err = Python::with_gil(|py| other.to_object(py).extract::<Self>(py));
match middleware_err {
Ok(err) => err,
Err(_) => Self::newpy(other.to_string(), None),
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::error::Error;
use std::fmt;

use pyo3::{exceptions::PyRuntimeError, PyErr};

#[derive(Debug)]
pub enum PyMiddlewareError {
ResponseAlreadyGone,
Expand All @@ -15,3 +17,9 @@ impl fmt::Display for PyMiddlewareError {
}

impl Error for PyMiddlewareError {}

impl From<PyMiddlewareError> for PyErr {
fn from(err: PyMiddlewareError) -> PyErr {
PyRuntimeError::new_err(err.to_string())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl PyMiddlewareHandler {
req: Request<Body>,
next: PyNextInner,
locals: TaskLocals,
) -> Result<Response<BoxBody>, BoxError> {
) -> PyResult<Response<BoxBody>> {
let py_req = PyRequest::new(req);
let py_next = PyNext::new(next);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use aws_smithy_http_server::body::to_boxed;
use aws_smithy_http_server::proto::rest_json_1::RestJson1;
use aws_smithy_http_server_python::{
middleware::{PyMiddlewareHandler, PyMiddlewareLayer},
PyResponse,
PyMiddlewareException, PyResponse,
};
use http::{Request, Response};
use http::{Request, Response, StatusCode};
use hyper::Body;
use pretty_assertions::assert_eq;
use pyo3::{
Expand Down Expand Up @@ -115,6 +115,97 @@ def middleware(request, next):
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn convert_exception_from_middleware_to_protocol_specific_response() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let py_handler = Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
r#"
def middleware(request, next):
raise RuntimeError("fail")
"#,
None,
Some(locals),
)?;
Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into())
})?;
let handler = PyMiddlewareHandler {
func: py_handler,
name: "middleware".to_string(),
is_coroutine: false,
};

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, _handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let request = Request::builder()
.body(Body::from("hello server"))
.expect("could not create request");
let response = service.call(request);

let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(body, r#"{"message":"RuntimeError: fail"}"#);
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn uses_status_code_and_message_from_middleware_exception() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let py_handler = Python::with_gil(|py| {
let globals = [(
"MiddlewareException",
py.get_type::<PyMiddlewareException>(),
)]
.into_py_dict(py);
let locals = PyDict::new(py);
py.run(
r#"
def middleware(request, next):
raise MiddlewareException("access denied", 401)
"#,
Some(globals),
Some(locals),
)?;
Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into())
})?;
let handler = PyMiddlewareHandler {
func: py_handler,
name: "middleware".to_string(),
is_coroutine: false,
};

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, _handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let request = Request::builder()
.body(Body::from("hello server"))
.expect("could not create request");
let response = service.call(request);

let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(body, r#"{"message":"access denied"}"#);
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn nested_middlewares() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Expand Down

0 comments on commit 67630d5

Please sign in to comment.