Skip to content

Commit

Permalink
Allow changing response
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Oct 24, 2022
1 parent 70c1826 commit d8eb21d
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def second_middleware(request, next):
}

#[pyo3_asyncio::tokio::test]
async fn changes_req_body() -> PyResult<()> {
async fn changes_request() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
Expand All @@ -256,6 +256,7 @@ async def middleware(request, next):
body = bytes(await request.body).decode()
body_reversed = body[::-1]
request.body = body_reversed.encode()
request.set_header("X-From-Middleware", "yes")
return await next(request)
"#,
"",
Expand All @@ -274,6 +275,7 @@ async def middleware(request, next):

let th = tokio::spawn(async move {
let (req, send_response) = handle.next_request().await.unwrap();
assert_eq!(&"yes", req.headers().get("X-From-Middleware").unwrap());
let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap();
assert_eq!(req_body, "hello server".chars().rev().collect::<String>());
send_response.send_response(
Expand All @@ -296,6 +298,62 @@ async def middleware(request, next):
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn changes_response() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let handler = Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
async def middleware(request, next):
response = await next(request)
body = bytes(await response.body).decode()
body_reversed = body[::-1]
response.body = body_reversed.encode()
response.set_header("X-From-Middleware", "yes")
return response
"#,
"",
"",
)?;
let handler = module.getattr("middleware")?.into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;
let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, mut handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let th = tokio::spawn(async move {
let (req, send_response) = handle.next_request().await.unwrap();
let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap();
assert_eq!(req_body, "hello server");
send_response.send_response(
Response::builder()
.body(to_boxed("hello client"))
.expect("could not create response"),
);
});

let request = Request::builder()
.body(Body::from("hello server"))
.expect("could not create request");
let response = service.call(request);

let response = response.await.unwrap();
assert_eq!(&"yes", response.headers().get("X-From-Middleware").unwrap());

let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(body, "hello client".chars().rev().collect::<String>());
th.await.unwrap();
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ async fn accessing_request_properties() -> PyResult<()> {
py,
req,
r#"
assert req.method == "POST"
assert req.uri == "https://www.rust-lang.org/"
assert req.headers["accept-encoding"] == "*"
assert req.headers["x-custom"] == "42"
assert req.version == "HTTP/2.0"
"#
assert req.method == "POST"
assert req.uri == "https://www.rust-lang.org/"
assert req.headers["accept-encoding"] == "*"
assert req.headers["x-custom"] == "42"
assert req.version == "HTTP/2.0"
assert req.headers.get("x-foo") == None
req.set_header("x-foo", "bar")
assert req.headers["x-foo"] == "bar"
"#
);
Ok(())
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/

use aws_smithy_http_server::body::to_boxed;
use aws_smithy_http_server_python::PyResponse;
use http::StatusCode;
use http::{Response, StatusCode, Version};
use pyo3::{
prelude::*,
py_run,
types::{IntoPyDict, PyDict},
};

Expand Down Expand Up @@ -43,3 +45,62 @@ response = Response(200, {"Content-Type": "application/json"}, b"hello world")

Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn accessing_response_properties() -> PyResult<()> {
let response = Response::builder()
.status(StatusCode::IM_A_TEAPOT)
.version(Version::HTTP_3)
.header("X-Secret", "42")
.body(to_boxed("hello world"))
.expect("could not build response");
let py_response = PyResponse::new(response);

Python::with_gil(|py| {
let res = PyCell::new(py, py_response)?;
py_run!(
py,
res,
r#"
assert res.status == 418
assert res.version == "HTTP/3.0"
assert res.headers["x-secret"] == "42"
assert res.headers.get("x-foo") == None
res.set_header("x-foo", "bar")
assert res.headers["x-foo"] == "bar"
"#
);
Ok(())
})
}

#[pyo3_asyncio::tokio::test]
async fn accessing_and_changing_response_body() -> PyResult<()> {
let response = Response::builder()
.body(to_boxed("hello world"))
.expect("could not build response");
let py_response = PyResponse::new(response);

Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
async def handler(res):
assert bytes(await res.body) == b"hello world"
res.body = b"hello world from middleware"
assert bytes(await res.body) == b"hello world from middleware"
"#,
"",
"",
)?;
let handler = module.getattr("handler")?;

let output = handler.call1((py_response,))?;
Ok::<_, PyErr>(pyo3_asyncio::tokio::into_future(output))
})??
.await?;

Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@ use std::str::FromStr;
use std::sync::Arc;

use aws_smithy_http_server::body::Body;
use http::header::HeaderName;
use http::request::Parts;
use http::{HeaderValue, Request};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use http::{header::HeaderName, request::Parts, HeaderValue, Request};
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use tokio::sync::Mutex;

use super::PyMiddlewareError;
Expand All @@ -25,7 +22,7 @@ use super::PyMiddlewareError;
#[derive(Debug)]
pub struct PyRequest {
parts: Option<Parts>,
body: Option<Arc<Mutex<Option<Body>>>>,
body: Arc<Mutex<Option<Body>>>,
}

impl PyRequest {
Expand All @@ -34,13 +31,13 @@ impl PyRequest {
let (parts, body) = request.into_parts();
Self {
parts: Some(parts),
body: Some(Arc::new(Mutex::new(Some(body)))),
body: Arc::new(Mutex::new(Some(body))),
}
}

pub fn take_inner(&mut self) -> Option<Request<Body>> {
let parts = self.parts.take()?;
let body = self.body.take()?;
let body = std::mem::replace(&mut self.body, Arc::new(Mutex::new(None)));
let body = Arc::try_unwrap(body).ok()?;
let body = body.into_inner().take()?;
Some(Request::from_parts(parts, body))
Expand Down Expand Up @@ -119,12 +116,8 @@ impl PyRequest {
/// Return the HTTP body of this request.
/// Note that this is a costly operation because the whole request body is cloned.
#[getter]
fn body<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
let body = self
.body
.as_mut()
.map(|b| b.clone())
.ok_or(PyMiddlewareError::RequestGone)?;
fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
let body = self.body.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
let body = {
let mut body_guard = body.lock().await;
Expand All @@ -143,13 +136,7 @@ impl PyRequest {

/// Set the HTTP body of this request.
#[setter]
fn set_body(&mut self, buf: &[u8]) -> PyResult<()> {
match self.body.as_mut() {
Some(body) => {
*body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned()))));
Ok(())
}
None => Err(PyMiddlewareError::RequestGone.into()),
}
fn set_body(&mut self, buf: &[u8]) {
self.body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned()))));
}
}
111 changes: 106 additions & 5 deletions rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,40 @@
//! Python-compatible middleware [http::Response] implementation.
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use aws_smithy_http_server::body::{to_boxed, BoxBody};
use http::Response;
use http::{header::HeaderName, response::Parts, HeaderValue, Response};
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use tokio::sync::Mutex;

use super::PyMiddlewareError;

/// Python-compatible [Response] object.
#[pyclass(name = "Response")]
#[pyo3(text_signature = "(status, headers, body)")]
pub struct PyResponse(Option<Response<BoxBody>>);
pub struct PyResponse {
parts: Option<Parts>,
body: Arc<Mutex<Option<BoxBody>>>,
}

impl PyResponse {
/// Create a new Python-compatible [Response] structure from the Rust side.
pub fn new(response: Response<BoxBody>) -> Self {
Self(Some(response))
let (parts, body) = response.into_parts();
Self {
parts: Some(parts),
body: Arc::new(Mutex::new(Some(body))),
}
}

pub fn take_inner(&mut self) -> Option<Response<BoxBody>> {
self.0.take()
let parts = self.parts.take()?;
let body = std::mem::replace(&mut self.body, Arc::new(Mutex::new(None)));
let body = Arc::try_unwrap(body).ok()?;
let body = body.into_inner().take()?;
Some(Response::from_parts(parts, body))
}
}

Expand All @@ -48,6 +64,91 @@ impl PyResponse {
.body(body.map(to_boxed).unwrap_or_default())
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;

Ok(Self(Some(response)))
Ok(Self::new(response))
}

/// Return the HTTP status of this response.
#[getter]
fn status(&self) -> PyResult<u16> {
self.parts
.as_ref()
.map(|parts| parts.status.as_u16())
.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
}

/// Return the HTTP version of this response.
#[getter]
fn version(&self) -> PyResult<String> {
self.parts
.as_ref()
.map(|parts| format!("{:?}", parts.version))
.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
}

/// Return the HTTP headers of this response.
/// TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?)
#[getter]
fn headers(&self) -> PyResult<HashMap<String, String>> {
self.parts
.as_ref()
.map(|parts| {
parts
.headers
.iter()
.map(|(k, v)| -> (String, String) {
let name: String = k.to_string();
let value: String = String::from_utf8_lossy(v.as_bytes()).to_string();
(name, value)
})
.collect()
})
.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
}

/// Insert a new key/value into this response's headers.
/// TODO(investigate if using a PyDict can make the experience more idiomatic)
/// I'd like to be able to do response.headers.get("my-header") and
/// response.headers["my-header"] = 42 instead of implementing set_header() and get_header()
/// under pymethods. The same applies to request.
#[pyo3(text_signature = "($self, key, value)")]
fn set_header(&mut self, key: &str, value: &str) -> PyResult<()> {
match self.parts.as_mut() {
Some(parts) => {
let key = HeaderName::from_str(key)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let value = HeaderValue::from_str(value)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
parts.headers.insert(key, value);
Ok(())
}
None => Err(PyMiddlewareError::ResponseGone.into()),
}
}

/// Return the HTTP body of this response.
/// Note that this is a costly operation because the whole response body is cloned.
#[getter]
fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
let body = self.body.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
let body = {
let mut body_guard = body.lock().await;
let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
let body = hyper::body::to_bytes(body)
.await
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let buf = body.clone();
body_guard.replace(to_boxed(body));
buf
};
// TODO: can we use `PyBytes` here?
Ok(body.to_vec())
})
}

/// Set the HTTP body of this response.
#[setter]
fn set_body(&mut self, buf: &[u8]) {
self.body = Arc::new(Mutex::new(Some(to_boxed(buf.to_owned()))));
}
}

0 comments on commit d8eb21d

Please sign in to comment.