Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow <Error> to trigger error handling for S3 #2958

Merged
merged 4 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ references = ["smithy-rs#2948"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "Velfi"

[[aws-sdk-rust]]
message = "Correctly identify HTTP 200 responses from S3 with `<Error>` as the root Element as errors. **Note**: This a behavior change and will change the error type returned by the SDK in some cases."
references = ["smithy-rs#2958", "aws-sdk-rust#873"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "rcoh"

[[aws-sdk-rust]]
message = "Allow `no_credentials` to be used with all S3 operations."
references = ["smithy-rs#2955", "aws-sdk-rust#878"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientRestXmlFactory
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
Expand Down Expand Up @@ -106,6 +108,34 @@ class S3Decorator : ClientCodegenDecorator {
)
}

override fun operationCustomizations(
codegenContext: ClientCodegenContext,
operation: OperationShape,
baseCustomizations: List<OperationCustomization>,
): List<OperationCustomization> {
return baseCustomizations + object : OperationCustomization() {
override fun section(section: OperationSection): Writable {
return writable {
when (section) {
is OperationSection.BeforeParseResponse -> {
section.body?.also { body ->
rustTemplate(
"""
if matches!(#{errors}::body_is_error($body), Ok(true)) {
${section.forceError} = true;
}
""",
"errors" to RuntimeType.unwrappedXmlErrors(codegenContext.runtimeConfig),
)
}
}
else -> {}
}
}
}
}
}

private fun isInInvalidXmlRootAllowList(shape: Shape): Boolean {
return shape.isStructureShape && invalidXmlRootAllowList.contains(shape.id)
}
Expand All @@ -115,15 +145,15 @@ class FilterEndpointTests(
private val testFilter: (EndpointTestCase) -> EndpointTestCase? = { a -> a },
private val operationInputFilter: (EndpointTestOperationInput) -> EndpointTestOperationInput? = { a -> a },
) {
fun updateEndpointTests(endpointTests: List<EndpointTestCase>): List<EndpointTestCase> {
private fun updateEndpointTests(endpointTests: List<EndpointTestCase>): List<EndpointTestCase> {
val filteredTests = endpointTests.mapNotNull { test -> testFilter(test) }
return filteredTests.map { test ->
val operationInputs = test.operationInputs
test.toBuilder().operationInputs(operationInputs.mapNotNull { operationInputFilter(it) }).build()
}
}

fun transform(model: Model) = ModelTransformer.create().mapTraits(model) { _, trait ->
fun transform(model: Model): Model = ModelTransformer.create().mapTraits(model) { _, trait ->
when (trait) {
is EndpointTestsTrait -> EndpointTestsTrait.builder().testCases(updateEndpointTests(trait.testCases))
.version(trait.version).build()
Expand All @@ -135,7 +165,7 @@ class FilterEndpointTests(

// TODO(P96049742): This model transform may need to change depending on if and how the S3 model is updated.
private class AddOptionalAuth {
fun transform(model: Model) = ModelTransformer.create().mapShapes(model) { shape ->
fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape ->
// Add @optionalAuth to all S3 operations
if (shape is OperationShape && !shape.hasTrait<OptionalAuthTrait>()) {
shape.toBuilder()
Expand Down
40 changes: 40 additions & 0 deletions aws/sdk/integration-tests/s3/tests/status-200-errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
use aws_sdk_s3::Client;
use aws_smithy_client::test_connection::infallible_connection_fn;
use aws_smithy_http::body::SdkBody;
use aws_smithy_types::error::metadata::ProvideErrorMetadata;
use aws_types::region::Region;
use aws_types::SdkConfig;

const ERROR_RESPONSE: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>SlowDown</Code>
<Message>Please reduce your request rate.</Message>
<RequestId>K2H6N7ZGQT6WHCEG</RequestId>
<HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
</Error>
"#;

#[tokio::test]
async fn status_200_errors() {
let conn = infallible_connection_fn(|_req| http::Response::new(SdkBody::from(ERROR_RESPONSE)));
let sdk_config = SdkConfig::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-west-4"))
.http_connector(conn)
.build();
let client = Client::new(&sdk_config);
let error = client
.delete_objects()
.bucket("bucket")
.send()
.await
.expect_err("should fail");
assert_eq!(error.into_service_error().code(), Some("SlowDown"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ sealed class OperationSection(name: String) : Section(name) {
data class BeforeParseResponse(
override val customizations: List<OperationCustomization>,
val responseName: String,
/**
* Name of the `force_error` variable. Set this to true to trigger error parsing.
*/
val forceError: String,
/**
* When set, the name of the response body data field
*/
val body: String?,
) : OperationSection("BeforeParseResponse")

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,12 @@ class ResponseDeserializerGenerator(
rustTemplate(
"""
fn deserialize_streaming(&self, response: &mut #{HttpResponse}) -> #{Option}<#{OutputOrError}> {
##[allow(unused_mut)]
let mut force_error = false;
#{BeforeParseResponse}

// If this is an error, defer to the non-streaming parser
if !response.status().is_success() && response.status().as_u16() != $successCode {
if (!response.status().is_success() && response.status().as_u16() != $successCode) || force_error {
return #{None};
}
#{Some}(#{type_erase_result}(#{parse_streaming_response}(response)))
Expand All @@ -106,7 +108,7 @@ class ResponseDeserializerGenerator(
*codegenScope,
"parse_streaming_response" to parserGenerator.parseStreamingResponseFn(operationShape, customizations),
"BeforeParseResponse" to writable {
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response"))
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", body = null))
},
)
}
Expand Down Expand Up @@ -136,8 +138,10 @@ class ResponseDeserializerGenerator(
let (success, status) = (response.status().is_success(), response.status().as_u16());
let headers = response.headers();
let body = response.body().bytes().expect("body loaded");
##[allow(unused_mut)]
let mut force_error = false;
#{BeforeParseResponse}
let parse_result = if !success && status != $successCode {
let parse_result = if !success && status != $successCode || force_error {
#{parse_error}(status, headers, body)
} else {
#{parse_response}(status, headers, body)
Expand All @@ -148,7 +152,7 @@ class ResponseDeserializerGenerator(
"parse_error" to parserGenerator.parseErrorFn(operationShape, customizations),
"parse_response" to parserGenerator.parseResponseFn(operationShape, customizations),
"BeforeParseResponse" to writable {
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response"))
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", "body"))
},
)
}
Expand Down