From 78022d69efd70504c6ab539a6bd0ce1e9aedbd32 Mon Sep 17 00:00:00 2001 From: Harry Barber <106155934+hlbarber@users.noreply.github.com> Date: Mon, 10 Oct 2022 14:29:37 +0100 Subject: [PATCH] Remove Protocol enum (#1829) * Remove Protocol enum * Update Python implementation --- CHANGELOG.next.toml | 12 ++ .../generators/PythonApplicationGenerator.kt | 42 ++++-- .../PythonServerOperationHandlerGenerator.kt | 4 +- .../PythonServerServiceGenerator.kt | 4 +- .../ServerOperationHandlerGenerator.kt | 16 +-- .../generators/ServerServiceGenerator.kt | 4 +- .../protocol/ServerProtocolGenerator.kt | 3 +- .../protocol/ServerProtocolTestGenerator.kt | 3 +- .../ServerHttpBoundProtocolGenerator.kt | 39 +---- .../src/error.rs | 77 ++++++---- .../src/middleware/handler.rs | 67 ++++----- .../src/middleware/layer.rs | 67 +++------ .../src/server.rs | 3 - .../aws-smithy-http-server/src/extension.rs | 2 +- .../aws-smithy-http-server/src/protocols.rs | 9 -- .../src/runtime_error.rs | 134 +++++++++--------- 16 files changed, 226 insertions(+), 260 deletions(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index ae179af55e..b303a4b601 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -137,3 +137,15 @@ message = "Fix regression where `connect_timeout` and `read_timeout` fields are references = ["smithy-rs#1822"] meta = { "breaking" = false, "tada" = false, "bug" = true } author = "kevinpark1217" + +[[smithy-rs]] +message = "Remove `Protocol` enum, removing an obstruction to extending smithy to third-party protocols." +references = ["smithy-rs#1829"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" } +author = "hlbarber" + +[[smithy-rs]] +message = "Convert the `protocol` argument on `PyMiddlewares::new` constructor to a type parameter." +references = ["smithy-rs#1829"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" } +author = "hlbarber" diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 1f3e3e197a..c032ea9102 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol /** * Generates a Python compatible application and server that can be configured from Python. @@ -62,13 +63,13 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency */ class PythonApplicationGenerator( codegenContext: CodegenContext, + private val protocol: ServerProtocol, private val operations: List, ) { private val symbolProvider = codegenContext.symbolProvider private val libName = "lib${codegenContext.settings.moduleName.toSnakeCase()}" private val runtimeConfig = codegenContext.runtimeConfig private val model = codegenContext.model - private val protocol = codegenContext.protocol private val codegenScope = arrayOf( "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(), @@ -88,6 +89,7 @@ class PythonApplicationGenerator( fun render(writer: RustWriter) { renderPyApplicationRustDocs(writer) renderAppStruct(writer) + renderAppDefault(writer) renderAppClone(writer) renderPyAppTrait(writer) renderAppImpl(writer) @@ -98,7 +100,7 @@ class PythonApplicationGenerator( writer.rustTemplate( """ ##[#{pyo3}::pyclass] - ##[derive(Debug, Default)] + ##[derive(Debug)] pub struct App { handlers: #{HashMap}, middlewares: #{SmithyPython}::PyMiddlewares, @@ -128,6 +130,25 @@ class PythonApplicationGenerator( ) } + private fun renderAppDefault(writer: RustWriter) { + writer.rustTemplate( + """ + impl Default for App { + fn default() -> Self { + Self { + handlers: Default::default(), + middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]), + context: None, + workers: #{parking_lot}::Mutex::new(vec![]), + } + } + } + """, + "Protocol" to protocol.markerStruct(), + *codegenScope, + ) + } + private fun renderAppImpl(writer: RustWriter) { writer.rustBlockTemplate( """ @@ -165,20 +186,17 @@ class PythonApplicationGenerator( rustTemplate( """ let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop); - use #{SmithyPython}::PyApp; - let service = #{tower}::ServiceBuilder::new().layer( - #{SmithyPython}::PyMiddlewareLayer::new( - self.middlewares.clone(), - self.protocol(), - middleware_locals - )?, - ); + let service = #{tower}::ServiceBuilder::new() + .layer( + #{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals), + ); let router: #{SmithyServer}::routing::Router = router .build() .expect("Unable to build operation registry") .into(); Ok(router.layer(service)) """, + "Protocol" to protocol.markerStruct(), *codegenScope, ) } @@ -186,7 +204,6 @@ class PythonApplicationGenerator( } private fun renderPyAppTrait(writer: RustWriter) { - val protocol = protocol.toString().replace("#", "##") writer.rustTemplate( """ impl #{SmithyPython}::PyApp for App { @@ -202,9 +219,6 @@ class PythonApplicationGenerator( fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares { &mut self.middlewares } - fn protocol(&self) -> &'static str { - "$protocol" - } } """, *codegenScope, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index 6f814d7ee3..076a757014 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationHandlerGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol /** * The Rust code responsible to run the Python business logic on the Python interpreter @@ -33,8 +34,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperat */ class PythonServerOperationHandlerGenerator( codegenContext: CodegenContext, + protocol: ServerProtocol, private val operations: List, -) : ServerOperationHandlerGenerator(codegenContext, operations) { +) : ServerOperationHandlerGenerator(codegenContext, protocol, operations) { private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt index ce1bf6a36f..56e119a87e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt @@ -34,12 +34,12 @@ class PythonServerServiceGenerator( } override fun renderOperationHandler(writer: RustWriter, operations: List) { - PythonServerOperationHandlerGenerator(context, operations).render(writer) + PythonServerOperationHandlerGenerator(context, protocol, operations).render(writer) } override fun renderExtras(operations: List) { rustCrate.withModule(RustModule.public("python_server_application", "Python server and application implementation.")) { writer -> - PythonApplicationGenerator(context, operations) + PythonApplicationGenerator(context, protocol, operations) .render(writer) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index 38e1a11409..5a3c93b2be 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -19,9 +19,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErr import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape -import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator /** @@ -29,12 +29,11 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBou */ open class ServerOperationHandlerGenerator( codegenContext: CodegenContext, + val protocol: ServerProtocol, private val operations: List, ) { private val serverCrate = "aws_smithy_http_server" - private val service = codegenContext.serviceShape private val model = codegenContext.model - private val protocol = codegenContext.protocol private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = arrayOf( @@ -83,11 +82,8 @@ open class ServerOperationHandlerGenerator( Ok(v) => v, Err(extension_not_found_rejection) => { let extension = $serverCrate::extension::RuntimeErrorExtension::new(extension_not_found_rejection.to_string()); - let runtime_error = $serverCrate::runtime_error::RuntimeError { - protocol: #{SmithyHttpServer}::protocols::Protocol::${protocol.name.toPascalCase()}, - kind: extension_not_found_rejection.into(), - }; - let mut response = runtime_error.into_response(); + let runtime_error = $serverCrate::runtime_error::RuntimeError::from(extension_not_found_rejection); + let mut response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error); response.extensions_mut().insert(extension); return response.map($serverCrate::body::boxed); } @@ -109,7 +105,8 @@ open class ServerOperationHandlerGenerator( let input_wrapper = match $inputWrapperName::from_request(&mut req).await { Ok(v) => v, Err(runtime_error) => { - return runtime_error.into_response().map($serverCrate::body::boxed); + let response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error); + return response.map($serverCrate::body::boxed); } }; $callImpl @@ -120,6 +117,7 @@ open class ServerOperationHandlerGenerator( response.map(#{SmithyHttpServer}::body::boxed) } """, + "Protocol" to protocol.markerStruct(), *codegenScope, ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index a002a167ea..2e659adc26 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -29,7 +29,7 @@ open class ServerServiceGenerator( private val rustCrate: RustCrate, private val protocolGenerator: ServerProtocolGenerator, private val protocolSupport: ProtocolSupport, - private val protocol: ServerProtocol, + val protocol: ServerProtocol, private val codegenContext: CodegenContext, ) { private val index = TopDownIndex.of(codegenContext.model) @@ -107,7 +107,7 @@ open class ServerServiceGenerator( // Render operations handler. open fun renderOperationHandler(writer: RustWriter, operations: List) { - ServerOperationHandlerGenerator(codegenContext, operations).render(writer) + ServerOperationHandlerGenerator(codegenContext, protocol, operations).render(writer) } // Render operations registry. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolGenerator.kt index f46ade2f66..6428e8dba0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolGenerator.kt @@ -11,11 +11,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.MakeOperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol open class ServerProtocolGenerator( codegenContext: CodegenContext, - protocol: Protocol, + val protocol: ServerProtocol, makeOperationGenerator: MakeOperationGenerator, private val traitGenerator: ProtocolTraitImplGenerator, ) : ProtocolGenerator(codegenContext, protocol, makeOperationGenerator, traitGenerator) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 7530438afc..f4b3bab27a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -452,8 +452,9 @@ class ServerProtocolTestGenerator( """ let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request); let rejection = super::$operationName::from_request(&mut http_request).await.expect_err("request was accepted but we expected it to be rejected"); - let http_response = rejection.into_response(); + let http_response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(rejection); """, + "Protocol" to protocolGenerator.protocol.markerStruct(), *codegenScope, ) checkResponse(this, testCase.response) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 47449ae0a5..952a0dc73a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -70,7 +70,6 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape -import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -177,10 +176,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ if ! #{SmithyHttpServer}::protocols::accept_header_classifier(req, ${contentType.dq()}) { - return Err(#{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable, - }) + return Err(#{RuntimeError}::NotAcceptable) } """, *codegenScope, @@ -200,10 +196,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() { - return Err(#{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType, - }) + return Err(#{RuntimeError}::UnsupportedMediaType) } """, *codegenScope, @@ -230,12 +223,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( #{parse_request}(req) .await .map($inputName) - .map_err( - |err| #{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: err.into() - } - ) + .map_err(Into::into) } } @@ -282,12 +270,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( Self::Output(o) => { match #{serialize_response}(o) { Ok(response) => response, - Err(e) => { - #{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: e.into() - }.into_response() - } + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) } }, Self::Error(err) => { @@ -296,12 +279,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(err.name())); response }, - Err(e) => { - #{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: e.into() - }.into_response() - } + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) } } } @@ -346,12 +324,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( """ match #{serialize_response}(self.0) { Ok(response) => response, - Err(e) => { - #{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: e.into() - }.into_response() - } + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) } """.trimIndent() diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 519e75501a..a6396fa969 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -5,8 +5,14 @@ //! Python error definition. -use aws_smithy_http_server::protocols::Protocol; -use aws_smithy_http_server::{body::to_boxed, response::Response}; +use aws_smithy_http_server::{ + body::{to_boxed, BoxBody}, + proto::{ + aws_json_10::AwsJson10, aws_json_11::AwsJson11, rest_json_1::AwsRestJson1, + rest_xml::AwsRestXml, + }, + response::IntoResponse, +}; use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr}; use thiserror::Error; @@ -62,39 +68,50 @@ impl From for PyMiddlewareException { } } -impl PyMiddlewareException { - /// Convert the exception into a [Response], following the [Protocol] specification. - pub(crate) fn into_response(self, protocol: Protocol) -> Response { - let body = to_boxed(match protocol { - Protocol::RestJson1 => self.json_body(), - Protocol::RestXml => self.xml_body(), - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization - Protocol::AwsJson10 => self.json_body(), - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization - Protocol::AwsJson11 => self.json_body(), - }); +impl IntoResponse for PyMiddlewareException { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code) + .header("Content-Type", "application/json") + .header("X-Amzn-Errortype", "MiddlewareException") + .body(to_boxed(self.json_body())) + .expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + } +} - let mut builder = http::Response::builder(); - builder = builder.status(self.status_code); +impl IntoResponse for PyMiddlewareException { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code) + .header("Content-Type", "application/xml") + .body(to_boxed(self.xml_body())) + .expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + } +} - match protocol { - Protocol::RestJson1 => { - builder = builder - .header("Content-Type", "application/json") - .header("X-Amzn-Errortype", "MiddlewareException"); - } - Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"), - Protocol::AwsJson10 => { - builder = builder.header("Content-Type", "application/x-amz-json-1.0") - } - Protocol::AwsJson11 => { - builder = builder.header("Content-Type", "application/x-amz-json-1.1") - } - } +impl IntoResponse for PyMiddlewareException { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code) + .header("Content-Type", "application/x-amz-json-1.0") + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization + .body(to_boxed(self.json_body())) + .expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + } +} - builder.body(body).expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") +impl IntoResponse for PyMiddlewareException { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code) + .header("Content-Type", "application/x-amz-json-1.1") + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + .body(to_boxed(self.json_body())) + .expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } +} +impl PyMiddlewareException { /// Serialize the body into a JSON object. fn json_body(&self) -> String { let mut out = String::new(); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 00f99312e8..20a4104925 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -4,11 +4,10 @@ */ //! Execute Python middleware handlers. -use aws_smithy_http_server::body::Body; +use aws_smithy_http_server::{body::Body, body::BoxBody, response::IntoResponse}; use http::Request; use pyo3::prelude::*; -use aws_smithy_http_server::protocols::Protocol; use pyo3_asyncio::TaskLocals; use crate::{PyMiddlewareException, PyRequest, PyResponse}; @@ -36,18 +35,27 @@ pub struct PyMiddlewareHandler { /// Structure holding the list of Python middlewares that will be executed by this server. /// /// Middlewares are executed one after each other inside the [crate::PyMiddlewareLayer] Tower layer. -#[derive(Debug, Clone, Default)] -pub struct PyMiddlewares(Vec); +#[derive(Debug, Clone)] +pub struct PyMiddlewares { + handlers: Vec, + into_response: fn(PyMiddlewareException) -> http::Response, +} impl PyMiddlewares { /// Create a new instance of `PyMiddlewareHandlers` from a list of heandlers. - pub fn new(handlers: Vec) -> Self { - Self(handlers) + pub fn new

(handlers: Vec) -> Self + where + PyMiddlewareException: IntoResponse

, + { + Self { + handlers, + into_response: PyMiddlewareException::into_response, + } } /// Add a new handler to the list. pub fn push(&mut self, handler: PyMiddlewareHandler) { - self.0.push(handler); + self.handlers.push(handler); } /// Execute a single middleware handler. @@ -114,13 +122,9 @@ impl PyMiddlewares { /// and return a protocol specific error, with the option of setting the HTTP return code. /// * Middleware raising any other exception will immediately terminate the request handling and /// return a protocol specific error, with HTTP status code 500. - pub fn run( - &mut self, - mut request: Request, - protocol: Protocol, - locals: TaskLocals, - ) -> PyFuture { - let handlers = self.0.clone(); + pub fn run(&mut self, mut request: Request, locals: TaskLocals) -> PyFuture { + let handlers = self.handlers.clone(); + let into_response = self.into_response; // Run all Python handlers in a loop. Box::pin(async move { tracing::debug!("Executing Python middleware stack"); @@ -152,7 +156,7 @@ impl PyMiddlewares { tracing::debug!( "Middleware `{name}` returned an error, exit middleware loop" ); - return Err(e.into_response(protocol)); + return Err((into_response)(e)); } } } @@ -166,6 +170,7 @@ impl PyMiddlewares { #[cfg(test)] mod tests { + use aws_smithy_http_server::proto::rest_json_1::AwsRestJson1; use http::HeaderValue; use hyper::body::to_bytes; use pretty_assertions::assert_eq; @@ -175,7 +180,7 @@ mod tests { #[tokio::test] async fn request_middleware_chain_keeps_headers_changes() -> PyResult<()> { let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares(vec![]); + let mut middlewares = PyMiddlewares::new::(vec![]); Python::with_gil(|py| { let middleware = PyModule::new(py, "middleware").unwrap(); @@ -212,11 +217,7 @@ def second_middleware(request: Request): })?; let result = middlewares - .run( - Request::builder().body(Body::from("")).unwrap(), - Protocol::RestJson1, - locals, - ) + .run(Request::builder().body(Body::from("")).unwrap(), locals) .await .unwrap(); assert_eq!( @@ -229,7 +230,7 @@ def second_middleware(request: Request): #[tokio::test] async fn request_middleware_return_response() -> PyResult<()> { let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares(vec![]); + let mut middlewares = PyMiddlewares::new::(vec![]); Python::with_gil(|py| { let middleware = PyModule::new(py, "middleware").unwrap(); @@ -252,11 +253,7 @@ def middleware(request: Request): })?; let result = middlewares - .run( - Request::builder().body(Body::from("")).unwrap(), - Protocol::RestJson1, - locals, - ) + .run(Request::builder().body(Body::from("")).unwrap(), locals) .await .unwrap_err(); assert_eq!(result.status(), 200); @@ -268,7 +265,7 @@ def middleware(request: Request): #[tokio::test] async fn request_middleware_raise_middleware_exception() -> PyResult<()> { let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares(vec![]); + let mut middlewares = PyMiddlewares::new::(vec![]); Python::with_gil(|py| { let middleware = PyModule::new(py, "middleware").unwrap(); @@ -291,11 +288,7 @@ def middleware(request: Request): })?; let result = middlewares - .run( - Request::builder().body(Body::from("")).unwrap(), - Protocol::RestJson1, - locals, - ) + .run(Request::builder().body(Body::from("")).unwrap(), locals) .await .unwrap_err(); assert_eq!(result.status(), 503); @@ -311,7 +304,7 @@ def middleware(request: Request): #[tokio::test] async fn request_middleware_raise_python_exception() -> PyResult<()> { let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares(vec![]); + let mut middlewares = PyMiddlewares::new::(vec![]); Python::with_gil(|py| { let middleware = PyModule::from_code( @@ -333,11 +326,7 @@ def middleware(request): })?; let result = middlewares - .run( - Request::builder().body(Body::from("")).unwrap(), - Protocol::RestJson1, - locals, - ) + .run(Request::builder().body(Body::from("")).unwrap(), locals) .await .unwrap_err(); assert_eq!(result.status(), 500); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 73508541a2..6c1cb6bce0 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -5,69 +5,52 @@ //! Tower layer implementation of Python middleware handling. use std::{ + marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use aws_smithy_http_server::{ body::{Body, BoxBody}, - protocols::Protocol, + response::IntoResponse, }; use futures::{ready, Future}; use http::{Request, Response}; use pin_project_lite::pin_project; -use pyo3::PyResult; use pyo3_asyncio::TaskLocals; use tower::{Layer, Service}; -use crate::{error::PyException, middleware::PyFuture, PyMiddlewares}; +use crate::{middleware::PyFuture, PyMiddlewareException, PyMiddlewares}; /// Tower [Layer] implementation of Python middleware handling. /// /// Middleware stored in the `handlers` attribute will be executed, in order, /// inside an async Tower middleware. #[derive(Debug, Clone)] -pub struct PyMiddlewareLayer { +pub struct PyMiddlewareLayer

{ handlers: PyMiddlewares, - protocol: Protocol, locals: TaskLocals, + _protocol: PhantomData

, } -impl PyMiddlewareLayer { - pub fn new( - handlers: PyMiddlewares, - protocol: &str, - locals: TaskLocals, - ) -> PyResult { - let protocol = match protocol { - "aws.protocols#restJson1" => Protocol::RestJson1, - "aws.protocols#restXml" => Protocol::RestXml, - "aws.protocols#awsjson10" => Protocol::AwsJson10, - "aws.protocols#awsjson11" => Protocol::AwsJson11, - _ => { - return Err(PyException::new_err(format!( - "Protocol {protocol} is not supported" - ))) - } - }; - Ok(Self { +impl

PyMiddlewareLayer

{ + pub fn new(handlers: PyMiddlewares, locals: TaskLocals) -> Self { + Self { handlers, - protocol, locals, - }) + _protocol: PhantomData, + } } } -impl Layer for PyMiddlewareLayer { +impl Layer for PyMiddlewareLayer

+where + PyMiddlewareException: IntoResponse

, +{ type Service = PyMiddlewareService; fn layer(&self, inner: S) -> Self::Service { - PyMiddlewareService::new( - inner, - self.handlers.clone(), - self.protocol, - self.locals.clone(), - ) + PyMiddlewareService::new(inner, self.handlers.clone(), self.locals.clone()) } } @@ -76,21 +59,14 @@ impl Layer for PyMiddlewareLayer { pub struct PyMiddlewareService { inner: S, handlers: PyMiddlewares, - protocol: Protocol, locals: TaskLocals, } impl PyMiddlewareService { - pub fn new( - inner: S, - handlers: PyMiddlewares, - protocol: Protocol, - locals: TaskLocals, - ) -> PyMiddlewareService { + pub fn new(inner: S, handlers: PyMiddlewares, locals: TaskLocals) -> PyMiddlewareService { Self { inner, handlers, - protocol, locals, } } @@ -113,7 +89,7 @@ where let clone = self.inner.clone(); // See https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services let inner = std::mem::replace(&mut self.inner, clone); - let run = self.handlers.run(req, self.protocol, self.locals.clone()); + let run = self.handlers.run(req, self.locals.clone()); ResponseFuture { middleware: State::Running { run }, @@ -184,6 +160,7 @@ mod tests { use super::*; use aws_smithy_http_server::body::to_boxed; + use aws_smithy_http_server::proto::rest_json_1::AwsRestJson1; use pyo3::prelude::*; use tower::{Service, ServiceBuilder, ServiceExt}; @@ -197,7 +174,7 @@ mod tests { #[tokio::test] async fn request_middlewares_are_chained_inside_layer() -> PyResult<()> { let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares::new(vec![]); + let mut middlewares = PyMiddlewares::new::(vec![]); Python::with_gil(|py| { let middleware = PyModule::new(py, "middleware").unwrap(); @@ -234,11 +211,7 @@ def second_middleware(request: Request): })?; let mut service = ServiceBuilder::new() - .layer(PyMiddlewareLayer::new( - middlewares, - "aws.protocols#restJson1", - locals, - )?) + .layer(PyMiddlewareLayer::::new(middlewares, locals)) .service_fn(echo); let request = Request::get("/").body(Body::empty()).unwrap(); diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index fdff93b398..6a4ac559c2 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -63,8 +63,6 @@ pub trait PyApp: Clone + pyo3::IntoPy { fn middlewares(&mut self) -> &mut PyMiddlewares; - fn protocol(&self) -> &'static str; - /// Handle the graceful termination of Python workers by looping through all the /// active workers and calling `terminate()` on them. If termination fails, this /// method will try to `kill()` any failed worker. @@ -385,7 +383,6 @@ event_loop.add_signal_handler(signal.SIGINT, /// fn context(&self) -> &Option { todo!() } /// fn handlers(&mut self) -> &mut HashMap { todo!() } /// fn middlewares(&mut self) -> &mut PyMiddlewares { todo!() } - /// fn protocol(&self) -> &'static str { "proto1" } /// } /// /// #[pymethods] diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index d0b4c1cdd9..736dfd0c44 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -131,7 +131,7 @@ impl Deref for ModeledErrorExtension { } /// Extension type used to store the _name_ of the [`crate::runtime_error::RuntimeError`] that -/// occurred during request handling (see [`crate::runtime_error::RuntimeErrorKind::name`]). +/// occurred during request handling (see [`crate::runtime_error::RuntimeError::name`]). /// These are _unmodeled_ errors; the operation handler was not invoked. #[derive(Debug, Clone)] pub struct RuntimeErrorExtension(String); diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index 1707e33cf7..6b22e9dd1b 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -7,15 +7,6 @@ use crate::rejection::MissingContentTypeReason; use crate::request::RequestParts; -/// Supported protocols. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum Protocol { - RestJson1, - RestXml, - AwsJson10, - AwsJson11, -} - /// When there are no modeled inputs, /// a request body is empty and the content-type request header must not be set pub fn content_type_header_empty_body_no_modeled_input( diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index b8a4e71e0a..e8174f0f08 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -11,7 +11,7 @@ //! the framework, `RuntimeError` is surfaced to clients in HTTP responses: indeed, it implements //! [`RuntimeError::into_response`]. Rejections can be "grouped" and converted into a //! specific `RuntimeError` kind: for example, all request rejections due to serialization issues -//! can be conflated under the [`RuntimeErrorKind::Serialization`] enum variant. +//! can be conflated under the [`RuntimeError::Serialization`] enum variant. //! //! The HTTP response representation of the specific `RuntimeError` can be protocol-specific: for //! example, the runtime error in the RestJson1 protocol sets the `X-Amzn-Errortype` header. @@ -21,15 +21,17 @@ //! and converts into the corresponding `RuntimeError`, and then it uses the its //! [`RuntimeError::into_response`] method to render and send a response. +use http::StatusCode; + +use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json_10::AwsJson10; use crate::proto::aws_json_11::AwsJson11; use crate::proto::rest_json_1::AwsRestJson1; use crate::proto::rest_xml::AwsRestXml; -use crate::protocols::Protocol; -use crate::response::{IntoResponse, Response}; +use crate::response::IntoResponse; #[derive(Debug)] -pub enum RuntimeErrorKind { +pub enum RuntimeError { /// Request failed to deserialize or response failed to serialize. Serialization(crate::Error), /// As of writing, this variant can only occur upon failure to extract an @@ -43,13 +45,22 @@ pub enum RuntimeErrorKind { /// String representation of the runtime error type. /// Used as the value of the `X-Amzn-Errortype` header in RestJson1. /// Used as the value passed to construct an [`crate::extension::RuntimeErrorExtension`]. -impl RuntimeErrorKind { +impl RuntimeError { pub fn name(&self) -> &'static str { match self { - RuntimeErrorKind::Serialization(_) => "SerializationException", - RuntimeErrorKind::InternalFailure(_) => "InternalFailureException", - RuntimeErrorKind::NotAcceptable => "NotAcceptableException", - RuntimeErrorKind::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Serialization(_) => "SerializationException", + Self::InternalFailure(_) => "InternalFailureException", + Self::NotAcceptable => "NotAcceptableException", + Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + } + } + + pub fn status_code(&self) -> StatusCode { + match self { + Self::Serialization(_) => StatusCode::BAD_REQUEST, + Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, + Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, } } } @@ -58,104 +69,93 @@ pub struct InternalFailureException; impl IntoResponse for InternalFailureException { fn into_response(self) -> http::Response { - RuntimeError::internal_failure_from_protocol(Protocol::AwsJson10).into_response() + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) } } impl IntoResponse for InternalFailureException { fn into_response(self) -> http::Response { - RuntimeError::internal_failure_from_protocol(Protocol::AwsJson11).into_response() + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) } } impl IntoResponse for InternalFailureException { fn into_response(self) -> http::Response { - RuntimeError::internal_failure_from_protocol(Protocol::RestJson1).into_response() + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) } } impl IntoResponse for InternalFailureException { fn into_response(self) -> http::Response { - RuntimeError::internal_failure_from_protocol(Protocol::RestXml).into_response() + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) } } -#[derive(Debug)] -pub struct RuntimeError { - pub protocol: Protocol, - pub kind: RuntimeErrorKind, +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/json") + .header("X-Amzn-Errortype", self.name()) + .extension(RuntimeErrorExtension::new(self.name().to_string())) + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + .body(crate::body::to_boxed("{}")) + .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + } } -impl

IntoResponse

for RuntimeError { +impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - self.into_response() + http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/xml") + .extension(RuntimeErrorExtension::new(self.name().to_string())) + .body(crate::body::to_boxed("")) + .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } -impl RuntimeError { - pub fn internal_failure_from_protocol(protocol: Protocol) -> Self { - RuntimeError { - protocol, - kind: RuntimeErrorKind::InternalFailure(crate::Error::new(String::new())), - } +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/x-amz-json-1.0") + .extension(RuntimeErrorExtension::new(self.name().to_string())) + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization + .body(crate::body::to_boxed("")) + .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } +} - pub fn into_response(self) -> Response { - let status_code = match self.kind { - RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST, - RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR, - RuntimeErrorKind::NotAcceptable => http::StatusCode::NOT_ACCEPTABLE, - RuntimeErrorKind::UnsupportedMediaType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE, - }; - - let body = crate::body::to_boxed(match self.protocol { - Protocol::RestJson1 => "{}", - Protocol::RestXml => "", - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization - Protocol::AwsJson10 => "", +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/x-amz-json-1.1") + .extension(RuntimeErrorExtension::new(self.name().to_string())) // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization - Protocol::AwsJson11 => "", - }); - - let mut builder = http::Response::builder(); - builder = builder.status(status_code); - - match self.protocol { - Protocol::RestJson1 => { - builder = builder - .header("Content-Type", "application/json") - .header("X-Amzn-Errortype", self.kind.name()); - } - Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"), - Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"), - Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"), - } - - builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from( - self.kind.name(), - ))); - - builder.body(body).expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + .body(crate::body::to_boxed("")) + .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } -impl From for RuntimeErrorKind { +impl From for RuntimeError { fn from(err: crate::rejection::RequestExtensionNotFoundRejection) -> Self { - RuntimeErrorKind::InternalFailure(crate::Error::new(err)) + Self::InternalFailure(crate::Error::new(err)) } } -impl From for RuntimeErrorKind { +impl From for RuntimeError { fn from(err: crate::rejection::ResponseRejection) -> Self { - RuntimeErrorKind::Serialization(crate::Error::new(err)) + Self::Serialization(crate::Error::new(err)) } } -impl From for RuntimeErrorKind { +impl From for RuntimeError { fn from(err: crate::rejection::RequestRejection) -> Self { match err { - crate::rejection::RequestRejection::MissingContentType(_reason) => RuntimeErrorKind::UnsupportedMediaType, - _ => RuntimeErrorKind::Serialization(crate::Error::new(err)), + crate::rejection::RequestRejection::MissingContentType(_reason) => Self::UnsupportedMediaType, + _ => Self::Serialization(crate::Error::new(err)), } } }