Skip to content

Commit

Permalink
Python: map Python middlewares to Tower layers
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Oct 21, 2022
1 parent e74ecf3 commit 1fd9c73
Show file tree
Hide file tree
Showing 17 changed files with 577 additions and 678 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PythonApplicationGenerator(
renderAppStruct(writer)
renderAppDefault(writer)
renderAppClone(writer)
renderAppImpl(writer)
renderPyAppTrait(writer)
renderPyMethods(writer)
}
Expand All @@ -107,7 +108,7 @@ class PythonApplicationGenerator(
##[derive(Debug)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>,
context: Option<#{pyo3}::PyObject>,
workers: #{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>>,
}
Expand Down Expand Up @@ -141,7 +142,7 @@ class PythonApplicationGenerator(
fn default() -> Self {
Self {
handlers: Default::default(),
middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]),
middlewares: vec![],
context: None,
workers: #{parking_lot}::Mutex::new(vec![]),
}
Expand All @@ -153,6 +154,30 @@ class PythonApplicationGenerator(
)
}

private fun renderAppImpl(writer: RustWriter) {
writer.rustBlockTemplate(
"""
impl App
""",
*codegenScope,
) {
rustTemplate(
"""
// Check if a Python function is a coroutine. Since the function has not run yet,
// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`.
fn is_coroutine(&self, py: #{pyo3}::Python, func: &#{pyo3}::PyObject) -> #{pyo3}::PyResult<bool> {
let inspect = py.import("inspect")?;
// NOTE: that `asyncio.iscoroutine()` doesn't work here.
inspect
.call_method1("iscoroutinefunction", (func,))?
.extract::<bool>()
}
""",
*codegenScope,
)
}
}

private fun renderPyAppTrait(writer: RustWriter) {
writer.rustBlockTemplate(
"""
Expand All @@ -171,9 +196,6 @@ class PythonApplicationGenerator(
fn handlers(&mut self) -> &mut #{HashMap}<String, #{SmithyPython}::PyHandler> {
&mut self.handlers
}
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
""",
*codegenScope,
)
Expand Down Expand Up @@ -212,13 +234,25 @@ class PythonApplicationGenerator(
}
rustTemplate(
"""
let middleware_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let service = #{tower}::ServiceBuilder::new()
.boxed_clone()
.layer(
#{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals),
)
.service(builder.build());
use #{tower}::{Layer, ServiceExt};
let mut service = #{tower}::util::BoxCloneService::new(builder.build());
let mut middlewares = self.middlewares.clone();
// Reverse the middlewares, so they run with same order as they defined
middlewares.reverse();
for handler in middlewares {
let locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let name = handler.name.clone();
let layer = #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(handler, locals);
service = #{tower}::util::BoxCloneService::new(layer.layer(service).map_err(move |err| {
// TODO: correctly propogate errors.
tracing::error!(
err = %err,
"'{}' middleware failed",
name,
);
loop {}
}));
}
Ok(service)
""",
"Protocol" to protocol.markerStruct(),
Expand Down Expand Up @@ -248,11 +282,23 @@ class PythonApplicationGenerator(
pub fn context(&mut self, context: #{pyo3}::PyObject) {
self.context = Some(context);
}
/// Register a request middleware function that will be run inside a Tower layer, without cloning the body.
/// Register a Python function to be executed inside a Tower middleware layer.
##[pyo3(text_signature = "(${'$'}self, func)")]
pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
use #{SmithyPython}::PyApp;
self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request)
pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
let name = func.getattr(py, "__name__")?.extract::<String>(py)?;
let is_coroutine = self.is_coroutine(py, &func)?;
let handler = #{SmithyPython}::PyMiddlewareHandler {
name,
func,
is_coroutine,
};
tracing::info!(
"registering middleware function `{}`, coroutine: {}",
handler.name,
handler.is_coroutine,
);
self.middlewares.push(handler);
Ok(())
}
/// Main entrypoint: start the server on multiple workers.
##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class PythonServerModuleGenerator(
middleware.add_class::<#{SmithyPython}::PyRequest>()?;
middleware.add_class::<#{SmithyPython}::PyResponse>()?;
middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?;
middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?;
pyo3::py_run!(
py,
middleware,
Expand Down
8 changes: 8 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ tracing-appender = { version = "0.2.2"}
[dev-dependencies]
pretty_assertions = "1"
futures-util = "0.3"
tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] }

[[test]]
name = "middleware_tests"
path = "src/middleware/pytests/harness.rs"
harness = false

[package.metadata.docs.rs]
all-features = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,35 +138,37 @@ def get_random_radio_stream(self) -> str:
# * Middleware raising any other exception will immediately terminate the
# request handling and return a protocol specific error, with HTTP status
# code 500.
@app.request_middleware
def check_content_type_header(request: Request):
content_type = request.get_header("content-type")
@app.middleware
async def check_content_type_header(request: Request, next):
content_type = request.headers.get("content-type")
if content_type == "application/json":
logging.debug("Found valid `application/json` content type")
else:
logging.warning(
f"Invalid content type {content_type}, dumping headers: {request.headers()}"
f"Invalid content type {content_type}, dumping headers: {request.headers}"
)
return await next(request)


# This middleware adds a new header called `x-amzn-answer` to the
# request. We expect to see this header to be populated in the next
# middleware.
@app.request_middleware
def add_x_amzn_answer_header(request: Request):
@app.middleware
async def add_x_amzn_answer_header(request: Request, next):
request.set_header("x-amzn-answer", "42")
logging.debug("Setting `x-amzn-answer` header to 42")
return request
return await next(request)


# This middleware checks if the header `x-amzn-answer` is correctly set
# to 42, otherwise it returns an exception with a set status code.
@app.request_middleware
async def check_x_amzn_answer_header(request: Request):
@app.middleware
async def check_x_amzn_answer_header(request: Request, next):
# Check that `x-amzn-answer` is 42.
if request.get_header("x-amzn-answer") != "42":
if request.headers.get("x-amzn-answer") != "42":
# Return an HTTP 401 Unauthorized if the content type is not JSON.
raise MiddlewareException("Invalid answer", 401)
return await next(request)


###########################################################
Expand Down Expand Up @@ -232,4 +234,5 @@ async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context):
def main():
app.run(workers=1)

main()
if __name__ == '__main__':
main()
7 changes: 7 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use aws_smithy_http_server::{
use aws_smithy_types::date_time::{ConversionError, DateTimeParseError};
use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr};
use thiserror::Error;
use tower::BoxError;

/// Python error that implements foreign errors.
#[derive(Error, Debug)]
Expand Down Expand Up @@ -67,6 +68,12 @@ impl From<PyErr> for PyMiddlewareException {
}
}

impl From<BoxError> for PyMiddlewareException {
fn from(other: BoxError) -> Self {
Self::newpy(other.to_string(), None)
}
}

impl IntoResponse<RestJson1> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
Expand Down
4 changes: 1 addition & 3 deletions rust-runtime/aws-smithy-http-server-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ pub use error::{PyError, PyMiddlewareException};
#[doc(inline)]
pub use logging::{py_tracing_event, PyTracingHandler};
#[doc(inline)]
pub use middleware::{
PyHttpVersion, PyMiddlewareLayer, PyMiddlewareType, PyMiddlewares, PyRequest, PyResponse,
};
pub use middleware::{PyMiddlewareHandler, PyMiddlewareLayer, PyRequest, PyResponse};
#[doc(inline)]
pub use server::{PyApp, PyHandler};
#[doc(inline)]
Expand Down
17 changes: 17 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use std::error::Error;
use std::fmt;

#[derive(Debug)]
pub enum PyMiddlewareError {
ResponseAlreadyGone,
}

impl fmt::Display for PyMiddlewareError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::ResponseAlreadyGone => write!(f, "response is already consumed"),
}
}
}

impl Error for PyMiddlewareError {}
Loading

0 comments on commit 1fd9c73

Please sign in to comment.