Skip to content

Commit

Permalink
Fix protocol tests against the orchestrator (#2768)
Browse files Browse the repository at this point in the history
This PR fixes the protocol tests in orchestrator mode, and adds
`--all-targets` to the orchestrator CI checks.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: Zelda Hessler <[email protected]>
  • Loading branch information
jdisanti and Velfi authored Jun 14, 2023
1 parent 988eb61 commit 45f2711
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Compani
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.FuturesUtil
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.HdrHistogram
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.Hound
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.HttpBody
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.SerdeJson
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.Smol
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.TempFile
Expand Down Expand Up @@ -122,6 +123,7 @@ class S3TestDependencies(private val codegenContext: ClientCodegenContext) : Lib
addDependency(BytesUtils.toDevDependency())
addDependency(FastRand.toDevDependency())
addDependency(HdrHistogram)
addDependency(HttpBody.toDevDependency())
addDependency(Smol)
addDependency(TempFile)
addDependency(TracingAppender)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class DefaultProtocolTestGenerator(
writeInline("let expected_output =")
instantiator.render(this, expectedShape, testCase.params)
write(";")
write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder)
write("let mut http_response = #T::new()", RuntimeType.HttpResponseBuilder)
testCase.headers.forEach { (key, value) ->
writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})")
}
Expand Down Expand Up @@ -360,7 +360,9 @@ class DefaultProtocolTestGenerator(
let de = #{OperationDeserializer};
let parsed = de.deserialize_streaming(&mut http_response);
let parsed = parsed.unwrap_or_else(|| {
let http_response = http_response.map(|body|#{copy_from_slice}(body.bytes().unwrap()));
let http_response = http_response.map(|body| {
#{SdkBody}::from(#{copy_from_slice}(body.bytes().unwrap()))
});
de.deserialize_nonstreaming(&http_response)
});
""",
Expand All @@ -369,20 +371,34 @@ class DefaultProtocolTestGenerator(
"copy_from_slice" to RuntimeType.Bytes.resolve("copy_from_slice"),
"ResponseDeserializer" to CargoDependency.smithyRuntimeApi(codegenContext.runtimeConfig).toType()
.resolve("client::orchestrator::ResponseDeserializer"),
"SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig),
)
}
if (expectedShape.hasTrait<ErrorTrait>()) {
val errorSymbol = codegenContext.symbolProvider.symbolForOperationError(operationShape)
val errorVariant = codegenContext.symbolProvider.toSymbol(expectedShape).name
rust("""let parsed = parsed.expect_err("should be error response");""")
if (codegenContext.smithyRuntimeMode.defaultToOrchestrator) {
rustTemplate(
"""let parsed: &#{Error} = parsed.as_operation_error().expect("operation error").downcast_ref().unwrap();""",
"Error" to codegenContext.symbolProvider.symbolForOperationError(operationShape),
)
}
rustBlock("if let #T::$errorVariant(parsed) = parsed", errorSymbol) {
compareMembers(expectedShape)
}
rustBlock("else") {
rust("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
}
} else {
rust("let parsed = parsed.unwrap();")
if (codegenContext.smithyRuntimeMode.defaultToMiddleware) {
rust("let parsed = parsed.unwrap();")
} else {
rustTemplate(
"""let parsed: #{Output} = *parsed.expect("should be successful response").downcast().unwrap();""",
"Output" to codegenContext.symbolProvider.toSymbol(expectedShape),
)
}
compareMembers(outputShape)
}
}
Expand Down
4 changes: 2 additions & 2 deletions tools/ci-scripts/check-aws-sdk-orchestrator-impl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ cd aws/sdk/build/aws-sdk/sdk
for service in "${services_that_compile[@]}"; do
pushd "${service}"
echo -e "${C_YELLOW}# Running 'cargo check --all-features' on '${service}'${C_RESET}"
RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo check --all-features
RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo check --all-features --all-targets
popd
done

for service in "${services_that_pass_tests[@]}"; do
pushd "${service}"
echo -e "${C_YELLOW}# Running 'cargo test --all-features' on '${service}'${C_RESET}"
RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo test --all-features --no-fail-fast
RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo test --all-features --all-targets --no-fail-fast
popd
done

Expand Down

0 comments on commit 45f2711

Please sign in to comment.