From ec9515d9db7131b813eb031c1a9c1c317919c7da Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 13 Jan 2022 16:45:47 +0100 Subject: [PATCH] `rust-server-codegen`: add `ResponseExtensions` to non-fallible operations This commit also renames `RequestExtensions` to `ResponseExtensions`, since we add this extension type only to HTTP responses. It also makes its members private. Closes #1063. --- .../protocol/ServerProtocolTestGenerator.kt | 37 ++++++------------- .../protocols/ServerHttpProtocolGenerator.kt | 8 ++-- .../aws-smithy-http-server/src/extension.rs | 16 ++++---- .../aws-smithy-http-server/src/lib.rs | 2 +- 4 files changed, 26 insertions(+), 37 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index ea78b93fdf..ab51436994 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -295,7 +295,7 @@ class ServerProtocolTestGenerator( ); """ ) - checkHttpExtensions(this) + checkHttpResponseExtensions(this) if (!testCase.body.isEmpty()) { rustTemplate( """ @@ -333,12 +333,18 @@ class ServerProtocolTestGenerator( } } - private fun checkHttpExtensions(rustWriter: RustWriter) { - rustWriter.rust( + private fun checkHttpResponseExtensions(rustWriter: RustWriter) { + rustWriter.rustTemplate( + """ + let request_extensions = http_response.extensions() + .get::<#{SmithyHttpServer}::ResponseExtensions>() + .expect("extension `ResponseExtensions` not found"); + """.trimIndent(), + *codegenScope + ) + rustWriter.writeWithNoFormatting( """ - let request_extensions = http_response.extensions().get::().expect("extension `RequestExtensions` not found"); - assert_eq!(request_extensions.namespace, ${operationShape.id.getNamespace().dq()}); - assert_eq!(request_extensions.operation_name, ${operationSymbol.name.dq()}); + assert_eq!(request_extensions.operation(), format!("{}#{}", "${operationShape.id.namespace}", "${operationSymbol.name}")); """.trimIndent() ) } @@ -434,7 +440,6 @@ class ServerProtocolTestGenerator( private val Ec2Query = "aws.protocoltests.ec2#AwsEc2" private val ExpectFail = setOf( FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Request), - FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response), FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response), FailingTest(RestJson, "RestJsonUnitInputAllowsAccept", Action.Request), FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response), @@ -450,7 +455,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "DocumentTypeAsPayloadInputString", Action.Request), FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response), FailingTest(RestJson, "DocumentTypeAsPayloadOutputString", Action.Response), - FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response), FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request), FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request), FailingTest(RestJson, "RestJsonInvalidGreetingError", Action.Response), @@ -476,13 +480,9 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonHttpPayloadWithStructure", Action.Response), FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Request), FailingTest(RestJson, "RestJsonHttpPrefixHeadersAreNotPresent", Action.Request), - FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Response), - FailingTest(RestJson, "HttpPrefixHeadersResponse", Action.Response), FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request), - FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response), FailingTest(RestJson, "StringPayloadRequest", Action.Request), FailingTest(RestJson, "StringPayloadResponse", Action.Response), - FailingTest(RestJson, "RestJsonIgnoreQueryParamsInResponse", Action.Response), FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Request), FailingTest(RestJson, "RestJsonInputAndOutputWithNumericHeaders", Action.Request), FailingTest(RestJson, "RestJsonInputAndOutputWithBooleanHeaders", Action.Request), @@ -491,14 +491,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request), FailingTest(RestJson, "RestJsonSupportsInfinityFloatHeaderInputs", Action.Request), FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatHeaderInputs", Action.Request), - FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Response), - FailingTest(RestJson, "RestJsonInputAndOutputWithNumericHeaders", Action.Response), - FailingTest(RestJson, "RestJsonInputAndOutputWithBooleanHeaders", Action.Response), - FailingTest(RestJson, "RestJsonInputAndOutputWithTimestampHeaders", Action.Response), - FailingTest(RestJson, "RestJsonInputAndOutputWithEnumHeaders", Action.Response), - FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderOutputs", Action.Response), - FailingTest(RestJson, "RestJsonSupportsInfinityFloatHeaderOutputs", Action.Response), - FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatHeaderOutputs", Action.Response), FailingTest(RestJson, "RestJsonJsonBlobs", Action.Response), FailingTest(RestJson, "RestJsonJsonEnums", Action.Response), FailingTest(RestJson, "RestJsonLists", Action.Response), @@ -524,17 +516,13 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonDeserializeMapUnionValue", Action.Response), FailingTest(RestJson, "RestJsonDeserializeStructureUnionValue", Action.Response), FailingTest(RestJson, "MediaTypeHeaderInputBase64", Action.Request), - FailingTest(RestJson, "MediaTypeHeaderOutputBase64", Action.Response), FailingTest(RestJson, "RestJsonNoInputAllowsAccept", Action.Request), FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response), FailingTest(RestJson, "RestJsonNoInputAndOutputAllowsAccept", Action.Request), - FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response), - FailingTest(RestJson, "RestJsonNullAndEmptyHeaders", Action.Response), FailingTest(RestJson, "RestJsonRecursiveShapes", Action.Response), FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Request), FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request), FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Response), - FailingTest(RestJson, "RestJsonServersDontSerializeNullStructureValues", Action.Response), FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Response), FailingTest(RestJson, "RestJsonSupportsInfinityFloatInputs", Action.Response), FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatInputs", Action.Response), @@ -557,7 +545,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonTestPayloadStructure", Action.Request), FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", Action.Request), FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Request), - FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Response), FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", Action.Response), FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", Action.Request), FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", Action.Request), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index 8438c25618..9267314d86 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -242,10 +242,12 @@ private class ServerHttpProtocolImplGenerator( intoResponseStreaming } else { """ - match #{serialize_response}(&self.0) { + let mut response = match #{serialize_response}(&self.0) { Ok(response) => response, Err(e) => e.into_response() - } + }; + $httpExtensions + response """.trimIndent() } // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type @@ -315,7 +317,7 @@ private class ServerHttpProtocolImplGenerator( val namespace = operationShape.id.getNamespace() val operationName = symbolProvider.toSymbol(operationShape).name return """ - response.extensions_mut().insert(#{SmithyHttpServer}::RequestExtensions::new(${namespace.dq()}, ${operationName.dq()})); + response.extensions_mut().insert(#{SmithyHttpServer}::ResponseExtensions::new(${namespace.dq()}, ${operationName.dq()})); """.trimIndent() } diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index 1355c9e696..8721674c9d 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -39,17 +39,17 @@ use async_trait::async_trait; use axum_core::extract::{FromRequest, RequestParts}; use std::ops::Deref; -/// Extension type used to store Smithy request information. -#[derive(Debug, Clone, Default, Copy)] -pub struct RequestExtensions { +/// Extension type used to store information in HTTP responses. +#[derive(Debug, Clone)] +pub struct ResponseExtensions { /// Smithy model namespace. - pub namespace: &'static str, + namespace: &'static str, /// Smithy operation name. - pub operation_name: &'static str, + operation_name: &'static str, } -impl RequestExtensions { - /// Generates a new `RequestExtensions`. +impl ResponseExtensions { + /// Creates a new `ResponseExtensions`. pub fn new(namespace: &'static str, operation_name: &'static str) -> Self { Self { namespace, @@ -57,7 +57,7 @@ impl RequestExtensions { } } - /// Returns the current operation formatted as #. + /// Returns the current operation formatted as `#`. pub fn operation(&self) -> String { format!("{}#{}", self.namespace, self.operation_name) } diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 95302b5027..2e604645c4 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -24,7 +24,7 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody}; #[doc(inline)] pub use self::error::Error; #[doc(inline)] -pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, RequestExtensions}; +pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, ResponseExtensions}; #[doc(inline)] pub use self::routing::Router; #[doc(inline)]