Skip to content

Commit

Permalink
Python: map Python middlewares to Tower layers (#1871)
Browse files Browse the repository at this point in the history
* Python: map Python middlewares to Tower layers

* Make middleware layer infallible

* Use message and status code from `PyMiddlewareException`

* Introduce `FuncMetadata` to represent some information about a Python function

* Improve middleware errors

* Add missing copyright headers

* Allow accessing and changing request body

* Allow changing response

* Add some documentation about moving data back-and-forth between Rust and Python

* Add `mypy` to Pokemon service and update typings and comments for middlewares

* Add or update comments on the important types

* Add Rust equivalent of `collections.abc.MutableMapping`

* Add `PyHeaderMap` to make `HeaderMap` accessible from Python

* Apply suggestions from code review

Co-authored-by: Luca Palmieri <[email protected]>

* Improve logging

* Add `RichPyErr` to have a better output for `PyErr`s

* Better error messages for `PyMiddlewareError` variants

* Factor out repeating patterns in tests

* Preserve `__builtins__` in `globals` to fix tests in Python 3.7.10 (our CI version)

* Export `RichPyErr` to fix `cargo doc` error

* Apply suggestions from code review

Co-authored-by: Matteo Bigoi <[email protected]>

* Add missing SPDX headers

* Document that `keys`, `values` and `items` on `PyMutableMapping` causes clones

Co-authored-by: Luca Palmieri <[email protected]>
Co-authored-by: Matteo Bigoi <[email protected]>
  • Loading branch information
3 people authored Nov 10, 2022
1 parent b82b6a6 commit 4f76e35
Show file tree
Hide file tree
Showing 27 changed files with 1,798 additions and 766 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class PythonApplicationGenerator(
##[derive(Debug)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>,
context: Option<#{pyo3}::PyObject>,
workers: #{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>>,
}
Expand Down Expand Up @@ -141,7 +141,7 @@ class PythonApplicationGenerator(
fn default() -> Self {
Self {
handlers: Default::default(),
middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]),
middlewares: vec![],
context: None,
workers: #{parking_lot}::Mutex::new(vec![]),
}
Expand Down Expand Up @@ -171,9 +171,6 @@ class PythonApplicationGenerator(
fn handlers(&mut self) -> &mut #{HashMap}<String, #{SmithyPython}::PyHandler> {
&mut self.handlers
}
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
""",
*codegenScope,
)
Expand Down Expand Up @@ -212,13 +209,22 @@ class PythonApplicationGenerator(
}
rustTemplate(
"""
let middleware_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let service = #{tower}::ServiceBuilder::new()
.boxed_clone()
.layer(
#{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals),
)
.service(builder.build());
let mut service = #{tower}::util::BoxCloneService::new(builder.build());
{
use #{tower}::Layer;
#{tracing}::trace!("adding middlewares to rust python router");
let mut middlewares = self.middlewares.clone();
// Reverse the middlewares, so they run with same order as they defined
middlewares.reverse();
for handler in middlewares {
#{tracing}::trace!(name = &handler.name, "adding python middleware");
let locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let layer = #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(handler, locals);
service = #{tower}::util::BoxCloneService::new(layer.layer(service));
}
}
Ok(service)
""",
"Protocol" to protocol.markerStruct(),
Expand Down Expand Up @@ -248,11 +254,17 @@ class PythonApplicationGenerator(
pub fn context(&mut self, context: #{pyo3}::PyObject) {
self.context = Some(context);
}
/// Register a request middleware function that will be run inside a Tower layer, without cloning the body.
/// Register a Python function to be executed inside a Tower middleware layer.
##[pyo3(text_signature = "(${'$'}self, func)")]
pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
use #{SmithyPython}::PyApp;
self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request)
pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
let handler = #{SmithyPython}::PyMiddlewareHandler::new(py, func)?;
#{tracing}::trace!(
name = &handler.name,
is_coroutine = handler.is_coroutine,
"registering middleware function",
);
self.middlewares.push(handler);
Ok(())
}
/// Main entrypoint: start the server on multiple workers.
##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class PythonServerModuleGenerator(
middleware.add_class::<#{SmithyPython}::PyRequest>()?;
middleware.add_class::<#{SmithyPython}::PyResponse>()?;
middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?;
middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?;
pyo3::py_run!(
py,
middleware,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PythonServerOperationHandlerGenerator(
writable {
rustTemplate(
"""
#{tracing}::debug!("Executing Python handler function `$name()`");
#{tracing}::trace!(name = "$name", "executing python handler function");
#{pyo3}::Python::with_gil(|py| {
let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
let output = if handler.args == 1 {
Expand All @@ -110,7 +110,7 @@ class PythonServerOperationHandlerGenerator(
writable {
rustTemplate(
"""
#{tracing}::debug!("Executing Python handler coroutine `$name()`");
#{tracing}::trace!(name = "$name", "executing python handler coroutine");
let result = #{pyo3}::Python::with_gil(|py| {
let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
let coroutine = if handler.args == 1 {
Expand All @@ -132,15 +132,9 @@ class PythonServerOperationHandlerGenerator(
"""
// Catch and record a Python traceback.
result.map_err(|e| {
let traceback = #{pyo3}::Python::with_gil(|py| {
match e.traceback(py) {
Some(t) => t.format().unwrap_or_else(|e| e.to_string()),
None => "Unknown traceback\n".to_string()
}
});
let error = e.into();
#{tracing}::error!("{}{}", traceback, error);
error
let rich_py_err = #{SmithyPython}::rich_py_err(#{pyo3}::Python::with_gil(|py| { e.clone_ref(py) }));
#{tracing}::error!(error = ?rich_py_err, "handler error");
e.into()
})
""",
*codegenScope,
Expand Down
8 changes: 8 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ tracing-appender = { version = "0.2.2"}
[dev-dependencies]
pretty_assertions = "1"
futures-util = "0.3"
tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] }

[[test]]
name = "middleware_tests"
path = "src/middleware/pytests/harness.rs"
harness = false

[package.metadata.docs.rs]
all-features = true
Expand Down
3 changes: 3 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/examples/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ build: codegen
cargo build
ln -sf $(DEBUG_SHARED_LIBRARY_SRC) $(SHARED_LIBRARY_DST)

py_check: build
mypy pokemon_service.py

release: codegen
cargo build --release
ln -sf $(RELEASE_SHARED_LIBRARY_SRC) $(SHARED_LIBRARY_DST)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# NOTE: This is manually created to surpass some mypy errors and it is incomplete,
# in future we will autogenerate correct stubs.

from typing import Any, TypeVar, Callable

F = TypeVar("F", bound=Callable[..., Any])

class App:
context: Any
run: Any

def middleware(self, func: F) -> F: ...
def do_nothing(self, func: F) -> F: ...
def get_pokemon_species(self, func: F) -> F: ...
def get_server_statistics(self, func: F) -> F: ...
def check_health(self, func: F) -> F: ...
def stream_pokemon_radio(self, func: F) -> F: ...
5 changes: 5 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/examples/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

[mypy]
strict = True
117 changes: 67 additions & 50 deletions rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,52 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import itertools
import logging
import random
from threading import Lock
from dataclasses import dataclass
from typing import List, Optional

import aiohttp
from typing import List, Optional, Callable, Awaitable

from libpokemon_service_server_sdk import App
from libpokemon_service_server_sdk.error import ResourceNotFoundException
from libpokemon_service_server_sdk.input import (
DoNothingInput, GetPokemonSpeciesInput, GetServerStatisticsInput,
CheckHealthInput, StreamPokemonRadioInput)
from libpokemon_service_server_sdk.logging import TracingHandler
from libpokemon_service_server_sdk.middleware import (MiddlewareException,
Request)
from libpokemon_service_server_sdk.model import FlavorText, Language
from libpokemon_service_server_sdk.output import (
DoNothingOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput,
CheckHealthOutput, StreamPokemonRadioOutput)
from libpokemon_service_server_sdk.types import ByteStream
from libpokemon_service_server_sdk.error import ResourceNotFoundException # type: ignore
from libpokemon_service_server_sdk.input import ( # type: ignore
DoNothingInput,
GetPokemonSpeciesInput,
GetServerStatisticsInput,
CheckHealthInput,
StreamPokemonRadioInput,
)
from libpokemon_service_server_sdk.logging import TracingHandler # type: ignore
from libpokemon_service_server_sdk.middleware import ( # type: ignore
MiddlewareException,
Response,
Request,
)
from libpokemon_service_server_sdk.model import FlavorText, Language # type: ignore
from libpokemon_service_server_sdk.output import ( # type: ignore
DoNothingOutput,
GetPokemonSpeciesOutput,
GetServerStatisticsOutput,
CheckHealthOutput,
StreamPokemonRadioOutput,
)
from libpokemon_service_server_sdk.types import ByteStream # type: ignore

# Logging can bee setup using standard Python tooling. We provide
# fast logging handler, Tracingandler based on Rust tracing crate.
logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()])


class SafeCounter:
def __init__(self):
def __init__(self) -> None:
self._val = 0
self._lock = Lock()

def increment(self):
def increment(self) -> None:
with self._lock:
self._val += 1

def value(self):
def value(self) -> int:
with self._lock:
return self._val

Expand All @@ -63,9 +72,9 @@ def value(self):
#
# Synchronization:
# Instance of `Context` class will be cloned for every worker and all state kept in `Context`
# will be specific to that process. There is no protection provided by default,
# it is up to you to have synchronization between processes.
# If you really want to share state between different processes you need to use `multiprocessing` primitives:
# will be specific to that process. There is no protection provided by default,
# it is up to you to have synchronization between processes.
# If you really want to share state between different processes you need to use `multiprocessing` primitives:
# https://docs.python.org/3/library/multiprocessing.html#sharing-state-between-processes
@dataclass
class Context:
Expand Down Expand Up @@ -124,49 +133,51 @@ def get_random_radio_stream(self) -> str:
# Middleware
############################################################
# Middlewares are sync or async function decorated by `@app.middleware`.
# They are executed in order and take as input the HTTP request object.
# A middleware can return multiple values, following these rules:
# * Middleware not returning will let the execution continue without
# changing the original request.
# * Middleware returning a modified Request will update the original
# request before continuing the execution.
# * Middleware returning a Response will immediately terminate the request
# handling and return the response constructed from Python.
# * Middleware raising MiddlewareException will immediately terminate the
# request handling and return a protocol specific error, with the option of
# setting the HTTP return code.
# * Middleware raising any other exception will immediately terminate the
# request handling and return a protocol specific error, with HTTP status
# code 500.
@app.request_middleware
def check_content_type_header(request: Request):
content_type = request.get_header("content-type")
# They are executed in order and take as input the HTTP request object and
# the handler or the next middleware in the stack.
# A middleware should return a `Response`, either by calling `next` with `Request`
# to get `Response` from the handler or by constructing `Response` by itself.
# It can also modify the `Request` before calling `next` or it can also modify
# the `Response` returned by the handler.
# It can also raise an `MiddlewareException` with custom error message and HTTP status code,
# any other raised exceptions will cause an internal server error response to be returned.

# Next is either the next middleware in the stack or the handler.
Next = Callable[[Request], Awaitable[Response]]

# This middleware checks the `Content-Type` from the request header,
# logs some information depending on that and then calls `next`.
@app.middleware
async def check_content_type_header(request: Request, next: Next) -> Response:
content_type = request.headers.get("content-type")
if content_type == "application/json":
logging.debug("Found valid `application/json` content type")
else:
logging.warning(
f"Invalid content type {content_type}, dumping headers: {request.headers()}"
f"Invalid content type {content_type}, dumping headers: {request.headers}"
)
return await next(request)


# This middleware adds a new header called `x-amzn-answer` to the
# request. We expect to see this header to be populated in the next
# middleware.
@app.request_middleware
def add_x_amzn_answer_header(request: Request):
request.set_header("x-amzn-answer", "42")
@app.middleware
async def add_x_amzn_answer_header(request: Request, next: Next) -> Response:
request.headers["x-amzn-answer"] = "42"
logging.debug("Setting `x-amzn-answer` header to 42")
return request
return await next(request)


# This middleware checks if the header `x-amzn-answer` is correctly set
# to 42, otherwise it returns an exception with a set status code.
@app.request_middleware
async def check_x_amzn_answer_header(request: Request):
@app.middleware
async def check_x_amzn_answer_header(request: Request, next: Next) -> Response:
# Check that `x-amzn-answer` is 42.
if request.get_header("x-amzn-answer") != "42":
if request.headers.get("x-amzn-answer") != "42":
# Return an HTTP 401 Unauthorized if the content type is not JSON.
raise MiddlewareException("Invalid answer", 401)
return await next(request)


###########################################################
Expand Down Expand Up @@ -216,7 +227,11 @@ def check_health(_: CheckHealthInput) -> CheckHealthOutput:

# Stream a random Pokémon song.
@app.stream_pokemon_radio
async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context):
async def stream_pokemon_radio(
_: StreamPokemonRadioInput, context: Context
) -> StreamPokemonRadioOutput:
import aiohttp

radio_url = context.get_random_radio_stream()
logging.info("Random radio URL for this stream is %s", radio_url)
async with aiohttp.ClientSession() as session:
Expand All @@ -229,7 +244,9 @@ async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context):
###########################################################
# Run the server.
###########################################################
def main():
def main() -> None:
app.run(workers=1)

main()

if __name__ == "__main__":
main()
7 changes: 6 additions & 1 deletion rust-runtime/aws-smithy-http-server-python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ impl PyMiddlewareException {

impl From<PyErr> for PyMiddlewareException {
fn from(other: PyErr) -> Self {
Self::newpy(other.to_string(), None)
// Try to extract `PyMiddlewareException` from `PyErr` and use that if succeed
let middleware_err = Python::with_gil(|py| other.to_object(py).extract::<Self>(py));
match middleware_err {
Ok(err) => err,
Err(_) => Self::newpy(other.to_string(), None),
}
}
}

Expand Down
Loading

0 comments on commit 4f76e35

Please sign in to comment.