Skip to content

Commit

Permalink
Update Python implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry Barber committed Oct 7, 2022
1 parent cbff82b commit 0469b3b
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Generates a Python compatible application and server that can be configured from Python.
Expand Down Expand Up @@ -62,13 +63,13 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
*/
class PythonApplicationGenerator(
codegenContext: CodegenContext,
private val protocol: ServerProtocol,
private val operations: List<OperationShape>,
) {
private val symbolProvider = codegenContext.symbolProvider
private val libName = "lib${codegenContext.settings.moduleName.toSnakeCase()}"
private val runtimeConfig = codegenContext.runtimeConfig
private val model = codegenContext.model
private val protocol = codegenContext.protocol
private val codegenScope =
arrayOf(
"SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(),
Expand Down Expand Up @@ -98,7 +99,7 @@ class PythonApplicationGenerator(
writer.rustTemplate(
"""
##[#{pyo3}::pyclass]
##[derive(Debug, Default)]
##[derive(Debug)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
Expand Down Expand Up @@ -165,28 +166,24 @@ class PythonApplicationGenerator(
rustTemplate(
"""
let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop);
use #{SmithyPython}::PyApp;
let service = #{tower}::ServiceBuilder::new().layer(
#{SmithyPython}::PyMiddlewareLayer::new(
self.middlewares.clone(),
self.protocol(),
middleware_locals
)?,
);
let service = #{tower}::ServiceBuilder::new()
.layer(
#{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals),
);
let router: #{SmithyServer}::routing::Router = router
.build()
.expect("Unable to build operation registry")
.into();
Ok(router.layer(service))
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}
}
}

private fun renderPyAppTrait(writer: RustWriter) {
val protocol = protocol.toString().replace("#", "##")
writer.rustTemplate(
"""
impl #{SmithyPython}::PyApp for App {
Expand All @@ -202,9 +199,6 @@ class PythonApplicationGenerator(
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
fn protocol(&self) -> &'static str {
"$protocol"
}
}
""",
*codegenScope,
Expand All @@ -224,7 +218,13 @@ class PythonApplicationGenerator(
/// Create a new [App].
##[new]
pub fn new() -> Self {
Self::default()
Self {
handlers: Default::default(),
middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]),
context: None,
workers: #{parking_lot}::Mutex::new(vec![]),
}
}
/// Register a context object that will be shared between handlers.
##[pyo3(text_signature = "(${'$'}self, context)")]
Expand Down Expand Up @@ -264,6 +264,7 @@ class PythonApplicationGenerator(
self.start_hyper_worker(py, socket, event_loop, router, worker_number)
}
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
operations.map { operation ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationHandlerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* The Rust code responsible to run the Python business logic on the Python interpreter
Expand All @@ -33,8 +34,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperat
*/
class PythonServerOperationHandlerGenerator(
codegenContext: CodegenContext,
protocol: ServerProtocol,
private val operations: List<OperationShape>,
) : ServerOperationHandlerGenerator(codegenContext, operations) {
) : ServerOperationHandlerGenerator(codegenContext, protocol, operations) {
private val symbolProvider = codegenContext.symbolProvider
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ class PythonServerServiceGenerator(
}

override fun renderOperationHandler(writer: RustWriter, operations: List<OperationShape>) {
PythonServerOperationHandlerGenerator(context, operations).render(writer)
PythonServerOperationHandlerGenerator(context, protocol, operations).render(writer)
}

override fun renderExtras(operations: List<OperationShape>) {
rustCrate.withModule(RustModule.public("python_server_application", "Python server and application implementation.")) { writer ->
PythonApplicationGenerator(context, operations)
PythonApplicationGenerator(context, protocol, operations)
.render(writer)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBou
*/
open class ServerOperationHandlerGenerator(
codegenContext: CodegenContext,
private val protocol: ServerProtocol,
val protocol: ServerProtocol,
private val operations: List<OperationShape>,
) {
private val serverCrate = "aws_smithy_http_server"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ open class ServerServiceGenerator(
private val rustCrate: RustCrate,
private val protocolGenerator: ServerProtocolGenerator,
private val protocolSupport: ProtocolSupport,
private val protocol: ServerProtocol,
val protocol: ServerProtocol,
private val codegenContext: CodegenContext,
) {
private val index = TopDownIndex.of(codegenContext.model)
Expand Down
77 changes: 47 additions & 30 deletions rust-runtime/aws-smithy-http-server-python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@

//! Python error definition.
use aws_smithy_http_server::protocols::Protocol;
use aws_smithy_http_server::{body::to_boxed, response::Response};
use aws_smithy_http_server::{
body::{to_boxed, BoxBody},
proto::{
aws_json_10::AwsJson10, aws_json_11::AwsJson11, rest_json_1::AwsRestJson1,
rest_xml::AwsRestXml,
},
response::IntoResponse,
};
use aws_smithy_types::date_time::{ConversionError, DateTimeParseError};
use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr};
use thiserror::Error;
Expand Down Expand Up @@ -62,39 +68,50 @@ impl From<PyErr> for PyMiddlewareException {
}
}

impl PyMiddlewareException {
/// Convert the exception into a [Response], following the [Protocol] specification.
pub(crate) fn into_response(self, protocol: Protocol) -> Response {
let body = to_boxed(match protocol {
Protocol::RestJson1 => self.json_body(),
Protocol::RestXml => self.xml_body(),
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
Protocol::AwsJson10 => self.json_body(),
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
Protocol::AwsJson11 => self.json_body(),
});
impl IntoResponse<AwsRestJson1> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/json")
.header("X-Amzn-Errortype", "MiddlewareException")
.body(to_boxed(self.json_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}

let mut builder = http::Response::builder();
builder = builder.status(self.status_code);
impl IntoResponse<AwsRestXml> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/xml")
.body(to_boxed(self.xml_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}

match protocol {
Protocol::RestJson1 => {
builder = builder
.header("Content-Type", "application/json")
.header("X-Amzn-Errortype", "MiddlewareException");
}
Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"),
Protocol::AwsJson10 => {
builder = builder.header("Content-Type", "application/x-amz-json-1.0")
}
Protocol::AwsJson11 => {
builder = builder.header("Content-Type", "application/x-amz-json-1.1")
}
}
impl IntoResponse<AwsJson10> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/x-amz-json-1.0")
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
.body(to_boxed(self.json_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}

builder.body(body).expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
impl IntoResponse<AwsJson11> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/x-amz-json-1.1")
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
.body(to_boxed(self.json_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}

impl PyMiddlewareException {
/// Serialize the body into a JSON object.
fn json_body(&self) -> String {
let mut out = String::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
*/

//! Execute Python middleware handlers.
use aws_smithy_http_server::body::Body;
use aws_smithy_http_server::{body::Body, body::BoxBody, response::IntoResponse};
use http::Request;
use pyo3::prelude::*;

use aws_smithy_http_server::protocols::Protocol;
use pyo3_asyncio::TaskLocals;

use crate::{PyMiddlewareException, PyRequest, PyResponse};
Expand Down Expand Up @@ -36,18 +35,27 @@ pub struct PyMiddlewareHandler {
/// Structure holding the list of Python middlewares that will be executed by this server.
///
/// Middlewares are executed one after each other inside the [crate::PyMiddlewareLayer] Tower layer.
#[derive(Debug, Clone, Default)]
pub struct PyMiddlewares(Vec<PyMiddlewareHandler>);
#[derive(Debug, Clone)]
pub struct PyMiddlewares {
handlers: Vec<PyMiddlewareHandler>,
into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
}

impl PyMiddlewares {
/// Create a new instance of `PyMiddlewareHandlers` from a list of heandlers.
pub fn new(handlers: Vec<PyMiddlewareHandler>) -> Self {
Self(handlers)
pub fn new<P>(handlers: Vec<PyMiddlewareHandler>) -> Self
where
PyMiddlewareException: IntoResponse<P>,
{
Self {
handlers,
into_response: PyMiddlewareException::into_response,
}
}

/// Add a new handler to the list.
pub fn push(&mut self, handler: PyMiddlewareHandler) {
self.0.push(handler);
self.handlers.push(handler);
}

/// Execute a single middleware handler.
Expand Down Expand Up @@ -114,13 +122,9 @@ impl PyMiddlewares {
/// and return a protocol specific error, with the option of setting the HTTP return code.
/// * Middleware raising any other exception will immediately terminate the request handling and
/// return a protocol specific error, with HTTP status code 500.
pub fn run(
&mut self,
mut request: Request<Body>,
protocol: Protocol,
locals: TaskLocals,
) -> PyFuture {
let handlers = self.0.clone();
pub fn run(&mut self, mut request: Request<Body>, locals: TaskLocals) -> PyFuture {
let handlers = self.handlers.clone();
let into_response = self.into_response;
// Run all Python handlers in a loop.
Box::pin(async move {
tracing::debug!("Executing Python middleware stack");
Expand Down Expand Up @@ -152,7 +156,7 @@ impl PyMiddlewares {
tracing::debug!(
"Middleware `{name}` returned an error, exit middleware loop"
);
return Err(e.into_response(protocol));
return Err((into_response)(e));
}
}
}
Expand All @@ -166,6 +170,7 @@ impl PyMiddlewares {

#[cfg(test)]
mod tests {
use aws_smithy_http_server::proto::rest_json_1::AwsRestJson1;
use http::HeaderValue;
use hyper::body::to_bytes;
use pretty_assertions::assert_eq;
Expand Down Expand Up @@ -212,11 +217,7 @@ def second_middleware(request: Request):
})?;

let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run::<AwsRestJson1>(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap();
assert_eq!(
Expand Down Expand Up @@ -252,11 +253,7 @@ def middleware(request: Request):
})?;

let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run::<AwsRestJson1>(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap_err();
assert_eq!(result.status(), 200);
Expand Down Expand Up @@ -291,11 +288,7 @@ def middleware(request: Request):
})?;

let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run::<AwsRestJson1>(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap_err();
assert_eq!(result.status(), 503);
Expand Down Expand Up @@ -333,11 +326,7 @@ def middleware(request):
})?;

let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run::<AwsRestJson1>(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap_err();
assert_eq!(result.status(), 500);
Expand Down
Loading

0 comments on commit 0469b3b

Please sign in to comment.