diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs index 39037bc9e8d..11546dca4ff 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -244,7 +244,7 @@ def second_middleware(request, next): } #[pyo3_asyncio::tokio::test] -async fn changes_req_body() -> PyResult<()> { +async fn changes_request() -> PyResult<()> { let locals = Python::with_gil(|py| { Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) })?; @@ -256,6 +256,7 @@ async def middleware(request, next): body = bytes(await request.body).decode() body_reversed = body[::-1] request.body = body_reversed.encode() + request.set_header("X-From-Middleware", "yes") return await next(request) "#, "", @@ -274,6 +275,7 @@ async def middleware(request, next): let th = tokio::spawn(async move { let (req, send_response) = handle.next_request().await.unwrap(); + assert_eq!(&"yes", req.headers().get("X-From-Middleware").unwrap()); let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); assert_eq!(req_body, "hello server".chars().rev().collect::()); send_response.send_response( @@ -296,6 +298,62 @@ async def middleware(request, next): Ok(()) } +#[pyo3_asyncio::tokio::test] +async fn changes_response() -> PyResult<()> { + let locals = Python::with_gil(|py| { + Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) + })?; + let handler = Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +async def middleware(request, next): + response = await next(request) + body = bytes(await response.body).decode() + body_reversed = body[::-1] + response.body = body_reversed.encode() + response.set_header("X-From-Middleware", "yes") + return response +"#, + "", + "", + )?; + let handler = module.getattr("middleware")?.into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) + })?; + let layer = PyMiddlewareLayer::::new(handler, locals); + let (mut service, mut handle) = mock::spawn_with(|svc| { + let svc = svc.map_err(|err| panic!("service failed: {err}")); + let svc = layer.layer(svc); + svc + }); + assert_ready_ok!(service.poll_ready()); + + let th = tokio::spawn(async move { + let (req, send_response) = handle.next_request().await.unwrap(); + let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + assert_eq!(req_body, "hello server"); + send_response.send_response( + Response::builder() + .body(to_boxed("hello client")) + .expect("could not create response"), + ); + }); + + let request = Request::builder() + .body(Body::from("hello server")) + .expect("could not create request"); + let response = service.call(request); + + let response = response.await.unwrap(); + assert_eq!(&"yes", response.headers().get("X-From-Middleware").unwrap()); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, "hello client".chars().rev().collect::()); + th.await.unwrap(); + Ok(()) +} + #[pyo3_asyncio::tokio::test] async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> { let locals = Python::with_gil(|py| { diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs index 2fe064a0795..804647fe1cd 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs @@ -26,12 +26,16 @@ async fn accessing_request_properties() -> PyResult<()> { py, req, r#" - assert req.method == "POST" - assert req.uri == "https://www.rust-lang.org/" - assert req.headers["accept-encoding"] == "*" - assert req.headers["x-custom"] == "42" - assert req.version == "HTTP/2.0" - "# +assert req.method == "POST" +assert req.uri == "https://www.rust-lang.org/" +assert req.headers["accept-encoding"] == "*" +assert req.headers["x-custom"] == "42" +assert req.version == "HTTP/2.0" + +assert req.headers.get("x-foo") == None +req.set_header("x-foo", "bar") +assert req.headers["x-foo"] == "bar" +"# ); Ok(()) }) diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs index d42b8541831..a4bb6989177 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs @@ -3,10 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_http_server::body::to_boxed; use aws_smithy_http_server_python::PyResponse; -use http::StatusCode; +use http::{Response, StatusCode, Version}; use pyo3::{ prelude::*, + py_run, types::{IntoPyDict, PyDict}, }; @@ -43,3 +45,62 @@ response = Response(200, {"Content-Type": "application/json"}, b"hello world") Ok(()) } + +#[pyo3_asyncio::tokio::test] +async fn accessing_response_properties() -> PyResult<()> { + let response = Response::builder() + .status(StatusCode::IM_A_TEAPOT) + .version(Version::HTTP_3) + .header("X-Secret", "42") + .body(to_boxed("hello world")) + .expect("could not build response"); + let py_response = PyResponse::new(response); + + Python::with_gil(|py| { + let res = PyCell::new(py, py_response)?; + py_run!( + py, + res, + r#" +assert res.status == 418 +assert res.version == "HTTP/3.0" +assert res.headers["x-secret"] == "42" + +assert res.headers.get("x-foo") == None +res.set_header("x-foo", "bar") +assert res.headers["x-foo"] == "bar" +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +async fn accessing_and_changing_response_body() -> PyResult<()> { + let response = Response::builder() + .body(to_boxed("hello world")) + .expect("could not build response"); + let py_response = PyResponse::new(response); + + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +async def handler(res): + assert bytes(await res.body) == b"hello world" + + res.body = b"hello world from middleware" + assert bytes(await res.body) == b"hello world from middleware" +"#, + "", + "", + )?; + let handler = module.getattr("handler")?; + + let output = handler.call1((py_response,))?; + Ok::<_, PyErr>(pyo3_asyncio::tokio::into_future(output)) + })?? + .await?; + + Ok(()) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index a3c2d4b3504..49ca273402c 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -10,11 +10,8 @@ use std::str::FromStr; use std::sync::Arc; use aws_smithy_http_server::body::Body; -use http::header::HeaderName; -use http::request::Parts; -use http::{HeaderValue, Request}; -use pyo3::exceptions::PyRuntimeError; -use pyo3::prelude::*; +use http::{header::HeaderName, request::Parts, HeaderValue, Request}; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; use tokio::sync::Mutex; use super::PyMiddlewareError; @@ -25,7 +22,7 @@ use super::PyMiddlewareError; #[derive(Debug)] pub struct PyRequest { parts: Option, - body: Option>>>, + body: Arc>>, } impl PyRequest { @@ -34,13 +31,13 @@ impl PyRequest { let (parts, body) = request.into_parts(); Self { parts: Some(parts), - body: Some(Arc::new(Mutex::new(Some(body)))), + body: Arc::new(Mutex::new(Some(body))), } } pub fn take_inner(&mut self) -> Option> { let parts = self.parts.take()?; - let body = self.body.take()?; + let body = std::mem::replace(&mut self.body, Arc::new(Mutex::new(None))); let body = Arc::try_unwrap(body).ok()?; let body = body.into_inner().take()?; Some(Request::from_parts(parts, body)) @@ -119,12 +116,8 @@ impl PyRequest { /// Return the HTTP body of this request. /// Note that this is a costly operation because the whole request body is cloned. #[getter] - fn body<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { - let body = self - .body - .as_mut() - .map(|b| b.clone()) - .ok_or(PyMiddlewareError::RequestGone)?; + fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let body = self.body.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { let body = { let mut body_guard = body.lock().await; @@ -143,13 +136,7 @@ impl PyRequest { /// Set the HTTP body of this request. #[setter] - fn set_body(&mut self, buf: &[u8]) -> PyResult<()> { - match self.body.as_mut() { - Some(body) => { - *body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned())))); - Ok(()) - } - None => Err(PyMiddlewareError::RequestGone.into()), - } + fn set_body(&mut self, buf: &[u8]) { + self.body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned())))); } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 1e2d6fb5c39..6b5a4a02fca 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -6,24 +6,40 @@ //! Python-compatible middleware [http::Response] implementation. use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; use aws_smithy_http_server::body::{to_boxed, BoxBody}; -use http::Response; +use http::{header::HeaderName, response::Parts, HeaderValue, Response}; use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use tokio::sync::Mutex; + +use super::PyMiddlewareError; /// Python-compatible [Response] object. #[pyclass(name = "Response")] #[pyo3(text_signature = "(status, headers, body)")] -pub struct PyResponse(Option>); +pub struct PyResponse { + parts: Option, + body: Arc>>, +} impl PyResponse { /// Create a new Python-compatible [Response] structure from the Rust side. pub fn new(response: Response) -> Self { - Self(Some(response)) + let (parts, body) = response.into_parts(); + Self { + parts: Some(parts), + body: Arc::new(Mutex::new(Some(body))), + } } pub fn take_inner(&mut self) -> Option> { - self.0.take() + let parts = self.parts.take()?; + let body = std::mem::replace(&mut self.body, Arc::new(Mutex::new(None))); + let body = Arc::try_unwrap(body).ok()?; + let body = body.into_inner().take()?; + Some(Response::from_parts(parts, body)) } } @@ -48,6 +64,91 @@ impl PyResponse { .body(body.map(to_boxed).unwrap_or_default()) .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; - Ok(Self(Some(response))) + Ok(Self::new(response)) + } + + /// Return the HTTP status of this response. + #[getter] + fn status(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| parts.status.as_u16()) + .ok_or_else(|| PyMiddlewareError::ResponseGone.into()) + } + + /// Return the HTTP version of this response. + #[getter] + fn version(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| format!("{:?}", parts.version)) + .ok_or_else(|| PyMiddlewareError::ResponseGone.into()) + } + + /// Return the HTTP headers of this response. + /// TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) + #[getter] + fn headers(&self) -> PyResult> { + self.parts + .as_ref() + .map(|parts| { + parts + .headers + .iter() + .map(|(k, v)| -> (String, String) { + let name: String = k.to_string(); + let value: String = String::from_utf8_lossy(v.as_bytes()).to_string(); + (name, value) + }) + .collect() + }) + .ok_or_else(|| PyMiddlewareError::ResponseGone.into()) + } + + /// Insert a new key/value into this response's headers. + /// TODO(investigate if using a PyDict can make the experience more idiomatic) + /// I'd like to be able to do response.headers.get("my-header") and + /// response.headers["my-header"] = 42 instead of implementing set_header() and get_header() + /// under pymethods. The same applies to request. + #[pyo3(text_signature = "($self, key, value)")] + fn set_header(&mut self, key: &str, value: &str) -> PyResult<()> { + match self.parts.as_mut() { + Some(parts) => { + let key = HeaderName::from_str(key) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + let value = HeaderValue::from_str(value) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + parts.headers.insert(key, value); + Ok(()) + } + None => Err(PyMiddlewareError::ResponseGone.into()), + } + } + + /// Return the HTTP body of this response. + /// Note that this is a costly operation because the whole response body is cloned. + #[getter] + fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let body = self.body.clone(); + pyo3_asyncio::tokio::future_into_py(py, async move { + let body = { + let mut body_guard = body.lock().await; + let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?; + let body = hyper::body::to_bytes(body) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + let buf = body.clone(); + body_guard.replace(to_boxed(body)); + buf + }; + // TODO: can we use `PyBytes` here? + Ok(body.to_vec()) + }) + } + + /// Set the HTTP body of this response. + #[setter] + fn set_body(&mut self, buf: &[u8]) { + self.body = Arc::new(Mutex::new(Some(to_boxed(buf.to_owned())))); } }