Skip to content

Commit

Permalink
Allow accessing and changing request body
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Oct 24, 2022
1 parent 2329748 commit a1cab3e
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

use aws_smithy_http_server::body::to_boxed;
use aws_smithy_http_server::proto::rest_json_1::RestJson1;
use aws_smithy_http_server::{body::to_boxed, proto::rest_json_1::RestJson1};
use aws_smithy_http_server_python::{
middleware::{PyMiddlewareHandler, PyMiddlewareLayer},
PyMiddlewareException, PyResponse,
Expand Down Expand Up @@ -244,6 +243,59 @@ def second_middleware(request, next):
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn changes_req_body() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let handler = Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
async def middleware(request, next):
body = bytes(await request.body).decode()
body_reversed = body[::-1]
request.body = body_reversed.encode()
return await next(request)
"#,
"",
"",
)?;
let handler = module.getattr("middleware")?.into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;
let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, mut handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
let svc = layer.layer(svc);
svc
});
assert_ready_ok!(service.poll_ready());

let th = tokio::spawn(async move {
let (req, send_response) = handle.next_request().await.unwrap();
let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap();
assert_eq!(req_body, "hello server".chars().rev().collect::<String>());
send_response.send_response(
Response::builder()
.body(to_boxed("hello client"))
.expect("could not create response"),
);
});

let request = Request::builder()
.body(Body::from("hello server"))
.expect("could not create request");
let response = service.call(request);

let body = hyper::body::to_bytes(response.await.unwrap().into_body())
.await
.unwrap();
assert_eq!(body, "hello client");
th.await.unwrap();
Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,33 @@ async fn accessing_request_properties() -> PyResult<()> {
})
}

// #[pyo3_asyncio::tokio::test]
// async fn accessing_request_body() -> PyResult<()> {
// todo!()
// }
#[pyo3_asyncio::tokio::test]
async fn accessing_and_changing_request_body() -> PyResult<()> {
let request = Request::builder()
.body(Body::from("hello world"))
.expect("could not build request");
let py_request = PyRequest::new(request);

Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
async def handler(req):
# TODO: why we need to wrap with `bytes`?
assert bytes(await req.body) == b"hello world"
req.body = b"hello world from middleware"
assert bytes(await req.body) == b"hello world from middleware"
"#,
"",
"",
)?;
let handler = module.getattr("handler")?;

let output = handler.call1((py_request,))?;
Ok::<_, PyErr>(pyo3_asyncio::tokio::into_future(output))
})??
.await?;

Ok(())
}
104 changes: 76 additions & 28 deletions rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,43 @@
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use aws_smithy_http_server::body::Body;
use http::header::HeaderName;
use http::request::Parts;
use http::{HeaderValue, Request};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use tokio::sync::Mutex;

use super::PyMiddlewareError;

/// Python-compatible [Request] object.
///
/// For performance reasons, there is not support yet to pass the body to the Python middleware,
/// as it requires to consume and clone the body, which is a very expensive operation.
///
/// TODO(if customers request for it, we can implemented an opt-in functionality to also pass
/// the body around).
#[pyclass(name = "Request")]
#[pyo3(text_signature = "(request)")]
#[derive(Debug)]
pub struct PyRequest(Option<Request<Body>>);
pub struct PyRequest {
parts: Option<Parts>,
body: Option<Arc<Mutex<Option<Body>>>>,
}

impl PyRequest {
/// Create a new Python-compatible [Request] structure from the Rust side.
pub fn new(request: Request<Body>) -> Self {
Self(Some(request))
let (parts, body) = request.into_parts();
Self {
parts: Some(parts),
body: Some(Arc::new(Mutex::new(Some(body)))),
}
}

pub fn take_inner(&mut self) -> Option<Request<Body>> {
self.0.take()
let parts = self.parts.take()?;
let body = self.body.take()?;
let body = Arc::try_unwrap(body).ok()?;
let body = body.into_inner().take()?;
Some(Request::from_parts(parts, body))
}
}

Expand All @@ -42,47 +52,48 @@ impl PyRequest {
/// Return the HTTP method of this request.
#[getter]
fn method(&self) -> PyResult<String> {
self.0
self.parts
.as_ref()
.map(|req| req.method().to_string())
.ok_or_else(|| PyRuntimeError::new_err("request is gone"))
.map(|parts| parts.method.to_string())
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}

/// Return the URI of this request.
#[getter]
fn uri(&self) -> PyResult<String> {
self.0
self.parts
.as_ref()
.map(|req| req.uri().to_string())
.ok_or_else(|| PyRuntimeError::new_err("request is gone"))
.map(|parts| parts.uri.to_string())
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}

/// Return the HTTP version of this request.
#[getter]
fn version(&self) -> PyResult<String> {
self.0
self.parts
.as_ref()
.map(|req| format!("{:?}", req.version()))
.ok_or_else(|| PyRuntimeError::new_err("request is gone"))
.map(|parts| format!("{:?}", parts.version))
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}

/// Return the HTTP headers of this request.
/// TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?)
#[getter]
fn headers(&self) -> PyResult<HashMap<String, String>> {
self.0
self.parts
.as_ref()
.map(|req| {
req.headers()
.into_iter()
.map(|parts| {
parts
.headers
.iter()
.map(|(k, v)| -> (String, String) {
let name: String = k.as_str().to_string();
let name: String = k.to_string();
let value: String = String::from_utf8_lossy(v.as_bytes()).to_string();
(name, value)
})
.collect()
})
.ok_or_else(|| PyRuntimeError::new_err("request is gone"))
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}

/// Insert a new key/value into this request's headers.
Expand All @@ -92,16 +103,53 @@ impl PyRequest {
/// under pymethods. The same applies to response.
#[pyo3(text_signature = "($self, key, value)")]
fn set_header(&mut self, key: &str, value: &str) -> PyResult<()> {
match self.0.as_mut() {
Some(ref mut req) => {
match self.parts.as_mut() {
Some(parts) => {
let key = HeaderName::from_str(key)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let value = HeaderValue::from_str(value)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
req.headers_mut().insert(key, value);
parts.headers.insert(key, value);
Ok(())
}
None => Err(PyMiddlewareError::RequestGone.into()),
}
}

/// Return the HTTP body of this request.
/// Note that this is a costly operation because the whole request body is cloned.
#[getter]
fn body<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
let body = self
.body
.as_mut()
.map(|b| b.clone())
.ok_or(PyMiddlewareError::RequestGone)?;
pyo3_asyncio::tokio::future_into_py(py, async move {
let body = {
let mut body_guard = body.lock().await;
let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
let body = hyper::body::to_bytes(body)
.await
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let buf = body.clone();
body_guard.replace(Body::from(body));
buf
};
// TODO: can we use `PyBytes` here?
Ok(body.to_vec())
})
}

/// Set the HTTP body of this request.
#[setter]
fn set_body(&mut self, buf: &[u8]) -> PyResult<()> {
match self.body.as_mut() {
Some(body) => {
*body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned()))));
Ok(())
}
None => Err(PyRuntimeError::new_err("request is gone")),
None => Err(PyMiddlewareError::RequestGone.into()),
}
}
}

0 comments on commit a1cab3e

Please sign in to comment.