From c63e79245430c98b2725627566b80a0fbf8e9619 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 17 Jul 2024 11:50:52 +0200 Subject: [PATCH] Add server RPC v2 CBOR support (#2544) RPC v2 CBOR is a new protocol that ~is being added~ has [recently been added](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) to the Smithy specification. _(I'll add more details here as the patchset evolves)_ Credit goes to @jjant for initial implementation of the router, which I built on top of from his [`jjant/smithy-rpc-v2-exploration`](https://github.com/awslabs/smithy-rs/tree/jjant/smithy-rpc-v2-exploration) branch. Tracking issue: https://github.com/smithy-lang/smithy-rs/issues/3573. ## Caveats `TODO`s are currently exhaustively sprinkled throughout the patch documenting what remains to be done. Most of these need to be addressed before this can be merged in; some can be punted on to not make this PR bigger. However, I'd like to call out the major caveats and blockers here. I'll keep updating this list as the patchset evolves. - [x] RPC v2 has still not been added to the Smithy specification. It is currently being worked on over in the [`smithy-rpc-v2`](https://github.com/awslabs/smithy/tree/smithy-rpc-v2) branch. The following are prerrequisites for this PR to be merged; **until they are done CI on this PR will fail**: - [x] Smithy merges in RPC v2 support. - [x] Smithy releases a new version incorporating RPC v2 support. - Released in [Smithy v1.47](https://github.com/smithy-lang/smithy/releases/tag/1.47.0) - [x] smithy-rs updates to the new version. - Updated in https://github.com/smithy-lang/smithy-rs/pull/3552 - [x] ~Protocol tests for the protocol do not currently exist in Smithy. Until those get written~, this PR resorts to Rust unit tests and integration tests that use `serde` to round-trip messages and compare `serde`'s encoders and decoders with ours for correctness. - Protocol tests are under the [`smithy-protocol-tests`](https://github.com/smithy-lang/smithy/tree/main/smithy-protocol-tests/model/rpcv2Cbor) directory in Smithy. - We're keeping the `serde_cbor` round-trip tests for defense in depth. - [ ] https://github.com/smithy-lang/smithy-rs/issues/3709 - Currently only server-side support has been implemented, because that's what I'm most familiar. However, we're almost all the way there to add client-side support. - ~[ ] [Smithy `document` shapes](https://smithy.io/2.0/spec/simple-types.html#document) are not supported. RPC v2's specification currently doesn't indicate how to implement them.~ - [The spec](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#shape-serialization) ended up leaving them as unsupported: "Document types are not currently supported in this protocol." ## Prerequisite PRs This section lists prerequisite PRs and issues that would make the diff for this one lighter or easier to understand. It's preferable that these PRs be merged prior to this one; some are hard prerequisites. They mostly relate to parts of the codebase I've had to touch or ~pilfer~ inspect in this PR where I've made necessary changes, refactors and "drive-by improvements" that are mostly unrelated, although some directly unlock things I've needed in this patchset. It makes sense to pull them out to ease reviewability and make this patch more semantically self-contained. - https://github.com/awslabs/smithy-rs/pull/2516 - https://github.com/awslabs/smithy-rs/pull/2517 - https://github.com/awslabs/smithy-rs/pull/2522 - https://github.com/awslabs/smithy-rs/pull/2524 - https://github.com/awslabs/smithy-rs/pull/2528 - https://github.com/awslabs/smithy-rs/pull/2536 - https://github.com/awslabs/smithy-rs/pull/2537 - https://github.com/awslabs/smithy-rs/pull/2531 - https://github.com/awslabs/smithy-rs/pull/2538 - https://github.com/awslabs/smithy-rs/pull/2539 - https://github.com/awslabs/smithy-rs/pull/2542 - https://github.com/smithy-lang/smithy-rs/pull/3684 - https://github.com/smithy-lang/smithy-rs/pull/3678 - https://github.com/smithy-lang/smithy-rs/pull/3690 - https://github.com/smithy-lang/smithy-rs/pull/3713 - https://github.com/smithy-lang/smithy-rs/pull/3726 - https://github.com/smithy-lang/smithy-rs/pull/3752 ## Testing ~RPC v2 has still not been added to the Smithy specification. It is currently being worked on over in the [`smithy-rpc-v2`](https://github.com/awslabs/smithy/tree/smithy-rpc-v2) branch.~ This can only be tested _locally_ following these steps: ~1. Clone [the Smithy repository](https://github.com/smithy-lang/smithy/tree/smithy-rpc-v2) and checkout the `smithy-rpc-v2` branch. 2. Inside your local checkout of smithy-rs pointing to this PR's branch, make sure you've added `mavenLocal()` as a repository in the `build.gradle.kts` files. [Example](https://github.com/smithy-lang/smithy-rs/pull/2544/commits/8df82fd3fc92434ea4f4ffb10b02df2da458624c). 4. Inside your local checkout of Smithy's `smithy-rpc-v2` branch: 1. Set `VERSION` to the current Smithy version used in smithy-rs (1.28.1 as of writing, but [check here](https://github.com/awslabs/smithy-rs/blob/main/gradle.properties#L21)). 2. Run `./gradlew clean build pTML`.~ ~6.~ 1. In your local checkout of the smithy-rs's `smithy-rpc-v2` branch, run `./gradlew codegen-server-test:build -P modules='rpcv2Cbor'`. ~You can troubleshoot whether you have Smithy correctly set up locally by inspecting `~/.m2/repository/software/amazon/smithy/smithy-protocols-traits`.~ ## Checklist - [ ] I have updated `CHANGELOG.next.toml` if I made changes to the smithy-rs codegen or runtime crates ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --- .cargo-deny-config.toml | 3 + build.gradle.kts | 2 +- buildSrc/src/main/kotlin/CodegenTestCommon.kt | 79 ++- buildSrc/src/main/kotlin/CrateSet.kt | 1 + codegen-client/build.gradle.kts | 3 +- .../customize/ClientCodegenDecorator.kt | 17 - .../smithy/generators/ClientInstantiator.kt | 3 +- .../protocol/ClientProtocolTestGenerator.kt | 12 +- .../smithy/protocols/ClientProtocolLoader.kt | 12 + codegen-core/build.gradle.kts | 1 + .../rpcv2Cbor-extras.smithy | 349 +++++++++ .../codegen/core/rustlang/CargoDependency.kt | 26 +- .../core/rustlang/RustReservedWords.kt | 1 + .../codegen/core/smithy/CodegenContext.kt | 2 +- .../codegen/core/smithy/CoreRustSettings.kt | 5 +- .../rust/codegen/core/smithy/RuntimeType.kt | 16 +- .../smithy/customize/CoreCodegenDecorator.kt | 17 + .../core/smithy/generators/Instantiator.kt | 77 +- .../protocol/ProtocolTestGenerator.kt | 21 +- .../codegen/core/smithy/protocols/AwsJson.kt | 2 +- .../smithy/protocols/HttpBindingResolver.kt | 8 + .../codegen/core/smithy/protocols/Protocol.kt | 19 +- .../smithy/protocols/ProtocolFunctions.kt | 2 +- .../codegen/core/smithy/protocols/RestJson.kt | 8 +- .../core/smithy/protocols/RpcV2Cbor.kt | 121 ++++ .../protocols/parse/CborParserGenerator.kt | 666 ++++++++++++++++++ .../protocols/parse/JsonParserGenerator.kt | 7 +- .../protocols/parse/ReturnSymbolToParse.kt | 14 + .../parse/StructuredDataParserGenerator.kt | 4 +- .../serialize/CborSerializerGenerator.kt | 419 +++++++++++ .../serialize/JsonSerializerGenerator.kt | 41 +- .../StructuredDataSerializerGenerator.kt | 8 +- .../core/smithy/traits/SyntheticInputTrait.kt | 9 +- .../smithy/traits/SyntheticOutputTrait.kt | 10 +- .../transformers/OperationNormalizer.kt | 36 +- .../NamingObstacleCourseTestModels.kt | 2 +- .../rust/codegen/core/testutil/TestHelpers.kt | 19 +- .../smithy/rust/codegen/core/util/Smithy.kt | 76 +- codegen-server-test/build.gradle.kts | 7 + codegen-server/build.gradle.kts | 4 + .../protocols/PythonServerProtocolLoader.kt | 4 +- .../server/smithy/ServerCargoDependency.kt | 3 +- .../server/smithy/ServerCodegenVisitor.kt | 11 +- ...ypeFieldToServerErrorsCborCustomization.kt | 60 ++ ...ncodingMapOrCollectionCborCustomization.kt | 41 ++ .../smithy/generators/ServerInstantiator.kt | 29 +- .../generators/protocol/ServerProtocol.kt | 95 ++- .../protocol/ServerProtocolTestGenerator.kt | 84 ++- .../ServerHttpBoundProtocolGenerator.kt | 90 ++- .../smithy/protocols/ServerProtocolLoader.kt | 10 +- .../protocols/ServerRpcV2CborFactory.kt | 41 ++ ...rGeneratorSerdeRoundTripIntegrationTest.kt | 358 ++++++++++ examples/Cargo.toml | 1 - rust-runtime/Cargo.lock | 127 +++- rust-runtime/Cargo.toml | 1 + rust-runtime/aws-smithy-cbor/Cargo.toml | 41 ++ rust-runtime/aws-smithy-cbor/LICENSE | 175 +++++ rust-runtime/aws-smithy-cbor/README.md | 8 + rust-runtime/aws-smithy-cbor/benches/blob.rs | 26 + .../aws-smithy-cbor/benches/string.rs | 136 ++++ rust-runtime/aws-smithy-cbor/src/data.rs | 102 +++ rust-runtime/aws-smithy-cbor/src/decode.rs | 341 +++++++++ rust-runtime/aws-smithy-cbor/src/encode.rs | 117 +++ rust-runtime/aws-smithy-cbor/src/lib.rs | 17 + .../aws-smithy-http-server/Cargo.toml | 3 +- .../src/operation/upgrade.rs | 2 + .../src/protocol/aws_json/router.rs | 4 +- .../src/protocol/aws_json_10/mod.rs | 2 +- .../src/protocol/aws_json_10/router.rs | 2 + .../src/protocol/aws_json_11/mod.rs | 2 +- .../src/protocol/aws_json_11/router.rs | 2 + .../src/protocol/mod.rs | 1 + .../src/protocol/rest/router.rs | 6 +- .../src/protocol/rest_json_1/mod.rs | 2 +- .../src/protocol/rest_json_1/router.rs | 2 + .../src/protocol/rest_xml/mod.rs | 2 +- .../src/protocol/rest_xml/router.rs | 3 +- .../src/protocol/rpc_v2_cbor/mod.rs | 12 + .../src/protocol/rpc_v2_cbor/rejection.rs | 49 ++ .../src/protocol/rpc_v2_cbor/router.rs | 406 +++++++++++ .../src/protocol/rpc_v2_cbor/runtime_error.rs | 98 +++ .../aws-smithy-http-server/src/routing/mod.rs | 4 +- .../aws-smithy-protocol-test/Cargo.toml | 6 +- .../aws-smithy-protocol-test/src/lib.rs | 70 +- 84 files changed, 4479 insertions(+), 246 deletions(-) create mode 100644 codegen-core/common-test-models/rpcv2Cbor-extras.smithy create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt create mode 100644 rust-runtime/aws-smithy-cbor/Cargo.toml create mode 100644 rust-runtime/aws-smithy-cbor/LICENSE create mode 100644 rust-runtime/aws-smithy-cbor/README.md create mode 100644 rust-runtime/aws-smithy-cbor/benches/blob.rs create mode 100644 rust-runtime/aws-smithy-cbor/benches/string.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/data.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/decode.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/encode.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/lib.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs diff --git a/.cargo-deny-config.toml b/.cargo-deny-config.toml index 3a9407f00c..abb8c8c067 100644 --- a/.cargo-deny-config.toml +++ b/.cargo-deny-config.toml @@ -25,6 +25,9 @@ exceptions = [ { allow = ["OpenSSL"], name = "ring", version = "*" }, { allow = ["OpenSSL"], name = "aws-lc-sys", version = "*" }, { allow = ["OpenSSL"], name = "aws-lc-fips-sys", version = "*" }, + { allow = ["BlueOak-1.0.0"], name = "minicbor", version = "<=0.24.2" }, + # Safe to bump as long as license does not change -------------^ + # See D105255799. ] [[licenses.clarify]] diff --git a/build.gradle.kts b/build.gradle.kts index 20f2d9e400..5e11e0ab02 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ allprojects { val allowLocalDeps: String by project repositories { if (allowLocalDeps.toBoolean()) { - mavenLocal() + mavenLocal() } mavenCentral() google() diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index 3c025288ac..8e0fd36447 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -26,9 +26,84 @@ fun generateImports(imports: List): String = if (imports.isEmpty()) { "" } else { - "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," + "\"imports\": [${imports.joinToString(", ") { "\"$it\"" }}]," } +val RustKeywords = + setOf( + "as", + "break", + "const", + "continue", + "crate", + "else", + "enum", + "extern", + "false", + "fn", + "for", + "if", + "impl", + "in", + "let", + "loop", + "match", + "mod", + "move", + "mut", + "pub", + "ref", + "return", + "self", + "Self", + "static", + "struct", + "super", + "trait", + "true", + "type", + "unsafe", + "use", + "where", + "while", + "async", + "await", + "dyn", + "abstract", + "become", + "box", + "do", + "final", + "macro", + "override", + "priv", + "typeof", + "unsized", + "virtual", + "yield", + "try", + ) + +fun toRustCrateName(input: String): String { + if (input.isBlank()) { + throw IllegalArgumentException("Rust crate name cannot be empty") + } + val lowerCased = input.lowercase() + // Replace any sequence of characters that are not lowercase letters, numbers, dashes, or underscores with a single underscore. + val sanitized = lowerCased.replace(Regex("[^a-z0-9_-]+"), "_") + // Trim leading or trailing underscores. + val trimmed = sanitized.trim('_') + // Check if the resulting string is empty, purely numeric, or a reserved name + val finalName = + when { + trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") + trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. + trimmed in RustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. + else -> trimmed + } + return finalName +} + private fun generateSmithyBuild( projectDir: String, pluginName: String, @@ -48,7 +123,7 @@ private fun generateSmithyBuild( ${it.extraCodegenConfig ?: ""} }, "service": "${it.service}", - "module": "${it.module}", + "module": "${toRustCrateName(it.module)}", "moduleVersion": "0.0.1", "moduleDescription": "test", "moduleAuthors": ["protocoltest@example.com"] diff --git a/buildSrc/src/main/kotlin/CrateSet.kt b/buildSrc/src/main/kotlin/CrateSet.kt index bc90115443..253bfa08ca 100644 --- a/buildSrc/src/main/kotlin/CrateSet.kt +++ b/buildSrc/src/main/kotlin/CrateSet.kt @@ -56,6 +56,7 @@ object CrateSet { val SMITHY_RUNTIME_COMMON = listOf( "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index 485a656d7b..3e1f1ec580 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -27,9 +27,10 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in event stream -// marshalling/unmarshalling tests. + // marshalling/unmarshalling tests. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index c4aec33b59..306a439a9e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorC import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import java.util.ServiceLoader import java.util.logging.Logger @@ -93,14 +92,6 @@ interface ClientCodegenDecorator : CoreCodegenDecorator, ): List = baseCustomizations - - /** - * Hook to override the protocol test generator - */ - fun protocolTestGenerator( - codegenContext: ClientCodegenContext, - baseGenerator: ProtocolTestGenerator, - ): ProtocolTestGenerator = baseGenerator } /** @@ -176,14 +167,6 @@ open class CombinedClientCodegenDecorator(decorators: List - decorator.protocolTestGenerator(codegenContext, gen) - } - companion object { fun fromClasspath( context: PluginContext, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt index ca39264a1c..992d85dd72 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -28,11 +28,12 @@ class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true } -class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator( +class ClientInstantiator(private val codegenContext: ClientCodegenContext, withinTest: Boolean = false) : Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, ClientBuilderKindBehavior(codegenContext), + withinTest = false, ) { fun renderFluentCall( writer: RustWriter, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 5936438adb..85d32c2bf3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -114,7 +114,7 @@ class ClientProtocolTestGenerator( get() = AppliesTo.CLIENT override val expectFail: Set get() = ExpectFail - override val runOnly: Set + override val generateOnly: Set get() = emptySet() override val disabledTests: Set get() = emptySet() @@ -128,7 +128,7 @@ class ClientProtocolTestGenerator( private val inputShape = operationShape.inputShape(codegenContext.model) private val outputShape = operationShape.outputShape(codegenContext.model) - private val instantiator = ClientInstantiator(codegenContext) + private val instantiator = ClientInstantiator(codegenContext, withinTest = true) private val codegenScope = arrayOf( @@ -149,6 +149,8 @@ class ClientProtocolTestGenerator( } private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { + logger.info("Generating request test: ${httpRequestTestCase.id}") + if (!protocolSupport.requestSerialization) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -234,6 +236,8 @@ class ClientProtocolTestGenerator( testCase: HttpResponseTestCase, expectedShape: StructureShape, ) { + logger.info("Generating response test: ${testCase.id}") + if (!protocolSupport.responseDeserialization || ( !protocolSupport.errorDeserialization && expectedShape.hasTrait( @@ -357,8 +361,8 @@ class ClientProtocolTestGenerator( if (body == "") { rustWriter.rustTemplate( """ - // No body - #{AssertEq}(::std::str::from_utf8(body).unwrap(), ""); + // No body. + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index 5c0f7e6e1a..f1a01edd6b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -28,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : @@ -41,6 +43,7 @@ class ClientProtocolLoader(supportedProtocols: ProtocolMap { + override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2Cbor(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = + OperationGenerator(codegenContext, protocol(codegenContext)) + + override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT +} diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index 2fdd74abfb..eff612be35 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -28,6 +28,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") } fun gitCommitHash(): String { diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy new file mode 100644 index 0000000000..c60b93736d --- /dev/null +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -0,0 +1,349 @@ +$version: "2.0" + +namespace smithy.protocoltests.rpcv2Cbor + +use smithy.framework#ValidationException +use smithy.protocols#rpcv2Cbor +use smithy.test#httpResponseTests +use smithy.test#httpMalformedRequestTests + +@rpcv2Cbor +service RpcV2CborService { + operations: [ + SimpleStructOperation + ErrorSerializationOperation + ComplexStructOperation + EmptyStructOperation + SingleMemberStructOperation + ] +} + +// TODO(https://github.com/smithy-lang/smithy/issues/2326): Smithy should not +// allow HTTP binding traits in this protocol. +@http(uri: "/simple-struct-operation", method: "POST") +operation SimpleStructOperation { + input: SimpleStruct + output: SimpleStruct + errors: [ValidationException] +} + +operation ErrorSerializationOperation { + input: SimpleStruct + output: ErrorSerializationOperationOutput + errors: [ValidationException] +} + +operation ComplexStructOperation { + input: ComplexStruct + output: ComplexStruct + errors: [ValidationException] +} + +operation EmptyStructOperation { + input: EmptyStruct + output: EmptyStruct +} + +operation SingleMemberStructOperation { + input: SingleMemberStruct + output: SingleMemberStruct +} + +apply EmptyStructOperation @httpMalformedRequestTests([ + { + id: "AdditionalTokensEmptyStruct", + documentation: """ + When additional tokens are found past where we expect the end of the body, + the request should be rejected with a serialization exception.""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2CborService/operation/EmptyStructOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // Two empty variable-length encoded CBOR maps back to back. + body: "v/+//w==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // An empty CBOR map. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type` because `SerializationException` is not modeled. + contents: "oA==" + } + } + } + + } +]) + +apply SingleMemberStructOperation @httpMalformedRequestTests([ + { + id: "AdditionalTokensSingleMemberStruct", + documentation: """ + When additional tokens are found past where we expect the end of the body, + the request should be rejected with a serialization exception.""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2CborService/operation/SingleMemberStructOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // Two empty variable-length encoded CBOR maps back to back. + body: "v/+//w==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // An empty CBOR map. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type` because `SerializationException` is not modeled. + contents: "oA==" + } + } + } + } +]) + +apply ErrorSerializationOperation @httpMalformedRequestTests([ + { + id: "ErrorSerializationIncludesTypeField", + documentation: """ + When invalid input is provided the request should be rejected with + a validation exception, and a `__type` field should be included""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2CborService/operation/ErrorSerializationOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // An empty CBOR map. We're missing a lot of `@required` members! + body: "oA==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + contents: "v2ZfX3R5cGV4JHNtaXRoeS5mcmFtZXdvcmsjVmFsaWRhdGlvbkV4Y2VwdGlvbmdtZXNzYWdleGsxIHZhbGlkYXRpb24gZXJyb3IgZGV0ZWN0ZWQuIFZhbHVlIGF0ICcvcmVxdWlyZWRCbG9iJyBmYWlsZWQgdG8gc2F0aXNmeSBjb25zdHJhaW50OiBNZW1iZXIgbXVzdCBub3QgYmUgbnVsbGlmaWVsZExpc3SBv2RwYXRobS9yZXF1aXJlZEJsb2JnbWVzc2FnZXhOVmFsdWUgYXQgJy9yZXF1aXJlZEJsb2InIGZhaWxlZCB0byBzYXRpc2Z5IGNvbnN0cmFpbnQ6IE1lbWJlciBtdXN0IG5vdCBiZSBudWxs//8=" + } + } + } + } +]) + +apply ErrorSerializationOperation @httpResponseTests([ + { + id: "OperationOutputSerializationQuestionablyIncludesTypeField", + documentation: """ + Despite the operation output being a structure shape with the `@error` trait, + `__type` field should, in a strict interpretation of the spec, not be included, + because we're not serializing a server error response. However, we do, because + there shouldn't™️ be any harm in doing so, and it greatly simplifies the + code generator. This test just pins this behavior in case we ever modify it.""", + protocol: rpcv2Cbor, + code: 200, + params: { + errorShape: { + message: "ValidationException message field" + } + } + bodyMediaType: "application/cbor" + body: "v2plcnJvclNoYXBlv2ZfX3R5cGV4JHNtaXRoeS5mcmFtZXdvcmsjVmFsaWRhdGlvbkV4Y2VwdGlvbmdtZXNzYWdleCFWYWxpZGF0aW9uRXhjZXB0aW9uIG1lc3NhZ2UgZmllbGT//w==" + } +]) + +apply SimpleStructOperation @httpResponseTests([ + { + id: "SimpleStruct", + protocol: rpcv2Cbor, + code: 200, // Not used. + params: { + blob: "blobby blob", + boolean: false, + + string: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + byte: 69, + short: 70, + integer: 71, + long: 72, + + float: 0.69, + double: 0.6969, + + timestamp: 1546300800, + enum: "DIAMOND" + + // With `@required`. + + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + requiredEnum: "DIAMOND" + } + }, + // Same test, but leave optional types empty + { + id: "SimpleStructWithOptionsSetToNone", + protocol: rpcv2Cbor, + code: 200, // Not used. + params: { + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + requiredEnum: "DIAMOND" + } + } +]) + +structure ErrorSerializationOperationOutput { + errorShape: ValidationException +} + +structure SimpleStruct { + blob: Blob + boolean: Boolean + + string: String + + byte: Byte + short: Short + integer: Integer + long: Long + + float: Float + double: Double + + timestamp: Timestamp + enum: Suit + + // With `@required`. + + @required requiredBlob: Blob + @required requiredBoolean: Boolean + + @required requiredString: String + + @required requiredByte: Byte + @required requiredShort: Short + @required requiredInteger: Integer + @required requiredLong: Long + + @required requiredFloat: Float + @required requiredDouble: Double + + @required requiredTimestamp: Timestamp + // @required requiredDocument: MyDocument + @required requiredEnum: Suit +} + +structure ComplexStruct { + structure: SimpleStruct + emptyStructure: EmptyStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion + unitUnion: UnitUnion + + structureList: StructList + + // `@required` for good measure here. + @required complexList: ComplexList + @required complexMap: ComplexMap + @required complexUnion: ComplexUnion +} + +structure EmptyStruct { } + +structure SingleMemberStruct { + message: String +} + +list StructList { + member: SimpleStruct +} + +list SimpleList { + member: String +} + +map SimpleMap { + key: String + value: Integer +} + +// TODO(https://github.com/smithy-lang/smithy/issues/2325): Upstream protocol +// test suite doesn't cover unions. While the generated SDK compiles, we're not +// exercising the (de)serializers with actual values. +union SimpleUnion { + blob: Blob + boolean: Boolean + string: String + unit: Unit +} + +union UnitUnion { + unitA: Unit + unitB: Unit +} + +list ComplexList { + member: ComplexMap +} + +map ComplexMap { + key: String + value: ComplexUnion +} + +union ComplexUnion { + // Recursive path here. + complexStruct: ComplexStruct + + structure: SimpleStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion +} + +enum Suit { + DIAMOND + CLUB + HEART + SPADE +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index fbdd0dca11..f34921dfff 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolDependencyContainer import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq import java.nio.file.Path @@ -41,6 +42,12 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { ) + dependencies().flatMap { it.dependencies } } + open fun toDevDependency(): RustDependency = + when (this) { + is CargoDependency -> this.toDevDependency() + is InlineDependency -> PANIC("it does not make sense for an inline dependency to be a dev-dependency") + } + companion object { private const val PROPERTY_KEY = "rustdep" @@ -71,9 +78,7 @@ class InlineDependency( return renderer.hashCode().toString() } - override fun dependencies(): List { - return extraDependencies - } + override fun dependencies(): List = extraDependencies fun key() = "${module.fullyQualifiedPath()}::$name" @@ -170,7 +175,7 @@ data class Feature(val name: String, val default: Boolean, val deps: List { * Hook for customizing symbols by inserting an additional symbol provider. */ fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = base + + /** + * Hook to override the protocol test generator. + */ + fun protocolTestGenerator( + codegenContext: CodegenContext, + baseGenerator: ProtocolTestGenerator, + ): ProtocolTestGenerator = baseGenerator } /** @@ -199,6 +208,14 @@ abstract class CombinedCoreCodegenDecorator + decorator.protocolTestGenerator(codegenContext, gen) + } + /** * Combines customizations from multiple ordered codegen decorators. * diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index de73ab760d..c09bc545fc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -65,6 +65,7 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.letIf import java.math.BigDecimal +import kotlin.jvm.optionals.getOrNull /** * Class describing an instantiator section that can be used in a customization. @@ -94,6 +95,13 @@ open class Instantiator( private val customizations: List = listOf(), private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER, private val customWritable: CustomWritable = NoCustomWritable(), + /** + * A protocol test may provide data for missing members (because we transformed the model). + * This flag makes it so that it is simply ignored, and code generation continues. + **/ + private val ignoreMissingMembers: Boolean = false, + /** Whether we're rendering within a test, in which case we should use dev-dependencies. */ + private val withinTest: Boolean = false, ) { data class Ctx( // The `http` crate requires that headers be lowercase, but Smithy protocol tests @@ -171,7 +179,7 @@ open class Instantiator( is MemberShape -> renderMember(writer, shape, data, ctx) is SimpleShape -> - PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate( + PrimitiveInstantiator(runtimeConfig, symbolProvider, withinTest).instantiate( shape, data, customWritable, @@ -422,8 +430,14 @@ open class Instantiator( } } - data.members.forEach { (key, value) -> - val memberShape = shape.expectMember(key.value) + for ((key, value) in data.members) { + val memberShape = + shape.getMember(key.value).getOrNull() + ?: if (ignoreMissingMembers) { + continue + } else { + throw CodegenException("Protocol test defines data for member shape `${key.value}`, but member shape was not found on structure shape ${shape.id}") + } renderMemberHelper(memberShape, value) } @@ -471,7 +485,27 @@ open class Instantiator( } } -class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) { +class PrimitiveInstantiator( + private val runtimeConfig: RuntimeConfig, + private val symbolProvider: SymbolProvider, + withinTest: Boolean = false, +) { + val codegenScope = + listOf( + "DateTime" to RuntimeType.dateTime(runtimeConfig), + "Bytestream" to RuntimeType.byteStream(runtimeConfig), + "Blob" to RuntimeType.blob(runtimeConfig), + "SmithyJson" to RuntimeType.smithyJson(runtimeConfig), + "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), + ).map { + it.first to + if (withinTest) { + it.second.toDevDependencyType() + } else { + it.second + } + }.toTypedArray() + fun instantiate( shape: SimpleShape, data: Node, @@ -485,9 +519,9 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val num = BigDecimal(node.toString()) val wholePart = num.toInt() val fractionalPart = num.remainder(BigDecimal.ONE) - rust( - "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), + rustTemplate( + "#{DateTime}::from_fractional_secs($wholePart, ${fractionalPart}_f64)", + *codegenScope, ) } @@ -498,14 +532,14 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va */ is BlobShape -> if (shape.hasTrait()) { - rust( - "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig), + rustTemplate( + "#{Bytestream}::from_static(b${(data as StringNode).value.dq()})", + *codegenScope, ) } else { - rust( - "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig), + rustTemplate( + "#{Blob}::new(${(data as StringNode).value.dq()})", + *codegenScope, ) } @@ -515,10 +549,10 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va is StringNode -> { val numberSymbol = symbolProvider.toSymbol(shape) // support Smithy custom values, such as Infinity - rust( - """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", - numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + rustTemplate( + """<#{NumberSymbol} as #{SmithyTypes}::primitive::Parse>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", + "NumberSymbol" to numberSymbol, + *codegenScope, ) } @@ -533,15 +567,14 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va is BooleanShape -> rust(data.asBooleanNode().get().toString()) is DocumentShape -> rustBlock("") { - val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() + val smithyJson = CargoDependency.smithyJson(runtimeConfig).toDevDependency().toType() rustTemplate( """ let json_bytes = br##"${Node.prettyPrintJson(data)}"##; - let mut tokens = #{json_token_iter}(json_bytes).peekable(); - #{expect_document}(&mut tokens).expect("well formed json") + let mut tokens = #{SmithyJson}::deserialize::json_token_iter(json_bytes).peekable(); + #{SmithyJson}::deserialize::token::expect_document(&mut tokens).expect("well formed json") """, - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + *codegenScope, ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 20121535a8..3c7950ef34 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -64,7 +64,7 @@ abstract class ProtocolTestGenerator { abstract val brokenTests: Set /** Only generate these tests; useful to temporarily set and shorten development cycles */ - abstract val runOnly: Set + abstract val generateOnly: Set /** * These tests are not even attempted to be generated, either because they will not compile @@ -89,6 +89,8 @@ abstract class ProtocolTestGenerator { allMatchingTestCases().flatMap { fixBrokenTestCase(it) } + // Filter afterward in case a fixed broken test is disabled. + .filterMatching() if (allTests.isEmpty()) { return } @@ -109,6 +111,8 @@ abstract class ProtocolTestGenerator { if (!it.isBroken()) { listOf(it) } else { + logger.info("Fixing ${it.kind} test case ${it.id}") + assert(it.expectFail()) val brokenTest = it.findInBroken()!! @@ -160,11 +164,11 @@ abstract class ProtocolTestGenerator { /** Filter out test cases that are disabled or don't match the service protocol. */ private fun List.filterMatching(): List = - if (runOnly.isEmpty()) { + if (generateOnly.isEmpty()) { this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } } else { logger.warning("Generating only specified tests") - this.filter { testCase -> runOnly.contains(testCase.id) } + this.filter { testCase -> generateOnly.contains(testCase.id) } } private fun TestCase.toFailingTest(): FailingTest = @@ -191,7 +195,7 @@ abstract class ProtocolTestGenerator { val requestTests = operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() .map { TestCase.RequestTest(it) } - return requestTests.filterMatching() + return requestTests } fun responseTestCases(): List { @@ -209,7 +213,7 @@ abstract class ProtocolTestGenerator { ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.ResponseTest(it, error) } } - return (responseTestsOnOperations + responseTestsOnErrors).filterMatching() + return (responseTestsOnOperations + responseTestsOnErrors) } fun malformedRequestTestCases(): List { @@ -221,7 +225,7 @@ abstract class ProtocolTestGenerator { } else { emptyList() } - return malformedRequestTests.filterMatching() + return malformedRequestTests } /** @@ -412,6 +416,11 @@ object ServiceShapeId { const val AWS_JSON_10 = "aws.protocoltests.json10#JsonRpc10" const val AWS_JSON_11 = "aws.protocoltests.json#JsonProtocol" const val REST_JSON = "aws.protocoltests.restjson#RestJson" + const val RPC_V2_CBOR = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol" + const val RPC_V2_CBOR_EXTRAS = "smithy.protocoltests.rpcv2Cbor#RpcV2CborService" + const val REST_XML = "aws.protocoltests.restxml#RestXml" + const val AWS_QUERY = "aws.protocoltests.query#AwsQuery" + const val EC2_QUERY = "aws.protocoltests.ec2#AwsEc2" const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation" } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 486b443a6a..b44bdfb84d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -174,10 +174,10 @@ open class AwsJson( override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt index ad36e79190..afeaf5e1ce 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt @@ -169,6 +169,10 @@ open class HttpTraitHttpBindingResolver( model: Model, ): TimestampFormatTrait.Format = httpIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat) + /** + * Note that `null` will be returned and hence `Content-Type` will not be set when operation input has no members. + * This is in line with what protocol tests assert. + */ override fun requestContentType(operationShape: OperationShape): String? = httpIndex.determineRequestContentType( operationShape, @@ -176,6 +180,10 @@ open class HttpTraitHttpBindingResolver( contentTypes.eventStreamContentType, ).orNull() + /** + * Note that `null` will be returned and hence `Content-Type` will not be set when operation output has no members. + * This is in line with what protocol tests assert. + */ override fun responseContentType(operationShape: OperationShape): String? = httpIndex.determineResponseContentType( operationShape, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index 236c297db9..c7b139bfd5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -27,11 +27,28 @@ interface Protocol { /** The timestamp format that should be used if no override is specified in the model */ val defaultTimestampFormat: TimestampFormatTrait.Format - /** Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. */ + /** + * Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ fun additionalRequestHeaders(operationShape: OperationShape): List> = emptyList() + /** + * Returns additional HTTP headers that should be included in HTTP responses for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ + fun additionalResponseHeaders(operationShape: OperationShape): List> = emptyList() + /** * Returns additional HTTP headers that should be included in HTTP responses for the given error shape. + * These headers are added to responses _in addition_ to those returned by `additionalResponseHeaders`; if a header + * added by this function has the same header name as one added by `additionalResponseHeaders`, the one added by + * `additionalResponseHeaders` takes precedence. + * * These MUST all be lowercase, or the application will panic, as per * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static */ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt index e40046f1d8..cb9b766771 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt @@ -39,7 +39,7 @@ class ProtocolFunctions( private val codegenContext: CodegenContext, ) { companion object { - private val serDeModule = RustModule.pubCrate("protocol_serde") + val serDeModule = RustModule.pubCrate("protocol_serde") fun crossOperationFn( fnName: String, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 641548fc11..c4e3980668 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -56,6 +56,12 @@ class RestJsonHttpBindingResolver( } } } + + // The spec does not mention whether we should set the `Content-Type` header when there is no modeled output. + // The protocol tests indicate it's optional: + // + // + // In our implementation, we opt to always set it to `application/json`. return super.responseContentType(operationShape) ?: "application/json" } } @@ -124,10 +130,10 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt new file mode 100644 index 0000000000..d1af7ae72c --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -0,0 +1,121 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.core.util.outputShape + +class RpcV2CborHttpBindingResolver( + private val model: Model, + private val contentTypes: ProtocolContentTypes, +) : HttpBindingResolver { + private fun bindings(shape: ToShapeId): List { + val members = shape.let { model.expectShape(it.toShapeId()) }.members() + // TODO(https://github.com/awslabs/smithy-rs/issues/2237): support non-streaming members too + if (members.size > 1 && members.any { it.isStreaming(model) }) { + throw CodegenException( + "We only support one payload member if that payload contains a streaming member." + + "Tracking issue to relax this constraint: https://github.com/awslabs/smithy-rs/issues/2237", + ) + } + + return members.map { + if (it.isStreaming(model)) { + HttpBindingDescriptor(it, HttpLocation.PAYLOAD, "document") + } else { + HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") + } + } + .toList() + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + // In the server, this is only used when the protocol actually supports the `@http` trait. + // However, we will have to do this for client support. Perhaps this method deserves a rename. + override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") + + override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape) + + override fun responseBindings(operationShape: OperationShape) = bindings(operationShape.outputShape) + + override fun errorResponseBindings(errorShape: ToShapeId) = bindings(errorShape) + + /** + * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#requests + * > Requests for operations with no defined input type MUST NOT contain bodies in their HTTP requests. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun requestContentType(operationShape: OperationShape): String? = + if (OperationNormalizer.hadUserModeledOperationInput(operationShape, model)) { + contentTypes.requestDocument + } else { + null + } + + /** + * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#responses + * > Responses for operations with no defined output type MUST NOT contain bodies in their HTTP responses. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun responseContentType(operationShape: OperationShape): String? = + if (OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { + contentTypes.responseDocument + } else { + null + } + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") +} + +open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { + override val httpBindingResolver: HttpBindingResolver = + RpcV2CborHttpBindingResolver( + codegenContext.model, + ProtocolContentTypes( + requestDocument = "application/cbor", + responseDocument = "application/cbor", + eventStreamContentType = "application/vnd.amazon.eventstream", + eventStreamMessageContentType = "application/cbor", + ), + ) + + // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps + // using floating point seconds from the epoch. + override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS + + override fun additionalResponseHeaders(operationShape: OperationShape): List> = + listOf("smithy-protocol" to "rpc-v2-cbor") + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator(codegenContext, httpBindingResolver) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = + CborSerializerGenerator(codegenContext, httpBindingResolver) + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + TODO("rpcv2Cbor client support has not yet been implemented") + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + TODO("rpcv2Cbor event streams have not yet been implemented") +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt new file mode 100644 index 0000000000..99208b0b9a --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -0,0 +1,666 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.SparseTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.outputShape + +/** Class describing a CBOR parser section that can be used in a customization. */ +sealed class CborParserSection(name: String) : Section(name) { + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember") +} + +/** Customization for the CBOR parser. */ +typealias CborParserCustomization = NamedCustomization + +class CborParserGenerator( + private val codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + /** See docs for this parameter in [JsonParserGenerator]. */ + private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + }, + private val customizations: List = emptyList(), +) : StructuredDataParserGenerator { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + private val codegenTarget = codegenContext.target + private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() + private val protocolFunctions = ProtocolFunctions(codegenContext) + private val codegenScope = + arrayOf( + "SmithyCbor" to smithyCbor, + "Decoder" to smithyCbor.resolve("Decoder"), + "Error" to smithyCbor.resolve("decode::DeserializeError"), + "HashMap" to RuntimeType.HashMap, + *preludeScope, + ) + + private fun listMemberParserFn( + listSymbol: Symbol, + isSparseList: Boolean, + memberShape: MemberShape, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn member( + mut list: #{ListSymbol}, + decoder: &mut #{Decoder}, + ) -> #{Result}<#{ListSymbol}, #{Error}> + """, + *codegenScope, + "ListSymbol" to listSymbol, + ) { + val deserializeMemberWritable = deserializeMember(memberShape) + if (isSparseList) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeMember:W}?), + }; + """, + *codegenScope, + "DeserializeMember" to deserializeMemberWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeMember:W}?; + """, + "DeserializeMember" to deserializeMemberWritable, + ) + } + + if (returnUnconstrainedType) { + rust("list.0.push(value);") + } else { + rust("list.push(value);") + } + + rust("Ok(list)") + } + } + + private fun mapPairParserFnWritable( + keyTarget: StringShape, + valueShape: MemberShape, + isSparseMap: Boolean, + mapSymbol: Symbol, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn pair( + mut map: #{MapSymbol}, + decoder: &mut #{Decoder}, + ) -> #{Result}<#{MapSymbol}, #{Error}> + """, + *codegenScope, + "MapSymbol" to mapSymbol, + ) { + val deserializeKeyWritable = deserializeString(keyTarget) + rustTemplate( + """ + let key = #{DeserializeKey:W}?; + """, + "DeserializeKey" to deserializeKeyWritable, + ) + val deserializeValueWritable = deserializeMember(valueShape) + if (isSparseMap) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeValue:W}?), + }; + """, + *codegenScope, + "DeserializeValue" to deserializeValueWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeValue:W}?; + """, + "DeserializeValue" to deserializeValueWritable, + ) + } + + if (returnUnconstrainedType) { + rust("map.0.insert(key, value);") + } else { + rust("map.insert(key, value);") + } + + rust("Ok(map)") + } + } + + private fun structurePairParserFnWritable( + builderSymbol: Symbol, + includedMembers: Collection, + ) = writable { + rustBlockTemplate( + """ + ##[allow(clippy::match_single_binding)] + fn pair( + mut builder: #{Builder}, + decoder: &mut #{Decoder} + ) -> #{Result}<#{Builder}, #{Error}> + """, + *codegenScope, + "Builder" to builderSymbol, + ) { + withBlock("builder = match decoder.str()?.as_ref() {", "};") { + for (member in includedMembers) { + rustBlock("${member.memberName.dq()} =>") { + val callBuilderSetMemberFieldWritable = + writable { + withBlock("builder.${member.setterName()}(", ")") { + conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { + val symbol = symbolProvider.toSymbol(member) + if (symbol.isRustBoxed()) { + rustBlock("") { + rustTemplate( + "let v = #{DeserializeMember:W}?;", + "DeserializeMember" to deserializeMember(member), + ) + + for (customization in customizations) { + customization.section( + CborParserSection.BeforeBoxingDeserializedMember( + member, + ), + )(this) + } + rust("Box::new(v)") + } + } else { + rustTemplate( + "#{DeserializeMember:W}?", + "DeserializeMember" to deserializeMember(member), + ) + } + } + } + } + + if (member.isOptional) { + // Call `builder.set_member()` only if the value for the field on the wire is not null. + rustTemplate( + """ + #{SmithyCbor}::decode::set_optional(builder, decoder, |builder, decoder| { + Ok(#{MemberSettingWritable:W}) + })? + """, + *codegenScope, + "MemberSettingWritable" to callBuilderSetMemberFieldWritable, + ) + } else { + callBuilderSetMemberFieldWritable.invoke(this) + } + } + } + + rust( + """ + _ => { + decoder.skip()?; + builder + } + """, + ) + } + rust("Ok(builder)") + } + } + + private fun unionPairParserFnWritable(shape: UnionShape) = + writable { + val returnSymbolToParse = returnSymbolToParse(shape) + rustBlockTemplate( + """ + fn pair( + decoder: &mut #{Decoder} + ) -> #{Result}<#{UnionSymbol}, #{Error}> + """, + *codegenScope, + "UnionSymbol" to returnSymbolToParse.symbol, + ) { + withBlock("Ok(match decoder.str()?.as_ref() {", "})") { + for (member in shape.members()) { + val variantName = symbolProvider.toMemberName(member) + + if (member.isTargetUnit()) { + rust( + """ + ${member.memberName.dq()} => { + decoder.skip()?; + #T::$variantName + } + """, + returnSymbolToParse.symbol, + ) + } else { + withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { + deserializeMember(member).invoke(this) + } + } + } + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. + true -> + rustTemplate( + """ + _ => { + decoder.skip()?; + Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) + } + """, + "Union" to returnSymbolToParse.symbol, + *codegenScope, + ) + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + false -> + rustTemplate( + "variant => return Err(#{Error}::unknown_union_variant(variant, decoder.position()))", + *codegenScope, + ) + } + } + } + } + + enum class CollectionKind { + Map, + List, + ; + + /** Method to invoke on the decoder to decode this collection kind. **/ + fun decoderMethodName() = + when (this) { + Map -> "map" + List -> "list" + } + } + + /** + * Decode a collection of homogeneous CBOR data items: a map or an array. + * The first branch of the `match` corresponds to when the collection is encoded using variable-length encoding; + * the second branch corresponds to fixed-length encoding. + * + * https://www.rfc-editor.org/rfc/rfc8949.html#name-indefinite-length-arrays-an + */ + private fun decodeCollectionLoopWritable( + collectionKind: CollectionKind, + variableBindingName: String, + decodeItemFnName: String, + ) = writable { + rustTemplate( + """ + match decoder.${collectionKind.decoderMethodName()}()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + $variableBindingName = $decodeItemFnName($variableBindingName, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + $variableBindingName = $decodeItemFnName($variableBindingName, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + private fun decodeStructureMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "builder", "pair") + + private fun decodeMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "map", "pair") + + private fun decodeListLoopWritable() = decodeCollectionLoopWritable(CollectionKind.List, "list", "member") + + /** + * Reusable structure parser implementation that can be used to generate parsing code for + * operation, error and structure shapes. + * We still generate the parser symbol even if there are no included members because the server + * generation requires parsers for all input structures. + */ + private fun structureParser( + shape: Shape, + builderSymbol: Symbol, + includedMembers: List, + fnNameSuffix: String? = null, + ): RuntimeType { + return protocolFunctions.deserializeFn(shape, fnNameSuffix) { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> #{Result}<#{Builder}, #{Error}> { + #{StructurePairParserFn:W} + + let decoder = &mut #{Decoder}::new(value); + + #{DecodeStructureMapLoop:W} + + if decoder.position() != value.len() { + return Err(#{Error}::expected_end_of_stream(decoder.position())); + } + + Ok(builder) + } + """, + "Builder" to builderSymbol, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + *codegenScope, + ) + } + } + + override fun payloadParser(member: MemberShape): RuntimeType { + UNREACHABLE("No protocol using CBOR serialization supports payload binding") + } + + override fun operationParser(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR deserializer if there is nothing bound to the HTTP body. + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + val outputShape = operationShape.outputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(outputShape), httpDocumentMembers) + } + + override fun errorParser(errorShape: StructureShape): RuntimeType? { + if (errorShape.members().isEmpty()) { + return null + } + return structureParser( + errorShape, + symbolProvider.symbolForBuilder(errorShape), + errorShape.members().toList(), + fnNameSuffix = "cbor_err", + ) + } + + override fun serverInputParser(operationShape: OperationShape): RuntimeType? { + val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (includedMembers.isEmpty()) { + return null + } + val inputShape = operationShape.inputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(inputShape), includedMembers) + } + + private fun deserializeMember(memberShape: MemberShape) = + writable { + when (val target = model.expectShape(memberShape.target)) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("decoder.blob()") + is BooleanShape -> rust("decoder.boolean()") + + is StringShape -> deserializeString(target).invoke(this) + + is ByteShape -> rust("decoder.byte()") + is ShortShape -> rust("decoder.short()") + is IntegerShape -> rust("decoder.integer()") + is LongShape -> rust("decoder.long()") + + is FloatShape -> rust("decoder.float()") + is DoubleShape -> rust("decoder.double()") + + is TimestampShape -> rust("decoder.timestamp()") + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + is StructureShape -> deserializeStruct(target) + is CollectionShape -> deserializeCollection(target) + is MapShape -> deserializeMap(target) + is UnionShape -> deserializeUnion(target) + + // Note that no protocol using CBOR serialization supports `document` shapes. + else -> PANIC("unexpected shape: $target") + } + } + + private fun deserializeString(target: StringShape) = + writable { + when (target.hasTrait()) { + true -> { + if (this@CborParserGenerator.returnSymbolToParse(target).isUnconstrained) { + rust("decoder.string()") + } else { + rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) + } + } + false -> rust("decoder.string()") + } + } + + private fun RustWriter.deserializeCollection(shape: CollectionShape) { + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = + writable { + withBlock("let mut list = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{Vec}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { + #{ListMemberParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeListLoop:W} + + Ok(list) + } + """, + "ReturnType" to returnSymbol, + "ListMemberParserFn" to + listMemberParserFn( + returnSymbol, + isSparseList = shape.hasTrait(), + shape.member, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeListLoop" to decodeListLoopWritable(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeMap(shape: MapShape) { + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = + writable { + withBlock("let mut map = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{HashMap}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { + #{MapPairParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeMapLoop:W} + + Ok(map) + } + """, + "ReturnType" to returnSymbol, + "MapPairParserFn" to + mapPairParserFnWritable( + keyTarget, + shape.value, + isSparseMap = shape.hasTrait(), + returnSymbol, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeMapLoop" to decodeMapLoopWritable(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeStruct(shape: StructureShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}>", + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + val builderSymbol = symbolProvider.symbolForBuilder(shape) + val includedMembers = shape.members() + + rustTemplate( + """ + #{StructurePairParserFn:W} + + let mut builder = #{Builder}::default(); + + #{DecodeStructureMapLoop:W} + """, + *codegenScope, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "Builder" to builderSymbol, + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + ) + + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(builder)") + } else { + rust("Ok(builder.build())") + } + } + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeUnion(shape: UnionShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> { + #{UnionPairParserFnWritable} + + match decoder.map()? { + None => { + let variant = pair(decoder)?; + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + Ok(variant) + } + ty => Err( + #{Error}::unexpected_union_variant( + ty, + decoder.position(), + ), + ), + } + } + Some(1) => pair(decoder), + Some(_) => Err(#{Error}::mixed_union_variants(decoder.position())) + } + } + """, + "UnionSymbol" to returnSymbolToParse.symbol, + "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 4833f808fb..cf0676ebd0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -77,8 +77,6 @@ sealed class JsonParserSection(name: String) : Section(name) { */ typealias JsonParserCustomization = NamedCustomization -data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) - class JsonParserGenerator( private val codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, @@ -339,8 +337,7 @@ class JsonParserGenerator( rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) } } - - else -> rust("u.into_owned()") + false -> rust("u.into_owned()") } } } @@ -447,7 +444,7 @@ class JsonParserGenerator( } private fun RustWriter.deserializeMap(shape: MapShape) { - val keyTarget = model.expectShape(shape.key.target) as StringShape + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) val isSparse = shape.hasTrait() val returnSymbolToParse = returnSymbolToParse(shape) val parser = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt new file mode 100644 index 0000000000..4b69e87328 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt @@ -0,0 +1,14 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol + +/** + * Parsers need to know what symbol to parse and return, and whether it's unconstrained or not. + * This data class holds this information that the parsers fill out from a shape. + */ +data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt index f8b053d80f..fd7e6fcb28 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt @@ -12,8 +12,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType interface StructuredDataParserGenerator { /** - * Generate a parse function for a given targeted as a payload. + * Generate a parse function for a given shape targeted with `@httpPayload`. * Entry point for payload-based parsing. + * * Roughly: * ```rust * fn parse_my_struct(input: &[u8]) -> Result { @@ -49,6 +50,7 @@ interface StructuredDataParserGenerator { /** * Generate a parser for a server operation input structure + * * ```rust * fn deser_operation_crate_operation_my_operation_input( * value: &[u8], builder: my_operation_input::Builder diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt new file mode 100644 index 0000000000..f96a8b7cbc --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -0,0 +1,419 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.isUnit +import software.amazon.smithy.rust.codegen.core.util.outputShape + +/** + * Class describing a CBOR serializer section that can be used in a customization. + */ +sealed class CborSerializerSection(name: String) : Section(name) { + /** + * Mutate the serializer prior to serializing any structure members. Eg: this can be used to inject `__type` + * to record the error type in the case of an error structure. + */ + data class BeforeSerializingStructureMembers( + val structureShape: StructureShape, + val encoderBindingName: String, + ) : CborSerializerSection("BeforeSerializingStructureMembers") + + /** Manipulate the serializer context for a map prior to it being serialized. **/ + data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : + CborSerializerSection("BeforeIteratingOverMapOrCollection") +} + +/** + * Customization for the CBOR serializer. + */ +typealias CborSerializerCustomization = NamedCustomization + +class CborSerializerGenerator( + codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + private val customizations: List = listOf(), +) : StructuredDataSerializerGenerator { + data class Context( + /** Expression representing the value to write to the encoder */ + var valueExpression: ValueExpression, + /** Shape to serialize */ + val shape: T, + ) + + data class MemberContext( + /** Name for the variable bound to the encoder object **/ + val encoderBindingName: String, + /** Expression representing the value to write to the `Encoder` */ + var valueExpression: ValueExpression, + val shape: MemberShape, + /** Whether to serialize null values if the type is optional */ + val writeNulls: Boolean = false, + ) { + companion object { + fun collectionMember( + context: Context, + itemName: String, + ): MemberContext = + MemberContext( + "encoder", + ValueExpression.Reference(itemName), + context.shape.member, + writeNulls = true, + ) + + fun mapMember( + context: Context, + key: String, + value: String, + ): MemberContext = + MemberContext( + "encoder.str($key)", + ValueExpression.Reference(value), + context.shape.value, + writeNulls = true, + ) + + fun structMember( + context: StructContext, + member: MemberShape, + symProvider: RustSymbolProvider, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Value("${context.localName}.${symProvider.toMemberName(member)}"), + member, + ) + + fun unionMember( + variantReference: String, + member: MemberShape, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Reference(variantReference), + member, + ) + + /** Returns an expression to encode a key member **/ + private fun encodeKeyExpression(name: String): String = "encoder.str(${name.dq()})" + } + } + + data class StructContext( + /** Name of the variable that holds the struct */ + val localName: String, + val shape: StructureShape, + ) + + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val codegenTarget = codegenContext.target + private val runtimeConfig = codegenContext.runtimeConfig + private val protocolFunctions = ProtocolFunctions(codegenContext) + + private val codegenScope = + arrayOf( + "Error" to runtimeConfig.serializationError(), + "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), + *preludeScope, + ) + private val serializerUtil = SerializerUtil(model, symbolProvider) + + /** + * Reusable structure serializer implementation that can be used to generate serializing code for + * operation outputs or errors. + * This function is only used by the server, the client uses directly [serializeStructure]. + */ + private fun serverSerializer( + structureShape: StructureShape, + includedMembers: List, + error: Boolean, + ): RuntimeType { + val suffix = + when (error) { + true -> "error" + else -> "output" + } + return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> + rustBlockTemplate( + "pub fn $fnName(value: &#{target}) -> #{Result}<#{Vec}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(structureShape), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) + // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference. + rustBlock("") { + rust("let encoder = &mut encoder;") + serializeStructure( + StructContext("value", structureShape), + includedMembers, + ) + } + rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) + } + } + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + override fun payloadSerializer(member: MemberShape): RuntimeType { + TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") + } + + override fun unsetStructure(structure: StructureShape): RuntimeType = + UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") + + override fun unsetUnion(union: UnionShape): RuntimeType = + UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") + + override fun operationInputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR serializer if there is no CBOR body. + val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + TODO("Client implementation should fill this out") + } + + override fun documentSerializer(): RuntimeType = + UNREACHABLE("No protocol using CBOR supports `document` shapes, so we don't need to implement this") + + override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR serializer if there was no operation output shape in the + // original (untransformed) model. + if (!OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { + return null + } + + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + val outputShape = operationShape.outputShape(model) + return serverSerializer(outputShape, httpDocumentMembers, error = false) + } + + override fun serverErrorSerializer(shape: ShapeId): RuntimeType { + val errorShape = model.expectShape(shape, StructureShape::class.java) + val includedMembers = + httpBindingResolver.errorResponseBindings(shape).filter { it.location == HttpLocation.DOCUMENT } + .map { it.member } + return serverSerializer(errorShape, includedMembers, error = true) + } + + private fun RustWriter.serializeStructure( + context: StructContext, + includedMembers: List? = null, + ) { + if (context.shape.isUnit()) { + rust( + """ + encoder.begin_map(); + encoder.end(); + """, + ) + return + } + + val structureSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> #{Result}<(), #{Error}>", + "StructureSymbol" to symbolProvider.toSymbol(context.shape), + *codegenScope, + ) { + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3745) If all members are non-`Option`-al, + // we know AOT the map's size and can use `.map()` instead of `.begin_map()` for efficiency. + rust("encoder.begin_map();") + for (customization in customizations) { + customization.section( + CborSerializerSection.BeforeSerializingStructureMembers( + context.shape, + "encoder", + ), + )(this) + } + context.copy(localName = "input").also { inner -> + val members = includedMembers ?: inner.shape.members() + for (member in members) { + serializeMember(MemberContext.structMember(inner, member, symbolProvider)) + } + } + rust("encoder.end();") + rust("Ok(())") + } + } + rust("#T(encoder, ${context.localName})?;", structureSerializer) + } + + private fun RustWriter.serializeMember(context: MemberContext) { + val targetShape = model.expectShape(context.shape.target) + if (symbolProvider.toSymbol(context.shape).isOptional()) { + safeName().also { local -> + rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { + context.valueExpression = ValueExpression.Reference(local) + serializeMemberValue(context, targetShape) + } + if (context.writeNulls) { + rustBlock("else") { + rust("${context.encoderBindingName}.null();") + } + } + } + } else { + with(serializerUtil) { + ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { + serializeMemberValue(context, targetShape) + } + } + } + } + + private fun RustWriter.serializeMemberValue( + context: MemberContext, + target: Shape, + ) { + val encoder = context.encoderBindingName + val value = context.valueExpression + val containerShape = model.expectShape(context.shape.container) + + when (target) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("$encoder.blob(${value.asRef()});") + is BooleanShape -> rust("$encoder.boolean(${value.asValue()});") + + is StringShape -> rust("$encoder.str(${value.name}.as_str());") + + is ByteShape -> rust("$encoder.byte(${value.asValue()});") + is ShortShape -> rust("$encoder.short(${value.asValue()});") + is IntegerShape -> rust("$encoder.integer(${value.asValue()});") + is LongShape -> rust("$encoder.long(${value.asValue()});") + + is FloatShape -> rust("$encoder.float(${value.asValue()});") + is DoubleShape -> rust("$encoder.double(${value.asValue()});") + + is TimestampShape -> rust("$encoder.timestamp(${value.asRef()});") + + is DocumentShape -> UNREACHABLE("Smithy RPC v2 CBOR does not support `document` shapes") + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + else -> { + // This condition is equivalent to `containerShape !is CollectionShape`. + if (containerShape is StructureShape || containerShape is UnionShape || containerShape is MapShape) { + rust("$encoder;") // Encode the member key. + } + when (target) { + is StructureShape -> serializeStructure(StructContext(value.name, target)) + is CollectionShape -> serializeCollection(Context(value, target)) + is MapShape -> serializeMap(Context(value, target)) + is UnionShape -> serializeUnion(Context(value, target)) + else -> UNREACHABLE("Smithy added a new aggregate shape: $target") + } + } + } + } + + private fun RustWriter.serializeCollection(context: Context) { + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust("encoder.array((${context.valueExpression.asValue()}).len());") + val itemName = safeName("item") + rustBlock("for $itemName in ${context.valueExpression.asRef()}") { + serializeMember(MemberContext.collectionMember(context, itemName)) + } + } + + private fun RustWriter.serializeMap(context: Context) { + val keyName = safeName("key") + val valueName = safeName("value") + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust("encoder.map((${context.valueExpression.asValue()}).len());") + rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { + val keyExpression = "$keyName.as_str()" + serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) + } + } + + private fun RustWriter.serializeUnion(context: Context) { + val unionSymbol = symbolProvider.toSymbol(context.shape) + val unionSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> #{Result}<(), #{Error}>", + "UnionSymbol" to unionSymbol, + *codegenScope, + ) { + // A union is serialized identically as a `structure` shape, but only a single member can be set to a + // non-null value. + rust("encoder.map(1);") + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = + if (member.isTargetUnit()) { + symbolProvider.toMemberName(member) + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + rustBlock("#T::$variantName =>", unionSymbol) { + serializeMember(MemberContext.unionMember("inner", member)) + } + } + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return #{Err}(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) + } + } + rustTemplate("#{Ok}(())", *codegenScope) + } + } + rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 46341bb09c..69ec11fdd2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -47,9 +47,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -212,12 +211,21 @@ class JsonSerializerGenerator( *codegenScope, "target" to symbolProvider.toSymbol(structureShape), ) { - rust("let mut out = String::new();") - rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + rustTemplate( + """ + let mut out = #{String}::new(); + let mut object = #{JsonObjectWriter}::new(&mut out); + """, + *codegenScope, + ) serializeStructure(StructContext("object", "value", structureShape), includedMembers) customizations.forEach { it.section(makeSection(structureShape, "object"))(this) } - rust("object.finish();") - rustTemplate("Ok(out)", *codegenScope) + rust( + """ + object.finish(); + Ok(out) + """, + ) } } } @@ -304,8 +312,7 @@ class JsonSerializerGenerator( override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { // Don't generate an operation JSON serializer if there was no operation output shape in the // original (untransformed) model. - val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() - if (syntheticOutputTrait.originalId == null) { + if (!OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { return null } @@ -485,13 +492,17 @@ class JsonSerializerGenerator( rust("let mut $objectName = ${context.writerExpression}.start_object();") // We call inner only when context's shape is not the Unit type. // If it were, calling inner would generate the following function: - // pub fn serialize_structure_crate_model_unit( - // object: &mut aws_smithy_json::serialize::JsonObjectWriter, - // input: &crate::model::Unit, - // ) -> Result<(), aws_smithy_http::operation::error::SerializationError> { - // let (_, _) = (object, input); - // Ok(()) - // } + // + // ```rust + // pub fn serialize_structure_crate_model_unit( + // object: &mut aws_smithy_json::serialize::JsonObjectWriter, + // input: &crate::model::Unit, + // ) -> Result<(), aws_smithy_http::operation::error::SerializationError> { + // let (_, _) = (object, input); + // Ok(()) + // } + // ``` + // // However, this would cause a compilation error at a call site because it cannot // extract data out of the Unit type that corresponds to the variable "input" above. if (!context.shape.isTargetUnit()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt index 92b28d89fc..a85646673d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt @@ -25,12 +25,15 @@ interface StructuredDataSerializerGenerator { fun payloadSerializer(member: MemberShape): RuntimeType /** - * Generate the correct data when attempting to serialize a structure that is unset + * Generate the correct data when attempting to serialize a structure that is unset. * * ```rust * fn rest_json_unset_struct_payload() -> Vec { * ... * } + * ``` + * + * This method is only invoked when serializing an `@httpPayload`. */ fun unsetStructure(structure: StructureShape): RuntimeType @@ -41,6 +44,9 @@ interface StructuredDataSerializerGenerator { * fn rest_json_unset_union_payload() -> Vec { * ... * } + * ``` + * + * This method is only invoked when serializing an `@httpPayload`. */ fun unsetUnion(union: UnionShape): RuntimeType diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt index e4b95f5cde..b0c61d3d6c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt @@ -12,8 +12,13 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic input (see `OperationNormalizer.kt`) * - * All operations are normalized to have an input, even when they are defined without on. This is done for backwards - * compatibility and to produce a consistent API. + * All operations are normalized to have an input, even when they are defined without one. + * This is NOT done for backwards-compatibility, as adding an operation input is a breaking change + * (see ). + * + * It is only done to produce a consistent API. + * TODO(https://github.com/smithy-lang/smithy-rs/issues/3577): In the server, we'd like to stop adding + * these synthetic inputs. */ class SyntheticInputTrait( val operation: ShapeId, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt index d34ca0e39d..cc310e8a47 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt @@ -12,8 +12,14 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic output (see `OperationNormalizer.kt`) * - * All operations are normalized to have an output, even when they are defined without on. This is done for backwards - * compatibility and to produce a consistent API. + * All operations are normalized to have an output, even when they are defined without one. + * + * This is NOT done for backwards-compatibility, as adding an operation output is a breaking change + * (see ). + * + * It is only done to produce a consistent API. + * TODO(https://github.com/smithy-lang/smithy-rs/issues/3577): In the server, we'd like to stop adding + * these synthetic outputs. */ class SyntheticOutputTrait constructor(val operation: ShapeId, val originalId: ShapeId?) : AnnotationTrait(ID, Node.objectNode()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index 4092174b55..89d4512007 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -14,23 +14,27 @@ import software.amazon.smithy.model.traits.InputTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.rename import java.util.Optional -import kotlin.streams.toList /** * Generate synthetic Input and Output structures for operations. * - * Operation input/output shapes can be retroactively added. In order to support this while maintaining backwards compatibility, - * we need to generate input/output shapes for all operations in a backwards compatible way. - * * This works by **adding** new shapes to the model for operation inputs & outputs. These new shapes have `SyntheticInputTrait` * and `SyntheticOutputTrait` attached to them as needed. This enables downstream code generators to determine if a shape is * "real" vs. a shape created as a synthetic input/output. * * The trait also tracks the original shape id for certain serialization tasks that require it to exist. + * + * Note that adding/removing operation input/output [is a breaking change]; the only reason why we synthetically add them + * is to produce a consistent API. + * + * [is a breaking change]: */ object OperationNormalizer { // Functions to construct synthetic shape IDs—Don't rely on these in external code. @@ -43,6 +47,30 @@ object OperationNormalizer { private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace + ".synthetic", "${this.id.name}Output") + /** + * Returns `true` if the user had originally modeled an operation input shape on the given [operation]; + * `false` if the transform added a synthetic one. + */ + fun hadUserModeledOperationInput( + operation: OperationShape, + model: Model, + ): Boolean { + val syntheticInputTrait = operation.inputShape(model).expectTrait() + return syntheticInputTrait.originalId != null + } + + /** + * Returns `true` if the user had originally modeled an operation output shape on the given [operation]; + * `false` if the transform added a synthetic one. + */ + fun hadUserModeledOperationOutput( + operation: OperationShape, + model: Model, + ): Boolean { + val syntheticOutputTrait = operation.outputShape(model).expectTrait() + return syntheticOutputTrait.originalId != null + } + /** * Add synthetic input & output shapes to every Operation in model. The generated shapes will be marked with * [SyntheticInputTrait] and [SyntheticOutputTrait] respectively. Shapes will be added _even_ if the operation does diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt index 5508611972..e0279c53c0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt @@ -178,7 +178,7 @@ object NamingObstacleCourseTestModels { /** * This targets two bug classes: * - operation inputs used as nested outputs - * - operation outputs used as nested outputs + * - operation outputs used as nested inputs */ fun reusedInputOutputShapesModel(protocol: Trait) = """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 13a59f3bae..42997a8012 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.testutil import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.loader.ModelDiscovery import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape @@ -151,7 +152,23 @@ fun String.asSmithyModel( disableValidation: Boolean = false, ): Model { val processed = letIf(!this.trimStart().startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } - val assembler = Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed) + val denyModelsContaining = + arrayOf( + // If Smithy protocol test models are in our classpath, don't load them, since they are fairly large and we + // almost never need them. + "smithy-protocol-tests", + ) + val urls = + ModelDiscovery.findModels().filter { modelUrl -> + denyModelsContaining.none { + modelUrl.toString().contains(it) + } + } + val assembler = Model.assembler() + for (url in urls) { + assembler.addImport(url) + } + assembler.addUnparsedModel(sourceLocation ?: "test.smithy", processed) if (disableValidation) { assembler.disableValidation() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt index 975416d72f..f6d6ddf84f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt @@ -24,19 +24,15 @@ import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait -inline fun Model.lookup(shapeId: String): T { - return this.expectShape(ShapeId.from(shapeId), T::class.java) -} +inline fun Model.lookup(shapeId: String): T = this.expectShape(ShapeId.from(shapeId), T::class.java) -fun OperationShape.inputShape(model: Model): StructureShape { +fun OperationShape.inputShape(model: Model): StructureShape = // The Rust Smithy generator adds an input to all shapes automatically - return model.expectShape(this.input.get(), StructureShape::class.java) -} + model.expectShape(this.input.get(), StructureShape::class.java) -fun OperationShape.outputShape(model: Model): StructureShape { +fun OperationShape.outputShape(model: Model): StructureShape = // The Rust Smithy generator adds an output to all shapes automatically - return model.expectShape(this.output.get(), StructureShape::class.java) -} + model.expectShape(this.output.get(), StructureShape::class.java) fun StructureShape.expectMember(member: String): MemberShape = this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") } @@ -55,43 +51,32 @@ fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait() -} +fun MemberShape.isInputEventStream(model: Model): Boolean = + isEventStream(model) && model.expectShape(container).hasTrait() -fun MemberShape.isOutputEventStream(model: Model): Boolean { - return isEventStream(model) && model.expectShape(container).hasTrait() -} +fun MemberShape.isOutputEventStream(model: Model): Boolean = + isEventStream(model) && model.expectShape(container).hasTrait() private val unitShapeId = ShapeId.from("smithy.api#Unit") -fun MemberShape.isTargetUnit(): Boolean { - return this.target == unitShapeId -} +fun Shape.isUnit(): Boolean = this.id == unitShapeId -fun Shape.hasEventStreamMember(model: Model): Boolean { - return members().any { it.isEventStream(model) } -} +fun MemberShape.isTargetUnit(): Boolean = this.target == unitShapeId -fun OperationShape.isInputEventStream(model: Model): Boolean { - return input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -} +fun Shape.hasEventStreamMember(model: Model): Boolean = members().any { it.isEventStream(model) } -fun OperationShape.isOutputEventStream(model: Model): Boolean { - return output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -} +fun OperationShape.isInputEventStream(model: Model): Boolean = + input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -fun OperationShape.isEventStream(model: Model): Boolean { - return isInputEventStream(model) || isOutputEventStream(model) -} +fun OperationShape.isOutputEventStream(model: Model): Boolean = + output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) + +fun OperationShape.isEventStream(model: Model): Boolean = isInputEventStream(model) || isOutputEventStream(model) fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = operations.any { id -> @@ -125,17 +110,13 @@ fun Shape.redactIfNecessary( * * A structure must have at most one streaming member. */ -fun StructureShape.findStreamingMember(model: Model): MemberShape? { - return this.findMemberWithTrait(model) -} +fun StructureShape.findStreamingMember(model: Model): MemberShape? = this.findMemberWithTrait(model) -inline fun StructureShape.findMemberWithTrait(model: Model): MemberShape? { - return this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -} +inline fun StructureShape.findMemberWithTrait(model: Model): MemberShape? = + this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -inline fun UnionShape.findMemberWithTrait(model: Model): MemberShape? { - return this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -} +inline fun UnionShape.findMemberWithTrait(model: Model): MemberShape? = + this.members().find { it.getMemberTrait(model, T::class.java).isPresent } /** * If is member shape returns target, otherwise returns self. @@ -156,12 +137,11 @@ inline fun Shape.expectTrait(): T = expectTrait(T::class.jav /** Kotlin sugar for getTrait() check. e.g. shape.getTrait() instead of shape.getTrait(EnumTrait::class.java) */ inline fun Shape.getTrait(): T? = getTrait(T::class.java).orNull() -fun Shape.isPrimitive(): Boolean { - return when (this) { +fun Shape.isPrimitive(): Boolean = + when (this) { is NumberShape, is BooleanShape -> true else -> false } -} /** Convert a string to a ShapeId */ fun String.shapeId() = ShapeId.from(this) diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 4134cdd039..808d476058 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -24,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-server-test/" dependencies { implementation(project(":codegen-server")) implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") @@ -43,6 +44,12 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), ), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), + CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), + CodegenTest( + "smithy.protocoltests.rpcv2Cbor#RpcV2CborService", + "rpcv2Cbor_extras", + imports = listOf("$commonModels/rpcv2Cbor-extras.smithy") + ), CodegenTest( "com.amazonaws.constraints#ConstraintsService", "constraints_without_public_constrained_types", diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 0ba262225d..49e0462888 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -26,10 +26,14 @@ dependencies { implementation(project(":codegen-core")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in `constraints.smithy`, which is used // in `CustomValidationExceptionWithReasonDecoratorTest`. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") + + // It's handy to re-use protocol test suite models from Smithy in our Kotlin tests. + testImplementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") } java { diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index 4e41cc8ed3..f50e236337 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -56,10 +56,10 @@ class PythonServerAfterDeserializedMemberJsonParserCustomization(private val run } /** - * Customization class used to force casting a non primitive type into one overriden by a new symbol provider, + * Customization class used to force casting a non-primitive type into one overridden by a new symbol provider, * by explicitly calling `into()` on it. */ -class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : +class PythonServerAfterDeserializedMemberServerHttpBoundCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 01df0c0a93..a9c488503d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig */ object ServerCargoDependency { val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1.74")) + val Base64SimdDev: CargoDependency = CargoDependency("base64-simd", CratesIo("0.8"), scope = DependencyScope.Dev) val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1")) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3")) val Mime: CargoDependency = CargoDependency("mime", CratesIo("0.3")) @@ -26,7 +27,7 @@ object ServerCargoDependency { val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) - val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) + val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), scope = DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index ea4eadad84..49c0e7c540 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -609,13 +609,20 @@ open class ServerCodegenVisitor( } /** - * Generate protocol tests. This method can be overridden by other languages such has Python. + * Generate protocol tests. This method can be overridden by other languages such as Python. */ open fun protocolTestsForOperation( writer: RustWriter, shape: OperationShape, ) { - ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), shape).render(writer) + codegenDecorator.protocolTestGenerator( + codegenContext, + ServerProtocolTestGenerator( + codegenContext, + protocolGeneratorFactory.support(), + shape, + ), + ).render(writer) } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt new file mode 100644 index 0000000000..464a52dc46 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.escape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.util.hasTrait + +/** + * Smithy RPC v2 CBOR requires errors to be serialized in server responses with an additional `__type` field. + * + * Note that we apply this customization when serializing _any_ structure with the `@error` trait, regardless if it's + * an error response or not. Consider this model: + * + * ```smithy + * operation ErrorSerializationOperation { + * input: SimpleStruct + * output: ErrorSerializationOperationOutput + * errors: [ValidationException] + * } + * + * structure ErrorSerializationOperationOutput { + * errorShape: ValidationException + * } + * ``` + * + * `ValidationException` is re-used across the operation output and the operation error. The `__type` field will + * appear when serializing both. + * + * Strictly speaking, the spec says we should only add `__type` when serializing an operation error response, but + * there shouldn't™️ be any harm in always including it, which simplifies the code generator. + */ +class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeSerializingStructureMembers -> + if (section.structureShape.hasTrait()) { + writable { + rust( + """ + ${section.encoderBindingName} + .str("__type") + .str("${escape(section.structureShape.id.toString())}"); + """, + ) + } + } else { + emptySection + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt new file mode 100644 index 0000000000..a01d0076e9 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt @@ -0,0 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we encode over a _constrained_ map or collection shape in a CBOR serializer, + * unwrap the wrapper newtype and take a shared reference to the actual value within it. + * That value will be a `std::collections::HashMap` for map shapes, and a `std::vec::Vec` for collection shapes. + */ +class BeforeEncodingMapOrCollectionCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeIteratingOverMapOrCollection -> + writable { + check(section.shape is CollectionShape || section.shape is MapShape) + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index d06c21ff70..5b4860c26d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -70,18 +70,25 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat codegenContext.symbolProvider.toSymbol(memberShape).isOptional() } -class ServerInstantiator(codegenContext: CodegenContext, customWritable: CustomWritable = NoCustomWritable()) : +class ServerInstantiator( + codegenContext: CodegenContext, + customWritable: CustomWritable = NoCustomWritable(), + ignoreMissingMembers: Boolean = false, + withinTest: Boolean = false, +) : Instantiator( - codegenContext.symbolProvider, - codegenContext.model, - codegenContext.runtimeConfig, - ServerBuilderKindBehavior(codegenContext), - defaultsForRequiredFields = true, - customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), - // Construct with direct pattern to more closely replicate actual server customer usage - constructPattern = InstantiatorConstructPattern.DIRECT, - customWritable = customWritable, - ) + codegenContext.symbolProvider, + codegenContext.model, + codegenContext.runtimeConfig, + ServerBuilderKindBehavior(codegenContext), + defaultsForRequiredFields = true, + customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), + // Construct with direct pattern to more closely replicate actual server customer usage + constructPattern = InstantiatorConstructPattern.DIRECT, + customWritable = customWritable, + ignoreMissingMembers = ignoreMissingMembers, + withinTest = withinTest, + ) class ServerBuilderInstantiator( private val symbolProvider: RustSymbolProvider, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index f31f6d92da..d4984d65e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -24,18 +24,26 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator @@ -298,6 +306,68 @@ class ServerRestXmlProtocol( ) } +class ServerRpcV2CborProtocol( + private val serverCodegenContext: ServerCodegenContext, +) : RpcV2Cbor(serverCodegenContext), ServerProtocol { + val runtimeConfig = codegenContext.runtimeConfig + + override val protocolModulePath = "rpc_v2_cbor" + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator( + serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( + serverCodegenContext, + ), + ), + ) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator { + return CborSerializerGenerator( + codegenContext, + httpBindingResolver, + listOf( + BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), + AddTypeFieldToServerErrorsCborCustomization(), + ), + ) + } + + override fun markerStruct() = ServerRuntimeType.protocol("RpcV2Cbor", "rpc_v2_cbor", runtimeConfig) + + override fun routerType() = + ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::rpc_v2_cbor::router::RpcV2CborRouter") + + override fun serverRouterRequestSpec( + operationShape: OperationShape, + operationName: String, + serviceName: String, + requestSpecModule: RuntimeType, + ) = writable { + // This is just the key used by the router's map to store and look up operations, it's completely arbitrary. + // We use the same key used by the awsJson1.x routers for simplicity. + // The router will extract the service name and the operation name from the URI, build this key, and lookup the + // operation stored there. + rust("$serviceName.$operationName".dq()) + } + + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr + + override fun serverRouterRuntimeConstructor() = "rpc_v2_router" + + override fun serverContentTypeCheckNoModeledInput() = false + + override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType = + deserializePayloadErrorType( + codegenContext, + binding, + requestRejection(runtimeConfig), + RuntimeType.smithyCbor(codegenContext.runtimeConfig).resolve("decode::DeserializeError"), + ) +} + /** Just a common function to keep things DRY. **/ fun deserializePayloadErrorType( codegenContext: CodegenContext, @@ -317,8 +387,8 @@ fun deserializePayloadErrorType( } /** - * A customization to, just before we box a recursive member that we've deserialized into `Option`, convert it into - * `MaybeConstrained` if the target shape can reach a constrained shape. + * A customization to, just before we box a recursive member that we've deserialized from JSON into `Option`, convert + * it into `MaybeConstrained` if the target shape can reach a constrained shape. */ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : JsonParserCustomization() { @@ -338,3 +408,24 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa else -> emptySection } } + +/** + * A customization to, just before we box a recursive member that we've deserialized from CBOR into `T` held in a + * variable binding `v`, convert it into `MaybeConstrained` if the target shape can reach a constrained shape. + */ +class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization(val codegenContext: ServerCodegenContext) : + CborParserCustomization() { + override fun section(section: CborParserSection): Writable = + when (section) { + is CborParserSection.BeforeBoxingDeserializedMember -> + writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust("let v = v.into();") + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index cbcbca16e3..09e6b635de 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.AppliesTo @@ -38,6 +39,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.core.util.dq @@ -145,8 +148,15 @@ class ServerProtocolTestGenerator( AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", ), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): This affects all protocols + FailingTest.MalformedRequestTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct"), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3339) + FailingTest.ResponseTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsInResponseWhenMissingInParams"), FailingTest.ResponseTest(REST_JSON, "RestJsonServerPopulatesDefaultsInResponseWhenMissingInParams"), FailingTest.ResponseTest(REST_JSON, "RestJsonServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams"), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3743): We need to be able to configure + // instantiator so that it uses default _modeled_ values; `""` is not a valid enum value for `defaultEnum`. + FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsWhenMissingInRequestBody"), // TODO(https://github.com/smithy-lang/smithy-rs/issues/3735): Null `Document` may come through a request even though its shape is `@required` FailingTest.RequestTest(REST_JSON, "RestJsonServerPopulatesDefaultsWhenMissingInRequestBody"), ) @@ -223,7 +233,7 @@ class ServerProtocolTestGenerator( get() = ExpectFail override val brokenTests: Set get() = BrokenTests - override val runOnly: Set + override val generateOnly: Set get() = emptySet() override val disabledTests: Set get() = DisabledTests @@ -258,10 +268,11 @@ class ServerProtocolTestGenerator( inputT to outputT } - private val instantiator = ServerInstantiator(codegenContext) + private val instantiator = ServerInstantiator(codegenContext, withinTest = true) private val codegenScope = arrayOf( + "Base64SimdDev" to ServerCargoDependency.Base64SimdDev.toType(), "Bytes" to RuntimeType.Bytes, "Hyper" to RuntimeType.Hyper, "Tokio" to ServerCargoDependency.TokioDev.toType(), @@ -288,20 +299,31 @@ class ServerProtocolTestGenerator( * an operation's input shape, the resulting shape is of the form we expect, as defined in the test case. */ private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { + logger.info("Generating request test: ${httpRequestTestCase.id}") + if (!protocolSupport.requestDeserialization) { rust("/* test case disabled for this protocol (not yet supported) */") return } with(httpRequestTestCase) { - renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest( + uri, + method, + headers, + body.orNull(), + bodyMediaType.orNull(), + protocol, + queryParams, + host.orNull(), + ) } if (protocolSupport.requestBodyDeserialization) { makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) checkHandlerWasEntered(this) } - // Explicitly warn if the test case defined parameters that we aren't doing anything with + // Explicitly warn if the test case defined parameters that we aren't doing anything with. with(httpRequestTestCase) { if (authScheme.isPresent) { logger.warning("Test case provided authScheme but this was ignored") @@ -322,6 +344,8 @@ class ServerProtocolTestGenerator( testCase: HttpResponseTestCase, shape: StructureShape, ) { + logger.info("Generating response test: ${testCase.id}") + val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( @@ -354,6 +378,8 @@ class ServerProtocolTestGenerator( * with the given response. */ private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) { + logger.info("Generating malformed request test: ${testCase.id}") + val (_, outputT) = operationInputOutputTypes[operationShape]!! val panicMessage = "request should have been rejected, but we accepted it; we parsed operation input `{:?}`" @@ -361,7 +387,18 @@ class ServerProtocolTestGenerator( rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) + // TODO(https://github.com/smithy-lang/smithy/issues/1932): we send `null` for `bodyMediaType` for now but + // the Smithy protocol test should give it to us. + renderHttpRequest( + uri.get(), + method, + headers, + body.orNull(), + bodyMediaType = null, + testCase.protocol, + queryParams, + host.orNull(), + ) } makeRequest( @@ -379,6 +416,8 @@ class ServerProtocolTestGenerator( method: String, headers: Map, body: String?, + bodyMediaType: String?, + protocol: ShapeId, queryParams: List, host: String?, ) { @@ -409,7 +448,26 @@ class ServerProtocolTestGenerator( // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - "#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))" + // TODO(https://github.com/smithy-lang/smithy/issues/1932): We're using the `protocol` field as a + // proxy for `bodyMediaType`. This works because `rpcv2Cbor` happens to be the only protocol where + // the body is base64-encoded in the protocol test, but checking `bodyMediaType` should be a more + // resilient check. + val encodedBody = + if (protocol.toShapeId() == ShapeId.from("smithy.protocols#rpcv2Cbor")) { + """ + #{Bytes}::from( + #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( + "`body` field of Smithy protocol test is not correctly base64 encoded" + ) + ) + """ + } else { + """ + #{Bytes}::from_static($sanitizedBody.as_bytes()) + """ + } + + "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { "#{SmithyHttpServer}::body::Body::empty()" } @@ -426,7 +484,7 @@ class ServerProtocolTestGenerator( } } - /** Returns the body of the request test. */ + /** Returns the body of the operation handler in a request test. */ private fun checkRequestHandler( operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase, @@ -434,7 +492,7 @@ class ServerProtocolTestGenerator( val inputShape = operationShape.inputShape(codegenContext.model) val outputShape = operationShape.outputShape(codegenContext.model) - // Construct expected request. + // Construct expected operation input. withBlock("let expected = ", ";") { instantiator.render(this, inputShape, httpRequestTestCase.params, httpRequestTestCase.headers) } @@ -442,14 +500,14 @@ class ServerProtocolTestGenerator( checkRequestParams(inputShape, this) // Construct a dummy response. - withBlock("let response = ", ";") { + withBlock("let output = ", ";") { instantiator.render(this, outputShape, Node.objectNode()) } if (operationShape.errors.isEmpty()) { - rust("response") + rust("output") } else { - rust("Ok(response)") + rust("Ok(output)") } } @@ -634,13 +692,13 @@ class ServerProtocolTestGenerator( rustWriter.rustTemplate( """ // No body. - #{AssertEq}(std::str::from_utf8(&body).unwrap(), ""); + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) } else { assertOk(rustWriter) { - rustWriter.rust( + rust( "#T(&body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 3d94bb8821..8eeab9c22e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -54,7 +54,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.smithy.wrapOptional import software.amazon.smithy.rust.codegen.core.util.dq @@ -109,7 +109,7 @@ typealias ServerHttpBoundProtocolCustomization = NamedCustomization + setResponseHeaderIfAbsent(this, "content-type", contentTypeValue) + } + + for ((headerName, headerValue) in protocol.additionalResponseHeaders(operationShape)) { + setResponseHeaderIfAbsent(this, headerName, headerValue) } if (errorShape != null) { for ((headerName, headerValue) in protocol.additionalErrorResponseHeaders(errorShape)) { - rustTemplate( - """ - builder = #{header_util}::set_response_header_if_absent( - builder, - http::header::HeaderName::from_static("$headerName"), - "${escape(headerValue)}" - ); - """, - *codegenScope, - ) + setResponseHeaderIfAbsent(this, headerName, headerValue) } } } @@ -709,6 +720,28 @@ class ServerHttpBoundProtocolTraitImplGenerator( // there's something to parse (i.e. `parser != null`), so `!!` is safe here. val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! rustTemplate("let bytes = #{Hyper}::body::to_bytes(body).await?;", *codegenScope) + // Note that the server is being very lenient here. We're accepting an empty body for when there is modeled + // operation input; we simply parse it as empty operation input. + // This behavior applies to all protocols. This might seem like a bug, but it isn't. There's protocol tests + // that assert that the server should be lenient and accept both empty payloads and no payload + // when there is modeled input: + // + // * [restJson1]: clients omit the payload altogether when the input is empty! So services must accept this. + // * [rpcv2Cbor]: services must accept no payload or empty CBOR map for operations with modeled input. + // + // For the AWS JSON 1.x protocols, services are lenient in the case when there is no modeled input: + // + // * [awsJson1_0]: services must accept no payload or empty JSON document payload for operations with no modeled input + // * [awsJson1_1]: services must accept no payload or empty JSON document payload for operations with no modeled input + // + // However, it's true that there are no tests pinning server behavior when there is _empty_ input. There's + // a [consultation with Smithy] to remedy this. Until that gets resolved, in the meantime, we are being lenient. + // + // [restJson1]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/restJson1/empty-input-output.smithy#L22 + // [awsJson1_0]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/awsJson1_0/empty-input-output.smithy + // [awsJson1_1]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/awsJson1_1/empty-operation.smithy + // [rpcv2Cbor]: https://github.com/smithy-lang/smithy/blob/main/smithy-protocol-tests/model/rpcv2Cbor/empty-input-output.smithy + // [consultation with Smithy]: https://github.com/smithy-lang/smithy/issues/2327 rustBlock("if !bytes.is_empty()") { rustTemplate( """ @@ -750,7 +783,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( serverRenderQueryStringParser(this, operationShape) // If there's no modeled operation input, some protocols require that `Content-Type` header not be present. - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null + val noInputs = !OperationNormalizer.hadUserModeledOperationInput(operationShape, model) if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { rustTemplate( """ @@ -760,6 +793,9 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): we should inject a check here that asserts that + // the body contents are valid when there is empty operation input or no operation input. + val err = if (ServerBuilderGenerator.hasFallibleBuilder( inputShape, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ae87ec5723..a121697eb7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -20,7 +21,7 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator -class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { +class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { is ServerHttpBoundProtocolSection.WrapStreamPayload -> @@ -79,6 +80,13 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap = + emptyList(), +) : ProtocolGeneratorFactory { + override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRpcV2CborProtocol(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = + ServerHttpBoundProtocolGenerator( + codegenContext, + ServerRpcV2CborProtocol(codegenContext), + additionalServerHttpBoundProtocolCustomizations, + ) + + override fun support(): ProtocolSupport { + return ProtocolSupport( + // Client support + requestSerialization = false, + requestBodySerialization = false, + responseDeserialization = false, + errorDeserialization = false, + // Server support + requestDeserialization = true, + requestBodyDeserialization = true, + responseSerialization = true, + errorSerialization = true, + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt new file mode 100644 index 0000000000..2e92cde4e2 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -0,0 +1,358 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocoltests.traits.AppliesTo +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2CborProtocol +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRpcV2CborFactory +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.util.function.Predicate +import java.util.logging.Logger + +/** + * This lives in `codegen-server` because we want to run a full integration test for convenience, + * but there's really nothing server-specific here. We're just testing that the CBOR (de)serializers work like + * the ones generated by `serde_cbor`. This is a good exhaustive litmus test for correctness, since `serde_cbor` + * is battle-tested. + */ +internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { + class DeriveSerdeSerializeDeserializeSymbolMetadataProvider( + private val base: RustSymbolProvider, + ) : SymbolMetadataProvider(base) { + private val serdeDeserialize = + CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Deserialize") + private val serdeSerialize = + CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Serialize") + + private fun addDeriveSerdeSerializeDeserialize(shape: Shape): RustMetadata { + check(shape !is MemberShape) + + val baseMetadata = base.toSymbol(shape).expectRustMetadata() + return baseMetadata.withDerives(serdeSerialize, serdeDeserialize) + } + + override fun memberMeta(memberShape: MemberShape): RustMetadata { + val baseMetadata = base.toSymbol(memberShape).expectRustMetadata() + return baseMetadata.copy( + additionalAttributes = + baseMetadata.additionalAttributes + + Attribute( + """serde(rename = "${memberShape.memberName}")""", + isDeriveHelper = true, + ), + ) + } + + override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeSerializeDeserialize(structureShape) + + override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeSerializeDeserialize(unionShape) + + override fun enumMeta(stringShape: StringShape) = addDeriveSerdeSerializeDeserialize(stringShape) + + override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeSerializeDeserialize(listShape) + + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeSerializeDeserialize(mapShape) + + override fun stringMeta(stringShape: StringShape): RustMetadata = + addDeriveSerdeSerializeDeserialize(stringShape) + + override fun numberMeta(numberShape: NumberShape): RustMetadata = + addDeriveSerdeSerializeDeserialize(numberShape) + + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeSerializeDeserialize(blobShape) + } + + fun prepareRpcV2CborModel(): Model { + var model = Model.assembler().discoverModels().assemble().result.get() + + // Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on + // which we can't `#[derive(serde::Deserialize)]`. + // Note we can't use `ModelTransformer.removeShapes` because it will leave the model in an inconsistent state + // when removing list/set shape member shapes. + val removeTimestampAndBlobShapes: Predicate = + Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + targetShape is BlobShape || targetShape is TimestampShape + } + is BlobShape, is TimestampShape -> true + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + targetShape is BlobShape || targetShape is TimestampShape + } + else -> false + } + } + + fun removeShapesByShapeId(shapeIds: Set): Predicate { + val predicate: Predicate = + Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + shapeIds.contains(targetShape.id) + } + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + shapeIds.contains(targetShape.id) + } + else -> { + shapeIds.contains(shape.id) + } + } + } + return predicate + } + + val modelTransformer = ModelTransformer.create() + model = + modelTransformer.removeShapesIf( + modelTransformer.removeShapesIf(model, removeTimestampAndBlobShapes), + // These enums do not serialize their variants using the Rust members' names. + // We'd have to tack on `#[serde(rename = "name")]` using the proper name defined in the Smithy enum definition. + // But we have no way of injecting that attribute on Rust enum variants in the code generator. + // So we just remove these problematic shapes. + removeShapesByShapeId( + setOf( + ShapeId.from("smithy.protocoltests.shared#FooEnum"), + ShapeId.from("smithy.protocoltests.rpcv2Cbor#TestEnum"), + ), + ), + ) + + return model + } + + @Test + fun `serde_cbor round trip`() { + val addDeriveSerdeSerializeDeserializeDecorator = + object : ServerCodegenDecorator { + override val name: String = "Add `#[derive(serde::Serialize, serde::Deserialize)]`" + override val order: Byte = 0 + + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = + DeriveSerdeSerializeDeserializeSymbolMetadataProvider(base) + } + + // Don't generate protocol tests, because it'll attempt to pull out `params` for member shapes we'll remove + // from the model. + val noProtocolTestsDecorator = + object : ServerCodegenDecorator { + override val name: String = "Don't generate protocol tests" + override val order: Byte = 0 + + override fun protocolTestGenerator( + codegenContext: ServerCodegenContext, + baseGenerator: ProtocolTestGenerator, + ): ProtocolTestGenerator { + val noOpProtocolTestsGenerator = + object : ProtocolTestGenerator() { + override val codegenContext: CodegenContext + get() = baseGenerator.codegenContext + override val protocolSupport: ProtocolSupport + get() = baseGenerator.protocolSupport + override val operationShape: OperationShape + get() = baseGenerator.operationShape + override val appliesTo: AppliesTo + get() = baseGenerator.appliesTo + override val logger: Logger + get() = Logger.getLogger(javaClass.name) + override val expectFail: Set + get() = baseGenerator.expectFail + override val brokenTests: Set + get() = emptySet() + override val generateOnly: Set + get() = baseGenerator.generateOnly + override val disabledTests: Set + get() = baseGenerator.disabledTests + + override fun RustWriter.renderAllTestCases(allTests: List) { + // No-op. + } + } + return noOpProtocolTestsGenerator + } + } + + val model = prepareRpcV2CborModel() + val serviceShape = model.expectShape(ShapeId.from(RPC_V2_CBOR)) + serverIntegrationTest( + model, + additionalDecorators = listOf(addDeriveSerdeSerializeDeserializeDecorator, noProtocolTestsDecorator), + params = IntegrationTestParams(service = serviceShape.id.toString()), + ) { codegenContext, rustCrate -> + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1147): NaN != NaN. Ideally we when we address + // this issue, we'd re-use the structure shape comparison code that both client and server protocol test + // generators would use. + val expectFail = setOf("RpcV2CborSupportsNaNFloatInputs", "RpcV2CborSupportsNaNFloatOutputs") + + val codegenScope = + arrayOf( + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "SerdeCbor" to CargoDependency.SerdeCbor.toType(), + ) + + val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true, withinTest = true) + val rpcv2Cbor = ServerRpcV2CborProtocol(codegenContext) + + for (operationShape in codegenContext.model.operationShapes) { + val serverProtocolTestGenerator = + ServerProtocolTestGenerator(codegenContext, ServerRpcV2CborFactory().support(), operationShape) + + rustCrate.withModule(ProtocolFunctions.serDeModule) { + // The SDK can only serialize operation outputs, so we only ask for response tests. + val responseTests = + serverProtocolTestGenerator.responseTestCases() + + for (test in responseTests) { + when (test) { + is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.RequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.ResponseTest -> { + val targetShape = test.targetShape + val params = test.testCase.params + + val serializeFn = + if (targetShape.hasTrait()) { + rpcv2Cbor.structuredDataSerializer().serverErrorSerializer(targetShape.id) + } else { + rpcv2Cbor.structuredDataSerializer().operationOutputSerializer(operationShape) + } + + if (serializeFn == null) { + // Skip if there's nothing to serialize. + continue + } + + if (expectFail.contains(test.id)) { + writeWithNoFormatting("#[should_panic]") + } + unitTest("we_serialize_and_serde_cbor_deserializes_${test.id.toSnakeCase()}_${test.kind.toString().toSnakeCase()}") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes = #{SerializeFn}(&expected) + .expect("our generated CBOR serializer failed"); + let actual = #{SerdeCbor}::from_slice(&bytes) + .expect("serde_cbor failed deserializing from bytes"); + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(targetShape, params), + "SerializeFn" to serializeFn, + *codegenScope, + ) + } + } + } + } + + // The SDK can only deserialize operation inputs, so we only ask for request tests. + val requestTests = + serverProtocolTestGenerator.requestTestCases() + val inputShape = operationShape.inputShape(codegenContext.model) + val err = + if (ServerBuilderGenerator.hasFallibleBuilder( + inputShape, + codegenContext.model, + codegenContext.symbolProvider, + takeInUnconstrainedTypes = true, + ) + ) { + """.expect("builder failed to build")""" + } else { + "" + } + + for (test in requestTests) { + when (test) { + is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.ResponseTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.RequestTest -> { + val targetShape = operationShape.inputShape(codegenContext.model) + val params = test.testCase.params + + val deserializeFn = + rpcv2Cbor.structuredDataParser().serverInputParser(operationShape) + ?: // Skip if there's nothing to serialize. + continue + + if (expectFail.contains(test.id)) { + writeWithNoFormatting("#[should_panic]") + } + unitTest("serde_cbor_serializes_and_we_deserialize_${test.id.toSnakeCase()}_${test.kind.toString().toSnakeCase()}") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes: Vec = #{SerdeCbor}::to_vec(&expected) + .expect("serde_cbor failed serializing to `Vec`"); + let input = #{InputBuilder}::default(); + let input = #{DeserializeFn}(&bytes, input) + .expect("our generated CBOR deserializer failed"); + let actual = input.build()$err; + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(targetShape, params), + "DeserializeFn" to deserializeFn, + "InputBuilder" to inputShape.serverBuilderSymbol(codegenContext), + *codegenScope, + ) + } + } + } + } + } + } + } + } +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d92869a661..a374adf6f0 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -9,7 +9,6 @@ members = [ "pokemon-service-server-sdk", "pokemon-service-client", "pokemon-service-client-usage", - ] [profile.release] diff --git a/rust-runtime/Cargo.lock b/rust-runtime/Cargo.lock index 1e929f7b58..602ebd4502 100644 --- a/rust-runtime/Cargo.lock +++ b/rust-runtime/Cargo.lock @@ -302,6 +302,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "aws-smithy-cbor" +version = "0.60.6" +dependencies = [ + "aws-smithy-types 1.2.0", + "criterion", + "minicbor", +] + [[package]] name = "aws-smithy-checksums" version = "0.60.10" @@ -398,7 +407,7 @@ name = "aws-smithy-experimental" version = "0.1.3" dependencies = [ "aws-smithy-async 1.2.1", - "aws-smithy-runtime 1.6.1", + "aws-smithy-runtime 1.6.2", "aws-smithy-runtime-api 1.7.1", "aws-smithy-types 1.2.0", "h2 0.4.5", @@ -465,8 +474,9 @@ version = "0.60.3" [[package]] name = "aws-smithy-http-server" -version = "0.63.0" +version = "0.63.2" dependencies = [ + "aws-smithy-cbor", "aws-smithy-http 0.60.9", "aws-smithy-json 0.60.7", "aws-smithy-runtime-api 1.7.1", @@ -495,7 +505,7 @@ dependencies = [ [[package]] name = "aws-smithy-http-server-python" -version = "0.62.1" +version = "0.63.1" dependencies = [ "aws-smithy-http 0.60.9", "aws-smithy-http-server", @@ -582,14 +592,17 @@ dependencies = [ [[package]] name = "aws-smithy-protocol-test" -version = "0.60.8" +version = "0.62.0" dependencies = [ "assert-json-diff", "aws-smithy-runtime-api 1.7.1", + "base64-simd", + "cbor-diag", "http 0.2.12", "pretty_assertions", "regex-lite", "roxmltree", + "serde_cbor", "serde_json", "thiserror", ] @@ -635,12 +648,12 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.6.1" +version = "1.6.2" dependencies = [ "approx", "aws-smithy-async 1.2.1", "aws-smithy-http 0.60.9", - "aws-smithy-protocol-test 0.60.8", + "aws-smithy-protocol-test 0.62.0", "aws-smithy-runtime-api 1.7.1", "aws-smithy-types 1.2.0", "bytes", @@ -789,7 +802,7 @@ dependencies = [ name = "aws-smithy-xml" version = "0.60.8" dependencies = [ - "aws-smithy-protocol-test 0.60.8", + "aws-smithy-protocol-test 0.62.0", "base64 0.13.1", "proptest", "xmlparser", @@ -954,6 +967,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bs58" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4" +dependencies = [ + "tinyvec", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -985,6 +1007,25 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cbor-diag" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc245b6ecd09b23901a4fbad1ad975701fd5061ceaef6afa93a2d70605a64429" +dependencies = [ + "bs58", + "chrono", + "data-encoding", + "half 2.4.1", + "nom", + "num-bigint", + "num-rational", + "num-traits", + "separator", + "url", + "uuid", +] + [[package]] name = "cc" version = "1.0.99" @@ -1044,7 +1085,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" dependencies = [ "ciborium-io", - "half", + "half 2.4.1", ] [[package]] @@ -1282,6 +1323,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.6.1" @@ -1634,6 +1681,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -2168,6 +2221,27 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minicbor" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f8e213c36148d828083ae01948eed271d03f95f7e72571fa242d78184029af2" +dependencies = [ + "half 2.4.1", + "minicbor-derive", +] + +[[package]] +name = "minicbor-derive" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6bdc119b1a405df86a8cde673295114179dbd0ebe18877c26ba89fb080365c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.67", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2220,6 +2294,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2235,6 +2319,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -3032,6 +3127,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "separator" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f97841a747eef040fcd2e7b3b9a220a7205926e60488e673d9e4926d27772ce5" + [[package]] name = "serde" version = "1.0.203" @@ -3041,6 +3142,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.203" diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 3a618e6bbe..d7853b0099 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "inlineable", "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml new file mode 100644 index 0000000000..b87366d6ef --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "aws-smithy-cbor" +version = "0.60.6" +authors = [ + "AWS Rust SDK Team ", + "David Pérez ", +] +description = "CBOR utilities for smithy-rs." +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/awslabs/smithy-rs" + +[dependencies.minicbor] +version = "0.24.2" +features = [ + # To write to a `Vec`: https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E + "alloc", + # To support reading `f16` to accomodate fewer bytes transmitted that fit the value. + "half", +] + +[dependencies] +aws-smithy-types = { path = "../aws-smithy-types" } + +[dev-dependencies] +criterion = "0.5.1" + +[[bench]] +name = "string" +harness = false + +[[bench]] +name = "blob" +harness = false + +[package.metadata.docs.rs] +all-features = true +targets = ["x86_64-unknown-linux-gnu"] +cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] +rustdoc-args = ["--cfg", "docsrs"] +# End of docs.rs metadata diff --git a/rust-runtime/aws-smithy-cbor/LICENSE b/rust-runtime/aws-smithy-cbor/LICENSE new file mode 100644 index 0000000000..67db858821 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/rust-runtime/aws-smithy-cbor/README.md b/rust-runtime/aws-smithy-cbor/README.md new file mode 100644 index 0000000000..367577b3e5 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/README.md @@ -0,0 +1,8 @@ +# aws-smithy-cbor + +CBOR serialization and deserialization primitives for clients and servers +generated by [smithy-rs](https://github.com/smithy-lang/smithy-rs). + + +This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/smithy-lang/smithy-rs) code generator. In most cases, it should not be used directly. + diff --git a/rust-runtime/aws-smithy-cbor/benches/blob.rs b/rust-runtime/aws-smithy-cbor/benches/blob.rs new file mode 100644 index 0000000000..221940bb98 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/blob.rs @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn blob_benchmark(c: &mut Criterion) { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + let blob_indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, 0x79, + 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, 0x4e, 0x20, + 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0xff, + ]; + + c.bench_function("blob", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&blob_indefinite_bytes); + let _ = black_box(decoder.blob()); + }) + }); +} + +criterion_group!(benches, blob_benchmark); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/benches/string.rs b/rust-runtime/aws-smithy-cbor/benches/string.rs new file mode 100644 index 0000000000..f60ff353e0 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/string.rs @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::borrow::Cow; + +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn str_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("definite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +pub fn string_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result { + decoder.str_iter()?.collect() +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +fn str_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result, minicbor::decode::Error> { + // This implementation uses `next` twice to see if there is + // another str chunk. If there is, it returns a owned `String`. + let mut chunks_iter = decoder.str_iter()?; + let head = match chunks_iter.next() { + Some(Ok(head)) => head, + None => return Ok(Cow::Borrowed("")), + Some(Err(e)) => return Err(e), + }; + + match chunks_iter.next() { + None => Ok(Cow::Borrowed(head)), + Some(Err(e)) => Err(e), + Some(Ok(next)) => { + let mut concatenated_string = String::from(head); + concatenated_string.push_str(next); + for chunk in chunks_iter { + concatenated_string.push_str(chunk?); + } + Ok(Cow::Owned(concatenated_string)) + } + } +} + +// We have two `string` implementations. One uses `collect` the other +// uses `String::new` followed by `string::push`. +pub fn string_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("definite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); + + c.bench_function("indefinite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("indefinite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); +} + +criterion_group!(benches, string_benchmark, str_benchmark,); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/src/data.rs b/rust-runtime/aws-smithy-cbor/src/data.rs new file mode 100644 index 0000000000..e3bfdad2d9 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/data.rs @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)] +pub enum Type { + Bool, + Null, + Undefined, + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + Int, + F16, + F32, + F64, + Simple, + Bytes, + BytesIndef, + String, + StringIndef, + Array, + ArrayIndef, + Map, + MapIndef, + Tag, + Break, + Unknown(u8), +} + +impl Type { + pub(crate) fn new(ty: minicbor::data::Type) -> Self { + match ty { + minicbor::data::Type::Bool => Self::Bool, + minicbor::data::Type::Null => Self::Null, + minicbor::data::Type::Undefined => Self::Undefined, + minicbor::data::Type::U8 => Self::U8, + minicbor::data::Type::U16 => Self::U16, + minicbor::data::Type::U32 => Self::U32, + minicbor::data::Type::U64 => Self::U64, + minicbor::data::Type::I8 => Self::I8, + minicbor::data::Type::I16 => Self::I16, + minicbor::data::Type::I32 => Self::I32, + minicbor::data::Type::I64 => Self::I64, + minicbor::data::Type::Int => Self::Int, + minicbor::data::Type::F16 => Self::F16, + minicbor::data::Type::F32 => Self::F32, + minicbor::data::Type::F64 => Self::F64, + minicbor::data::Type::Simple => Self::Simple, + minicbor::data::Type::Bytes => Self::Bytes, + minicbor::data::Type::BytesIndef => Self::BytesIndef, + minicbor::data::Type::String => Self::String, + minicbor::data::Type::StringIndef => Self::StringIndef, + minicbor::data::Type::Array => Self::Array, + minicbor::data::Type::ArrayIndef => Self::ArrayIndef, + minicbor::data::Type::Map => Self::Map, + minicbor::data::Type::MapIndef => Self::MapIndef, + minicbor::data::Type::Tag => Self::Tag, + minicbor::data::Type::Break => Self::Break, + minicbor::data::Type::Unknown(byte) => Self::Unknown(byte), + } + } + + // This is just the reverse mapping of `new`. + pub(crate) fn into_minicbor_type(self) -> minicbor::data::Type { + match self { + Type::Bool => minicbor::data::Type::Bool, + Type::Null => minicbor::data::Type::Null, + Type::Undefined => minicbor::data::Type::Undefined, + Type::U8 => minicbor::data::Type::U8, + Type::U16 => minicbor::data::Type::U16, + Type::U32 => minicbor::data::Type::U32, + Type::U64 => minicbor::data::Type::U64, + Type::I8 => minicbor::data::Type::I8, + Type::I16 => minicbor::data::Type::I16, + Type::I32 => minicbor::data::Type::I32, + Type::I64 => minicbor::data::Type::I64, + Type::Int => minicbor::data::Type::Int, + Type::F16 => minicbor::data::Type::F16, + Type::F32 => minicbor::data::Type::F32, + Type::F64 => minicbor::data::Type::F64, + Type::Simple => minicbor::data::Type::Simple, + Type::Bytes => minicbor::data::Type::Bytes, + Type::BytesIndef => minicbor::data::Type::BytesIndef, + Type::String => minicbor::data::Type::String, + Type::StringIndef => minicbor::data::Type::StringIndef, + Type::Array => minicbor::data::Type::Array, + Type::ArrayIndef => minicbor::data::Type::ArrayIndef, + Type::Map => minicbor::data::Type::Map, + Type::MapIndef => minicbor::data::Type::MapIndef, + Type::Tag => minicbor::data::Type::Tag, + Type::Break => minicbor::data::Type::Break, + Type::Unknown(byte) => minicbor::data::Type::Unknown(byte), + } + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs new file mode 100644 index 0000000000..3cfe070397 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -0,0 +1,341 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::borrow::Cow; + +use aws_smithy_types::{Blob, DateTime}; +use minicbor::decode::Error; + +use crate::data::Type; + +/// Provides functions for decoding a CBOR object with a known schema. +/// +/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema +/// is known in advance. Therefore, the caller can determine which object key exists at the current +/// position by calling `str` method, and call the relevant function based on the predetermined schema +/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip +/// over the element. +#[derive(Debug, Clone)] +pub struct Decoder<'b> { + decoder: minicbor::Decoder<'b>, +} + +/// When any of the decode methods are called they look for that particular data type at the current +/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned. +#[derive(Debug)] +pub struct DeserializeError { + #[allow(dead_code)] + _inner: Error, +} + +impl std::fmt::Display for DeserializeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self._inner.fmt(f) + } +} + +impl std::error::Error for DeserializeError {} + +impl DeserializeError { + pub(crate) fn new(inner: Error) -> Self { + Self { _inner: inner } + } + + /// More than one union variant was detected: `unexpected_type` was unexpected. + pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self { + Self { + _inner: Error::type_mismatch(unexpected_type.into_minicbor_type()) + .with_message("encountered unexpected union variant; expected end of union") + .at(at), + } + } + + /// Unknown union variant was detected. Servers reject unknown union varaints. + pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self { + Self { + _inner: Error::message(format!( + "encountered unknown union variant {}", + variant_name + )) + .at(at), + } + } + + /// More than one union variant was detected, but we never even got to parse the first one. + /// We immediately raise this error when detecting a union serialized as a fixed-length CBOR + /// map whose length (specified upfront) is a value different than 1. + pub fn mixed_union_variants(at: usize) -> Self { + Self { + _inner: Error::message( + "encountered mixed variants in union; expected a single union variant to be set", + ) + .at(at), + } + } + + /// Expected end of stream but more data is available. + pub fn expected_end_of_stream(at: usize) -> Self { + Self { + _inner: Error::message("encountered additional data; expected end of stream").at(at), + } + } + + /// An unexpected type was encountered. + // We handle this one when decoding sparse collections: we have to expect either a `null` or an + // item, so we try decoding both. + pub fn is_type_mismatch(&self) -> bool { + self._inner.is_type_mismatch() + } +} + +/// Macro for delegating method calls to the decoder. +/// +/// This macro generates wrapper methods for calling specific methods on the decoder and returning +/// the result with error handling. +/// +/// # Example +/// +/// ```ignore +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the decoder. +/// encode_str_wrapper => encode_str(String); +/// /// Wrapper method for encoding method `encode_int` on the decoder. +/// encode_int_wrapper => encode_int(i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => { + $( + pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> { + self.decoder.$encoder_name().map_err(DeserializeError::new) + } + )+ + }; +} + +impl<'b> Decoder<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { + decoder: minicbor::Decoder::new(bytes), + } + } + + pub fn datatype(&self) -> Result { + self.decoder + .datatype() + .map(Type::new) + .map_err(DeserializeError::new) + } + + delegate_method! { + /// Skips the current CBOR element. + skip => skip(()); + /// Reads a boolean at the current position. + boolean => bool(bool); + /// Reads a byte at the current position. + byte => i8(i8); + /// Reads a short at the current position. + short => i16(i16); + /// Reads a integer at the current position. + integer => i32(i32); + /// Reads a long at the current position. + long => i64(i64); + /// Reads a float at the current position. + float => f32(f32); + /// Reads a double at the current position. + double => f64(f64); + /// Reads a null CBOR element at the current position. + null => null(()); + /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`. + list => array(Option); + /// Returns the number of elements in a definite map. For indefinite map it returns a `None`. + map => map(Option); + } + + /// Returns the current position of the buffer, which will be decoded when any of the methods is called. + pub fn position(&self) -> usize { + self.decoder.position() + } + + /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite + /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an + /// indefinite-length string. An error is returned if the element is neither a definite length nor an + /// indefinite-length string. + pub fn str(&mut self) -> Result, DeserializeError> { + let bookmark = self.decoder.position(); + match self.decoder.str() { + Ok(str_value) => Ok(Cow::Borrowed(str_value)), + Err(e) if e.is_type_mismatch() => { + // Move the position back to the start of the CBOR element and then try + // decoding it as an indefinite length string. + self.decoder.set_position(bookmark); + Ok(Cow::Owned(self.string()?)) + } + Err(e) => Err(DeserializeError::new(e)), + } + } + + /// Allocates and returns a `String` if the element at the current position in the buffer is either a + /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type. + pub fn string(&mut self) -> Result { + let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?; + let head = iter.next(); + + let decoded_string = match head { + None => String::new(), + Some(head) => { + let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?); + for chunk in iter { + combined_chunks.push_str(chunk.map_err(DeserializeError::new)?); + } + combined_chunks + } + }; + + Ok(decoded_string) + } + + /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise, + /// a `DeserializeError` error is returned. + pub fn blob(&mut self) -> Result { + let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?; + let parts: Vec<&[u8]> = iter + .collect::>() + .map_err(DeserializeError::new)?; + + Ok(if parts.len() == 1 { + Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part. + } else { + Blob::new(parts.concat()) // Concatenate all parts into a single Blob. + }) + } + + /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise, + /// a `DeserializeError` error is returned. + pub fn timestamp(&mut self) -> Result { + let tag = self.decoder.tag().map_err(DeserializeError::new)?; + let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp); + + if tag != timestamp_tag { + Err(DeserializeError::new(Error::message( + "expected timestamp tag", + ))) + } else { + let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?; + Ok(DateTime::from_secs_f64(epoch_seconds)) + } + } +} + +#[derive(Debug)] +pub struct ArrayIter<'a, 'b, T> { + inner: minicbor::decode::ArrayIter<'a, 'b, T>, +} + +impl<'a, 'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'a, 'b, T> { + type Item = Result; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +#[derive(Debug)] +pub struct MapIter<'a, 'b, K, V> { + inner: minicbor::decode::MapIter<'a, 'b, K, V>, +} + +impl<'a, 'b, K, V> Iterator for MapIter<'a, 'b, K, V> +where + K: minicbor::Decode<'b, ()>, + V: minicbor::Decode<'b, ()>, +{ + type Item = Result<(K, V), DeserializeError>; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +pub fn set_optional(builder: B, decoder: &mut Decoder, f: F) -> Result +where + F: Fn(B, &mut Decoder) -> Result, +{ + match decoder.datatype()? { + crate::data::Type::Null => { + decoder.null()?; + Ok(builder) + } + _ => f(builder, decoder), + } +} + +#[cfg(test)] +mod tests { + use crate::Decoder; + + #[test] + fn test_definite_str_is_cow_borrowed() { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + let mut decoder = Decoder::new(&definite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Borrowed(_))); + } + + #[test] + fn test_indefinite_str_is_cow_owned() { + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, + 0x79, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Owned(_))); + } + + #[test] + fn test_empty_str_works() { + let bytes = [0x60]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.str().expect("could not decode empty str"); + assert_eq!(member, ""); + } + + #[test] + fn test_empty_blob_works() { + let bytes = [0x40]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.blob().expect("could not decode an empty blob"); + assert_eq!(member, aws_smithy_types::Blob::new(&[])); + } + + #[test] + fn test_indefinite_length_blob() { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff + let indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, + 0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, + 0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, + 0x61, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.blob().expect("could not decode blob"); + assert_eq!( + member, + aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes()) + ); + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs new file mode 100644 index 0000000000..1651c37f9b --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -0,0 +1,117 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_types::{Blob, DateTime}; + +/// Macro for delegating method calls to the encoder. +/// +/// This macro generates wrapper methods for calling specific encoder methods on the encoder +/// and returning a mutable reference to self for method chaining. +/// +/// # Example +/// +/// ```ignore +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the encoder. +/// encode_str_wrapper => encode_str(data: &str); +/// /// Wrapper method for encoding method `encode_int` on the encoder. +/// encode_int_wrapper => encode_int(value: i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($($param_name:ident : $param_type:ty),*);)+) => { + $( + pub fn $wrapper_name(&mut self, $($param_name: $param_type),*) -> &mut Self { + self.encoder.$encoder_name($($param_name)*).expect(INFALLIBLE_WRITE); + self + } + )+ + }; +} + +#[derive(Debug, Clone)] +pub struct Encoder { + encoder: minicbor::Encoder>, +} + +/// We always write to a `Vec`, which is infallible in `minicbor`. +/// +const INFALLIBLE_WRITE: &str = "write failed"; + +impl Encoder { + pub fn new(writer: Vec) -> Self { + Self { + encoder: minicbor::Encoder::new(writer), + } + } + + delegate_method! { + /// Used when it's not cheap to calculate the size, i.e. when the struct has one or more + /// `Option`al members. + begin_map => begin_map(); + /// Writes a definite length string. + str => str(x: &str); + /// Writes a boolean value. + boolean => bool(x: bool); + /// Writes a byte value. + byte => i8(x: i8); + /// Writes a short value. + short => i16(x: i16); + /// Writes an integer value. + integer => i32(x: i32); + /// Writes an long value. + long => i64(x: i64); + /// Writes an float value. + float => f32(x: f32); + /// Writes an double value. + double => f64(x: f64); + /// Writes a null tag. + null => null(); + /// Writes an end tag. + end => end(); + } + + pub fn blob(&mut self, x: &Blob) -> &mut Self { + self.encoder.bytes(x.as_ref()).expect(INFALLIBLE_WRITE); + self + } + + /// Writes a fixed length array of given length. + pub fn array(&mut self, len: usize) -> &mut Self { + self.encoder + // `.expect()` safety: `From for usize` is not in the standard library, + // but the conversion should be infallible (unless we ever have 128-bit machines I + // guess). . + .array(len.try_into().expect("`usize` to `u64` conversion failed")) + .expect(INFALLIBLE_WRITE); + self + } + + /// Writes a fixed length map of given length. + /// Used when we know the size in advance, i.e.: + /// - when a struct has all non-`Option`al members. + /// - when serializing `union` shapes (they can only have one member set). + /// - when serializing a `map` shape. + pub fn map(&mut self, len: usize) -> &mut Self { + self.encoder + .map(len.try_into().expect("`usize` to `u64` conversion failed")) + .expect(INFALLIBLE_WRITE); + self + } + + pub fn timestamp(&mut self, x: &DateTime) -> &mut Self { + self.encoder + .tag(minicbor::data::Tag::from( + minicbor::data::IanaTag::Timestamp, + )) + .expect(INFALLIBLE_WRITE); + self.encoder.f64(x.as_secs_f64()).expect(INFALLIBLE_WRITE); + self + } + + pub fn into_writer(self) -> Vec { + self.encoder.into_writer() + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/lib.rs b/rust-runtime/aws-smithy-cbor/src/lib.rs new file mode 100644 index 0000000000..6db4813980 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/lib.rs @@ -0,0 +1,17 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! CBOR abstractions for Smithy. + +/* Automatically managed default lints */ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +/* End of automatically managed default lints */ + +pub mod data; +pub mod decode; +pub mod encode; + +pub use decode::Decoder; +pub use encode::Encoder; diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index a63a125941..73c2ba42d1 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.63.1" +version = "0.63.2" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" @@ -23,6 +23,7 @@ aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["http-02x"] } aws-smithy-types = { path = "../aws-smithy-types", features = ["http-body-0-4-x", "hyper-0-14-x"] } aws-smithy-xml = { path = "../aws-smithy-xml" } +aws-smithy-cbor = { path = "../aws-smithy-cbor" } bytes = "1.1" futures-util = { version = "0.3.29", default-features = false } http = "0.2" diff --git a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs index ac7a645f26..cd9e333bb2 100644 --- a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs +++ b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs @@ -175,6 +175,8 @@ where type Future = UpgradeFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // The check that the inner service is ready is done by `Oneshot` in `UpgradeFuture`'s + // implementation. Poll::Ready(Ok(())) } diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index df304d823f..38538fe1e9 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -39,9 +39,9 @@ pub enum Error { // This constant determines when the `TinyMap` implementation switches from being a `Vec` to a // `HashMap`. This is chosen to be 15 as a result of the discussion around // https://github.com/smithy-lang/smithy-rs/pull/1429#issuecomment-1147516546 -const ROUTE_CUTOFF: usize = 15; +pub(crate) const ROUTE_CUTOFF: usize = 15; -/// A [`Router`] supporting [`AWS JSON 1.0`] and [`AWS JSON 1.1`] protocols. +/// A [`Router`] supporting [AWS JSON 1.0] and [AWS JSON 1.1] protocols. /// /// [AWS JSON 1.0]: https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html /// [AWS JSON 1.1]: https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs index 8dfd48e466..a3e8f2c919 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs @@ -5,5 +5,5 @@ pub mod router; -/// [AWS JSON 1.0 Protocol](https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html). +/// [AWS JSON 1.0](https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html) protocol. pub struct AwsJson1_0; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs index 30a28d6255..ac963ffe51 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs @@ -12,6 +12,8 @@ use super::AwsJson1_0; pub use crate::protocol::aws_json::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs index 6fb09920a0..697aae52d3 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs @@ -5,5 +5,5 @@ pub mod router; -/// [AWS JSON 1.1 Protocol](https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html). +/// [AWS JSON 1.1](https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html) protocol. pub struct AwsJson1_1; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs index 5ebd1002f2..2e3e16d8ad 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs @@ -12,6 +12,8 @@ use super::AwsJson1_1; pub use crate::protocol::aws_json::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs index 27a16d9f18..6d6bbf3b65 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs @@ -9,6 +9,7 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; +pub mod rpc_v2_cbor; use crate::rejection::MissingContentTypeReason; use aws_smithy_runtime_api::http::Headers as SmithyHeaders; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs index 19a644e7e4..94f99a98df 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs @@ -26,10 +26,10 @@ pub enum Error { MethodNotAllowed, } -/// A [`Router`] supporting [`AWS REST JSON 1.0`] and [`AWS REST XML`] protocols. +/// A [`Router`] supporting [AWS restJson1] and [AWS restXml] protocols. /// -/// [AWS REST JSON 1.0]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html -/// [AWS REST XML]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html +/// [AWS restJson1]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html +/// [AWS restXml]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html #[derive(Debug, Clone)] pub struct RestRouter { routes: Vec<(RequestSpec, S)>, diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs index f8384578d2..695d995ce1 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs @@ -7,5 +7,5 @@ pub mod rejection; pub mod router; pub mod runtime_error; -/// [AWS REST JSON 1.0 Protocol](https://smithy.io/2.0/aws/protocols/aws-restjson1-protocol.html). +/// [AWS restJson1](https://smithy.io/2.0/aws/protocols/aws-restjson1-protocol.html) protocol. pub struct RestJson1; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs index 023b43031c..939b1bb6ec 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs @@ -12,6 +12,8 @@ use super::RestJson1; pub use crate::protocol::rest::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs index 0b16df11e3..e16570567e 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs @@ -7,5 +7,5 @@ pub mod rejection; pub mod router; pub mod runtime_error; -/// [AWS REST XML Protocol](https://smithy.io/2.0/aws/protocols/aws-restxml-protocol.html). +/// [AWS restXml](https://smithy.io/2.0/aws/protocols/aws-restxml-protocol.html) protocol. pub struct RestXml; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs index 529a3d19a2..e684ced4de 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs @@ -13,7 +13,8 @@ use super::RestXml; pub use crate::protocol::rest::router::*; -/// An AWS REST routing error. +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs new file mode 100644 index 0000000000..287a756446 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub mod rejection; +pub mod router; +pub mod runtime_error; + +/// [Smithy RPC v2 CBOR](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) +/// protocol. +pub struct RpcV2Cbor; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs new file mode 100644 index 0000000000..2ec8b957af --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::num::TryFromIntError; + +use crate::rejection::MissingContentTypeReason; +use aws_smithy_runtime_api::http::HttpError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ResponseRejection { + #[error("invalid bound HTTP status code; status codes must be inside the 100-999 range: {0}")] + InvalidHttpStatusCode(TryFromIntError), + #[error("error serializing CBOR-encoded body: {0}")] + Serialization(#[from] aws_smithy_types::error::operation::SerializationError), + #[error("error building HTTP response: {0}")] + HttpBuild(#[from] http::Error), +} + +#[derive(Debug, Error)] +pub enum RequestRejection { + #[error("error converting non-streaming body to bytes: {0}")] + BufferHttpBodyBytes(crate::Error), + #[error("request contains invalid value for `Accept` header")] + NotAcceptable, + #[error("expected `Content-Type` header not found: {0}")] + MissingContentType(#[from] MissingContentTypeReason), + #[error("error deserializing request HTTP body as CBOR: {0}")] + CborDeserialize(#[from] aws_smithy_cbor::decode::DeserializeError), + // Unlike the other protocols, RPC v2 uses CBOR, a binary serialization format, so we take in a + // `Vec` here instead of `String`. + #[error("request does not adhere to modeled constraints")] + ConstraintViolation(Vec), + + /// Typically happens when the request has headers that are not valid UTF-8. + #[error("failed to convert request: {0}")] + HttpConversion(#[from] HttpError), +} + +impl From for RequestRejection { + fn from(_err: std::convert::Infallible) -> Self { + match _err {} + } +} + +convert_to_request_rejection!(hyper::Error, BufferHttpBodyBytes); +convert_to_request_rejection!(Box, BufferHttpBodyBytes); diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs new file mode 100644 index 0000000000..53d6e31483 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs @@ -0,0 +1,406 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::convert::Infallible; +use std::str::FromStr; + +use http::header::ToStrError; +use http::HeaderMap; +use once_cell::sync::Lazy; +use regex::Regex; +use thiserror::Error; +use tower::Layer; +use tower::Service; + +use crate::body::empty; +use crate::body::BoxBody; +use crate::extension::RuntimeErrorExtension; +use crate::protocol::aws_json_11::router::ROUTE_CUTOFF; +use crate::response::IntoResponse; +use crate::routing::tiny_map::TinyMap; +use crate::routing::Route; +use crate::routing::Router; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; + +use super::RpcV2Cbor; + +pub use crate::protocol::rest::router::*; + +/// An RPC v2 CBOR routing error. +#[derive(Debug, Error)] +pub enum Error { + /// Method was not `POST`. + #[error("method not POST")] + MethodNotAllowed, + /// Requests for the `rpcv2Cbor` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` + /// header. + #[error("contains forbidden headers")] + ForbiddenHeaders, + /// Unable to parse `smithy-protocol` header into a valid wire format value. + #[error("failed to parse `smithy-protocol` header into a valid wire format value")] + InvalidWireFormatHeader(#[from] WireFormatError), + /// Operation not found. + #[error("operation not found")] + NotFound, +} + +/// A [`Router`] supporting the [Smithy RPC v2 CBOR] protocol. +/// +/// [Smithy RPC v2 CBOR]: https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html +#[derive(Debug, Clone)] +pub struct RpcV2CborRouter { + routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, +} + +/// Requests for the `rpcv2Cbor` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` +/// header. An `rpcv2Cbor` request is malformed if it contains either of these headers. Server-side +/// implementations MUST reject such requests for security reasons. +const FORBIDDEN_HEADERS: &[&str] = &["x-amz-target", "x-amzn-target"]; + +/// Matches the `Identifier` ABNF rule in +/// . +const IDENTIFIER_PATTERN: &str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; + +impl RpcV2CborRouter { + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3748) Consider building a nom parser. + fn uri_path_regex() -> &'static Regex { + // Every request for the `rpcv2Cbor` protocol MUST be sent to a URL with the + // following form: `{prefix?}/service/{serviceName}/operation/{operationName}` + // + // * The optional `prefix` segment may span multiple path segments and is not + // utilized by the Smithy RPC v2 CBOR protocol. For example, a service could + // use a `v1` prefix for the following URL path: `v1/service/FooService/operation/BarOperation` + // * The `serviceName` segment MUST be replaced by the [`shape + // name`](https://smithy.io/2.0/spec/model.html#grammar-token-smithy-Identifier) + // of the service's [Shape ID](https://smithy.io/2.0/spec/model.html#shape-id) + // in the Smithy model. The `serviceName` produced by client implementations + // MUST NOT contain the namespace of the `service` shape. Service + // implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it + // the same as if only the name was specified. For example, if the `service`'s + // absolute shape ID is `com.example#TheService`, a service should accept both + // `TheService` and `com.example.TheService` as values for the `serviceName` + // segment. + static PATH_REGEX: Lazy = Lazy::new(|| { + Regex::new(&format!( + r#"/service/({}\.)*(?P{})/operation/(?P{})$"#, + IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, + )) + .unwrap() + }); + + &PATH_REGEX + } + + pub fn wire_format_regex() -> &'static Regex { + static SMITHY_PROTOCOL_REGEX: Lazy = Lazy::new(|| Regex::new(r#"^rpc-v2-(?P\w+)$"#).unwrap()); + + &SMITHY_PROTOCOL_REGEX + } + + pub fn boxed(self) -> RpcV2CborRouter> + where + S: Service, Response = http::Response, Error = Infallible>, + S: Send + Clone + 'static, + S::Future: Send + 'static, + { + RpcV2CborRouter { + routes: self.routes.into_iter().map(|(key, s)| (key, Route::new(s))).collect(), + } + } + + /// Applies a [`Layer`] uniformly to all routes. + pub fn layer(self, layer: L) -> RpcV2CborRouter + where + L: Layer, + { + RpcV2CborRouter { + routes: self + .routes + .into_iter() + .map(|(key, route)| (key, layer.layer(route))) + .collect(), + } + } +} + +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! +impl IntoResponse for Error { + fn into_response(self) -> http::Response { + match self { + Error::MethodNotAllowed => method_disallowed(), + _ => http::Response::builder() + .status(http::StatusCode::NOT_FOUND) + .header(http::header::CONTENT_TYPE, "application/cbor") + .extension(RuntimeErrorExtension::new( + UNKNOWN_OPERATION_EXCEPTION.to_string(), + )) + .body(empty()) + .expect("invalid HTTP response for RPCv2 CBOR routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"), + } + } +} + +/// Errors that can happen when parsing the wire format from the `smithy-protocol` header. +#[derive(Debug, Error)] +pub enum WireFormatError { + /// Header not found. + #[error("`smithy-protocol` header not found")] + HeaderNotFound, + /// Header value is not visible ASCII. + #[error("`smithy-protocol` header not visible ASCII")] + HeaderValueNotVisibleAscii(ToStrError), + /// Header value does not match the `rpc-v2-{format}` pattern. The actual parsed header value + /// is stored in the tuple struct. + // https://doc.rust-lang.org/std/fmt/index.html#escaping + #[error("`smithy-protocol` header does not match the `rpc-v2-{{format}}` pattern: `{0}`")] + HeaderValueNotValid(String), + /// Header value matches the `rpc-v2-{format}` pattern, but the `format` is not supported. The + /// actual parsed header value is stored in the tuple struct. + #[error("found unsupported `smithy-protocol` wire format: `{0}`")] + WireFormatNotSupported(String), +} + +/// Smithy RPC V2 requests have a `smithy-protocol` header with the value +/// `"rpc-v2-{format}"`, where `format` is one of the supported wire formats +/// by the protocol (see [`WireFormat`]). +fn parse_wire_format_from_header(headers: &HeaderMap) -> Result { + let header = headers.get("smithy-protocol").ok_or(WireFormatError::HeaderNotFound)?; + let header = header.to_str().map_err(WireFormatError::HeaderValueNotVisibleAscii)?; + let captures = RpcV2CborRouter::<()>::wire_format_regex() + .captures(header) + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let format = captures + .name("format") + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let wire_format_parse_res: Result = format.as_str().parse(); + wire_format_parse_res.map_err(|_| WireFormatError::WireFormatNotSupported(header.to_owned())) +} + +/// Supported wire formats by RPC V2. +enum WireFormat { + Cbor, +} + +struct WireFormatFromStrError; + +impl FromStr for WireFormat { + type Err = WireFormatFromStrError; + + fn from_str(format: &str) -> Result { + match format { + "cbor" => Ok(Self::Cbor), + _ => Err(WireFormatFromStrError), + } + } +} + +impl Router for RpcV2CborRouter { + type Service = S; + + type Error = Error; + + fn match_route(&self, request: &http::Request) -> Result { + // Only `Method::POST` is allowed. + if request.method() != http::Method::POST { + return Err(Error::MethodNotAllowed); + } + + // Some headers are not allowed. + let request_has_forbidden_header = FORBIDDEN_HEADERS + .iter() + .any(|&forbidden_header| request.headers().contains_key(forbidden_header)); + if request_has_forbidden_header { + return Err(Error::ForbiddenHeaders); + } + + // Wire format has to be specified and supported. + let _wire_format = parse_wire_format_from_header(request.headers())?; + + // Extract the service name and the operation name from the request URI. + let request_path = request.uri().path(); + let regex = Self::uri_path_regex(); + + tracing::trace!(%request_path, "capturing service and operation from URI"); + let captures = regex.captures(request_path).ok_or(Error::NotFound)?; + let (service, operation) = (&captures["service"], &captures["operation"]); + tracing::trace!(%service, %operation, "captured service and operation from URI"); + + // Lookup in the `TinyMap` for a route for the target. + let route = self + .routes + .get((format!("{service}.{operation}")).as_str()) + .ok_or(Error::NotFound)?; + Ok(route.clone()) + } +} + +impl FromIterator<(&'static str, S)> for RpcV2CborRouter { + #[inline] + fn from_iter>(iter: T) -> Self { + Self { + routes: iter.into_iter().collect(), + } + } +} + +#[cfg(test)] +mod tests { + use http::{HeaderMap, HeaderValue, Method}; + use regex::Regex; + + use crate::protocol::test_helpers::req; + + use super::{Error, Router, RpcV2CborRouter}; + + fn identifier_regex() -> Regex { + Regex::new(&format!("^{}$", super::IDENTIFIER_PATTERN)).unwrap() + } + + #[test] + fn valid_identifiers() { + let valid_identifiers = vec!["a", "_a", "_0", "__0", "variable123", "_underscored_variable"]; + + for id in &valid_identifiers { + assert!(identifier_regex().is_match(id), "'{}' is incorrectly rejected", id); + } + } + + #[test] + fn invalid_identifiers() { + let invalid_identifiers = vec![ + "0", + "123starts_with_digit", + "@invalid_start_character", + " space_in_identifier", + "invalid-character", + "invalid@character", + "no#hashes", + ]; + + for id in &invalid_identifiers { + assert!(!identifier_regex().is_match(id), "'{}' is incorrectly accepted", id); + } + } + + #[test] + fn uri_regex_works_accepts() { + let regex = RpcV2CborRouter::<()>::uri_path_regex(); + + for uri in [ + "/service/Service/operation/Operation", + "prefix/69/service/Service/operation/Operation", + // Here the prefix is up to the last occurrence of the string `/service`. + "prefix/69/service/Service/operation/Operation/service/Service/operation/Operation", + // Service implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it the same as + // if only the name was specified. For example, if the `service`'s absolute shape ID is + // `com.example#TheService`, a service should accept both `TheService` and + // `com.example.TheService` as values for the `serviceName` segment. + "/service/aws.protocoltests.rpcv2Cbor.Service/operation/Operation", + "/service/namespace.Service/operation/Operation", + ] { + let captures = regex.captures(uri).unwrap(); + assert_eq!("Service", &captures["service"], "uri: {}", uri); + assert_eq!("Operation", &captures["operation"], "uri: {}", uri); + } + } + + #[test] + fn uri_regex_works_rejects() { + let regex = RpcV2CborRouter::<()>::uri_path_regex(); + + for uri in [ + "", + "foo", + "/servicee/Service/operation/Operation", + "/service/Service", + "/service/Service/operation/", + "/service/Service/operation/Operation/", + "/service/Service/operation/Operation/invalid-suffix", + "/service/namespace.foo#Service/operation/Operation", + "/service/namespace-Service/operation/Operation", + "/service/.Service/operation/Operation", + ] { + assert!(regex.captures(uri).is_none(), "uri: {}", uri); + } + } + + #[test] + fn wire_format_regex_works() { + let regex = RpcV2CborRouter::<()>::wire_format_regex(); + + let captures = regex.captures("rpc-v2-something").unwrap(); + assert_eq!("something", &captures["format"]); + + let captures = regex.captures("rpc-v2-SomethingElse").unwrap(); + assert_eq!("SomethingElse", &captures["format"]); + + let invalid = regex.captures("rpc-v1-something"); + assert!(invalid.is_none()); + } + + /// Helper function returning the only strictly required header. + fn headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static("rpc-v2-cbor")); + headers + } + + #[test] + fn simple_routing() { + let router: RpcV2CborRouter<_> = ["Service.Operation"].into_iter().map(|op| (op, ())).collect(); + let good_uri = "/prefix/service/Service/operation/Operation"; + + // The request should match. + let routing_result = router.match_route(&req(&Method::POST, good_uri, Some(headers()))); + assert!(routing_result.is_ok()); + + // The request would be valid if it used `Method::POST`. + let invalid_request = req(&Method::GET, good_uri, Some(headers())); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::MethodNotAllowed) + )); + + // The request would be valid if it did not have forbidden headers. + for forbidden_header_name in ["x-amz-target", "x-amzn-target"] { + let mut headers = headers(); + headers.insert(forbidden_header_name, HeaderValue::from_static("Service.Operation")); + let invalid_request = req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::ForbiddenHeaders) + )); + } + + for bad_uri in [ + // These requests would be valid if they used correct URIs. + "/prefix/Service/Service/operation/Operation", + "/prefix/service/Service/operation/Operation/suffix", + // These requests would be valid if their URI matched an existing operation. + "/prefix/service/ThisServiceDoesNotExist/operation/Operation", + "/prefix/service/Service/operation/ThisOperationDoesNotExist", + ] { + let invalid_request = &req(&Method::POST, bad_uri, Some(headers())); + assert!(matches!(router.match_route(&invalid_request), Err(Error::NotFound))); + } + + // The request would be valid if it specified a supported wire format in the + // `smithy-protocol` header. + for header_name in ["bad-header", "rpc-v2-json", "foo-rpc-v2-cbor", "rpc-v2-cbor-foo"] { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static(header_name)); + let invalid_request = &req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::InvalidWireFormatHeader(_)) + )); + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs new file mode 100644 index 0000000000..b3f01da351 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::response::IntoResponse; +use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; +use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2_cbor::RpcV2Cbor}; +use bytes::Bytes; +use http::StatusCode; + +use super::rejection::{RequestRejection, ResponseRejection}; + +#[derive(Debug, thiserror::Error)] +pub enum RuntimeError { + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::Serialization`] + #[error("request failed to deserialize or response failed to serialize: {0}")] + Serialization(crate::Error), + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::InternalFailure`] + #[error("internal failure: {0}")] + InternalFailure(crate::Error), + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::NotAcceptable`] + #[error("not acceptable request: request contains an `Accept` header with a MIME type, and the server cannot return a response body adhering to that MIME type")] + NotAcceptable, + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::UnsupportedMediaType`] + #[error("unsupported media type: request does not contain the expected `Content-Type` header value")] + UnsupportedMediaType, + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::Validation`] + #[error( + "validation failure: operation input contains data that does not adhere to the modeled constraints: {0:?}" + )] + Validation(Vec), +} + +impl RuntimeError { + pub fn name(&self) -> &'static str { + match self { + Self::Serialization(_) => "SerializationException", + Self::InternalFailure(_) => "InternalFailureException", + Self::NotAcceptable => "NotAcceptableException", + Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Validation(_) => "ValidationException", + } + } + + pub fn status_code(&self) -> StatusCode { + match self { + Self::Serialization(_) => StatusCode::BAD_REQUEST, + Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, + Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::Validation(_) => StatusCode::BAD_REQUEST, + } + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) + } +} + +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + let res = http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/cbor") + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + // https://cbor.nemo157.com/#type=hex&value=a0 + const EMPTY_CBOR_MAP: Bytes = Bytes::from_static(&[0xa0]); + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing + // `__type`. + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + _ => crate::body::to_boxed(EMPTY_CBOR_MAP), + }; + + res.body(body) + .expect(INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE) + } +} + +impl From for RuntimeError { + fn from(err: ResponseRejection) -> Self { + Self::Serialization(crate::Error::new(err)) + } +} + +impl From for RuntimeError { + fn from(err: RequestRejection) -> Self { + match err { + RequestRejection::ConstraintViolation(reason) => Self::Validation(reason), + _ => Self::Serialization(crate::Error::new(err)), + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 14f124f687..ede1f5117b 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -37,7 +37,6 @@ use futures_util::{ use http::Response; use http_body::Body as HttpBody; use tower::{util::Oneshot, Service, ServiceExt}; -use tracing::debug; use crate::{ body::{boxed, BoxBody}, @@ -191,12 +190,13 @@ where } fn call(&mut self, req: http::Request) -> Self::Future { + tracing::debug!("inside routing service call"); match self.router.match_route(&req) { // Successfully routed, use the routes `Service::call`. Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), // Failed to route, use the `R::Error`s `IntoResponse

`. Err(error) => { - debug!(%error, "failed to route"); + tracing::debug!(%error, "failed to route"); RoutingFuture::from_response(error.into_response()) } } diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index e674b66dfe..9f5189079a 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.60.8" +version = "0.62.0" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" @@ -10,6 +10,9 @@ repository = "https://github.com/smithy-lang/smithy-rs" [dependencies] # Not perfect for our needs, but good for now assert-json-diff = "1.1" +base64-simd = "0.8" +cbor-diag = "0.1.12" +serde_cbor = "0.11" http = "0.2.1" pretty_assertions = "1.3" regex-lite = "0.1.5" @@ -18,7 +21,6 @@ serde_json = "1" thiserror = "1.0.40" aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client"] } - [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 7fca5463d6..06cdbc2ff2 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -306,10 +306,12 @@ pub fn require_headers( #[derive(Clone)] pub enum MediaType { - /// Json media types are deserialized and compared + /// JSON media types are deserialized and compared Json, /// XML media types are normalized and compared Xml, + /// CBOR media types are decoded from base64 to binary and compared + Cbor, /// For x-www-form-urlencoded, do some map order comparison shenanigans UrlEncodedForm, /// Other media types are compared literally @@ -322,13 +324,14 @@ impl> From for MediaType { "application/json" => MediaType::Json, "application/x-amz-json-1.1" => MediaType::Json, "application/xml" => MediaType::Xml, + "application/cbor" => MediaType::Cbor, "application/x-www-form-urlencoded" => MediaType::UrlEncodedForm, other => MediaType::Other(other.to_string()), } } } -pub fn validate_body>( +pub fn validate_body + Debug>( actual_body: T, expected_body: &str, media_type: MediaType, @@ -336,11 +339,11 @@ pub fn validate_body>( let body_str = std::str::from_utf8(actual_body.as_ref()); match (media_type, body_str) { (MediaType::Json, Ok(actual_body)) => try_json_eq(expected_body, actual_body), - (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(expected_body, actual_body), (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(actual_body, expected_body), (MediaType::Xml, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "XML".to_owned(), found: "input was not valid UTF-8".to_owned(), @@ -352,6 +355,7 @@ pub fn validate_body>( expected: "x-www-form-urlencoded".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Cbor, _) => try_cbor_eq(actual_body, expected_body), (MediaType::Other(media_type), Ok(actual_body)) => { if actual_body != expected_body { Err(ProtocolTestFailure::BodyDidNotMatch { @@ -410,6 +414,66 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure> } } +fn try_cbor_eq + Debug>( + actual_body: T, + expected_body: &str, +) -> Result<(), ProtocolTestFailure> { + let decoded = base64_simd::STANDARD + .decode_to_vec(expected_body) + .expect("smithy protocol test `body` property is not properly base64 encoded"); + let expected_cbor_value: serde_cbor::Value = + serde_cbor::from_slice(decoded.as_slice()).expect("expected value must be valid CBOR"); + let actual_cbor_value: serde_cbor::Value = serde_cbor::from_slice(actual_body.as_ref()) + .map_err(|e| ProtocolTestFailure::InvalidBodyFormat { + expected: "cbor".to_owned(), + found: format!("{} {:?}", e, actual_body), + })?; + let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body); + + if expected_cbor_value != actual_cbor_value { + let expected_body_annotated_hex: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_hex(); + let expected_body_diag: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_diag_pretty(); + let actual_body_annotated_hex: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_hex(); + let actual_body_diag: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_diag_pretty(); + + Err(ProtocolTestFailure::BodyDidNotMatch { + comparison: PrettyString(format!( + "{}", + Comparison::new(&expected_cbor_value, &actual_cbor_value) + )), + // The last newline is important because the panic message ends with a `.` + hint: format!( + "expected body in diagnostic format: +{} +actual body in diagnostic format: +{} +expected body in annotated hex: +{} +actual body in annotated hex: +{} +actual body in base64 (useful to update the protocol test): +{} +", + expected_body_diag, + actual_body_diag, + expected_body_annotated_hex, + actual_body_annotated_hex, + actual_body_base64, + ), + }) + } else { + Ok(()) + } +} + #[cfg(test)] mod tests { use crate::{