Skip to content

Always write required query param keys, even if value is unset or empty #1973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ references = ["smithy-rs#1935"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[aws-sdk-rust]]
message = "Neglecting to include an upload ID when sending an AbortMultipartUpload request will no longer result in accidental deletion of data."
references = ["smithy-rs#1957"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "Velfi"

[[smithy-rs]]
message = "Upgrade to Smithy 1.26.2"
references = ["smithy-rs#1972"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class S3Decorator : RustCodegenDecorator<ClientProtocolGenerator, ClientCodegenC
return model.letIf(applies(service.id)) {
ModelTransformer.create().mapShapes(model) { shape ->
shape.letIf(isInInvalidXmlRootAllowList(shape)) {
logger.info("Adding AllowInvalidXmlRoot trait to $shape")
(shape as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build()
logger.info("Adding AllowInvalidXmlRoot trait to $it")
(it as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build()
}
}
}
Expand Down
97 changes: 97 additions & 0 deletions aws/sdk/integration-tests/s3/tests/mandatory-query-param.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use std::time::{Duration, UNIX_EPOCH};

use aws_http::user_agent::AwsUserAgent;
use aws_sdk_s3::middleware::DefaultMiddleware;
use aws_sdk_s3::operation::AbortMultipartUpload;
use aws_sdk_s3::{Credentials, Region};
use aws_smithy_client::test_connection::TestConnection;
use aws_smithy_client::Client as CoreClient;
use aws_smithy_http::body::SdkBody;

pub type Client<C> = CoreClient<C, DefaultMiddleware>;

fn abort_multipart_upload_response_with_empty_upload_id() -> http::Request<SdkBody> {
http::Request::builder()
.header("authorization", "AWS4-HMAC-SHA256 Credential=ANOTREAL/20210618/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token;x-amz-user-agent, Signature=82f2d8f8b0e7e05dc08a243abe6bf30ca5b4399d46f89408a8cdbcc331948dc7")
.uri("https://s3.us-east-1.amazonaws.com/test-bucket/test.txt?x-id=AbortMultipartUpload&uploadId=")
.body(SdkBody::empty())
.unwrap()
}

fn empty_ok_response() -> http::Response<&'static str> {
http::Response::builder().status(200).body("").unwrap()
}

#[tokio::test]
async fn test_mandatory_query_param_is_unset() {
let creds = Credentials::new(
"ANOTREAL",
"notrealrnrELgWzOk3IfjzDKtFBhDby",
Some("notarealsessiontoken".to_string()),
None,
"test",
);
let conf = aws_sdk_s3::Config::builder()
.credentials_provider(creds)
.region(Region::new("us-east-1"))
.build();
let conn = TestConnection::new(vec![(
abort_multipart_upload_response_with_empty_upload_id(),
empty_ok_response(),
)]);
let client = Client::new(conn.clone());
let mut op = AbortMultipartUpload::builder()
.bucket("test-bucket")
.key("test.txt")
.build()
.unwrap()
.make_operation(&conf)
.await
.unwrap();
op.properties_mut()
.insert(UNIX_EPOCH + Duration::from_secs(1624036048));
op.properties_mut().insert(AwsUserAgent::for_tests());

client.call(op).await.expect("empty responses are OK");
conn.assert_requests_match(&[]);
}

#[tokio::test]
async fn test_mandatory_query_param_is_set_but_empty() {
let creds = Credentials::new(
"ANOTREAL",
"notrealrnrELgWzOk3IfjzDKtFBhDby",
Some("notarealsessiontoken".to_string()),
None,
"test",
);
let conf = aws_sdk_s3::Config::builder()
.credentials_provider(creds)
.region(Region::new("us-east-1"))
.build();
let conn = TestConnection::new(vec![(
abort_multipart_upload_response_with_empty_upload_id(),
empty_ok_response(),
)]);
let client = Client::new(conn.clone());
let mut op = AbortMultipartUpload::builder()
.bucket("test-bucket")
.key("test.txt")
.upload_id("")
.build()
.unwrap()
.make_operation(&conf)
.await
.unwrap();
op.properties_mut()
.insert(UNIX_EPOCH + Duration::from_secs(1624036048));
op.properties_mut().insert(AwsUserAgent::for_tests());

client.call(op).await.expect("empty responses are OK");
conn.assert_requests_match(&[]);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.RequiredTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand All @@ -29,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape

fun HttpTrait.uriFormatString(): String {
Expand Down Expand Up @@ -201,23 +203,43 @@ class RequestBindingGenerator(
val memberSymbol = symbolProvider.toSymbol(memberShape)
val memberName = symbolProvider.toMemberName(memberShape)
val outerTarget = model.expectShape(memberShape.target)

ifSet(outerTarget, memberSymbol, "&_input.$memberName") { field ->
// if `param` is a list, generate another level of iteration
listForEach(outerTarget, field) { innerField, targetId ->
val target = model.expectShape(targetId)
rust(
"query.push_kv(${param.locationName.dq()}, ${
paramFmtFun(writer, target, memberShape, innerField)
});",
)
}
paramList(outerTarget, field, param, writer, memberShape)
}
if (memberShape.hasTrait<RequiredTrait>()) {
rust(
"""
else {
query.push_kv(${param.locationName.dq()}, "");
}
""",
)
}
}
writer.rust("Ok(())")
}
return true
}

private fun RustWriter.paramList(
outerTarget: Shape,
field: String,
param: HttpBinding,
writer: RustWriter,
memberShape: MemberShape,
) {
listForEach(outerTarget, field) { innerField, targetId ->
val target = model.expectShape(targetId)
rust(
"query.push_kv(${param.locationName.dq()}, ${
paramFmtFun(writer, target, memberShape, innerField)
});",
)
}
}

/**
* Format [member] when used as a queryParam
*/
Expand All @@ -227,16 +249,19 @@ class RequestBindingGenerator(
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string"))
"&$func(&$targetName)"
}

target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.QUERY, protocol.defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_timestamp"))
"&$func($targetName, ${writer.format(timestampFormatType)})?"
}

target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level")
}

else -> {
"${writer.format(Encoder)}::from(${autoDeref(targetName)}).encode()"
}
Expand Down Expand Up @@ -268,13 +293,15 @@ class RequestBindingGenerator(
}
rust("let $outputVar = $func($input, #T);", encodingStrategy)
}

target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.LABEL, protocol.defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
rust("let $outputVar = $func($input, ${format(timestampFormatType)})?;")
}

else -> {
rust(
"let mut ${outputVar}_encoder = #T::from(${autoDeref(input)}); let $outputVar = ${outputVar}_encoder.encode();",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ fun StructureShape.renderWithModelBuilder(
val TokioWithTestMacros = CargoDependency(
"tokio",
CratesIo("1"),
features = setOf("macros", "test-util", "rt"),
features = setOf("macros", "test-util", "rt", "rt-multi-thread"),
scope = DependencyScope.Dev,
)

Expand Down