Skip to content

Commit

Permalink
rust-server-codegen: add ResponseExtensions to non-fallible opera…
Browse files Browse the repository at this point in the history
…tions

This commit also renames `RequestExtensions` to `ResponseExtensions`,
since we add this extension type only to HTTP responses. It also makes
its members private.

Closes #1063.
  • Loading branch information
david-perez committed Jan 28, 2022
1 parent dac4816 commit 4ae272b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class ServerProtocolTestGenerator(
checkHeaders(this, "&http_response.headers()", testCase.headers)
checkForbidHeaders(this, "&http_response.headers()", testCase.forbidHeaders)
checkRequiredHeaders(this, "&http_response.headers()", testCase.requireHeaders)
checkHttpExtensions(this)
checkHttpResponseExtensions(this)
if (!testCase.body.isEmpty) {
rustTemplate(
"""
Expand Down Expand Up @@ -341,15 +341,20 @@ class ServerProtocolTestGenerator(
}
}

private fun checkHttpExtensions(rustWriter: RustWriter) {
private fun checkHttpResponseExtensions(rustWriter: RustWriter) {
rustWriter.rustTemplate(
"""
let request_extensions = http_response.extensions().get::<aws_smithy_http_server::RequestExtensions>().expect("extension `RequestExtensions` not found");
#{AssertEq}(request_extensions.namespace, ${operationShape.id.getNamespace().dq()});
#{AssertEq}(request_extensions.operation_name, ${operationSymbol.name.dq()});
let request_extensions = http_response.extensions()
.get::<#{SmithyHttpServer}::ResponseExtensions>()
.expect("extension `ResponseExtensions` not found");
""".trimIndent(),
*codegenScope
)
rustWriter.writeWithNoFormatting(
"""
assert_eq!(request_extensions.operation(), format!("{}#{}", "${operationShape.id.namespace}", "${operationSymbol.name}"));
""".trimIndent()
)
}

private fun checkRequiredHeaders(rustWriter: RustWriter, actualExpression: String, requireHeaders: List<String>) {
Expand Down Expand Up @@ -502,6 +507,8 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Response),

FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithBlob", Action.Request),
// probably remove
FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Request),
FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response),
FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonQueryStringEscaping", Action.Request),
Expand All @@ -513,7 +520,6 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "DocumentOutputArray", Action.Response),
FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response),
FailingTest(RestJson, "DocumentTypeAsPayloadOutputString", Action.Response),
FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response),
FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
FailingTest(RestJson, "RestJsonInvalidGreetingError", Action.Response),
Expand All @@ -531,10 +537,28 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithMediaTypeWithBlob", Action.Response),
FailingTest(RestJson, "RestJsonHttpPayloadWithStructure", Action.Request),
FailingTest(RestJson, "RestJsonHttpPayloadWithStructure", Action.Response),
// probably remove
FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
FailingTest(RestJson, "StringPayloadResponse", Action.Response),
FailingTest(RestJson, "RestJsonIgnoreQueryParamsInResponse", Action.Response),

FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Request),
FailingTest(RestJson, "RestJsonHttpPrefixHeadersAreNotPresent", Action.Request),
FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
FailingTest(RestJson, "StringPayloadRequest", Action.Request),
FailingTest(RestJson, "StringPayloadResponse", Action.Response),
FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithNumericHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithBooleanHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithTimestampHeaders", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithEnumHeaders", Action.Request),
FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
FailingTest(RestJson, "RestJsonSupportsInfinityFloatHeaderInputs", Action.Request),
FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatHeaderInputs", Action.Request),


// probably remove
FailingTest(RestJson, "RestJsonJsonBlobs", Action.Response),
FailingTest(RestJson, "RestJsonJsonEnums", Action.Response),
FailingTest(RestJson, "RestJsonLists", Action.Response),
Expand All @@ -559,12 +583,17 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonDeserializeListUnionValue", Action.Response),
FailingTest(RestJson, "RestJsonDeserializeMapUnionValue", Action.Response),
FailingTest(RestJson, "RestJsonDeserializeStructureUnionValue", Action.Response),
// probably remove
FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),

FailingTest(RestJson, "MediaTypeHeaderInputBase64", Action.Request),
FailingTest(RestJson, "RestJsonNoInputAllowsAccept", Action.Request),
FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonNoInputAndOutputAllowsAccept", Action.Request),
FailingTest(RestJson, "RestJsonRecursiveShapes", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Response),
FailingTest(RestJson, "RestJsonServersDontSerializeNullStructureValues", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Response),
FailingTest(RestJson, "RestJsonSupportsInfinityFloatInputs", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatInputs", Action.Response),
Expand All @@ -584,6 +613,10 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonTestPayloadBlob", Action.Request),
FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),
FailingTest(RestJson, "RestJsonTestPayloadStructure", Action.Request),
// probably remove
FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", Action.Request),
FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Request),

FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", Action.Response),
FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", Action.Request),
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", Action.Request),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,12 @@ private class ServerHttpProtocolImplGenerator(
intoResponseStreaming
} else {
"""
match #{serialize_response}(&self.0) {
let mut response = match #{serialize_response}(&self.0) {
Ok(response) => response,
Err(e) => e.into_response()
}
};
$httpExtensions
response
""".trimIndent()
}
// The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
Expand Down Expand Up @@ -323,7 +325,7 @@ private class ServerHttpProtocolImplGenerator(
val namespace = operationShape.id.getNamespace()
val operationName = symbolProvider.toSymbol(operationShape).name
return """
response.extensions_mut().insert(#{SmithyHttpServer}::RequestExtensions::new(${namespace.dq()}, ${operationName.dq()}));
response.extensions_mut().insert(#{SmithyHttpServer}::ResponseExtensions::new(${namespace.dq()}, ${operationName.dq()}));
""".trimIndent()
}

Expand Down
16 changes: 8 additions & 8 deletions rust-runtime/aws-smithy-http-server/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use std::ops::Deref;

/// Extension type used to store Smithy request information.
#[derive(Debug, Clone, Default, Copy)]
pub struct RequestExtensions {
/// Extension type used to store information in HTTP responses.
#[derive(Debug, Clone)]
pub struct ResponseExtensions {
/// Smithy model namespace.
pub namespace: &'static str,
namespace: &'static str,
/// Smithy operation name.
pub operation_name: &'static str,
operation_name: &'static str,
}

impl RequestExtensions {
/// Generates a new `RequestExtensions`.
impl ResponseExtensions {
/// Creates a new `ResponseExtensions`.
pub fn new(namespace: &'static str, operation_name: &'static str) -> Self {
Self {
namespace,
operation_name,
}
}

/// Returns the current operation formatted as <namespace>#<operation_name>.
/// Returns the current operation formatted as `<namespace>#<operation_name>`.
pub fn operation(&self) -> String {
format!("{}#{}", self.namespace, self.operation_name)
}
Expand Down
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-http-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody};
#[doc(inline)]
pub use self::error::Error;
#[doc(inline)]
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, RequestExtensions};
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, ResponseExtensions};
#[doc(inline)]
pub use self::routing::Router;
#[doc(inline)]
Expand Down

0 comments on commit 4ae272b

Please sign in to comment.