Skip to content

Commit

Permalink
Make required context parameters required
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Aug 30, 2023
1 parent 690be0f commit 38c137b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 3 deletions.
28 changes: 28 additions & 0 deletions aws/sdk/integration-tests/s3/tests/bucket-required.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_config::SdkConfig;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::Client;
use aws_smithy_client::test_connection::capture_request;

#[tokio::test]
async fn dont_dispatch_when_bucket_is_unset() {
let (conn, rcvr) = capture_request(None);
let sdk_config = SdkConfig::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-east-1"))
.http_connector(conn.clone())
.build();
let client = Client::new(&sdk_config);
let err = client
.list_objects_v2()
.send()
.await
.expect_err("bucket not set");
assert_eq!(format!("{}", err), "failed to construct request");
rcvr.expect_no_request();
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

Expand Down Expand Up @@ -134,8 +136,11 @@ class EndpointParamsInterceptorGenerator(
// lastly, allow these to be overridden by members
memberParams.forEach { (memberShape, param) ->
val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
rust(
".${EndpointParamsGenerator.setterName(param.name)}(_input.$memberName.clone())",
val member = writable("_input.$memberName.clone()").letIf(memberShape.isRequired) { ref ->
OperationBuildError(codegenContext.runtimeConfig).emptyOrUnset(symbolProvider, ref, memberShape)
}
rustTemplate(
".${EndpointParamsGenerator.setterName(param.name)}(dbg!(#{member}))", "member" to member,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class EndpointsDecoratorTest {
structure TestOperationInput {
@contextParam(name: "Bucket")
@required
bucket: String,
nested: NestedStructure
}
Expand Down Expand Up @@ -210,6 +211,10 @@ class EndpointsDecoratorTest {
interceptor.called.load(Ordering::Relaxed),
"the interceptor should have been called"
);
// bucket_name is unset and marked as required on the model, so we'll refuse to construct this request
let err = client.test_operation().send().await.expect_err("param missing");
assert_eq!(format!("{}", err), "failed to construct request");
}
""",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
Expand Down Expand Up @@ -77,9 +78,39 @@ fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).reso
fun RuntimeConfig.serializationError() = RuntimeType.operationModule(this).resolve("error::SerializationError")

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {

fun emptyOrUnset(symbolProvider: SymbolProvider, ref: Writable, member: MemberShape): Writable {
val checkSet = RuntimeType.forInlineFun("check_set", RustModule.private("serde_util")) {
rustTemplate(
"""
pub(crate) fn check_set<T: #{Default} + #{PartialEq}>(field: Option<T>, name: &'static str) -> Result<T, #{BuildError}> {
if let Some(field) = field {
if field != Default::default() {
return Ok(field)
}
}
Err(#{BuildError}::missing_field(name, "field was required"))
}
""",
"BuildError" to runtimeConfig.operationBuildError(), *preludeScope,
)
}

val fieldName = symbolProvider.toMemberName(member)
return writable {
rustTemplate(
"Some(#{checkSet}(#{ref}, ${fieldName.dq()})?)",
"checkSet" to checkSet,
"ref" to ref,
)
}
}

fun missingField(field: String, details: String) = writable {
rust("#T::missing_field(${field.dq()}, ${details.dq()})", runtimeConfig.operationBuildError())
}

fun invalidField(field: String, details: String) = invalidField(field) { rust(details.dq()) }
fun invalidField(field: String, details: Writable) = writable {
rustTemplate(
Expand Down Expand Up @@ -164,7 +195,8 @@ class BuilderGenerator(
}

private fun RustWriter.missingRequiredField(field: String) {
val detailedMessage = "$field was not specified but it is required when building ${symbolProvider.toSymbol(shape).name}"
val detailedMessage =
"$field was not specified but it is required when building ${symbolProvider.toSymbol(shape).name}"
OperationBuildError(runtimeConfig).missingField(field, detailedMessage)(this)
}

Expand Down

0 comments on commit 38c137b

Please sign in to comment.