Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: map Python middlewares to Tower layers #1871

Merged
merged 23 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e0be364
Python: map Python middlewares to Tower layers
unexge Oct 18, 2022
58f18ae
Make middleware layer infallible
unexge Oct 21, 2022
67dafeb
Use message and status code from `PyMiddlewareException`
unexge Oct 21, 2022
5437e73
Introduce `FuncMetadata` to represent some information about a Python…
unexge Oct 21, 2022
2b0d4f6
Improve middleware errors
unexge Oct 21, 2022
ec638e6
Add missing copyright headers
unexge Oct 21, 2022
b084ba2
Allow accessing and changing request body
unexge Oct 24, 2022
b19e6f3
Allow changing response
unexge Oct 24, 2022
ed09a19
Add some documentation about moving data back-and-forth between Rust …
unexge Oct 25, 2022
8e4cc14
Add `mypy` to Pokemon service and update typings and comments for mid…
unexge Oct 25, 2022
1f92ddf
Add or update comments on the important types
unexge Oct 25, 2022
5700cef
Add Rust equivalent of `collections.abc.MutableMapping`
unexge Oct 27, 2022
57b7c69
Add `PyHeaderMap` to make `HeaderMap` accessible from Python
unexge Oct 27, 2022
330e222
Apply suggestions from code review
unexge Oct 28, 2022
c0208b8
Improve logging
unexge Oct 28, 2022
63d522b
Add `RichPyErr` to have a better output for `PyErr`s
unexge Oct 28, 2022
c1cdad5
Better error messages for `PyMiddlewareError` variants
unexge Oct 28, 2022
02b5164
Factor out repeating patterns in tests
unexge Oct 28, 2022
4e1cd23
Preserve `__builtins__` in `globals` to fix tests in Python 3.7.10 (o…
unexge Oct 31, 2022
f5c9a56
Export `RichPyErr` to fix `cargo doc` error
unexge Oct 31, 2022
cef64c4
Apply suggestions from code review
unexge Oct 31, 2022
385f133
Add missing SPDX headers
unexge Oct 31, 2022
c8bd368
Document that `keys`, `values` and `items` on `PyMutableMapping` caus…
unexge Nov 10, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class PythonApplicationGenerator(
##[derive(Debug)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>,
context: Option<#{pyo3}::PyObject>,
workers: #{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>>,
}
Expand Down Expand Up @@ -141,7 +141,7 @@ class PythonApplicationGenerator(
fn default() -> Self {
Self {
handlers: Default::default(),
middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]),
middlewares: vec![],
context: None,
workers: #{parking_lot}::Mutex::new(vec![]),
}
Expand Down Expand Up @@ -171,9 +171,6 @@ class PythonApplicationGenerator(
fn handlers(&mut self) -> &mut #{HashMap}<String, #{SmithyPython}::PyHandler> {
&mut self.handlers
}
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
""",
*codegenScope,
)
Expand Down Expand Up @@ -212,13 +209,22 @@ class PythonApplicationGenerator(
}
rustTemplate(
"""
let middleware_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let service = #{tower}::ServiceBuilder::new()
.boxed_clone()
.layer(
#{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals),
)
.service(builder.build());
let mut service = #{tower}::util::BoxCloneService::new(builder.build());

{
use #{tower}::Layer;
#{tracing}::trace!("adding middlewares to rust python router");
let mut middlewares = self.middlewares.clone();
// Reverse the middlewares, so they run with same order as they defined
middlewares.reverse();
for handler in middlewares {
#{tracing}::trace!(name = &handler.name, "adding python middleware");
let locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let layer = #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(handler, locals);
service = #{tower}::util::BoxCloneService::new(layer.layer(service));
}
}

Ok(service)
""",
"Protocol" to protocol.markerStruct(),
Expand Down Expand Up @@ -248,11 +254,17 @@ class PythonApplicationGenerator(
pub fn context(&mut self, context: #{pyo3}::PyObject) {
self.context = Some(context);
}
/// Register a request middleware function that will be run inside a Tower layer, without cloning the body.
/// Register a Python function to be executed inside a Tower middleware layer.
##[pyo3(text_signature = "(${'$'}self, func)")]
pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
use #{SmithyPython}::PyApp;
self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request)
pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
let handler = #{SmithyPython}::PyMiddlewareHandler::new(py, func)?;
#{tracing}::trace!(
name = &handler.name,
is_coroutine = handler.is_coroutine,
"registering middleware function",
);
self.middlewares.push(handler);
Ok(())
}
/// Main entrypoint: start the server on multiple workers.
##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class PythonServerModuleGenerator(
middleware.add_class::<#{SmithyPython}::PyRequest>()?;
middleware.add_class::<#{SmithyPython}::PyResponse>()?;
middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?;
middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?;
pyo3::py_run!(
py,
middleware,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PythonServerOperationHandlerGenerator(
writable {
rustTemplate(
"""
#{tracing}::debug!("Executing Python handler function `$name()`");
#{tracing}::trace!(name = "$name", "executing python handler function");
#{pyo3}::Python::with_gil(|py| {
let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
let output = if handler.args == 1 {
Expand All @@ -110,7 +110,7 @@ class PythonServerOperationHandlerGenerator(
writable {
rustTemplate(
"""
#{tracing}::debug!("Executing Python handler coroutine `$name()`");
#{tracing}::trace!(name = "$name", "executing python handler coroutine");
let result = #{pyo3}::Python::with_gil(|py| {
let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
let coroutine = if handler.args == 1 {
Expand All @@ -132,15 +132,9 @@ class PythonServerOperationHandlerGenerator(
"""
// Catch and record a Python traceback.
result.map_err(|e| {
let traceback = #{pyo3}::Python::with_gil(|py| {
match e.traceback(py) {
Some(t) => t.format().unwrap_or_else(|e| e.to_string()),
None => "Unknown traceback\n".to_string()
}
});
let error = e.into();
#{tracing}::error!("{}{}", traceback, error);
error
let rich_py_err = #{SmithyPython}::rich_py_err(#{pyo3}::Python::with_gil(|py| { e.clone_ref(py) }));
#{tracing}::error!(error = ?rich_py_err, "handler error");
e.into()
})
""",
*codegenScope,
Expand Down
8 changes: 8 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ tracing-appender = { version = "0.2.2"}
[dev-dependencies]
pretty_assertions = "1"
futures-util = "0.3"
tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] }

[[test]]
name = "middleware_tests"
path = "src/middleware/pytests/harness.rs"
harness = false

[package.metadata.docs.rs]
all-features = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ build: codegen
cargo build
ln -sf $(DEBUG_SHARED_LIBRARY_SRC) $(SHARED_LIBRARY_DST)

py_check: build
mypy pokemon_service.py

release: codegen
cargo build --release
ln -sf $(RELEASE_SHARED_LIBRARY_SRC) $(SHARED_LIBRARY_DST)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# NOTE: This is manually created to surpass some mypy errors and it is incomplete,
unexge marked this conversation as resolved.
Show resolved Hide resolved
# in future we will autogenerate correct stubs.

from typing import Any, TypeVar, Callable

F = TypeVar("F", bound=Callable[..., Any])

class App:
context: Any
run: Any

def middleware(self, func: F) -> F: ...
def do_nothing(self, func: F) -> F: ...
def get_pokemon_species(self, func: F) -> F: ...
def get_server_statistics(self, func: F) -> F: ...
def check_health(self, func: F) -> F: ...
def stream_pokemon_radio(self, func: F) -> F: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

[mypy]
unexge marked this conversation as resolved.
Show resolved Hide resolved
strict = True
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,52 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import itertools
import logging
import random
from threading import Lock
from dataclasses import dataclass
from typing import List, Optional

import aiohttp
from typing import List, Optional, Callable, Awaitable

from libpokemon_service_server_sdk import App
from libpokemon_service_server_sdk.error import ResourceNotFoundException
from libpokemon_service_server_sdk.input import (
DoNothingInput, GetPokemonSpeciesInput, GetServerStatisticsInput,
CheckHealthInput, StreamPokemonRadioInput)
from libpokemon_service_server_sdk.logging import TracingHandler
from libpokemon_service_server_sdk.middleware import (MiddlewareException,
Request)
from libpokemon_service_server_sdk.model import FlavorText, Language
from libpokemon_service_server_sdk.output import (
DoNothingOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput,
CheckHealthOutput, StreamPokemonRadioOutput)
from libpokemon_service_server_sdk.types import ByteStream
from libpokemon_service_server_sdk.error import ResourceNotFoundException # type: ignore
from libpokemon_service_server_sdk.input import ( # type: ignore
DoNothingInput,
GetPokemonSpeciesInput,
GetServerStatisticsInput,
CheckHealthInput,
StreamPokemonRadioInput,
)
from libpokemon_service_server_sdk.logging import TracingHandler # type: ignore
from libpokemon_service_server_sdk.middleware import ( # type: ignore
MiddlewareException,
Response,
Request,
)
from libpokemon_service_server_sdk.model import FlavorText, Language # type: ignore
from libpokemon_service_server_sdk.output import ( # type: ignore
DoNothingOutput,
GetPokemonSpeciesOutput,
GetServerStatisticsOutput,
CheckHealthOutput,
StreamPokemonRadioOutput,
)
from libpokemon_service_server_sdk.types import ByteStream # type: ignore

# Logging can bee setup using standard Python tooling. We provide
# fast logging handler, Tracingandler based on Rust tracing crate.
logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()])


class SafeCounter:
def __init__(self):
def __init__(self) -> None:
self._val = 0
self._lock = Lock()

def increment(self):
def increment(self) -> None:
with self._lock:
self._val += 1

def value(self):
def value(self) -> int:
with self._lock:
return self._val

Expand All @@ -63,9 +72,9 @@ def value(self):
#
# Synchronization:
# Instance of `Context` class will be cloned for every worker and all state kept in `Context`
# will be specific to that process. There is no protection provided by default,
# it is up to you to have synchronization between processes.
# If you really want to share state between different processes you need to use `multiprocessing` primitives:
# will be specific to that process. There is no protection provided by default,
# it is up to you to have synchronization between processes.
# If you really want to share state between different processes you need to use `multiprocessing` primitives:
# https://docs.python.org/3/library/multiprocessing.html#sharing-state-between-processes
@dataclass
class Context:
Expand Down Expand Up @@ -124,49 +133,51 @@ def get_random_radio_stream(self) -> str:
# Middleware
############################################################
# Middlewares are sync or async function decorated by `@app.middleware`.
# They are executed in order and take as input the HTTP request object.
# A middleware can return multiple values, following these rules:
# * Middleware not returning will let the execution continue without
# changing the original request.
# * Middleware returning a modified Request will update the original
# request before continuing the execution.
# * Middleware returning a Response will immediately terminate the request
# handling and return the response constructed from Python.
# * Middleware raising MiddlewareException will immediately terminate the
# request handling and return a protocol specific error, with the option of
# setting the HTTP return code.
# * Middleware raising any other exception will immediately terminate the
# request handling and return a protocol specific error, with HTTP status
# code 500.
@app.request_middleware
def check_content_type_header(request: Request):
content_type = request.get_header("content-type")
# They are executed in order and take as input the HTTP request object and
# the handler or the next middleware in the stack.
# A middleware should return a `Response`, either by calling `next` with `Request`
# to get `Response` from the handler or by constructing `Response` by itself.
# It can also modify the `Request` before calling `next` or it can also modify
# the `Response` returned by the handler.
# It can also raise an `MiddlewareException` with custom error message and HTTP status code,
# any other raised exceptions will cause an internal server error response to be returned.

# Next is either the next middleware in the stack or the handler.
Next = Callable[[Request], Awaitable[Response]]

# This middleware checks the `Content-Type` from the request header,
# logs some information depending on that and then calls `next`.
@app.middleware
async def check_content_type_header(request: Request, next: Next) -> Response:
content_type = request.headers.get("content-type")
if content_type == "application/json":
logging.debug("Found valid `application/json` content type")
else:
logging.warning(
f"Invalid content type {content_type}, dumping headers: {request.headers()}"
f"Invalid content type {content_type}, dumping headers: {request.headers}"
)
return await next(request)


# This middleware adds a new header called `x-amzn-answer` to the
# request. We expect to see this header to be populated in the next
# middleware.
@app.request_middleware
def add_x_amzn_answer_header(request: Request):
request.set_header("x-amzn-answer", "42")
@app.middleware
async def add_x_amzn_answer_header(request: Request, next: Next) -> Response:
request.headers["x-amzn-answer"] = "42"
logging.debug("Setting `x-amzn-answer` header to 42")
return request
return await next(request)


# This middleware checks if the header `x-amzn-answer` is correctly set
# to 42, otherwise it returns an exception with a set status code.
@app.request_middleware
async def check_x_amzn_answer_header(request: Request):
@app.middleware
async def check_x_amzn_answer_header(request: Request, next: Next) -> Response:
# Check that `x-amzn-answer` is 42.
if request.get_header("x-amzn-answer") != "42":
if request.headers.get("x-amzn-answer") != "42":
# Return an HTTP 401 Unauthorized if the content type is not JSON.
raise MiddlewareException("Invalid answer", 401)
return await next(request)


###########################################################
Expand Down Expand Up @@ -216,7 +227,11 @@ def check_health(_: CheckHealthInput) -> CheckHealthOutput:

# Stream a random Pokémon song.
@app.stream_pokemon_radio
async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context):
async def stream_pokemon_radio(
_: StreamPokemonRadioInput, context: Context
) -> StreamPokemonRadioOutput:
import aiohttp

radio_url = context.get_random_radio_stream()
logging.info("Random radio URL for this stream is %s", radio_url)
async with aiohttp.ClientSession() as session:
Expand All @@ -229,7 +244,9 @@ async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context):
###########################################################
# Run the server.
###########################################################
def main():
def main() -> None:
app.run(workers=1)
Comment on lines -232 to 248
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a separate main function? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is not required but having a explicitly defined entry point is kinda useful when you want to setup your service via a supervisor or something


main()

if __name__ == "__main__":
main()
7 changes: 6 additions & 1 deletion rust-runtime/aws-smithy-http-server-python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ impl PyMiddlewareException {

impl From<PyErr> for PyMiddlewareException {
fn from(other: PyErr) -> Self {
Self::newpy(other.to_string(), None)
// Try to extract `PyMiddlewareException` from `PyErr` and use that if succeed
let middleware_err = Python::with_gil(|py| other.to_object(py).extract::<Self>(py));
match middleware_err {
Ok(err) => err,
Err(_) => Self::newpy(other.to_string(), None),
Comment on lines +67 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we probably want a log statement capturing the details of the exception that we failed to extract since it didn't match the one we expected.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error on the Err(_) variant will always be 'TypeError' object cannot be converted to 'MiddlewareException' here because it is a conversion error and the real error will be logged at https://github.com/awslabs/smithy-rs/blob/e12136563d734e3c1ea40d96a1705d3b1f5d9cbe/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs#L124

}
}
}

Expand Down
Loading