Skip to content

Commit

Permalink
rust-server-codegen: add ResponseExtensions to non-fallible opera…
Browse files Browse the repository at this point in the history
…tions

This commit also renames `RequestExtensions` to `ResponseExtensions`,
since we add this extension type only to HTTP responses. It also makes
its members private.

Closes #1063.
  • Loading branch information
david-perez committed Jan 13, 2022
1 parent 844422b commit ec9515d
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class ServerProtocolTestGenerator(
);
"""
)
checkHttpExtensions(this)
checkHttpResponseExtensions(this)
if (!testCase.body.isEmpty()) {
rustTemplate(
"""
Expand Down Expand Up @@ -333,12 +333,18 @@ class ServerProtocolTestGenerator(
}
}

private fun checkHttpExtensions(rustWriter: RustWriter) {
rustWriter.rust(
private fun checkHttpResponseExtensions(rustWriter: RustWriter) {
rustWriter.rustTemplate(
"""
let request_extensions = http_response.extensions()
.get::<#{SmithyHttpServer}::ResponseExtensions>()
.expect("extension `ResponseExtensions` not found");
""".trimIndent(),
*codegenScope
)
rustWriter.writeWithNoFormatting(
"""
let request_extensions = http_response.extensions().get::<aws_smithy_http_server::RequestExtensions>().expect("extension `RequestExtensions` not found");
assert_eq!(request_extensions.namespace, ${operationShape.id.getNamespace().dq()});
assert_eq!(request_extensions.operation_name, ${operationSymbol.name.dq()});
assert_eq!(request_extensions.operation(), format!("{}#{}", "${operationShape.id.namespace}", "${operationSymbol.name}"));
""".trimIndent()
)
}
Expand Down Expand Up @@ -434,7 +440,6 @@ class ServerProtocolTestGenerator(
private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
private val ExpectFail = setOf<FailingTest>(
FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),
FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response),
FailingTest(RestJson, "RestJsonUnitInputAllowsAccept", Action.Request),
FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
Expand All @@ -450,7 +455,6 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "DocumentTypeAsPayloadInputString", Action.Request),
FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response),
FailingTest(RestJson, "DocumentTypeAsPayloadOutputString", Action.Response),
FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response),
FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
FailingTest(RestJson, "RestJsonInvalidGreetingError", Action.Response),
Expand All @@ -476,13 +480,9 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonHttpPayloadWithStructure", Action.Response),
FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Request),
FailingTest(RestJson, "RestJsonHttpPrefixHeadersAreNotPresent", Action.Request),
FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Response),
FailingTest(RestJson, "HttpPrefixHeadersResponse", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
FailingTest(RestJson, "StringPayloadRequest", Action.Request),
FailingTest(RestJson, "StringPayloadResponse", Action.Response),
FailingTest(RestJson, "RestJsonIgnoreQueryParamsInResponse", Action.Response),
FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithNumericHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithBooleanHeaders", Action.Request),
Expand All @@ -491,14 +491,6 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
FailingTest(RestJson, "RestJsonSupportsInfinityFloatHeaderInputs", Action.Request),
FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatHeaderInputs", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Response),
FailingTest(RestJson, "RestJsonInputAndOutputWithNumericHeaders", Action.Response),
FailingTest(RestJson, "RestJsonInputAndOutputWithBooleanHeaders", Action.Response),
FailingTest(RestJson, "RestJsonInputAndOutputWithTimestampHeaders", Action.Response),
FailingTest(RestJson, "RestJsonInputAndOutputWithEnumHeaders", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderOutputs", Action.Response),
FailingTest(RestJson, "RestJsonSupportsInfinityFloatHeaderOutputs", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatHeaderOutputs", Action.Response),
FailingTest(RestJson, "RestJsonJsonBlobs", Action.Response),
FailingTest(RestJson, "RestJsonJsonEnums", Action.Response),
FailingTest(RestJson, "RestJsonLists", Action.Response),
Expand All @@ -524,17 +516,13 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonDeserializeMapUnionValue", Action.Response),
FailingTest(RestJson, "RestJsonDeserializeStructureUnionValue", Action.Response),
FailingTest(RestJson, "MediaTypeHeaderInputBase64", Action.Request),
FailingTest(RestJson, "MediaTypeHeaderOutputBase64", Action.Response),
FailingTest(RestJson, "RestJsonNoInputAllowsAccept", Action.Request),
FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonNoInputAndOutputAllowsAccept", Action.Request),
FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),
FailingTest(RestJson, "RestJsonNullAndEmptyHeaders", Action.Response),
FailingTest(RestJson, "RestJsonRecursiveShapes", Action.Response),
FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Request),
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Response),
FailingTest(RestJson, "RestJsonServersDontSerializeNullStructureValues", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Response),
FailingTest(RestJson, "RestJsonSupportsInfinityFloatInputs", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatInputs", Action.Response),
Expand All @@ -557,7 +545,6 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonTestPayloadStructure", Action.Request),
FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", Action.Request),
FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Request),
FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Response),
FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", Action.Response),
FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", Action.Request),
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", Action.Request),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,12 @@ private class ServerHttpProtocolImplGenerator(
intoResponseStreaming
} else {
"""
match #{serialize_response}(&self.0) {
let mut response = match #{serialize_response}(&self.0) {
Ok(response) => response,
Err(e) => e.into_response()
}
};
$httpExtensions
response
""".trimIndent()
}
// The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
Expand Down Expand Up @@ -315,7 +317,7 @@ private class ServerHttpProtocolImplGenerator(
val namespace = operationShape.id.getNamespace()
val operationName = symbolProvider.toSymbol(operationShape).name
return """
response.extensions_mut().insert(#{SmithyHttpServer}::RequestExtensions::new(${namespace.dq()}, ${operationName.dq()}));
response.extensions_mut().insert(#{SmithyHttpServer}::ResponseExtensions::new(${namespace.dq()}, ${operationName.dq()}));
""".trimIndent()
}

Expand Down
16 changes: 8 additions & 8 deletions rust-runtime/aws-smithy-http-server/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use std::ops::Deref;

/// Extension type used to store Smithy request information.
#[derive(Debug, Clone, Default, Copy)]
pub struct RequestExtensions {
/// Extension type used to store information in HTTP responses.
#[derive(Debug, Clone)]
pub struct ResponseExtensions {
/// Smithy model namespace.
pub namespace: &'static str,
namespace: &'static str,
/// Smithy operation name.
pub operation_name: &'static str,
operation_name: &'static str,
}

impl RequestExtensions {
/// Generates a new `RequestExtensions`.
impl ResponseExtensions {
/// Creates a new `ResponseExtensions`.
pub fn new(namespace: &'static str, operation_name: &'static str) -> Self {
Self {
namespace,
operation_name,
}
}

/// Returns the current operation formatted as <namespace>#<operation_name>.
/// Returns the current operation formatted as `<namespace>#<operation_name>`.
pub fn operation(&self) -> String {
format!("{}#{}", self.namespace, self.operation_name)
}
Expand Down
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-http-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody};
#[doc(inline)]
pub use self::error::Error;
#[doc(inline)]
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, RequestExtensions};
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, ResponseExtensions};
#[doc(inline)]
pub use self::routing::Router;
#[doc(inline)]
Expand Down

0 comments on commit ec9515d

Please sign in to comment.