Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always write required query param keys, even if value is unset or empty #1973

Merged
merged 20 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,8 @@ references = ["smithy-rs#1935"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[aws-sdk-rust]]
Velfi marked this conversation as resolved.
Show resolved Hide resolved
message = "Neglecting to include an upload ID when sending an AbortMultipartUpload request will no longer result in accidental deletion of data."
references = ["smithy-rs#1957"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "Velfi"
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package software.amazon.smithy.rustsdk.customize.s3

import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.AbstractShapeBuilder
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
Expand All @@ -32,8 +33,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.traits.AllowInvalidXmlRoot
import software.amazon.smithy.rust.codegen.core.smithy.traits.Mandatory
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rustsdk.AwsRuntimeType
import software.amazon.smithy.utils.ToSmithyBuilder
import java.util.logging.Logger

/**
Expand All @@ -47,6 +51,10 @@ class S3Decorator : RustCodegenDecorator<ClientProtocolGenerator, ClientCodegenC
// API returns GetObjectAttributes_Response_ instead of Output
ShapeId.from("com.amazonaws.s3#GetObjectAttributesOutput"),
)
private val mandatoryShapesList = setOf(
// Must be included or else S3 interprets the request as a get, put, or delete object request.
ShapeId.from("com.amazonaws.s3#MultipartUploadId"),
)
Velfi marked this conversation as resolved.
Show resolved Hide resolved

private fun applies(serviceId: ShapeId) =
serviceId == ShapeId.from("com.amazonaws.s3#AmazonS3")
Expand All @@ -70,6 +78,16 @@ class S3Decorator : RustCodegenDecorator<ClientProtocolGenerator, ClientCodegenC
logger.info("Adding AllowInvalidXmlRoot trait to $shape")
(shape as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build()
}

shape.letIf(isInMandatoryList(shape)) {
logger.info("Adding Mandatory trait to $shape")

if (shape is ToSmithyBuilder<*>) {
(shape.toBuilder() as AbstractShapeBuilder<*, *>).addTrait(Mandatory()).build()
} else {
PANIC("can't add Mandatory trait to $shape because it has no builder")
}
}
}
}
}
Expand All @@ -87,6 +105,10 @@ class S3Decorator : RustCodegenDecorator<ClientProtocolGenerator, ClientCodegenC
private fun isInInvalidXmlRootAllowList(shape: Shape): Boolean {
return shape.isStructureShape && invalidXmlRootAllowList.contains(shape.id)
}

private fun isInMandatoryList(shape: Shape): Boolean {
return mandatoryShapesList.contains(shape.id)
}
}

class S3(codegenContext: CodegenContext) : RestXml(codegenContext) {
Expand Down
45 changes: 45 additions & 0 deletions aws/sdk/integration-tests/s3/tests/mandatory-query-param.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_smithy_client::never::NeverConnector;

#[tokio::test]
#[should_panic(
expected = r#"ConstructionFailure(MissingField { field: "upload_id", details: "cannot be empty or unset" })"#
)]
async fn test_mandatory_query_param_is_unset() {
Velfi marked this conversation as resolved.
Show resolved Hide resolved
let conf = aws_sdk_s3::Config::builder().build();
let conn = NeverConnector::new();

let client = aws_sdk_s3::Client::from_conf_conn(conf, conn.clone());

client
.abort_multipart_upload()
.bucket("a-bucket")
.key("a-key")
.send()
.await
.unwrap();
}
Velfi marked this conversation as resolved.
Show resolved Hide resolved

#[tokio::test]
#[should_panic(
expected = r#"ConstructionFailure(MissingField { field: "upload_id", details: "cannot be empty or unset" })"#
)]
async fn test_mandatory_query_param_is_empty() {
let conf = aws_sdk_s3::Config::builder().build();
let conn = NeverConnector::new();

let client = aws_sdk_s3::Client::from_conf_conn(conf, conn.clone());

client
.abort_multipart_upload()
.bucket("a-bucket")
.key("a-key")
.upload_id("")
.send()
.await
.unwrap();
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuild
import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.traits.isMandatory
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectMember
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -201,15 +202,31 @@ class RequestBindingGenerator(
val memberSymbol = symbolProvider.toSymbol(memberShape)
val memberName = symbolProvider.toMemberName(memberShape)
val outerTarget = model.expectShape(memberShape.target)
ifSet(outerTarget, memberSymbol, "&_input.$memberName") { field ->
// if `param` is a list, generate another level of iteration
listForEach(outerTarget, field) { innerField, targetId ->
val target = model.expectShape(targetId)
rust(
"query.push_kv(${param.locationName.dq()}, ${
paramFmtFun(writer, target, memberShape, innerField)
});",
)

if (outerTarget.isMandatory()) {
val buildError = OperationBuildError(runtimeConfig).missingField(
this,
memberName,
"cannot be empty or unset",
)

rustBlock("match &_input.$memberName") {
val derefName = safeName("inner")
rustBlock("Some($derefName) if $derefName.is_empty() =>") {
rust("return Err($buildError);")
}
rustBlock("Some($derefName) =>") {
// if `param` is a list, generate another level of iteration
paramList(outerTarget, derefName, param, writer, memberShape)
}
rustBlock("None =>") {
rust("return Err($buildError);")
}
}
} else {
ifSet(outerTarget, memberSymbol, "&_input.$memberName") { field ->
// if `param` is a list, generate another level of iteration
paramList(outerTarget, field, param, writer, memberShape)
}
}
}
Expand All @@ -218,6 +235,23 @@ class RequestBindingGenerator(
return true
}

private fun RustWriter.paramList(
outerTarget: Shape,
field: String,
param: HttpBinding,
writer: RustWriter,
memberShape: MemberShape,
) {
listForEach(outerTarget, field) { innerField, targetId ->
val target = model.expectShape(targetId)
rust(
"query.push_kv(${param.locationName.dq()}, ${
paramFmtFun(writer, target, memberShape, innerField)
});",
)
}
}

/**
* Format [member] when used as a queryParam
*/
Expand All @@ -227,16 +261,19 @@ class RequestBindingGenerator(
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string"))
"&$func(&$targetName)"
}

target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.QUERY, protocol.defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_timestamp"))
"&$func($targetName, ${writer.format(timestampFormatType)})?"
}

target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level")
}

else -> {
"${writer.format(Encoder)}::from(${autoDeref(targetName)}).encode()"
}
Expand Down Expand Up @@ -268,13 +305,15 @@ class RequestBindingGenerator(
}
rust("let $outputVar = $func($input, #T);", encodingStrategy)
}

target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.LABEL, protocol.defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
rust("let $outputVar = $func($input, ${format(timestampFormatType)})?;")
}

else -> {
rust(
"let mut ${outputVar}_encoder = #T::from(${autoDeref(input)}); let $outputVar = ${outputVar}_encoder.encode();",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.traits

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

/**
* Indicates that a shape must be set by users and that operations will fail to build if it's unset.
*/
class Mandatory : AnnotationTrait(ID, Node.objectNode()) {
companion object {
val ID: ShapeId = ShapeId.from("smithy.api.internal#mandatory")
}
}

fun Shape.isMandatory() = this.hasTrait<Mandatory>()