Skip to content

Commit

Permalink
Make required context parameters required (#2964)
Browse files Browse the repository at this point in the history
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
When a `@contextParam` is marked as required, we will enforce it on
inputs. Since these fields may influence endpoint, omitting them can
result in a different target being hit.

- #1668 
- aws-sdk-rust#873

## Description
<!--- Describe your changes in detail -->

## Testing
- [x] S3 Integration test

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
rcoh authored Aug 31, 2023
1 parent c98d5fe commit 2c27834
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 6 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,15 @@ message = "Fix code generation for union members with the `@httpPayload` trait."
references = ["smithy-rs#2969", "smithy-rs#1896"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all" }
author = "jdisanti"

[[aws-sdk-rust]]
message = "Make `bucket` required for request construction for S3. When `bucket` is not set, a **different** operation than intended can be triggered."
references = ["smithy-rs#1668", "aws-sdk-rust#873", "smithy-rs#2964"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "rcoh"

[[smithy-rs]]
message = "Required members with @contextParam are now treated as client-side required."
references = ["smithy-rs#2964"]
meta = { "breaking" = false, "tada" = false, "bug" = false, target = "client" }
author = "rcoh"
28 changes: 28 additions & 0 deletions aws/sdk/integration-tests/s3/tests/bucket-required.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_config::SdkConfig;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::Client;
use aws_smithy_client::test_connection::capture_request;

#[tokio::test]
async fn dont_dispatch_when_bucket_is_unset() {
let (conn, rcvr) = capture_request(None);
let sdk_config = SdkConfig::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-east-1"))
.http_connector(conn.clone())
.build();
let client = Client::new(&sdk_config);
let err = client
.list_objects_v2()
.send()
.await
.expect_err("bucket not set");
assert_eq!(format!("{}", err), "failed to construct request");
rcvr.expect_no_request();
}
4 changes: 4 additions & 0 deletions aws/sdk/integration-tests/s3/tests/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async fn get_request_id_from_modeled_error() {
let err = client
.get_object()
.key("dontcare")
.bucket("dontcare")
.send()
.await
.expect_err("status was 404, this is an error")
Expand Down Expand Up @@ -83,6 +84,7 @@ async fn get_request_id_from_unmodeled_error() {
let client = Client::from_conf(config);
let err = client
.get_object()
.bucket("dontcare")
.key("dontcare")
.send()
.await
Expand Down Expand Up @@ -156,6 +158,7 @@ async fn get_request_id_from_successful_streaming_response() {
let output = client
.get_object()
.key("dontcare")
.bucket("dontcare")
.send()
.await
.expect("valid successful response");
Expand Down Expand Up @@ -194,6 +197,7 @@ async fn conversion_to_service_error_maintains_request_id() {
let client = Client::from_conf(config);
let err = client
.get_object()
.bucket("dontcare")
.key("dontcare")
.send()
.await
Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ plugins { }

allprojects {
repositories {
mavenLocal()
/* mavenLocal() */
mavenCentral()
google()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.enforceRequired
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -134,8 +135,10 @@ class EndpointParamsInterceptorGenerator(
// lastly, allow these to be overridden by members
memberParams.forEach { (memberShape, param) ->
val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
rust(
".${EndpointParamsGenerator.setterName(param.name)}(_input.$memberName.clone())",
val member = memberShape.enforceRequired(writable("_input.$memberName.clone()"), codegenContext)

rustTemplate(
".${EndpointParamsGenerator.setterName(param.name)}(#{member})", "member" to member,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class EndpointsDecoratorTest {
structure TestOperationInput {
@contextParam(name: "Bucket")
@required
bucket: String,
nested: NestedStructure
}
Expand Down Expand Up @@ -210,6 +211,10 @@ class EndpointsDecoratorTest {
interceptor.called.load(Ordering::Relaxed),
"the interceptor should have been called"
);
// bucket_name is unset and marked as required on the model, so we'll refuse to construct this request
let err = client.test_operation().send().await.expect_err("param missing");
assert_eq!(format!("{}", err), "failed to construct request");
}
""",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ fun Writable.isEmpty(): Boolean {
return writer.toString() == RustWriter.root().toString()
}

fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable {
val self = this
return writable { f(self) }
}

fun Writable.isNotEmpty(): Boolean = !this.isEmpty()

operator fun Writable.plus(other: Writable): Writable {
Expand Down Expand Up @@ -108,10 +113,12 @@ fun rustTypeParameters(
"#{gg:W}",
"gg" to typeParameter.declaration(withAngleBrackets = false),
)

else -> {
// Check if it's a writer. If it is, invoke it; Else, throw a codegen error.
@Suppress("UNCHECKED_CAST")
val func = typeParameter as? Writable ?: throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
val func = typeParameter as? Writable
?: throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
func.invoke(this)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
Expand All @@ -21,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlockTemplat
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
Expand All @@ -29,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand All @@ -46,6 +49,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

Expand Down Expand Up @@ -76,10 +80,37 @@ abstract class BuilderCustomization : NamedCustomization<BuilderSection>()
fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).resolve("error::BuildError")
fun RuntimeConfig.serializationError() = RuntimeType.operationModule(this).resolve("error::SerializationError")

fun MemberShape.enforceRequired(
field: Writable,
codegenContext: CodegenContext,
produceOption: Boolean = true,
): Writable {
if (!this.isRequired) {
return field
}
val shape = this
val error = OperationBuildError(codegenContext.runtimeConfig).missingField(
codegenContext.symbolProvider.toMemberName(shape), "A required field was not set",
)
val unwrapped = when (codegenContext.model.expectShape(this.target)) {
is StringShape -> writable {
rustTemplate(
"#{field}.filter(|f|!AsRef::<str>::as_ref(f).trim().is_empty())",
"field" to field,
)
}

else -> field
}.map { base -> rustTemplate("#{base}.ok_or_else(||#{error})?", "base" to base, "error" to error) }
return unwrapped.letIf(produceOption) { w -> w.map { rust("Some(#T)", it) } }
}

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {

fun missingField(field: String, details: String) = writable {
rust("#T::missing_field(${field.dq()}, ${details.dq()})", runtimeConfig.operationBuildError())
}

fun invalidField(field: String, details: String) = invalidField(field) { rust(details.dq()) }
fun invalidField(field: String, details: Writable) = writable {
rustTemplate(
Expand Down Expand Up @@ -164,7 +195,8 @@ class BuilderGenerator(
}

private fun RustWriter.missingRequiredField(field: String) {
val detailedMessage = "$field was not specified but it is required when building ${symbolProvider.toSymbol(shape).name}"
val detailedMessage =
"$field was not specified but it is required when building ${symbolProvider.toSymbol(shape).name}"
OperationBuildError(runtimeConfig).missingField(field, detailedMessage)(this)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.core.rustlang

import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldEndWith
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

Expand Down Expand Up @@ -106,4 +107,13 @@ internal class RustTypeParametersTest {
}.join(writable("+"))(writer)
writer.toString() shouldContain "A-B-CD+E+F"
}

@Test
fun `test map`() {
val writer = RustWriter.forModule("model")
val a = writable { rust("a") }
val b = a.map { rust("b(#T)", it) }
b(writer)
writer.toString().trim() shouldEndWith "b(a)"
}
}
2 changes: 1 addition & 1 deletion settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pluginManagement {

buildscript {
repositories {
mavenLocal()
/* mavenLocal() */
mavenCentral()
google()
}
Expand Down

0 comments on commit 2c27834

Please sign in to comment.