diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 2dd7c76aee..777a6b6c1b 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -107,7 +107,7 @@ class PythonApplicationGenerator( ##[derive(Debug)] pub struct App { handlers: #{HashMap}, - middlewares: #{SmithyPython}::PyMiddlewares, + middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, } @@ -141,7 +141,7 @@ class PythonApplicationGenerator( fn default() -> Self { Self { handlers: Default::default(), - middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]), + middlewares: vec![], context: None, workers: #{parking_lot}::Mutex::new(vec![]), } @@ -171,9 +171,6 @@ class PythonApplicationGenerator( fn handlers(&mut self) -> &mut #{HashMap} { &mut self.handlers } - fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares { - &mut self.middlewares - } """, *codegenScope, ) @@ -212,13 +209,22 @@ class PythonApplicationGenerator( } rustTemplate( """ - let middleware_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); - let service = #{tower}::ServiceBuilder::new() - .boxed_clone() - .layer( - #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals), - ) - .service(builder.build()); + let mut service = #{tower}::util::BoxCloneService::new(builder.build()); + + { + use #{tower}::Layer; + #{tracing}::trace!("adding middlewares to rust python router"); + let mut middlewares = self.middlewares.clone(); + // Reverse the middlewares, so they run with same order as they defined + middlewares.reverse(); + for handler in middlewares { + #{tracing}::trace!(name = &handler.name, "adding python middleware"); + let locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); + let layer = #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(handler, locals); + service = #{tower}::util::BoxCloneService::new(layer.layer(service)); + } + } + Ok(service) """, "Protocol" to protocol.markerStruct(), @@ -248,11 +254,17 @@ class PythonApplicationGenerator( pub fn context(&mut self, context: #{pyo3}::PyObject) { self.context = Some(context); } - /// Register a request middleware function that will be run inside a Tower layer, without cloning the body. + /// Register a Python function to be executed inside a Tower middleware layer. ##[pyo3(text_signature = "(${'$'}self, func)")] - pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { - use #{SmithyPython}::PyApp; - self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request) + pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { + let handler = #{SmithyPython}::PyMiddlewareHandler::new(py, func)?; + #{tracing}::trace!( + name = &handler.name, + is_coroutine = handler.is_coroutine, + "registering middleware function", + ); + self.middlewares.push(handler); + Ok(()) } /// Main entrypoint: start the server on multiple workers. ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")] diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 4736719cec..5f2502b408 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -152,7 +152,6 @@ class PythonServerModuleGenerator( middleware.add_class::<#{SmithyPython}::PyRequest>()?; middleware.add_class::<#{SmithyPython}::PyResponse>()?; middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?; - middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?; pyo3::py_run!( py, middleware, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index 076a757014..c3d1f19146 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -91,7 +91,7 @@ class PythonServerOperationHandlerGenerator( writable { rustTemplate( """ - #{tracing}::debug!("Executing Python handler function `$name()`"); + #{tracing}::trace!(name = "$name", "executing python handler function"); #{pyo3}::Python::with_gil(|py| { let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?; let output = if handler.args == 1 { @@ -110,7 +110,7 @@ class PythonServerOperationHandlerGenerator( writable { rustTemplate( """ - #{tracing}::debug!("Executing Python handler coroutine `$name()`"); + #{tracing}::trace!(name = "$name", "executing python handler coroutine"); let result = #{pyo3}::Python::with_gil(|py| { let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?; let coroutine = if handler.args == 1 { @@ -132,15 +132,9 @@ class PythonServerOperationHandlerGenerator( """ // Catch and record a Python traceback. result.map_err(|e| { - let traceback = #{pyo3}::Python::with_gil(|py| { - match e.traceback(py) { - Some(t) => t.format().unwrap_or_else(|e| e.to_string()), - None => "Unknown traceback\n".to_string() - } - }); - let error = e.into(); - #{tracing}::error!("{}{}", traceback, error); - error + let rich_py_err = #{SmithyPython}::rich_py_err(#{pyo3}::Python::with_gil(|py| { e.clone_ref(py) })); + #{tracing}::error!(error = ?rich_py_err, "handler error"); + e.into() }) """, *codegenScope, diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index 1a4e4098b0..29306449a4 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -41,6 +41,14 @@ tracing-appender = { version = "0.2.2"} [dev-dependencies] pretty_assertions = "1" futures-util = "0.3" +tower-test = "0.4" +tokio-test = "0.4" +pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] } + +[[test]] +name = "middleware_tests" +path = "src/middleware/pytests/harness.rs" +harness = false [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-http-server-python/examples/Makefile b/rust-runtime/aws-smithy-http-server-python/examples/Makefile index c3af9f378b..35299e78c8 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/Makefile +++ b/rust-runtime/aws-smithy-http-server-python/examples/Makefile @@ -31,6 +31,9 @@ build: codegen cargo build ln -sf $(DEBUG_SHARED_LIBRARY_SRC) $(SHARED_LIBRARY_DST) +py_check: build + mypy pokemon_service.py + release: codegen cargo build --release ln -sf $(RELEASE_SHARED_LIBRARY_SRC) $(SHARED_LIBRARY_DST) diff --git a/rust-runtime/aws-smithy-http-server-python/examples/libpokemon_service_server_sdk.pyi b/rust-runtime/aws-smithy-http-server-python/examples/libpokemon_service_server_sdk.pyi new file mode 100644 index 0000000000..ccbd5edb2e --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/examples/libpokemon_service_server_sdk.pyi @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# NOTE: This is manually created to surpass some mypy errors and it is incomplete, +# in future we will autogenerate correct stubs. + +from typing import Any, TypeVar, Callable + +F = TypeVar("F", bound=Callable[..., Any]) + +class App: + context: Any + run: Any + + def middleware(self, func: F) -> F: ... + def do_nothing(self, func: F) -> F: ... + def get_pokemon_species(self, func: F) -> F: ... + def get_server_statistics(self, func: F) -> F: ... + def check_health(self, func: F) -> F: ... + def stream_pokemon_radio(self, func: F) -> F: ... diff --git a/rust-runtime/aws-smithy-http-server-python/examples/mypy.ini b/rust-runtime/aws-smithy-http-server-python/examples/mypy.ini new file mode 100644 index 0000000000..c6a867948d --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/examples/mypy.ini @@ -0,0 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +[mypy] +strict = True diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 9264b31dcb..98a844d687 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -3,43 +3,52 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import itertools import logging import random from threading import Lock from dataclasses import dataclass -from typing import List, Optional - -import aiohttp +from typing import List, Optional, Callable, Awaitable from libpokemon_service_server_sdk import App -from libpokemon_service_server_sdk.error import ResourceNotFoundException -from libpokemon_service_server_sdk.input import ( - DoNothingInput, GetPokemonSpeciesInput, GetServerStatisticsInput, - CheckHealthInput, StreamPokemonRadioInput) -from libpokemon_service_server_sdk.logging import TracingHandler -from libpokemon_service_server_sdk.middleware import (MiddlewareException, - Request) -from libpokemon_service_server_sdk.model import FlavorText, Language -from libpokemon_service_server_sdk.output import ( - DoNothingOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput, - CheckHealthOutput, StreamPokemonRadioOutput) -from libpokemon_service_server_sdk.types import ByteStream +from libpokemon_service_server_sdk.error import ResourceNotFoundException # type: ignore +from libpokemon_service_server_sdk.input import ( # type: ignore + DoNothingInput, + GetPokemonSpeciesInput, + GetServerStatisticsInput, + CheckHealthInput, + StreamPokemonRadioInput, +) +from libpokemon_service_server_sdk.logging import TracingHandler # type: ignore +from libpokemon_service_server_sdk.middleware import ( # type: ignore + MiddlewareException, + Response, + Request, +) +from libpokemon_service_server_sdk.model import FlavorText, Language # type: ignore +from libpokemon_service_server_sdk.output import ( # type: ignore + DoNothingOutput, + GetPokemonSpeciesOutput, + GetServerStatisticsOutput, + CheckHealthOutput, + StreamPokemonRadioOutput, +) +from libpokemon_service_server_sdk.types import ByteStream # type: ignore # Logging can bee setup using standard Python tooling. We provide # fast logging handler, Tracingandler based on Rust tracing crate. logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()]) + class SafeCounter: - def __init__(self): + def __init__(self) -> None: self._val = 0 self._lock = Lock() - def increment(self): + def increment(self) -> None: with self._lock: self._val += 1 - def value(self): + def value(self) -> int: with self._lock: return self._val @@ -63,9 +72,9 @@ def value(self): # # Synchronization: # Instance of `Context` class will be cloned for every worker and all state kept in `Context` -# will be specific to that process. There is no protection provided by default, -# it is up to you to have synchronization between processes. -# If you really want to share state between different processes you need to use `multiprocessing` primitives: +# will be specific to that process. There is no protection provided by default, +# it is up to you to have synchronization between processes. +# If you really want to share state between different processes you need to use `multiprocessing` primitives: # https://docs.python.org/3/library/multiprocessing.html#sharing-state-between-processes @dataclass class Context: @@ -124,49 +133,51 @@ def get_random_radio_stream(self) -> str: # Middleware ############################################################ # Middlewares are sync or async function decorated by `@app.middleware`. -# They are executed in order and take as input the HTTP request object. -# A middleware can return multiple values, following these rules: -# * Middleware not returning will let the execution continue without -# changing the original request. -# * Middleware returning a modified Request will update the original -# request before continuing the execution. -# * Middleware returning a Response will immediately terminate the request -# handling and return the response constructed from Python. -# * Middleware raising MiddlewareException will immediately terminate the -# request handling and return a protocol specific error, with the option of -# setting the HTTP return code. -# * Middleware raising any other exception will immediately terminate the -# request handling and return a protocol specific error, with HTTP status -# code 500. -@app.request_middleware -def check_content_type_header(request: Request): - content_type = request.get_header("content-type") +# They are executed in order and take as input the HTTP request object and +# the handler or the next middleware in the stack. +# A middleware should return a `Response`, either by calling `next` with `Request` +# to get `Response` from the handler or by constructing `Response` by itself. +# It can also modify the `Request` before calling `next` or it can also modify +# the `Response` returned by the handler. +# It can also raise an `MiddlewareException` with custom error message and HTTP status code, +# any other raised exceptions will cause an internal server error response to be returned. + +# Next is either the next middleware in the stack or the handler. +Next = Callable[[Request], Awaitable[Response]] + +# This middleware checks the `Content-Type` from the request header, +# logs some information depending on that and then calls `next`. +@app.middleware +async def check_content_type_header(request: Request, next: Next) -> Response: + content_type = request.headers.get("content-type") if content_type == "application/json": logging.debug("Found valid `application/json` content type") else: logging.warning( - f"Invalid content type {content_type}, dumping headers: {request.headers()}" + f"Invalid content type {content_type}, dumping headers: {request.headers}" ) + return await next(request) # This middleware adds a new header called `x-amzn-answer` to the # request. We expect to see this header to be populated in the next # middleware. -@app.request_middleware -def add_x_amzn_answer_header(request: Request): - request.set_header("x-amzn-answer", "42") +@app.middleware +async def add_x_amzn_answer_header(request: Request, next: Next) -> Response: + request.headers["x-amzn-answer"] = "42" logging.debug("Setting `x-amzn-answer` header to 42") - return request + return await next(request) # This middleware checks if the header `x-amzn-answer` is correctly set # to 42, otherwise it returns an exception with a set status code. -@app.request_middleware -async def check_x_amzn_answer_header(request: Request): +@app.middleware +async def check_x_amzn_answer_header(request: Request, next: Next) -> Response: # Check that `x-amzn-answer` is 42. - if request.get_header("x-amzn-answer") != "42": + if request.headers.get("x-amzn-answer") != "42": # Return an HTTP 401 Unauthorized if the content type is not JSON. raise MiddlewareException("Invalid answer", 401) + return await next(request) ########################################################### @@ -216,7 +227,11 @@ def check_health(_: CheckHealthInput) -> CheckHealthOutput: # Stream a random Pokémon song. @app.stream_pokemon_radio -async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context): +async def stream_pokemon_radio( + _: StreamPokemonRadioInput, context: Context +) -> StreamPokemonRadioOutput: + import aiohttp + radio_url = context.get_random_radio_stream() logging.info("Random radio URL for this stream is %s", radio_url) async with aiohttp.ClientSession() as session: @@ -229,7 +244,9 @@ async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context): ########################################################### # Run the server. ########################################################### -def main(): +def main() -> None: app.run(workers=1) -main() + +if __name__ == "__main__": + main() diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index b2cf345bcd..06e20e9b52 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -63,7 +63,12 @@ impl PyMiddlewareException { impl From for PyMiddlewareException { fn from(other: PyErr) -> Self { - Self::newpy(other.to_string(), None) + // Try to extract `PyMiddlewareException` from `PyErr` and use that if succeed + let middleware_err = Python::with_gil(|py| other.to_object(py).extract::(py)); + match middleware_err { + Ok(err) => err, + Err(_) => Self::newpy(other.to_string(), None), + } } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 793104af59..ee8e8b4b9b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -17,19 +17,20 @@ pub mod middleware; mod server; mod socket; pub mod types; +mod util; #[doc(inline)] pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] pub use logging::{py_tracing_event, PyTracingHandler}; #[doc(inline)] -pub use middleware::{ - PyHttpVersion, PyMiddlewareLayer, PyMiddlewareType, PyMiddlewares, PyRequest, PyResponse, -}; +pub use middleware::{PyMiddlewareHandler, PyMiddlewareLayer, PyRequest, PyResponse}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] pub use socket::PySocket; +#[doc(inline)] +pub use util::error::{rich_py_err, RichPyErr}; #[cfg(test)] mod tests { diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index c7894d0c5e..2492096269 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -148,7 +148,6 @@ pub fn py_tracing_event( filename = filename, lineno = lineno ); - println!("message2: {message}"); let _guard = span.enter(); match level { 40 => tracing::error!("{message}"), diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs new file mode 100644 index 0000000000..fa81f8faee --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/error.rs @@ -0,0 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use pyo3::{exceptions::PyRuntimeError, PyErr}; +use thiserror::Error; + +/// Possible middleware errors that might arise. +#[derive(Error, Debug)] +pub enum PyMiddlewareError { + #[error("`next` is called multiple times")] + NextAlreadyCalled, + #[error("request is accessed after `next` is called")] + RequestGone, + #[error("response is called after it is returned")] + ResponseGone, +} + +impl From for PyErr { + fn from(err: PyMiddlewareError) -> PyErr { + PyRuntimeError::new_err(err.to_string()) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 68fdc12acc..668ab363b3 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -3,21 +3,64 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! Execute Python middleware handlers. -use aws_smithy_http_server::{body::Body, body::BoxBody, response::IntoResponse}; -use http::Request; -use pyo3::prelude::*; +//! Execute pure-Python middleware handler. +use aws_smithy_http_server::body::{Body, BoxBody}; +use http::{Request, Response}; +use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction}; use pyo3_asyncio::TaskLocals; +use tower::{util::BoxService, BoxError, Service}; -use crate::{PyMiddlewareException, PyRequest, PyResponse}; +use crate::util::func_metadata; -use super::PyFuture; +use super::{PyMiddlewareError, PyRequest, PyResponse}; -#[derive(Debug, Clone, Copy)] -pub enum PyMiddlewareType { - Request, - Response, +// PyNextInner represents the inner service Tower layer applied to. +type PyNextInner = BoxService, Response, BoxError>; + +// PyNext wraps inner Tower service and makes it callable from Python. +#[pyo3::pyclass] +struct PyNext(Option); + +impl PyNext { + fn new(inner: PyNextInner) -> Self { + Self(Some(inner)) + } + + // Consumes self by taking the inner Tower service. + // This method would have been `into_inner(self) -> PyNextInner` + // but we can't do that because we are crossing Python boundary. + fn take_inner(&mut self) -> Option { + self.0.take() + } +} + +#[pyo3::pymethods] +impl PyNext { + // Calls the inner Tower service with the `Request` that is passed from Python. + // It returns a coroutine to be awaited on the Python side to complete the call. + // Note that it takes wrapped objects from both `PyRequest` and `PyNext`, + // so after calling `next`, consumer can't access to the `Request` or + // can't call the `next` again, this basically emulates consuming `self` and `Request`, + // but since we are crossing the Python boundary we can't express it in natural Rust terms. + // + // Naming the method `__call__` allows `next` to be called like `next(...)`. + fn __call__<'p>(&'p mut self, py: Python<'p>, py_req: Py) -> PyResult<&'p PyAny> { + let req = py_req + .borrow_mut(py) + .take_inner() + .ok_or(PyMiddlewareError::RequestGone)?; + let mut inner = self + .take_inner() + .ok_or(PyMiddlewareError::NextAlreadyCalled)?; + pyo3_asyncio::tokio::future_into_py(py, async move { + let res = inner + .call(req) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + Ok(Python::with_gil(|py| PyResponse::new(res).into_py(py))) + }) + } } /// A Python middleware handler function representation. @@ -29,313 +72,54 @@ pub struct PyMiddlewareHandler { pub name: String, pub func: PyObject, pub is_coroutine: bool, - pub _type: PyMiddlewareType, -} - -/// Structure holding the list of Python middlewares that will be executed by this server. -/// -/// Middlewares are executed one after each other inside the [crate::PyMiddlewareLayer] Tower layer. -#[derive(Debug, Clone)] -pub struct PyMiddlewares { - handlers: Vec, - into_response: fn(PyMiddlewareException) -> http::Response, } -impl PyMiddlewares { - /// Create a new instance of `PyMiddlewareHandlers` from a list of heandlers. - pub fn new

(handlers: Vec) -> Self - where - PyMiddlewareException: IntoResponse

, - { - Self { - handlers, - into_response: PyMiddlewareException::into_response, - } - } - - /// Add a new handler to the list. - pub fn push(&mut self, handler: PyMiddlewareHandler) { - self.handlers.push(handler); +impl PyMiddlewareHandler { + pub fn new(py: Python, func: PyObject) -> PyResult { + let func_metadata = func_metadata(py, &func)?; + Ok(Self { + name: func_metadata.name, + func, + is_coroutine: func_metadata.is_coroutine, + }) } - /// Execute a single middleware handler. - /// - /// The handler is scheduled on the Python interpreter syncronously or asynchronously, - /// dependening on the handler signature. - async fn execute_middleware( - request: PyRequest, - handler: PyMiddlewareHandler, - ) -> Result<(Option, Option), PyMiddlewareException> { - let handle: PyResult> = if handler.is_coroutine { - tracing::debug!("Executing Python middleware coroutine `{}`", handler.name); - let result = pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - let coroutine = pyhandler.call1((request,))?; - pyo3_asyncio::tokio::into_future(coroutine) - })?; - let output = result.await?; - Ok(output) - } else { - tracing::debug!("Executing Python middleware function `{}`", handler.name); - pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - let output = pyhandler.call1((request,))?; - Ok(output.into_py(py)) + // Calls pure-Python middleware handler with given `Request` and the next Tower service + // and returns the `Response` that returned from the pure-Python handler. + pub async fn call( + self, + req: Request, + next: PyNextInner, + locals: TaskLocals, + ) -> PyResult> { + let py_req = PyRequest::new(req); + let py_next = PyNext::new(next); + + let handler = self.func; + let result = if self.is_coroutine { + pyo3_asyncio::tokio::scope(locals, async move { + Python::with_gil(|py| { + let py_handler: &PyFunction = handler.extract(py)?; + let output = py_handler.call1((py_req, py_next))?; + pyo3_asyncio::tokio::into_future(output) + })? + .await }) + .await? + } else { + Python::with_gil(|py| { + let py_handler: &PyFunction = handler.extract(py)?; + let output = py_handler.call1((py_req, py_next))?; + Ok::<_, PyErr>(output.into()) + })? }; - Python::with_gil(|py| match handle { - Ok(result) => { - if let Ok(request) = result.extract::(py) { - return Ok((Some(request), None)); - } - if let Ok(response) = result.extract::(py) { - return Ok((None, Some(response))); - } - Ok((None, None)) - } - Err(e) => pyo3::Python::with_gil(|py| { - let traceback = match e.traceback(py) { - Some(t) => t.format().unwrap_or_else(|e| e.to_string()), - None => "Unknown traceback\n".to_string(), - }; - tracing::error!("{}{}", traceback, e); - let variant = e.value(py); - if let Ok(v) = variant.extract::() { - Err(v) - } else { - Err(e.into()) - } - }), - }) - } - - /// Execute all the available Python middlewares in order of registration. - /// - /// Once the response is returned by the Python interpreter, different scenarios can happen: - /// * Middleware not returning will let the execution continue to the next middleware without - /// changing the original request. - /// * Middleware returning a modified [PyRequest] will update the original request before - /// continuing the execution of the next middleware. - /// * Middleware returning a [PyResponse] will immediately terminate the request handling and - /// return the response constructed from Python. - /// * Middleware raising [PyMiddlewareException] will immediately terminate the request handling - /// and return a protocol specific error, with the option of setting the HTTP return code. - /// * Middleware raising any other exception will immediately terminate the request handling and - /// return a protocol specific error, with HTTP status code 500. - pub fn run(&mut self, mut request: Request, locals: TaskLocals) -> PyFuture { - let handlers = self.handlers.clone(); - let into_response = self.into_response; - // Run all Python handlers in a loop. - Box::pin(async move { - tracing::debug!("Executing Python middleware stack"); - for handler in handlers { - let name = handler.name.clone(); - let pyrequest = PyRequest::new(&request); - let loop_locals = locals.clone(); - let result = pyo3_asyncio::tokio::scope( - loop_locals, - Self::execute_middleware(pyrequest, handler), - ) - .await; - match result { - Ok((pyrequest, pyresponse)) => { - if let Some(pyrequest) = pyrequest { - if let Ok(headers) = (&pyrequest.headers).try_into() { - tracing::debug!("Python middleware `{name}` returned an HTTP request, override headers with middleware's one"); - *request.headers_mut() = headers; - } - } - if let Some(pyresponse) = pyresponse { - tracing::debug!( - "Python middleware `{name}` returned a HTTP response, exit middleware loop" - ); - return Err(pyresponse.into()); - } - } - Err(e) => { - tracing::debug!( - "Middleware `{name}` returned an error, exit middleware loop" - ); - return Err((into_response)(e)); - } - } - } - tracing::debug!( - "Python middleware execution finised, returning the request to operation handler" - ); - Ok(request) - }) - } -} - -#[cfg(test)] -mod tests { - use aws_smithy_http_server::proto::rest_json_1::RestJson1; - use http::HeaderValue; - use hyper::body::to_bytes; - use pretty_assertions::assert_eq; - - use super::*; - - #[tokio::test] - async fn request_middleware_chain_keeps_headers_changes() -> PyResult<()> { - let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares::new::(vec![]); - - Python::with_gil(|py| { - let middleware = PyModule::new(py, "middleware").unwrap(); - middleware.add_class::().unwrap(); - middleware.add_class::().unwrap(); - let pycode = r#" -def first_middleware(request: Request): - request.set_header("x-amzn-answer", "42") - return request - -def second_middleware(request: Request): - if request.get_header("x-amzn-answer") != "42": - raise MiddlewareException("wrong answer", 401) -"#; - py.run(pycode, Some(middleware.dict()), None)?; - let all = middleware.index()?; - let first_middleware = PyMiddlewareHandler { - func: middleware.getattr("first_middleware")?.into_py(py), - is_coroutine: false, - name: "first".to_string(), - _type: PyMiddlewareType::Request, - }; - all.append("first_middleware")?; - middlewares.push(first_middleware); - let second_middleware = PyMiddlewareHandler { - func: middleware.getattr("second_middleware")?.into_py(py), - is_coroutine: false, - name: "second".to_string(), - _type: PyMiddlewareType::Request, - }; - all.append("second_middleware")?; - middlewares.push(second_middleware); - Ok::<(), PyErr>(()) - })?; - - let result = middlewares - .run(Request::builder().body(Body::from("")).unwrap(), locals) - .await - .unwrap(); - assert_eq!( - result.headers().get("x-amzn-answer"), - Some(&HeaderValue::from_static("42")) - ); - Ok(()) - } - - #[tokio::test] - async fn request_middleware_return_response() -> PyResult<()> { - let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares::new::(vec![]); - - Python::with_gil(|py| { - let middleware = PyModule::new(py, "middleware").unwrap(); - middleware.add_class::().unwrap(); - middleware.add_class::().unwrap(); - let pycode = r#" -def middleware(request: Request): - return Response(200, {}, b"something")"#; - py.run(pycode, Some(middleware.dict()), None)?; - let all = middleware.index()?; - let middleware = PyMiddlewareHandler { - func: middleware.getattr("middleware")?.into_py(py), - is_coroutine: false, - name: "middleware".to_string(), - _type: PyMiddlewareType::Request, - }; - all.append("middleware")?; - middlewares.push(middleware); - Ok::<(), PyErr>(()) - })?; - - let result = middlewares - .run(Request::builder().body(Body::from("")).unwrap(), locals) - .await - .unwrap_err(); - assert_eq!(result.status(), 200); - let body = to_bytes(result.into_body()).await.unwrap(); - assert_eq!(body, "something".as_bytes()); - Ok(()) - } - - #[tokio::test] - async fn request_middleware_raise_middleware_exception() -> PyResult<()> { - let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares::new::(vec![]); - - Python::with_gil(|py| { - let middleware = PyModule::new(py, "middleware").unwrap(); - middleware.add_class::().unwrap(); - middleware.add_class::().unwrap(); - let pycode = r#" -def middleware(request: Request): - raise MiddlewareException("error", 503)"#; - py.run(pycode, Some(middleware.dict()), None)?; - let all = middleware.index()?; - let middleware = PyMiddlewareHandler { - func: middleware.getattr("middleware")?.into_py(py), - is_coroutine: false, - name: "middleware".to_string(), - _type: PyMiddlewareType::Request, - }; - all.append("middleware")?; - middlewares.push(middleware); - Ok::<(), PyErr>(()) - })?; - - let result = middlewares - .run(Request::builder().body(Body::from("")).unwrap(), locals) - .await - .unwrap_err(); - assert_eq!(result.status(), 503); - assert_eq!( - result.headers().get("X-Amzn-Errortype"), - Some(&HeaderValue::from_static("MiddlewareException")) - ); - let body = to_bytes(result.into_body()).await.unwrap(); - assert_eq!(body, r#"{"message":"error"}"#.as_bytes()); - Ok(()) - } - - #[tokio::test] - async fn request_middleware_raise_python_exception() -> PyResult<()> { - let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares::new::(vec![]); - Python::with_gil(|py| { - let middleware = PyModule::from_code( - py, - r#" -def middleware(request): - raise ValueError("error")"#, - "", - "", - )?; - let middleware = PyMiddlewareHandler { - func: middleware.getattr("middleware")?.into_py(py), - is_coroutine: false, - name: "middleware".to_string(), - _type: PyMiddlewareType::Request, - }; - middlewares.push(middleware); - Ok::<(), PyErr>(()) + let response = Python::with_gil(|py| { + let py_res: Py = result.extract(py)?; + let mut py_res = py_res.borrow_mut(py); + Ok::<_, PyErr>(py_res.take_inner()) })?; - let result = middlewares - .run(Request::builder().body(Body::from("")).unwrap(), locals) - .await - .unwrap_err(); - assert_eq!(result.status(), 500); - assert_eq!( - result.headers().get("X-Amzn-Errortype"), - Some(&HeaderValue::from_static("MiddlewareException")) - ); - let body = to_bytes(result.into_body()).await.unwrap(); - assert_eq!(body, r#"{"message":"ValueError: error"}"#.as_bytes()); - Ok(()) + response.ok_or_else(|| PyMiddlewareError::ResponseGone.into()) } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/header_map.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/header_map.rs new file mode 100644 index 0000000000..1c953a175b --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/header_map.rs @@ -0,0 +1,159 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::{mem, str::FromStr, sync::Arc}; + +use http::{header::HeaderName, HeaderMap, HeaderValue}; +use parking_lot::Mutex; +use pyo3::{ + exceptions::{PyKeyError, PyValueError}, + pyclass, PyErr, PyResult, +}; + +use crate::{mutable_mapping_pymethods, util::collection::PyMutableMapping}; + +/// Python-compatible [HeaderMap] object. +#[pyclass(mapping)] +#[derive(Clone, Debug)] +pub struct PyHeaderMap { + inner: Arc>, +} + +impl PyHeaderMap { + pub fn new(inner: HeaderMap) -> Self { + Self { + inner: Arc::new(Mutex::new(inner)), + } + } + + // Consumes self by taking the inner `HeaderMap`. + // This method would have been `into_inner(self) -> HeaderMap` + // but we can't do that because we are crossing Python boundary. + pub fn take_inner(&mut self) -> Option { + let header_map = mem::take(&mut self.inner); + let header_map = Arc::try_unwrap(header_map).ok()?; + let header_map = header_map.into_inner(); + Some(header_map) + } +} + +/// By implementing [PyMutableMapping] for [PyHeaderMap] we are making it to +/// behave like a dictionary on the Python. +impl PyMutableMapping for PyHeaderMap { + type Key = String; + type Value = String; + + fn len(&self) -> PyResult { + Ok(self.inner.lock().len()) + } + + fn contains(&self, key: Self::Key) -> PyResult { + Ok(self.inner.lock().contains_key(key)) + } + + fn keys(&self) -> PyResult> { + Ok(self.inner.lock().keys().map(|h| h.to_string()).collect()) + } + + fn values(&self) -> PyResult> { + self.inner + .lock() + .values() + .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error)) + .collect() + } + + fn get(&self, key: Self::Key) -> PyResult> { + self.inner + .lock() + .get(key) + .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error)) + .transpose() + } + + fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> { + self.inner.lock().insert( + HeaderName::from_str(&key).map_err(to_value_error)?, + HeaderValue::from_str(&value).map_err(to_value_error)?, + ); + Ok(()) + } + + fn del(&mut self, key: Self::Key) -> PyResult<()> { + if self.inner.lock().remove(key).is_none() { + Err(PyKeyError::new_err("unknown key")) + } else { + Ok(()) + } + } +} + +mutable_mapping_pymethods!(PyHeaderMap, keys_iter: PyHeaderMapKeys); + +fn to_value_error(err: impl std::error::Error) -> PyErr { + PyValueError::new_err(err.to_string()) +} + +#[cfg(test)] +mod tests { + use http::header; + use pyo3::{prelude::*, py_run}; + + use super::*; + + #[test] + fn py_header_map() -> PyResult<()> { + pyo3::prepare_freethreaded_python(); + + let mut header_map = HeaderMap::new(); + header_map.insert(header::CONTENT_LENGTH, "42".parse().unwrap()); + header_map.insert(header::HOST, "localhost".parse().unwrap()); + + let header_map = Python::with_gil(|py| { + let py_header_map = PyHeaderMap::new(header_map); + let headers = PyCell::new(py, py_header_map)?; + py_run!( + py, + headers, + r#" +assert len(headers) == 2 +assert headers["content-length"] == "42" +assert headers["host"] == "localhost" + +headers["content-length"] = "45" +assert headers["content-length"] == "45" +headers["content-encoding"] = "application/json" +assert headers["content-encoding"] == "application/json" + +del headers["host"] +assert headers.get("host") == None +assert len(headers) == 2 + +assert set(headers.items()) == set([ + ("content-length", "45"), + ("content-encoding", "application/json") +]) +"# + ); + + Ok::<_, PyErr>(headers.borrow_mut().take_inner().unwrap()) + })?; + + assert_eq!( + header_map, + vec![ + (header::CONTENT_LENGTH, "45".parse().unwrap()), + ( + header::CONTENT_ENCODING, + "application/json".parse().unwrap() + ), + ] + .into_iter() + .collect() + ); + + Ok(()) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 04e62fd732..a658e1bb06 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -4,9 +4,11 @@ */ //! Tower layer implementation of Python middleware handling. + use std::{ + convert::Infallible, marker::PhantomData, - pin::Pin, + mem, task::{Context, Poll}, }; @@ -14,29 +16,29 @@ use aws_smithy_http_server::{ body::{Body, BoxBody}, response::IntoResponse, }; -use futures::{ready, Future}; +use futures::{future::BoxFuture, TryFutureExt}; use http::{Request, Response}; -use pin_project_lite::pin_project; +use pyo3::Python; use pyo3_asyncio::TaskLocals; -use tower::{Layer, Service}; +use tower::{util::BoxService, Layer, Service, ServiceExt}; -use crate::{middleware::PyFuture, PyMiddlewareException, PyMiddlewares}; +use super::PyMiddlewareHandler; +use crate::{util::error::rich_py_err, PyMiddlewareException}; /// Tower [Layer] implementation of Python middleware handling. /// -/// Middleware stored in the `handlers` attribute will be executed, in order, -/// inside an async Tower middleware. +/// Middleware stored in the `handler` attribute will be executed inside an async Tower middleware. #[derive(Debug, Clone)] pub struct PyMiddlewareLayer

{ - handlers: PyMiddlewares, + handler: PyMiddlewareHandler, locals: TaskLocals, _protocol: PhantomData

, } impl

PyMiddlewareLayer

{ - pub fn new(handlers: PyMiddlewares, locals: TaskLocals) -> Self { + pub fn new(handler: PyMiddlewareHandler, locals: TaskLocals) -> Self { Self { - handlers, + handler, locals, _protocol: PhantomData, } @@ -50,175 +52,79 @@ where type Service = PyMiddlewareService; fn layer(&self, inner: S) -> Self::Service { - PyMiddlewareService::new(inner, self.handlers.clone(), self.locals.clone()) + PyMiddlewareService::new( + inner, + self.handler.clone(), + self.locals.clone(), + PyMiddlewareException::into_response, + ) } } -// Tower [Service] wrapping the Python middleware [Layer]. +/// Tower [Service] wrapping the Python middleware [Layer]. #[derive(Clone, Debug)] pub struct PyMiddlewareService { inner: S, - handlers: PyMiddlewares, + handler: PyMiddlewareHandler, locals: TaskLocals, + into_response: fn(PyMiddlewareException) -> http::Response, } impl PyMiddlewareService { - pub fn new(inner: S, handlers: PyMiddlewares, locals: TaskLocals) -> PyMiddlewareService { + pub fn new( + inner: S, + handler: PyMiddlewareHandler, + locals: TaskLocals, + into_response: fn(PyMiddlewareException) -> http::Response, + ) -> PyMiddlewareService { Self { inner, - handlers, + handler, locals, + into_response, } } } impl Service> for PyMiddlewareService where - S: Service, Response = Response> + Clone, + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send, { - type Response = Response; - type Error = S::Error; - type Future = ResponseFuture; + type Response = S::Response; + // We are making `Service` `Infallible` because we convert errors to responses via + // `PyMiddlewareException::into_response` which has `IntoResponse` bound, + // so we always return a protocol specific error response instead of erroring out. + type Error = Infallible; + type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { - // TODO(Should we make this clone less expensive by wrapping inner in a Arc?) - let clone = self.inner.clone(); - // See https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services - let inner = std::mem::replace(&mut self.inner, clone); - let run = self.handlers.run(req, self.locals.clone()); - - ResponseFuture { - middleware: State::Running { run }, - service: inner, - } - } -} - -pin_project! { - /// Response future handling the state transition between a running and a done future. - pub struct ResponseFuture - where - S: Service>, - { - #[pin] - middleware: State, - service: S, - } -} - -pin_project! { - /// Representation of the result of the middleware execution. - #[project = StateProj] - enum State { - Running { - #[pin] - run: A, - }, - Done { - #[pin] - fut: Fut - } - } -} - -impl Future for ResponseFuture -where - S: Service, Response = Response>, -{ - type Output = Result, S::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - loop { - match this.middleware.as_mut().project() { - // Run the handler and store the future inside the inner state. - StateProj::Running { run } => { - let run = ready!(run.poll(cx)); - match run { - Ok(req) => { - let fut = this.service.call(req); - this.middleware.set(State::Done { fut }); - } - Err(res) => return Poll::Ready(Ok(res)), - } - } - // Execute the future returned by the layer. - StateProj::Done { fut } => return fut.poll(cx), - } - } - } -} - -#[cfg(test)] -mod tests { - use std::error::Error; - - use super::*; - - use aws_smithy_http_server::body::to_boxed; - use aws_smithy_http_server::proto::rest_json_1::RestJson1; - use pyo3::prelude::*; - use tower::{Service, ServiceBuilder, ServiceExt}; - - use crate::middleware::PyMiddlewareHandler; - use crate::{PyMiddlewareException, PyMiddlewareType, PyRequest}; - - async fn echo(req: Request) -> Result, Box> { - Ok(Response::new(to_boxed(req.into_body()))) - } - - #[tokio::test] - async fn request_middlewares_are_chained_inside_layer() -> PyResult<()> { - let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares::new::(vec![]); - - Python::with_gil(|py| { - let middleware = PyModule::new(py, "middleware").unwrap(); - middleware.add_class::().unwrap(); - middleware.add_class::().unwrap(); - let pycode = r#" -def first_middleware(request: Request): - request.set_header("x-amzn-answer", "42") - return request - -def second_middleware(request: Request): - if request.get_header("x-amzn-answer") != "42": - raise MiddlewareException("wrong answer", 401) -"#; - py.run(pycode, Some(middleware.dict()), None)?; - let all = middleware.index()?; - let first_middleware = PyMiddlewareHandler { - func: middleware.getattr("first_middleware")?.into_py(py), - is_coroutine: false, - name: "first".to_string(), - _type: PyMiddlewareType::Request, - }; - all.append("first_middleware")?; - middlewares.push(first_middleware); - let second_middleware = PyMiddlewareHandler { - func: middleware.getattr("second_middleware")?.into_py(py), - is_coroutine: false, - name: "second".to_string(), - _type: PyMiddlewareType::Request, - }; - all.append("second_middleware")?; - middlewares.push(second_middleware); - Ok::<(), PyErr>(()) - })?; - - let mut service = ServiceBuilder::new() - .layer(PyMiddlewareLayer::::new(middlewares, locals)) - .service_fn(echo); - - let request = Request::get("/").body(Body::empty()).unwrap(); - - let res = service.ready().await.unwrap().call(request).await.unwrap(); - - assert_eq!(res.status(), 200); - Ok(()) + let inner = { + // https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + let clone = self.inner.clone(); + mem::replace(&mut self.inner, clone) + }; + let handler = self.handler.clone(); + let handler_name = handler.name.clone(); + let next = BoxService::new(inner.map_err(|err| err.into())); + let locals = self.locals.clone(); + let into_response = self.into_response; + + Box::pin( + handler + .call(req, next, locals) + .or_else(move |err| async move { + tracing::error!(error = ?rich_py_err(Python::with_gil(|py| err.clone_ref(py))), handler_name, "middleware failed"); + let response = (into_response)(err.into()); + Ok(response) + }), + ) } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index a1a2d14ced..267f825873 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -3,21 +3,85 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! Schedule pure Python middlewares as `Tower` layers. +//! Schedule pure-Python middlewares as [tower::Layer]s. +//! +//! # Moving data from Rust to Python and back +//! +//! In middlewares we need to move some data back-and-forth between Rust and Python. +//! When you move some data from Rust to Python you can't get its ownership back, +//! you can only get `&T` or `&mut T` but not `T` unless you clone it. +//! +//! In order to overcome this shortcoming we are using wrappers for Python that holds +//! pure-Rust types with [Option]s and provides `take_inner(&mut self) -> Option` +//! method to get the ownership of `T` back. +//! +//! For example: +//! ```no_run +//! # use pyo3::prelude::*; +//! # use pyo3::exceptions::PyRuntimeError; +//! # enum PyMiddlewareError { +//! # InnerGone +//! # } +//! # impl From for PyErr { +//! # fn from(_: PyMiddlewareError) -> PyErr { +//! # PyRuntimeError::new_err("inner gone") +//! # } +//! # } +//! // Pure Rust type +//! struct Inner { +//! num: i32 +//! } +//! +//! // Python wrapper +//! #[pyclass] +//! pub struct Wrapper(Option); +//! +//! impl Wrapper { +//! // Call when Python is done processing the `Wrapper` +//! // to get ownership of `Inner` back +//! pub fn take_inner(&mut self) -> Option { +//! self.0.take() +//! } +//! } +//! +//! // Python exposed methods checks if `Wrapper` still has the `Inner` and +//! // fails with `InnerGone` otherwise. +//! #[pymethods] +//! impl Wrapper { +//! #[getter] +//! fn num(&self) -> PyResult { +//! self.0 +//! .as_ref() +//! .map(|inner| inner.num) +//! .ok_or_else(|| PyMiddlewareError::InnerGone.into()) +//! } +//! +//! #[setter] +//! fn set_num(&mut self, num: i32) -> PyResult<()> { +//! match self.0.as_mut() { +//! Some(inner) => { +//! inner.num = num; +//! Ok(()) +//! } +//! None => Err(PyMiddlewareError::InnerGone.into()), +//! } +//! } +//! } +//! ``` +//! +//! You can see this pattern in [PyRequest], [PyResponse] and the others. +//! + +mod error; mod handler; +mod header_map; mod layer; mod request; mod response; -use aws_smithy_http_server::body::{Body, BoxBody}; -use futures::future::BoxFuture; -use http::{Request, Response}; - -pub use self::handler::{PyMiddlewareType, PyMiddlewares}; +pub use self::error::PyMiddlewareError; +pub use self::handler::PyMiddlewareHandler; +pub use self::header_map::PyHeaderMap; pub use self::layer::PyMiddlewareLayer; -pub use self::request::{PyHttpVersion, PyRequest}; +pub use self::request::PyRequest; pub use self::response::PyResponse; - -pub(crate) use self::handler::PyMiddlewareHandler; -/// Future type returned by the Python middleware handler. -pub(crate) type PyFuture = BoxFuture<'static, Result, Response>>; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/harness.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/harness.rs new file mode 100644 index 0000000000..1c01554c5b --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/harness.rs @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#[pyo3_asyncio::tokio::main] +async fn main() -> pyo3::PyResult<()> { + pyo3_asyncio::testing::main().await +} + +mod layer; +mod request; +mod response; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs new file mode 100644 index 0000000000..118b4e7cf8 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -0,0 +1,312 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::convert::Infallible; + +use aws_smithy_http_server::{ + body::{to_boxed, Body, BoxBody}, + proto::rest_json_1::RestJson1, +}; +use aws_smithy_http_server_python::{ + middleware::{PyMiddlewareHandler, PyMiddlewareLayer}, + PyMiddlewareException, PyResponse, +}; +use http::{Request, Response, StatusCode}; +use pretty_assertions::assert_eq; +use pyo3::{prelude::*, types::PyDict}; +use pyo3_asyncio::TaskLocals; +use tokio_test::assert_ready_ok; +use tower::{layer::util::Stack, util::BoxCloneService, Layer, Service, ServiceExt}; +use tower_test::mock; + +#[pyo3_asyncio::tokio::test] +async fn identity_middleware() -> PyResult<()> { + let layer = layer( + r#" +async def middleware(request, next): + return await next(request) +"#, + ); + let (mut service, mut handle) = spawn_service(layer); + + let th = tokio::spawn(async move { + let (req, send_response) = handle.next_request().await.unwrap(); + let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + assert_eq!(req_body, "hello server"); + send_response.send_response( + Response::builder() + .body(to_boxed("hello client")) + .expect("could not create response"), + ); + }); + + let request = simple_request("hello server"); + let response = service.call(request); + assert_body(response.await?, "hello client").await; + + th.await.unwrap(); + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn returning_response_from_python_middleware() -> PyResult<()> { + let layer = layer( + r#" +def middleware(request, next): + return Response(200, {}, b"hello client from Python") +"#, + ); + let (mut service, _handle) = spawn_service(layer); + + let request = simple_request("hello server"); + let response = service.call(request); + assert_body(response.await?, "hello client from Python").await; + + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn convert_exception_from_middleware_to_protocol_specific_response() -> PyResult<()> { + let layer = layer( + r#" +def middleware(request, next): + raise RuntimeError("fail") +"#, + ); + let (mut service, _handle) = spawn_service(layer); + + let request = simple_request("hello server"); + let response = service.call(request); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_body(response, r#"{"message":"RuntimeError: fail"}"#).await; + + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn uses_status_code_and_message_from_middleware_exception() -> PyResult<()> { + let layer = layer( + r#" +def middleware(request, next): + raise MiddlewareException("access denied", 401) +"#, + ); + let (mut service, _handle) = spawn_service(layer); + + let request = simple_request("hello server"); + let response = service.call(request); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_body(response, r#"{"message":"access denied"}"#).await; + + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn nested_middlewares() -> PyResult<()> { + let first_layer = layer( + r#" +async def middleware(request, next): + return await next(request) +"#, + ); + let second_layer = layer( + r#" +def middleware(request, next): + return Response(200, {}, b"hello client from Python second middleware") +"#, + ); + let layer = Stack::new(first_layer, second_layer); + let (mut service, _handle) = spawn_service(layer); + + let request = simple_request("hello server"); + let response = service.call(request); + assert_body( + response.await?, + "hello client from Python second middleware", + ) + .await; + + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn changes_request() -> PyResult<()> { + let layer = layer( + r#" +async def middleware(request, next): + body = bytes(await request.body).decode() + body_reversed = body[::-1] + request.body = body_reversed.encode() + request.headers["X-From-Middleware"] = "yes" + return await next(request) +"#, + ); + let (mut service, mut handle) = spawn_service(layer); + + let th = tokio::spawn(async move { + let (req, send_response) = handle.next_request().await.unwrap(); + assert_eq!(&"yes", req.headers().get("X-From-Middleware").unwrap()); + let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + assert_eq!(req_body, "hello server".chars().rev().collect::()); + send_response.send_response( + Response::builder() + .body(to_boxed("hello client")) + .expect("could not create response"), + ); + }); + + let request = simple_request("hello server"); + let response = service.call(request); + assert_body(response.await?, "hello client").await; + + th.await.unwrap(); + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn changes_response() -> PyResult<()> { + let layer = layer( + r#" +async def middleware(request, next): + response = await next(request) + body = bytes(await response.body).decode() + body_reversed = body[::-1] + response.body = body_reversed.encode() + response.headers["X-From-Middleware"] = "yes" + return response +"#, + ); + let (mut service, mut handle) = spawn_service(layer); + + let th = tokio::spawn(async move { + let (req, send_response) = handle.next_request().await.unwrap(); + let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + assert_eq!(req_body, "hello server"); + send_response.send_response( + Response::builder() + .body(to_boxed("hello client")) + .expect("could not create response"), + ); + }); + + let request = simple_request("hello server"); + let response = service.call(request); + let response = response.await.unwrap(); + assert_eq!(response.headers().get("X-From-Middleware").unwrap(), &"yes"); + assert_body(response, &"hello client".chars().rev().collect::()).await; + + th.await.unwrap(); + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn fails_if_req_is_used_after_calling_next() -> PyResult<()> { + let layer = layer( + r#" +async def middleware(request, next): + uri = request.uri + response = await next(request) + uri = request.uri # <- fails + return response +"#, + ); + + let (mut service, mut handle) = spawn_service(layer); + + let th = tokio::spawn(async move { + let (req, send_response) = handle.next_request().await.unwrap(); + let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + assert_eq!(req_body, "hello server"); + send_response.send_response( + Response::builder() + .body(to_boxed("hello client")) + .expect("could not create response"), + ); + }); + + let request = simple_request("hello server"); + let response = service.call(request); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_body( + response, + r#"{"message":"RuntimeError: request is accessed after `next` is called"}"#, + ) + .await; + + th.await.unwrap(); + Ok(()) +} + +async fn assert_body(response: Response, eq: &str) { + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, eq); +} + +fn simple_request(body: &'static str) -> Request { + Request::builder() + .body(Body::from(body)) + .expect("could not create request") +} + +fn spawn_service( + layer: L, +) -> ( + mock::Spawn, + mock::Handle, Response>, +) +where + L: Layer, Response, Infallible>>, + L::Service: Service, Error = E>, + E: std::fmt::Debug, +{ + let (mut service, handle) = mock::spawn_with(|svc| { + let svc = svc + .map_err(|err| panic!("service failed: {err}")) + .boxed_clone(); + layer.layer(svc) + }); + assert_ready_ok!(service.poll_ready()); + (service, handle) +} + +fn layer(code: &str) -> PyMiddlewareLayer { + PyMiddlewareLayer::::new(py_handler(code), task_locals()) +} + +fn task_locals() -> TaskLocals { + Python::with_gil(|py| { + Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) + }) + .unwrap() +} + +fn py_handler(code: &str) -> PyMiddlewareHandler { + Python::with_gil(|py| { + // `py.run` under the hood uses `eval` (`PyEval_EvalCode` in C API) + // and by default if you pass a `global` object without `__builtins__` key + // it inserts `__builtins__` with reference to the `builtins` module + // which provides prelude for Python so you can access `print()`, `bytes()`, `len()` etc. + // but this is not working for Python 3.7.10 which is the version we are using in our CI + // so our tests are failing in CI because there is no `print()`, `bytes()` etc. + // in order to fix that we are manually extending `__main__` module to preserve `__builtins__`. + let globals = PyModule::import(py, "__main__")?.dict(); + globals.set_item( + "MiddlewareException", + py.get_type::(), + )?; + globals.set_item("Response", py.get_type::())?; + let locals = PyDict::new(py); + py.run(code, Some(globals), Some(locals))?; + let handler = locals + .get_item("middleware") + .expect("your handler must be named `middleware`") + .into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) + }) + .unwrap() +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs new file mode 100644 index 0000000000..be5d85f8c9 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/request.rs @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http_server_python::PyRequest; +use http::{Request, Version}; +use hyper::Body; +use pyo3::{prelude::*, py_run}; + +#[pyo3_asyncio::tokio::test] +async fn accessing_request_properties() -> PyResult<()> { + let request = Request::builder() + .method("POST") + .uri("https://www.rust-lang.org/") + .header("Accept-Encoding", "*") + .header("X-Custom", "42") + .version(Version::HTTP_2) + .body(Body::from("hello world")) + .expect("could not build request"); + let py_request = PyRequest::new(request); + + Python::with_gil(|py| { + let req = PyCell::new(py, py_request)?; + py_run!( + py, + req, + r#" +assert req.method == "POST" +assert req.uri == "https://www.rust-lang.org/" +assert req.headers["accept-encoding"] == "*" +assert req.headers["x-custom"] == "42" +assert req.version == "HTTP/2.0" + +assert req.headers.get("x-foo") == None +req.headers["x-foo"] = "bar" +assert req.headers["x-foo"] == "bar" +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +async fn accessing_and_changing_request_body() -> PyResult<()> { + let request = Request::builder() + .body(Body::from("hello world")) + .expect("could not build request"); + let py_request = PyRequest::new(request); + + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +async def handler(req): + # TODO(Ergonomics): why we need to wrap with `bytes`? + assert bytes(await req.body) == b"hello world" + + req.body = b"hello world from middleware" + assert bytes(await req.body) == b"hello world from middleware" +"#, + "", + "", + )?; + let handler = module.getattr("handler")?; + + let output = handler.call1((py_request,))?; + Ok::<_, PyErr>(pyo3_asyncio::tokio::into_future(output)) + })?? + .await?; + + Ok(()) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs new file mode 100644 index 0000000000..e9f1a48066 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/response.rs @@ -0,0 +1,106 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http_server::body::to_boxed; +use aws_smithy_http_server_python::PyResponse; +use http::{Response, StatusCode, Version}; +use pyo3::{ + prelude::*, + py_run, + types::{IntoPyDict, PyDict}, +}; + +#[pyo3_asyncio::tokio::test] +async fn building_response_in_python() -> PyResult<()> { + let response = Python::with_gil(|py| { + let globals = [("Response", py.get_type::())].into_py_dict(py); + let locals = PyDict::new(py); + + py.run( + r#" +response = Response(200, {"Content-Type": "application/json"}, b"hello world") +"#, + Some(globals), + Some(locals), + ) + .unwrap(); + + let py_response: Py = locals.get_item("response").unwrap().extract().unwrap(); + let response = py_response.borrow_mut(py).take_inner(); + response.unwrap() + }); + + assert_eq!(response.status(), StatusCode::OK); + + let headers = response.headers(); + { + assert_eq!(headers.len(), 1); + assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); + } + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(body, "hello world"); + + Ok(()) +} + +#[pyo3_asyncio::tokio::test] +async fn accessing_response_properties() -> PyResult<()> { + let response = Response::builder() + .status(StatusCode::IM_A_TEAPOT) + .version(Version::HTTP_3) + .header("X-Secret", "42") + .body(to_boxed("hello world")) + .expect("could not build response"); + let py_response = PyResponse::new(response); + + Python::with_gil(|py| { + let res = PyCell::new(py, py_response)?; + py_run!( + py, + res, + r#" +assert res.status == 418 +assert res.version == "HTTP/3.0" +assert res.headers["x-secret"] == "42" + +assert res.headers.get("x-foo") == None +res.headers["x-foo"] = "bar" +assert res.headers["x-foo"] == "bar" +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +async fn accessing_and_changing_response_body() -> PyResult<()> { + let response = Response::builder() + .body(to_boxed("hello world")) + .expect("could not build response"); + let py_response = PyResponse::new(response); + + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +async def handler(res): + assert bytes(await res.body) == b"hello world" + + res.body = b"hello world from middleware" + assert bytes(await res.body) == b"hello world from middleware" +"#, + "", + "", + )?; + let handler = module.getattr("handler")?; + + let output = handler.call1((py_response,))?; + Ok::<_, PyErr>(pyo3_asyncio::tokio::into_future(output)) + })?? + .await?; + + Ok(()) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index 467d7dbb7b..16fc9d6d7f 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -4,120 +4,114 @@ */ //! Python-compatible middleware [http::Request] implementation. -use std::collections::HashMap; -use aws_smithy_http_server::body::Body; -use http::{Request, Version}; -use pyo3::prelude::*; +use std::mem; +use std::sync::Arc; -/// Python compabible HTTP [Version]. -#[pyclass(name = "HttpVersion")] -#[derive(PartialEq, PartialOrd, Copy, Clone, Eq, Ord, Hash)] -pub struct PyHttpVersion(Version); +use aws_smithy_http_server::body::Body; +use http::{request::Parts, Request}; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use tokio::sync::Mutex; -#[pymethods] -impl PyHttpVersion { - /// Extract the value of the HTTP [Version] into a string that - /// can be used by Python. - #[pyo3(text_signature = "($self)")] - fn value(&self) -> &str { - match self.0 { - Version::HTTP_09 => "HTTP/0.9", - Version::HTTP_10 => "HTTP/1.0", - Version::HTTP_11 => "HTTP/1.1", - Version::HTTP_2 => "HTTP/2.0", - Version::HTTP_3 => "HTTP/3.0", - _ => unreachable!(), - } - } -} +use super::{PyHeaderMap, PyMiddlewareError}; /// Python-compatible [Request] object. -/// -/// For performance reasons, there is not support yet to pass the body to the Python middleware, -/// as it requires to consume and clone the body, which is a very expensive operation. -/// -/// TODO(if customers request for it, we can implemented an opt-in functionality to also pass -/// the body around). #[pyclass(name = "Request")] #[pyo3(text_signature = "(request)")] -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct PyRequest { - #[pyo3(get, set)] - method: String, - #[pyo3(get, set)] - uri: String, - // TODO(investigate if using a PyDict can make the experience more idiomatic) - // I'd like to be able to do request.headers.get("my-header") and - // request.headers["my-header"] = 42 instead of implementing set_header() and get_header() - // under pymethods. The same applies to response. - pub(crate) headers: HashMap, - version: Version, + parts: Option, + headers: PyHeaderMap, + body: Arc>>, } impl PyRequest { /// Create a new Python-compatible [Request] structure from the Rust side. - /// - /// This is done by cloning the headers, method, URI and HTTP version to let them be owned by Python. - pub fn new(request: &Request) -> Self { + pub fn new(request: Request) -> Self { + let (mut parts, body) = request.into_parts(); + let headers = mem::take(&mut parts.headers); Self { - method: request.method().to_string(), - uri: request.uri().to_string(), - headers: request - .headers() - .into_iter() - .map(|(k, v)| -> (String, String) { - let name: String = k.as_str().to_string(); - let value: String = String::from_utf8_lossy(v.as_bytes()).to_string(); - (name, value) - }) - .collect(), - version: request.version(), + parts: Some(parts), + headers: PyHeaderMap::new(headers), + body: Arc::new(Mutex::new(Some(body))), } } + + // Consumes self by taking the inner Request. + // This method would have been `into_inner(self) -> Request` + // but we can't do that because we are crossing Python boundary. + pub fn take_inner(&mut self) -> Option> { + let headers = self.headers.take_inner()?; + let mut parts = self.parts.take()?; + parts.headers = headers; + let body = { + let body = mem::take(&mut self.body); + let body = Arc::try_unwrap(body).ok()?; + body.into_inner().take()? + }; + Some(Request::from_parts(parts, body)) + } } #[pymethods] impl PyRequest { - #[new] - /// Create a new Python-compatible `Request` object from the Python side. - fn newpy( - method: String, - uri: String, - headers: Option>, - version: Option, - ) -> Self { - let version = version.map(|v| v.0).unwrap_or(Version::HTTP_11); - Self { - method, - uri, - headers: headers.unwrap_or_default(), - version, - } + /// Return the HTTP method of this request. + #[getter] + fn method(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| parts.method.to_string()) + .ok_or_else(|| PyMiddlewareError::RequestGone.into()) + } + + /// Return the URI of this request. + #[getter] + fn uri(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| parts.uri.to_string()) + .ok_or_else(|| PyMiddlewareError::RequestGone.into()) } /// Return the HTTP version of this request. - #[pyo3(text_signature = "($self)")] - fn version(&self) -> String { - PyHttpVersion(self.version).value().to_string() + #[getter] + fn version(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| format!("{:?}", parts.version)) + .ok_or_else(|| PyMiddlewareError::RequestGone.into()) } /// Return the HTTP headers of this request. - /// TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) - #[pyo3(text_signature = "($self)")] - fn headers(&self) -> HashMap { + #[getter] + fn headers(&self) -> PyHeaderMap { self.headers.clone() } - /// Insert a new key/value into this request's headers. - #[pyo3(text_signature = "($self, key, value)")] - fn set_header(&mut self, key: &str, value: &str) { - self.headers.insert(key.to_string(), value.to_string()); + /// Return the HTTP body of this request. + /// Note that this is a costly operation because the whole request body is cloned. + #[getter] + fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let body = self.body.clone(); + pyo3_asyncio::tokio::future_into_py(py, async move { + let body = { + let mut body_guard = body.lock().await; + let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?; + let body = hyper::body::to_bytes(body) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + let buf = body.clone(); + body_guard.replace(Body::from(body)); + buf + }; + // TODO(Perf): can we use `PyBytes` here? + Ok(body.to_vec()) + }) } - /// Return a header value of this request. - #[pyo3(text_signature = "($self, key)")] - fn get_header(&self, key: &str) -> Option<&String> { - self.headers.get(key) + /// Set the HTTP body of this request. + #[setter] + fn set_body(&mut self, buf: &[u8]) { + self.body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned())))); } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 773fe76327..5e3619cb11 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -3,77 +3,128 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! Python-compatible middleware [http::Request] implementation. -use std::{collections::HashMap, convert::TryInto}; +//! Python-compatible middleware [http::Response] implementation. + +use std::collections::HashMap; +use std::mem; +use std::sync::Arc; use aws_smithy_http_server::body::{to_boxed, BoxBody}; -use http::{Response, StatusCode}; -use pyo3::prelude::*; +use http::{response::Parts, Response}; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use tokio::sync::Mutex; + +use super::{PyHeaderMap, PyMiddlewareError}; /// Python-compatible [Response] object. -/// -/// For performance reasons, there is not support yet to pass the body to the Python middleware, -/// as it requires to consume and clone the body, which is a very expensive operation. -/// -// TODO(if customers request for it, we can implemented an opt-in functionality to also pass -// the body around). #[pyclass(name = "Response")] #[pyo3(text_signature = "(status, headers, body)")] -#[derive(Debug, Clone)] pub struct PyResponse { - #[pyo3(get, set)] - status: u16, - #[pyo3(get, set)] - body: Vec, - headers: HashMap, + parts: Option, + headers: PyHeaderMap, + body: Arc>>, +} + +impl PyResponse { + /// Create a new Python-compatible [Response] structure from the Rust side. + pub fn new(response: Response) -> Self { + let (mut parts, body) = response.into_parts(); + let headers = mem::take(&mut parts.headers); + Self { + parts: Some(parts), + headers: PyHeaderMap::new(headers), + body: Arc::new(Mutex::new(Some(body))), + } + } + + // Consumes self by taking the inner Response. + // This method would have been `into_inner(self) -> Response` + // but we can't do that because we are crossing Python boundary. + pub fn take_inner(&mut self) -> Option> { + let headers = self.headers.take_inner()?; + let mut parts = self.parts.take()?; + parts.headers = headers; + let body = { + let body = mem::take(&mut self.body); + let body = Arc::try_unwrap(body).ok()?; + body.into_inner().take()? + }; + Some(Response::from_parts(parts, body)) + } } #[pymethods] impl PyResponse { /// Python-compatible [Response] object from the Python side. #[new] - fn newpy(status: u16, headers: Option>, body: Option>) -> Self { - Self { - status, - body: body.unwrap_or_default(), - headers: headers.unwrap_or_default(), + fn newpy( + status: u16, + headers: Option>, + body: Option>, + ) -> PyResult { + let mut builder = Response::builder().status(status); + + if let Some(headers) = headers { + for (k, v) in headers { + builder = builder.header(k, v); + } } + + let response = builder + .body(body.map(to_boxed).unwrap_or_default()) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + + Ok(Self::new(response)) } - /// Return the HTTP headers of this response. - // TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) - #[pyo3(text_signature = "($self)")] - fn headers(&self) -> HashMap { - self.headers.clone() + /// Return the HTTP status of this response. + #[getter] + fn status(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| parts.status.as_u16()) + .ok_or_else(|| PyMiddlewareError::ResponseGone.into()) } - /// Insert a new key/value into this response's headers. - #[pyo3(text_signature = "($self, key, value)")] - fn set_header(&mut self, key: &str, value: &str) { - self.headers.insert(key.to_string(), value.to_string()); + /// Return the HTTP version of this response. + #[getter] + fn version(&self) -> PyResult { + self.parts + .as_ref() + .map(|parts| format!("{:?}", parts.version)) + .ok_or_else(|| PyMiddlewareError::ResponseGone.into()) } - /// Return a header value of this response. - #[pyo3(text_signature = "($self, key)")] - fn get_header(&self, key: &str) -> Option<&String> { - self.headers.get(key) + /// Return the HTTP headers of this response. + #[getter] + fn headers(&self) -> PyHeaderMap { + self.headers.clone() } -} -/// Allow to convert between a [PyResponse] and a [Response]. -impl From for Response { - fn from(pyresponse: PyResponse) -> Self { - let mut response = Response::builder() - .status( - StatusCode::from_u16(pyresponse.status) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - ) - .body(to_boxed(pyresponse.body)) - .unwrap_or_default(); - match (&pyresponse.headers).try_into() { - Ok(headers) => *response.headers_mut() = headers, - Err(e) => tracing::error!("Error extracting HTTP headers from PyResponse: {e}"), - }; - response + /// Return the HTTP body of this response. + /// Note that this is a costly operation because the whole response body is cloned. + #[getter] + fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let body = self.body.clone(); + pyo3_asyncio::tokio::future_into_py(py, async move { + let body = { + let mut body_guard = body.lock().await; + let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?; + let body = hyper::body::to_bytes(body) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + let buf = body.clone(); + body_guard.replace(to_boxed(body)); + buf + }; + // TODO(Perf): can we use `PyBytes` here? + Ok(body.to_vec()) + }) + } + + /// Set the HTTP body of this response. + #[setter] + fn set_body(&mut self, buf: &[u8]) { + self.body = Arc::new(Mutex::new(Some(to_boxed(buf.to_owned())))); } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index e2749ce51f..2de6c1fe4b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -2,7 +2,6 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT. use std::{collections::HashMap, convert::Infallible, ops::Deref, process, thread}; @@ -18,7 +17,10 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::{util::BoxCloneService, ServiceBuilder}; -use crate::{middleware::PyMiddlewareHandler, PyMiddlewareType, PyMiddlewares, PySocket}; +use crate::{ + util::{error::rich_py_err, func_metadata}, + PySocket, +}; /// A Python handler function representation. /// @@ -29,6 +31,7 @@ use crate::{middleware::PyMiddlewareHandler, PyMiddlewareType, PyMiddlewares, Py #[derive(Debug, Clone)] pub struct PyHandler { pub func: PyObject, + // Number of args is needed to decide whether handler accepts context as an argument pub args: usize, pub is_coroutine: bool, } @@ -69,8 +72,6 @@ pub trait PyApp: Clone + pyo3::IntoPy { /// Mapping between operation names and their `PyHandler` representation. fn handlers(&mut self) -> &mut HashMap; - fn middlewares(&mut self) -> &mut PyMiddlewares; - /// Build the app's `Service` using given `event_loop`. fn build_service(&mut self, event_loop: &pyo3::PyAny) -> pyo3::PyResult; @@ -86,16 +87,16 @@ pub trait PyApp: Clone + pyo3::IntoPy { .getattr(py, "pid") .map(|pid| pid.extract(py).unwrap_or(-1)) .unwrap_or(-1); - tracing::debug!("Terminating worker {idx}, PID: {pid}"); + tracing::debug!(idx, pid, "terminating worker"); match worker.call_method0(py, "terminate") { Ok(_) => {} Err(e) => { - tracing::error!("Error terminating worker {idx}, PID: {pid}: {e}"); + tracing::error!(error = ?rich_py_err(e), idx, pid, "error terminating worker"); worker .call_method0(py, "kill") .map_err(|e| { tracing::error!( - "Unable to kill kill worker {idx}, PID: {pid}: {e}" + error = ?rich_py_err(e), idx, pid, "unable to kill kill worker" ); }) .unwrap(); @@ -117,11 +118,11 @@ pub trait PyApp: Clone + pyo3::IntoPy { .getattr(py, "pid") .map(|pid| pid.extract(py).unwrap_or(-1)) .unwrap_or(-1); - tracing::debug!("Killing worker {idx}, PID: {pid}"); + tracing::debug!(idx, pid, "killing worker"); worker .call_method0(py, "kill") .map_err(|e| { - tracing::error!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); + tracing::error!(error = ?rich_py_err(e), idx, pid, "unable to kill kill worker"); }) .unwrap(); }); @@ -145,19 +146,19 @@ pub trait PyApp: Clone + pyo3::IntoPy { match sig { SIGINT => { tracing::info!( - "Termination signal {sig:?} received, all workers will be immediately terminated" + sig = %sig, "termination signal received, all workers will be immediately terminated" ); self.immediate_termination(self.workers()); } SIGTERM | SIGQUIT => { tracing::info!( - "Termination signal {sig:?} received, all workers will be gracefully terminated" + sig = %sig, "termination signal received, all workers will be gracefully terminated" ); self.graceful_termination(self.workers()); } _ => { - tracing::debug!("Signal {sig:?} is ignored by this application"); + tracing::debug!(sig = %sig, "signal is ignored by this application"); } } } @@ -231,7 +232,7 @@ event_loop.add_signal_handler(signal.SIGINT, self.register_python_signals(py, event_loop.to_object(py))?; // Spawn a new background [std::thread] to run the application. - tracing::debug!("Start the Tokio runtime in a background task"); + tracing::trace!("start the tokio runtime in a background task"); thread::spawn(move || { // The thread needs a new [tokio] runtime. let rt = runtime::Builder::new_multi_thread() @@ -248,58 +249,19 @@ event_loop.add_signal_handler(signal.SIGINT, .expect("Unable to create hyper server from shared socket") .serve(IntoMakeService::new(service)); - tracing::debug!("Started hyper server from shared socket"); + tracing::trace!("started hyper server from shared socket"); // Run forever-ish... if let Err(err) = server.await { - tracing::error!("server error: {}", err); + tracing::error!(error = ?err, "server error"); } }); }); // Block on the event loop forever. - tracing::debug!("Run and block on the Python event loop until a signal is received"); + tracing::trace!("run and block on the python event loop until a signal is received"); event_loop.call_method0("run_forever")?; Ok(()) } - // Check if a Python function is a coroutine. Since the function has not run yet, - // we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. - fn is_coroutine(&self, py: Python, func: &PyObject) -> PyResult { - let inspect = py.import("inspect")?; - // NOTE: that `asyncio.iscoroutine()` doesn't work here. - inspect - .call_method1("iscoroutinefunction", (func,))? - .extract::() - } - - /// Register a Python function to be executed inside a Tower middleware layer. - /// - /// There are some information needed to execute the Python code from a Rust handler, - /// such has if the registered function needs to be awaited (if it is a coroutine).. - fn register_middleware( - &mut self, - py: Python, - func: PyObject, - _type: PyMiddlewareType, - ) -> PyResult<()> { - let name = func.getattr(py, "__name__")?.extract::(py)?; - let is_coroutine = self.is_coroutine(py, &func)?; - // Find number of expected methods (a Python implementation could not accept the context). - let handler = PyMiddlewareHandler { - name, - func, - is_coroutine, - _type, - }; - tracing::info!( - "Registering middleware function `{}`, coroutine: {}, type: {:?}", - handler.name, - handler.is_coroutine, - handler._type - ); - self.middlewares().push(handler); - Ok(()) - } - /// Register a Python function to be executed inside the Smithy Rust handler. /// /// There are some information needed to execute the Python code from a Rust handler, @@ -307,22 +269,17 @@ event_loop.add_signal_handler(signal.SIGINT, /// the number of arguments available, which tells us if the handler wants the state to be /// passed or not. fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> { - let is_coroutine = self.is_coroutine(py, &func)?; - // Find number of expected methods (a Python implementation could not accept the context). - let inspect = py.import("inspect")?; - let func_args = inspect - .call_method1("getargs", (func.getattr(py, "__code__")?,))? - .getattr("args")? - .extract::>()?; + let func_metadata = func_metadata(py, &func)?; let handler = PyHandler { func, - is_coroutine, - args: func_args.len(), + is_coroutine: func_metadata.is_coroutine, + args: func_metadata.num_args, }; tracing::info!( - "Registering handler function `{name}`, coroutine: {}, arguments: {}", - handler.is_coroutine, - handler.args, + name, + is_coroutine = handler.is_coroutine, + args = handler.args, + "registering handler function", ); // Insert the handler in the handlers map. self.handlers().insert(name.to_string(), handler); @@ -342,10 +299,10 @@ event_loop.add_signal_handler(signal.SIGINT, match py.import("uvloop") { Ok(uvloop) => { uvloop.call_method0("install")?; - tracing::debug!("Setting up uvloop for current process"); + tracing::trace!("setting up uvloop for current process"); } Err(_) => { - tracing::warn!("Uvloop not found, using Python standard event loop, which could have worse performance than uvloop"); + tracing::warn!("uvloop not found, using python standard event loop, which could have worse performance than uvloop"); } } let event_loop = asyncio.call_method0("new_event_loop")?; @@ -371,7 +328,7 @@ event_loop.add_signal_handler(signal.SIGINT, /// use std::convert::Infallible; /// use std::collections::HashMap; /// use pyo3::prelude::*; - /// use aws_smithy_http_server_python::{PyApp, PyHandler, PyMiddlewares}; + /// use aws_smithy_http_server_python::{PyApp, PyHandler}; /// use aws_smithy_http_server::body::{Body, BoxBody}; /// use parking_lot::Mutex; /// use http::{Request, Response}; @@ -385,7 +342,6 @@ event_loop.add_signal_handler(signal.SIGINT, /// fn workers(&self) -> &Mutex> { todo!() } /// fn context(&self) -> &Option { todo!() } /// fn handlers(&mut self) -> &mut HashMap { todo!() } - /// fn middlewares(&mut self) -> &mut PyMiddlewares { todo!() } /// fn build_service(&mut self, event_loop: &PyAny) -> PyResult, Response, Infallible>> { todo!() } /// } /// @@ -457,7 +413,7 @@ event_loop.add_signal_handler(signal.SIGINT, } // Unlock the workers mutex. drop(active_workers); - tracing::info!("Rust Python server started successfully"); + tracing::trace!("rust python server started successfully"); self.block_on_rust_signals(); Ok(()) } @@ -493,7 +449,6 @@ event_loop.add_signal_handler(signal.SIGINT, let service = self.build_service(event_loop)?; // Create the `PyState` object from the Python context object. let context = self.context().clone().unwrap_or_else(|| py.None()); - tracing::debug!("add middlewares to rust python router"); let service = ServiceBuilder::new() .boxed_clone() .layer(AddExtensionLayer::new(context)) diff --git a/rust-runtime/aws-smithy-http-server-python/src/socket.rs b/rust-runtime/aws-smithy-http-server-python/src/socket.rs index df0e151078..8243aa28c2 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/socket.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/socket.rs @@ -34,7 +34,7 @@ impl PySocket { pub fn new(address: String, port: i32, backlog: Option) -> PyResult { let address: SocketAddr = format!("{}:{}", address, port).parse()?; let (domain, ip_version) = PySocket::socket_domain(address); - tracing::info!("Shared socket listening on {address}, IP version: {ip_version}"); + tracing::trace!(address = %address, ip_version, "shared socket listening"); let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; // Set value for the `SO_REUSEPORT` and `SO_REUSEADDR` options on this socket. // This indicates that further calls to `bind` may allow reuse of local diff --git a/rust-runtime/aws-smithy-http-server-python/src/util.rs b/rust-runtime/aws-smithy-http-server-python/src/util.rs new file mode 100644 index 0000000000..fe1141e60c --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/util.rs @@ -0,0 +1,90 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub mod collection; +pub mod error; + +use pyo3::prelude::*; + +// Captures some information about a Python function. +#[derive(Debug, PartialEq)] +pub struct FuncMetadata { + pub name: String, + pub is_coroutine: bool, + pub num_args: usize, +} + +// Returns `FuncMetadata` for given `func`. +pub fn func_metadata(py: Python, func: &PyObject) -> PyResult { + let name = func.getattr(py, "__name__")?.extract::(py)?; + let is_coroutine = is_coroutine(py, func)?; + let inspect = py.import("inspect")?; + let args = inspect + .call_method1("getargs", (func.getattr(py, "__code__")?,))? + .getattr("args")? + .extract::>()?; + Ok(FuncMetadata { + name, + is_coroutine, + num_args: args.len(), + }) +} + +// Check if a Python function is a coroutine. Since the function has not run yet, +// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. +fn is_coroutine(py: Python, func: &PyObject) -> PyResult { + let inspect = py.import("inspect")?; + // NOTE: that `asyncio.iscoroutine()` doesn't work here. + inspect + .call_method1("iscoroutinefunction", (func,))? + .extract::() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn function_metadata() -> PyResult<()> { + pyo3::prepare_freethreaded_python(); + + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +def regular_func(first_arg, second_arg): + pass + +async def async_func(): + pass +"#, + "", + "", + )?; + + let regular_func = module.getattr("regular_func")?.into_py(py); + assert_eq!( + FuncMetadata { + name: "regular_func".to_string(), + is_coroutine: false, + num_args: 2, + }, + func_metadata(py, ®ular_func)? + ); + + let async_func = module.getattr("async_func")?.into_py(py); + assert_eq!( + FuncMetadata { + name: "async_func".to_string(), + is_coroutine: true, + num_args: 0, + }, + func_metadata(py, &async_func)? + ); + + Ok(()) + }) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/util/collection.rs b/rust-runtime/aws-smithy-http-server-python/src/util/collection.rs new file mode 100644 index 0000000000..7bf557f18a --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/util/collection.rs @@ -0,0 +1,349 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Provides Rust equivalents of [collections.abc] Python classes. +//! +//! Creating a custom container is achived in Python via extending a `collections.abc.*` class: +//! ```python +//! class MySeq(collections.abc.Sequence): +//! def __getitem__(self, index): ... # Required abstract method +//! def __len__(self): ... # Required abstract method +//! ``` +//! You just need to implement required abstract methods and you get +//! extra mixin methods for free. +//! +//! Ideally we also want to just extend abstract base classes from Python but +//! it is not supported yet: . +//! +//! Until then, we are providing traits with the required methods and, macros that +//! takes those types that implement those traits and provides mixin methods for them. +//! +//! [collections.abc]: https://docs.python.org/3/library/collections.abc.html + +use pyo3::PyResult; + +/// Rust version of [collections.abc.MutableMapping]. +/// +/// [collections.abc.MutableMapping]: https://docs.python.org/3/library/collections.abc.html#collections.abc.MutableMapping +pub trait PyMutableMapping { + type Key; + type Value; + + fn len(&self) -> PyResult; + fn contains(&self, key: Self::Key) -> PyResult; + fn get(&self, key: Self::Key) -> PyResult>; + fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()>; + fn del(&mut self, key: Self::Key) -> PyResult<()>; + + // TODO(Perf): This methods should return iterators instead of `Vec`s. + fn keys(&self) -> PyResult>; + fn values(&self) -> PyResult>; +} + +/// Macro that provides mixin methods of [collections.abc.MutableMapping] to the implementing type. +/// +/// [collections.abc.MutableMapping]: https://docs.python.org/3/library/collections.abc.html#collections.abc.MutableMapping +#[macro_export] +macro_rules! mutable_mapping_pymethods { + ($ty:ident, keys_iter: $keys_iter: ident) => { + const _: fn() = || { + fn assert_impl() {} + assert_impl::<$ty>(); + }; + + #[pyo3::pyclass] + struct $keys_iter(std::vec::IntoIter<<$ty as PyMutableMapping>::Key>); + + #[pyo3::pymethods] + impl $keys_iter { + fn __next__(&mut self) -> Option<<$ty as PyMutableMapping>::Key> { + self.0.next() + } + } + + #[pyo3::pymethods] + impl $ty { + // -- collections.abc.Sized + + fn __len__(&self) -> pyo3::PyResult { + self.len() + } + + // -- collections.abc.Container + + fn __contains__(&self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult { + self.contains(key) + } + + // -- collections.abc.Iterable + + /// Returns an iterator over the keys of the dictionary. + /// NOTE: This method currently causes all keys to be cloned. + fn __iter__(&self) -> pyo3::PyResult<$keys_iter> { + Ok($keys_iter(self.keys()?.into_iter())) + } + + // -- collections.abc.Mapping + + fn __getitem__( + &self, + key: <$ty as PyMutableMapping>::Key, + ) -> pyo3::PyResult::Value>> { + <$ty as PyMutableMapping>::get(&self, key) + } + + fn get( + &self, + key: <$ty as PyMutableMapping>::Key, + default: Option<<$ty as PyMutableMapping>::Value>, + ) -> pyo3::PyResult::Value>> { + Ok(<$ty as PyMutableMapping>::get(&self, key)?.or(default)) + } + + /// Returns keys of the dictionary. + /// NOTE: This method currently causes all keys to be cloned. + fn keys(&self) -> pyo3::PyResult::Key>> { + <$ty as PyMutableMapping>::keys(&self) + } + + /// Returns values of the dictionary. + /// NOTE: This method currently causes all values to be cloned. + fn values(&self) -> pyo3::PyResult::Value>> { + <$ty as PyMutableMapping>::values(&self) + } + + /// Returns items (key, value) of the dictionary. + /// NOTE: This method currently causes all keys and values to be cloned. + fn items( + &self, + ) -> pyo3::PyResult< + Vec<( + <$ty as PyMutableMapping>::Key, + <$ty as PyMutableMapping>::Value, + )>, + > { + Ok(self + .keys()? + .into_iter() + .zip(self.values()?.into_iter()) + .collect()) + } + + // -- collections.abc.MutableMapping + + fn __setitem__( + &mut self, + key: <$ty as PyMutableMapping>::Key, + value: <$ty as PyMutableMapping>::Value, + ) -> pyo3::PyResult<()> { + self.set(key, value) + } + + fn __delitem__(&mut self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<()> { + self.del(key) + } + + fn pop( + &mut self, + key: <$ty as PyMutableMapping>::Key, + default: Option<<$ty as PyMutableMapping>::Value>, + ) -> pyo3::PyResult<<$ty as PyMutableMapping>::Value> { + let val = self.__getitem__(key.clone())?; + match val { + Some(val) => { + self.del(key)?; + Ok(val) + } + None => { + default.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("unknown key")) + } + } + } + + fn popitem( + &mut self, + ) -> pyo3::PyResult<( + <$ty as PyMutableMapping>::Key, + <$ty as PyMutableMapping>::Value, + )> { + let key = self + .keys()? + .iter() + .cloned() + .next() + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("no key"))?; + let value = self.pop(key.clone(), None)?; + Ok((key, value)) + } + + fn clear(&mut self, py: pyo3::Python) -> pyo3::PyResult<()> { + loop { + match self.popitem() { + Ok(_) => {} + Err(err) if err.is_instance_of::(py) => { + return Ok(()) + } + Err(err) => return Err(err), + } + } + } + + fn setdefault( + &mut self, + key: <$ty as PyMutableMapping>::Key, + default: Option<<$ty as PyMutableMapping>::Value>, + ) -> pyo3::PyResult::Value>> { + match self.__getitem__(key.clone())? { + Some(value) => Ok(Some(value)), + None => { + if let Some(value) = default.clone() { + self.set(key, value)?; + } + Ok(default) + } + } + } + } + }; +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use pyo3::{prelude::*, py_run}; + + use super::*; + + #[pyclass(mapping)] + struct Map(HashMap); + + impl PyMutableMapping for Map { + type Key = String; + type Value = String; + + fn len(&self) -> PyResult { + Ok(self.0.len()) + } + + fn contains(&self, key: Self::Key) -> PyResult { + Ok(self.0.contains_key(&key)) + } + + fn keys(&self) -> PyResult> { + Ok(self.0.keys().cloned().collect()) + } + + fn values(&self) -> PyResult> { + Ok(self.0.values().cloned().collect()) + } + + fn get(&self, key: Self::Key) -> PyResult> { + Ok(self.0.get(&key).cloned()) + } + + fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> { + self.0.insert(key, value); + Ok(()) + } + + fn del(&mut self, key: Self::Key) -> PyResult<()> { + self.0.remove(&key); + Ok(()) + } + } + + mutable_mapping_pymethods!(Map, keys_iter: MapKeys); + + #[test] + fn mutable_mapping() -> PyResult<()> { + pyo3::prepare_freethreaded_python(); + + let map = Map({ + let mut hash_map = HashMap::new(); + hash_map.insert("foo".to_string(), "bar".to_string()); + hash_map.insert("baz".to_string(), "qux".to_string()); + hash_map + }); + + Python::with_gil(|py| { + let map = PyCell::new(py, map)?; + py_run!( + py, + map, + r#" +# collections.abc.Sized +assert len(map) == 2 + +# collections.abc.Container +assert "foo" in map +assert "foobar" not in map + +# collections.abc.Iterable +elems = ["foo", "baz"] + +for elem in map: + assert elem in elems + +it = iter(map) +assert next(it) in elems +assert next(it) in elems +try: + next(it) + assert False, "should stop iteration" +except StopIteration: + pass + +assert set(list(map)) == set(["foo", "baz"]) + +# collections.abc.Mapping +assert map["foo"] == "bar" +assert map.get("baz") == "qux" +assert map.get("foobar") == None +assert map.get("foobar", "default") == "default" + +assert set(list(map.keys())) == set(["foo", "baz"]) +assert set(list(map.values())) == set(["bar", "qux"]) +assert set(list(map.items())) == set([("foo", "bar"), ("baz", "qux")]) + +# collections.abc.MutableMapping +map["foobar"] = "bazqux" +del map["foo"] + +try: + map.pop("not_exist") + assert False, "should throw KeyError" +except KeyError: + pass +assert map.pop("not_exist", "default") == "default" +assert map.pop("foobar") == "bazqux" +assert "foobar" not in map + +# at this point there is only `baz => qux` in `map` +assert map.popitem() == ("baz", "qux") +assert len(map) == 0 +try: + map.popitem() + assert False, "should throw KeyError" +except KeyError: + pass + +map["foo"] = "bar" +assert len(map) == 1 +map.clear() +assert len(map) == 0 +assert "foo" not in "bar" + +assert map.setdefault("foo", "bar") == "bar" +assert map["foo"] == "bar" +assert map.setdefault("foo", "baz") == "bar" + +# TODO(MissingImpl): Add tests for map.update(...) +"# + ); + Ok(()) + }) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/util/error.rs b/rust-runtime/aws-smithy-http-server-python/src/util/error.rs new file mode 100644 index 0000000000..25b42e80e5 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/util/error.rs @@ -0,0 +1,89 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Provides utilities for Python errors. + +use std::fmt; + +use pyo3::{PyErr, Python}; + +/// Wraps [PyErr] with a richer debug output that includes traceback and cause. +pub struct RichPyErr(PyErr); + +impl fmt::Debug for RichPyErr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + Python::with_gil(|py| { + let mut debug_struct = f.debug_struct("RichPyErr"); + debug_struct + .field("type", self.0.get_type(py)) + .field("value", self.0.value(py)); + + if let Some(traceback) = self.0.traceback(py) { + if let Ok(traceback) = traceback.format() { + debug_struct.field("traceback", &traceback); + } + } + + if let Some(cause) = self.0.cause(py) { + debug_struct.field("cause", &rich_py_err(cause)); + } + + debug_struct.finish() + }) + } +} + +/// Wrap `err` with [RichPyErr] to have a richer debug output. +pub fn rich_py_err(err: PyErr) -> RichPyErr { + RichPyErr(err) +} + +#[cfg(test)] +mod tests { + use pyo3::prelude::*; + + use super::*; + + #[test] + fn rich_python_errors() -> PyResult<()> { + pyo3::prepare_freethreaded_python(); + + let py_err = Python::with_gil(|py| { + py.run( + r#" +def foo(): + base_err = ValueError("base error") + raise ValueError("some python error") from base_err + +def bar(): + foo() + +def baz(): + bar() + +baz() +"#, + None, + None, + ) + .unwrap_err() + }); + + let debug_output = format!("{:?}", rich_py_err(py_err)); + + // Make sure we are capturing error message + assert!(debug_output.contains("some python error")); + + // Make sure we are capturing traceback + assert!(debug_output.contains("foo")); + assert!(debug_output.contains("bar")); + assert!(debug_output.contains("baz")); + + // Make sure we are capturing cause + assert!(debug_output.contains("base error")); + + Ok(()) + } +}