diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs index 35044f3fc15..4715d3ee98b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -3,13 +3,18 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_http_server::{body::to_boxed, proto::rest_json_1::RestJson1}; +use std::convert::Infallible; + +use aws_smithy_http_server::{ + body::{to_boxed, Body, BoxBody}, + proto::rest_json_1::RestJson1, +}; use aws_smithy_http_server_python::{ middleware::{PyMiddlewareHandler, PyMiddlewareLayer}, PyMiddlewareException, PyResponse, }; use http::{Request, Response, StatusCode}; -use hyper::Body; +use lambda_http::Service; use pretty_assertions::assert_eq; use pyo3::{ prelude::*, @@ -17,34 +22,18 @@ use pyo3::{ }; use pyo3_asyncio::TaskLocals; use tokio_test::assert_ready_ok; -use tower::{layer::util::Stack, Layer, ServiceExt}; +use tower::{layer::util::Stack, util::BoxCloneService, Layer, ServiceExt}; use tower_test::mock; #[pyo3_asyncio::tokio::test] async fn identity_middleware() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let module = PyModule::from_code( - py, - r#" -async def identity_middleware(request, next): + let layer = layer( + r#" +async def middleware(request, next): return await next(request) "#, - "", - "", - )?; - let handler = module.getattr("identity_middleware")?.into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, mut handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + ); + let (mut service, mut handle) = spawn_service(layer); let th = tokio::spawn(async move { let (req, send_response) = handle.next_request().await.unwrap(); @@ -57,201 +46,101 @@ async def identity_middleware(request, next): ); }); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); + assert_body(response.await?, "hello client").await; - let body = hyper::body::to_bytes(response.await.unwrap().into_body()) - .await - .unwrap(); - assert_eq!(body, "hello client"); th.await.unwrap(); Ok(()) } #[pyo3_asyncio::tokio::test] async fn returning_response_from_python_middleware() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let globals = [("Response", py.get_type::())].into_py_dict(py); - let locals = PyDict::new(py); - py.run( - r#" + let layer = layer( + r#" def middleware(request, next): return Response(200, {}, b"hello client from Python") "#, - Some(globals), - Some(locals), - )?; - let handler = locals.get_item("middleware").unwrap().into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; - - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, _handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + ); + let (mut service, _handle) = spawn_service(layer); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); + assert_body(response.await?, "hello client from Python").await; - let body = hyper::body::to_bytes(response.await.unwrap().into_body()) - .await - .unwrap(); - assert_eq!(body, "hello client from Python"); Ok(()) } #[pyo3_asyncio::tokio::test] async fn convert_exception_from_middleware_to_protocol_specific_response() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let locals = PyDict::new(py); - py.run( - r#" + let layer = layer( + r#" def middleware(request, next): raise RuntimeError("fail") "#, - None, - Some(locals), - )?; - let handler = locals.get_item("middleware").unwrap().into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; - - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, _handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + ); + let (mut service, _handle) = spawn_service(layer); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); - let response = response.await.unwrap(); assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - assert_eq!(body, r#"{"message":"RuntimeError: fail"}"#); + assert_body(response, r#"{"message":"RuntimeError: fail"}"#).await; + Ok(()) } #[pyo3_asyncio::tokio::test] async fn uses_status_code_and_message_from_middleware_exception() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let globals = [( - "MiddlewareException", - py.get_type::(), - )] - .into_py_dict(py); - let locals = PyDict::new(py); - py.run( - r#" + let layer = layer( + r#" def middleware(request, next): raise MiddlewareException("access denied", 401) "#, - Some(globals), - Some(locals), - )?; - let handler = locals.get_item("middleware").unwrap().into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; - - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, _handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + ); + let (mut service, _handle) = spawn_service(layer); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); - let response = response.await.unwrap(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - assert_eq!(body, r#"{"message":"access denied"}"#); + assert_body(response, r#"{"message":"access denied"}"#).await; + Ok(()) } #[pyo3_asyncio::tokio::test] async fn nested_middlewares() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let (first_handler, second_handler) = Python::with_gil(|py| { - let globals = [("Response", py.get_type::())].into_py_dict(py); - let locals = PyDict::new(py); - - py.run( - r#" -async def first_middleware(request, next): + let first_layer = layer( + r#" +async def middleware(request, next): return await next(request) - -def second_middleware(request, next): +"#, + ); + let second_layer = layer( + r#" +def middleware(request, next): return Response(200, {}, b"hello client from Python second middleware") "#, - Some(globals), - Some(locals), - )?; - let first_handler = - PyMiddlewareHandler::new(py, locals.get_item("first_middleware").unwrap().into())?; - let second_handler = - PyMiddlewareHandler::new(py, locals.get_item("second_middleware").unwrap().into())?; - Ok::<_, PyErr>((first_handler, second_handler)) - })?; - - let layer = Stack::new( - PyMiddlewareLayer::::new(second_handler, locals.clone()), - PyMiddlewareLayer::::new(first_handler, locals.clone()), ); - let (mut service, _handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + let layer = Stack::new(first_layer, second_layer); + let (mut service, _handle) = spawn_service(layer); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); + assert_body( + response.await?, + "hello client from Python second middleware", + ) + .await; - let body = hyper::body::to_bytes(response.await.unwrap().into_body()) - .await - .unwrap(); - assert_eq!(body, "hello client from Python second middleware"); Ok(()) } #[pyo3_asyncio::tokio::test] async fn changes_request() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let module = PyModule::from_code( - py, - r#" + let layer = layer( + r#" async def middleware(request, next): body = bytes(await request.body).decode() body_reversed = body[::-1] @@ -259,19 +148,8 @@ async def middleware(request, next): request.headers["X-From-Middleware"] = "yes" return await next(request) "#, - "", - "", - )?; - let handler = module.getattr("middleware")?.into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, mut handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + ); + let (mut service, mut handle) = spawn_service(layer); let th = tokio::spawn(async move { let (req, send_response) = handle.next_request().await.unwrap(); @@ -285,28 +163,18 @@ async def middleware(request, next): ); }); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); + assert_body(response.await?, "hello client").await; - let body = hyper::body::to_bytes(response.await.unwrap().into_body()) - .await - .unwrap(); - assert_eq!(body, "hello client"); th.await.unwrap(); Ok(()) } #[pyo3_asyncio::tokio::test] async fn changes_response() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let module = PyModule::from_code( - py, - r#" + let layer = layer( + r#" async def middleware(request, next): response = await next(request) body = bytes(await response.body).decode() @@ -315,19 +183,8 @@ async def middleware(request, next): response.headers["X-From-Middleware"] = "yes" return response "#, - "", - "", - )?; - let handler = module.getattr("middleware")?.into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, mut handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + ); + let (mut service, mut handle) = spawn_service(layer); let th = tokio::spawn(async move { let (req, send_response) = handle.next_request().await.unwrap(); @@ -340,49 +197,29 @@ async def middleware(request, next): ); }); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); - let response = response.await.unwrap(); - assert_eq!(&"yes", response.headers().get("X-From-Middleware").unwrap()); + assert_eq!(response.headers().get("X-From-Middleware").unwrap(), &"yes"); + assert_body(response, &"hello client".chars().rev().collect::()).await; - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - assert_eq!(body, "hello client".chars().rev().collect::()); th.await.unwrap(); Ok(()) } #[pyo3_asyncio::tokio::test] async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> { - let locals = Python::with_gil(|py| { - Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) - })?; - let handler = Python::with_gil(|py| { - let locals = PyDict::new(py); - py.run( - r#" + let layer = layer( + r#" async def middleware(request, next): uri = request.uri response = await next(request) uri = request.uri # <- fails return response "#, - None, - Some(locals), - )?; - let handler = locals.get_item("middleware").unwrap().into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) - })?; + ); - let layer = PyMiddlewareLayer::::new(handler, locals); - let (mut service, mut handle) = mock::spawn_with(|svc| { - let svc = svc.map_err(|err| panic!("service failed: {err}")); - let svc = layer.layer(svc); - svc - }); - assert_ready_ok!(service.poll_ready()); + let (mut service, mut handle) = spawn_service(layer); let th = tokio::spawn(async move { let (req, send_response) = handle.next_request().await.unwrap(); @@ -395,18 +232,80 @@ async def middleware(request, next): ); }); - let request = Request::builder() - .body(Body::from("hello server")) - .expect("could not create request"); + let request = simple_request("hello server"); let response = service.call(request); - let response = response.await.unwrap(); assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - assert_eq!( - body, - r#"{"message":"RuntimeError: request is accessed after `next` is called"}"# - ); + assert_body( + response, + r#"{"message":"RuntimeError: request is accessed after `next` is called"}"#, + ) + .await; + th.await.unwrap(); Ok(()) } + +async fn assert_body(response: Response, eq: &str) { + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, eq); +} + +fn simple_request(body: &'static str) -> Request { + Request::builder() + .body(Body::from(body)) + .expect("could not create request") +} + +fn spawn_service( + layer: L, +) -> ( + mock::Spawn, + mock::Handle, Response>, +) +where + L: Layer, Response, Infallible>>, + L::Service: Service, Error = E>, + E: std::fmt::Debug, +{ + let (mut service, handle) = mock::spawn_with(|svc| { + let svc = svc + .map_err(|err| panic!("service failed: {err}")) + .boxed_clone(); + layer.layer(svc) + }); + assert_ready_ok!(service.poll_ready()); + (service, handle) +} + +fn layer(code: &str) -> PyMiddlewareLayer { + PyMiddlewareLayer::::new(py_handler(code), task_locals()) +} + +fn task_locals() -> TaskLocals { + Python::with_gil(|py| { + Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) + }) + .unwrap() +} + +fn py_handler(code: &str) -> PyMiddlewareHandler { + Python::with_gil(|py| { + let globals = [ + ( + "MiddlewareException", + py.get_type::(), + ), + ("Response", py.get_type::()), + ] + .into_py_dict(py); + let locals = PyDict::new(py); + py.run(code, Some(globals), Some(locals))?; + let handler = locals + .get_item("middleware") + .expect("your handler must be named `middleware`") + .into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) + }) + .unwrap() +}