Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always write required query param keys, even if value is unset or empty #1973

Merged
merged 20 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ references = ["smithy-rs#1935"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[aws-sdk-rust]]
Velfi marked this conversation as resolved.
Show resolved Hide resolved
message = "Neglecting to include an upload ID when sending an AbortMultipartUpload request will no longer result in accidental deletion of data."
references = ["smithy-rs#1957"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "Velfi"

[[smithy-rs]]
message = "Upgrade to Smithy 1.26.2"
references = ["smithy-rs#1972"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class S3Decorator : RustCodegenDecorator<ClientProtocolGenerator, ClientCodegenC
return model.letIf(applies(service.id)) {
ModelTransformer.create().mapShapes(model) { shape ->
shape.letIf(isInInvalidXmlRootAllowList(shape)) {
logger.info("Adding AllowInvalidXmlRoot trait to $shape")
(shape as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build()
logger.info("Adding AllowInvalidXmlRoot trait to $it")
(it as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build()
}
}
}
Expand Down
50 changes: 50 additions & 0 deletions aws/sdk/integration-tests/s3/tests/required-query-params.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_sdk_s3::operation::AbortMultipartUpload;
use aws_sdk_s3::Region;
use aws_smithy_http::operation::error::BuildError;

#[tokio::test]
async fn test_error_when_required_query_param_is_unset() {
Velfi marked this conversation as resolved.
Show resolved Hide resolved
let conf = aws_sdk_s3::Config::builder()
.region(Region::new("us-east-1"))
.build();

let err = AbortMultipartUpload::builder()
.bucket("test-bucket")
.key("test.txt")
.build()
.unwrap()
.make_operation(&conf)
.await
.unwrap_err();

assert_eq!(
BuildError::missing_field("upload_id", "cannot be empty or unset").to_string(),
err.to_string(),
)
}

#[tokio::test]
async fn test_error_when_required_query_param_is_set_but_empty() {
let conf = aws_sdk_s3::Config::builder()
.region(Region::new("us-east-1"))
.build();
let err = AbortMultipartUpload::builder()
.bucket("test-bucket")
.key("test.txt")
.upload_id("")
.build()
.unwrap()
.make_operation(&conf)
.await
.unwrap_err();

assert_eq!(
BuildError::missing_field("upload_id", "cannot be empty or unset").to_string(),
err.to_string(),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
Expand All @@ -23,7 +24,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.autoDeref
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustInlineTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError
Expand All @@ -33,6 +36,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape

fun HttpTrait.uriFormatString(): String {
Expand All @@ -55,7 +59,7 @@ fun SmithyPattern.rustFormatString(prefix: String, separator: String): String {
* Generates methods to serialize and deserialize requests based on the HTTP trait. Specifically:
* 1. `fn update_http_request(builder: http::request::Builder) -> Builder`
*
* This method takes a builder (perhaps pre configured with some headers) from the caller and sets the HTTP
* This method takes a builder (perhaps pre-configured with some headers) from the caller and sets the HTTP
* headers & URL based on the HTTP trait implementation.
*/
class RequestBindingGenerator(
Expand All @@ -72,7 +76,7 @@ class RequestBindingGenerator(
private val httpBindingGenerator =
HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol)
private val index = HttpBindingIndex.of(model)
private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")
private val encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")

private val codegenScope = arrayOf(
"BuildError" to runtimeConfig.operationBuildError(),
Expand Down Expand Up @@ -208,24 +212,64 @@ class RequestBindingGenerator(
val memberShape = param.member
val memberSymbol = symbolProvider.toSymbol(memberShape)
val memberName = symbolProvider.toMemberName(memberShape)
val outerTarget = model.expectShape(memberShape.target)
ifSet(outerTarget, memberSymbol, "&_input.$memberName") { field ->
// if `param` is a list, generate another level of iteration
listForEach(outerTarget, field) { innerField, targetId ->
val target = model.expectShape(targetId)
rust(
"query.push_kv(${param.locationName.dq()}, ${
paramFmtFun(writer, target, memberShape, innerField)
});",
val target = model.expectShape(memberShape.target)

if (memberShape.isRequired) {
val codegenScope = arrayOf(
"BuildError" to OperationBuildError(runtimeConfig).missingField(
memberName,
"cannot be empty or unset",
),
)
val derefName = safeName("inner")
rust("let $derefName = &_input.$memberName;")
if (memberSymbol.isOptional()) {
rustTemplate(
"let $derefName = $derefName.as_ref().ok_or_else(|| #{BuildError:W})?;",
*codegenScope,
)
}

when {
target.isStringShape -> {
val asStr = writable {
if (target.hasTrait<EnumTrait>()) {
Velfi marked this conversation as resolved.
Show resolved Hide resolved
rustInlineTemplate(".as_str()")
}
}
rustBlock("if $derefName#W.is_empty()", asStr) {
rustTemplate("return Err(#{BuildError:W});", *codegenScope)
}
}
}

paramList(target, derefName, param, writer, memberShape)
} else {
ifSet(target, memberSymbol, "&_input.$memberName") { field ->
// if `param` is a list, generate another level of iteration
paramList(target, field, param, writer, memberShape)
}
}
}
writer.rust("Ok(())")
}
return true
}

private fun RustWriter.paramList(
outerTarget: Shape,
field: String,
param: HttpBinding,
writer: RustWriter,
memberShape: MemberShape,
) {
listForEach(outerTarget, field) { innerField, targetId ->
val target = model.expectShape(targetId)
val value = paramFmtFun(writer, target, memberShape, innerField)
rust("""query.push_kv("${param.locationName}", $value);""")
}
}

/**
* Format [member] when used as a queryParam
*/
Expand All @@ -235,18 +279,21 @@ class RequestBindingGenerator(
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string"))
"&$func(&$targetName)"
}

target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.QUERY, protocol.defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_timestamp"))
"&$func($targetName, ${writer.format(timestampFormatType)})?"
}

target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level")
}

else -> {
"${writer.format(Encoder)}::from(${autoDeref(targetName)}).encode()"
"${writer.format(encoder)}::from(${autoDeref(targetName)}).encode()"
}
}
}
Expand All @@ -273,17 +320,19 @@ class RequestBindingGenerator(
}
rust("let $outputVar = $func($input, #T);", encodingStrategy)
}

target.isTimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.LABEL, protocol.defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
rust("let $outputVar = $func($input, ${format(timestampFormatType)})?;")
}

else -> {
rust(
"let mut ${outputVar}_encoder = #T::from(${autoDeref(input)}); let $outputVar = ${outputVar}_encoder.encode();",
Encoder,
encoder,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ fun StructureShape.renderWithModelBuilder(
val TokioWithTestMacros = CargoDependency(
"tokio",
CratesIo("1"),
features = setOf("macros", "test-util", "rt"),
features = setOf("macros", "test-util", "rt", "rt-multi-thread"),
scope = DependencyScope.Dev,
)

Expand Down