From b9016aed72b970ce0dcfb07357d27ae3bdd5f0e2 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 4 Jun 2024 12:24:40 +0200 Subject: [PATCH 01/77] Add RPCv2 support --- aws/sdk-adhoc-test/build.gradle.kts | 3 +- aws/sdk/build.gradle.kts | 3 +- build.gradle.kts | 2 +- buildSrc/src/main/kotlin/CodegenTestCommon.kt | 59 +- codegen-client-test/build.gradle.kts | 8 +- codegen-client/build.gradle.kts | 3 +- .../protocol/ProtocolTestGenerator.kt | 24 +- .../smithy/protocols/ClientProtocolLoader.kt | 29 +- codegen-core/build.gradle.kts | 1 + .../adwait-cbor-structs.smithy | 142 ++++ .../adwait-empty-input-output.smithy | 174 +++++ .../common-test-models/adwait-main.smithy | 30 + codegen-core/common-test-models/rpcv2.smithy | 206 ++++++ .../codegen/core/rustlang/CargoDependency.kt | 3 + .../rust/codegen/core/rustlang/RustType.kt | 1 + .../rust/codegen/core/smithy/RuntimeType.kt | 13 +- .../core/smithy/generators/Instantiator.kt | 9 +- .../codegen/core/smithy/protocols/AwsJson.kt | 2 +- .../codegen/core/smithy/protocols/Protocol.kt | 19 +- .../smithy/protocols/ProtocolFunctions.kt | 2 +- .../codegen/core/smithy/protocols/RestJson.kt | 8 +- .../codegen/core/smithy/protocols/RpcV2.kt | 141 ++++ .../protocols/parse/CborParserGenerator.kt | 657 ++++++++++++++++++ .../protocols/parse/JsonParserGenerator.kt | 4 +- .../protocols/parse/ReturnSymbolToParse.kt | 8 + .../parse/StructuredDataParserGenerator.kt | 4 +- .../serialize/CborSerializerGenerator.kt | 469 +++++++++++++ .../serialize/JsonSerializerGenerator.kt | 17 +- codegen-server-test/build.gradle.kts | 17 +- codegen-server-test/python/build.gradle.kts | 3 +- .../typescript/build.gradle.kts | 3 +- codegen-server/build.gradle.kts | 1 + .../server/smithy/ServerCargoDependency.kt | 3 +- ...ypeFieldToServerErrorsCborCustomization.kt | 39 ++ ...ncodingMapOrCollectionCborCustomization.kt | 39 ++ .../generators/protocol/ServerProtocol.kt | 93 ++- .../protocol/ServerProtocolTestGenerator.kt | 30 +- .../ServerHttpBoundProtocolGenerator.kt | 69 +- .../smithy/protocols/ServerProtocolLoader.kt | 5 +- .../protocols/ServerRpcV2CborFactory.kt | 35 + .../server/smithy/protocols/RpcV2Test.kt | 42 ++ .../serialize/CborSerializerGeneratorTest.kt | 126 ++++ examples/Cargo.toml | 3 + examples/Makefile | 13 +- gradle.properties | 8 +- rust-runtime/Cargo.toml | 1 + rust-runtime/aws-smithy-cbor/Cargo.toml | 40 ++ rust-runtime/aws-smithy-cbor/LICENSE | 175 +++++ rust-runtime/aws-smithy-cbor/README.md | 7 + rust-runtime/aws-smithy-cbor/benches/blob.rs | 21 + .../aws-smithy-cbor/benches/string.rs | 131 ++++ rust-runtime/aws-smithy-cbor/src/data.rs | 97 +++ rust-runtime/aws-smithy-cbor/src/decode.rs | 314 +++++++++ rust-runtime/aws-smithy-cbor/src/encode.rs | 92 +++ rust-runtime/aws-smithy-cbor/src/lib.rs | 13 + .../aws-smithy-http-server/Cargo.toml | 3 +- .../examples/rpcv2-service/Cargo.toml | 14 + .../examples/rpcv2-service/src/main.rs | 32 + .../src/protocol/aws_json/router.rs | 3 +- .../src/protocol/mod.rs | 1 + .../src/protocol/rpc_v2/mod.rs | 12 + .../src/protocol/rpc_v2/rejection.rs | 49 ++ .../src/protocol/rpc_v2/router.rs | 409 +++++++++++ .../src/protocol/rpc_v2/runtime_error.rs | 82 +++ .../aws-smithy-http-server/src/routing/mod.rs | 4 +- .../aws-smithy-protocol-test/Cargo.toml | 6 +- .../aws-smithy-protocol-test/src/lib.rs | 65 +- .../aws-smithy-types/src/error/operation.rs | 1 + 68 files changed, 4022 insertions(+), 120 deletions(-) create mode 100644 codegen-core/common-test-models/adwait-cbor-structs.smithy create mode 100644 codegen-core/common-test-models/adwait-empty-input-output.smithy create mode 100644 codegen-core/common-test-models/adwait-main.smithy create mode 100644 codegen-core/common-test-models/rpcv2.smithy create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt create mode 100644 rust-runtime/aws-smithy-cbor/Cargo.toml create mode 100644 rust-runtime/aws-smithy-cbor/LICENSE create mode 100644 rust-runtime/aws-smithy-cbor/README.md create mode 100644 rust-runtime/aws-smithy-cbor/benches/blob.rs create mode 100644 rust-runtime/aws-smithy-cbor/benches/string.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/data.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/decode.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/encode.rs create mode 100644 rust-runtime/aws-smithy-cbor/src/lib.rs create mode 100644 rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml create mode 100644 rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs create mode 100644 rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs diff --git a/aws/sdk-adhoc-test/build.gradle.kts b/aws/sdk-adhoc-test/build.gradle.kts index 1f2b7fde5a..3bdc663dad 100644 --- a/aws/sdk-adhoc-test/build.gradle.kts +++ b/aws/sdk-adhoc-test/build.gradle.kts @@ -20,7 +20,6 @@ java { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val pluginName = "rust-client-codegen" @@ -78,7 +77,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index 6e2165eb12..95b27bad1e 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -31,7 +31,6 @@ configure { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val crateHasherToolPath = rootProject.projectDir.resolve("tools/ci-build/crate-hasher") @@ -442,7 +441,7 @@ tasks["assemble"].apply { outputs.upToDateWhen { false } } -project.registerCargoCommandsTasks(outputDir.asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(outputDir.asFile) project.registerGenerateCargoConfigTomlTask(outputDir.asFile) //The task name "test" is already registered by one of our plugins diff --git a/build.gradle.kts b/build.gradle.kts index 20f2d9e400..5e11e0ab02 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ allprojects { val allowLocalDeps: String by project repositories { if (allowLocalDeps.toBoolean()) { - mavenLocal() + mavenLocal() } mavenCentral() google() diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index b7681a5b56..1977be9b11 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -29,6 +29,49 @@ fun generateImports(imports: List): String = "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," } +fun toRustCrateName(input: String): String { + val rustKeywords = setOf( + // Strict Keywords. + "as", "break", "const", "continue", "crate", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", + "move", "mut", "pub", "ref", "return", "Self", "self", "static", "struct", + "super", "trait", "true", "type", "unsafe", "use", "where", "while", + + // Weak Keywords. + "dyn", "async", "await", "try", + + // Reserved for Future Use. + "abstract", "become", "box", "do", "final", "macro", "override", "priv", + "typeof", "unsized", "virtual", "yield", + + // Primitive Types. + "bool", "char", "i8", "i16", "i32", "i64", "i128", "isize", + "u8", "u16", "u32", "u64", "u128", "usize", "f32", "f64", "str", + + // Additional significant identifiers. + "proc_macro" + ) + + // Then within your function, you could include a check against this set + if (input.isBlank()) { + throw IllegalArgumentException("Rust crate name cannot be empty") + } + val lowerCased = input.lowercase() + // Replace any sequence of characters that are not lowercase letters, numbers, or underscores with a single underscore. + val sanitized = lowerCased.replace(Regex("[^a-z0-9_]+"), "_") + // Trim leading or trailing underscores. + val trimmed = sanitized.trim('_') + // Check if the resulting string is empty, purely numeric, or a reserved name + val finalName = when { + trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") + trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. + trimmed in rustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. + else -> trimmed + } + return finalName +} + + private fun generateSmithyBuild( projectDir: String, pluginName: String, @@ -48,7 +91,7 @@ private fun generateSmithyBuild( ${it.extraCodegenConfig ?: ""} }, "service": "${it.service}", - "module": "${it.module}", + "module": "${toRustCrateName(it.module)}", "moduleVersion": "0.0.1", "moduleDescription": "test", "moduleAuthors": ["protocoltest@example.com"] @@ -205,13 +248,15 @@ fun Project.registerGenerateCargoWorkspaceTask( fun Project.registerGenerateCargoConfigTomlTask(outputDir: File) { this.tasks.register("generateCargoConfigToml") { description = "generate `.cargo/config.toml`" + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1068): Once doc normalization + // is completed, warnings can be prohibited in rustdoc by setting `rustdocflags` to `-D warnings`. doFirst { outputDir.resolve(".cargo").mkdirs() outputDir.resolve(".cargo/config.toml") .writeText( """ [build] - rustflags = ["--deny", "warnings"] + rustflags = ["--deny", "warnings", "--cfg", "aws_sdk_unstable"] """.trimIndent(), ) } @@ -255,10 +300,7 @@ fun Project.registerModifyMtimeTask() { } } -fun Project.registerCargoCommandsTasks( - outputDir: File, - defaultRustDocFlags: String, -) { +fun Project.registerCargoCommandsTasks(outputDir: File) { val dependentTasks = listOfNotNull( "assemble", @@ -269,29 +311,24 @@ fun Project.registerCargoCommandsTasks( this.tasks.register(Cargo.CHECK.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "check", "--lib", "--tests", "--benches", "--all-features") } this.tasks.register(Cargo.TEST.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "test", "--all-features", "--no-fail-fast") } this.tasks.register(Cargo.DOCS.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTDOCFLAGS", defaultRustDocFlags) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "doc", "--no-deps", "--document-private-items") } this.tasks.register(Cargo.CLIPPY.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "clippy") } } diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index de4b54d5b3..a1c6de5c23 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -15,7 +15,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "orchestrator" @@ -112,6 +111,11 @@ val allCodegenTests = listOf( "pokemon-service-awsjson-client", dependsOn = listOf("pokemon-awsjson.smithy", "pokemon-common.smithy"), ), + ClientTest( + "com.amazonaws.simple#RpcV2Service", + "rpcv2-pokemon-client", + dependsOn = listOf("rpcv2.smithy") + ), ClientTest("aws.protocoltests.misc#QueryCompatService", "query-compat-test", dependsOn = listOf("aws-json-query-compat.smithy")), ).map(ClientTest::toCodegenTest) @@ -125,7 +129,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index 485a656d7b..3e1f1ec580 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -27,9 +27,10 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in event stream -// marshalling/unmarshalling tests. + // marshalling/unmarshalling tests. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index cb961cbd19..47ea9f5626 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -33,7 +33,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -173,13 +175,19 @@ class DefaultProtocolTestGenerator( } testModuleWriter.write("Test ID: ${testCase.id}") testModuleWriter.newlinePrefix = "" + Attribute.TokioTest.render(testModuleWriter) - val action = - when (testCase) { - is HttpResponseTestCase -> Action.Response - is HttpRequestTestCase -> Action.Request - else -> throw CodegenException("unknown test case type") - } + Attribute.TracedTest.render(testModuleWriter) + // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in + // case the code rendered by the test does not make use of `tracing` at all. + val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } + testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) + + val action = when (testCase) { + is HttpResponseTestCase -> Action.Response + is HttpRequestTestCase -> Action.Request + else -> throw CodegenException("unknown test case type") + } if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } @@ -415,8 +423,8 @@ class DefaultProtocolTestGenerator( if (body == "") { rustWriter.rustTemplate( """ - // No body - #{AssertEq}(::std::str::from_utf8(body).unwrap(), ""); + // No body. + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index 5c0f7e6e1a..2af64bbad8 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -28,20 +29,21 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { - val DefaultProtocols = - mapOf( - AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), - AwsQueryTrait.ID to ClientAwsQueryFactory(), - Ec2QueryTrait.ID to ClientEc2QueryFactory(), - RestJson1Trait.ID to ClientRestJsonFactory(), - RestXmlTrait.ID to ClientRestXmlFactory(), - ) + val DefaultProtocols = mapOf( + AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), + AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), + AwsQueryTrait.ID to ClientAwsQueryFactory(), + Ec2QueryTrait.ID to ClientEc2QueryFactory(), + RestJson1Trait.ID to ClientRestJsonFactory(), + RestXmlTrait.ID to ClientRestXmlFactory(), + Rpcv2CborTrait.ID to ClientRpcV2CborFactory(), + ) val Default = ClientProtocolLoader(DefaultProtocols) } } @@ -117,3 +119,12 @@ class ClientRestXmlFactory( override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT } + +class ClientRpcV2CborFactory : ProtocolGeneratorFactory { + override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = + OperationGenerator(codegenContext, protocol(codegenContext)) + + override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT +} diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index 2fdd74abfb..eff612be35 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -28,6 +28,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") } fun gitCommitHash(): String { diff --git a/codegen-core/common-test-models/adwait-cbor-structs.smithy b/codegen-core/common-test-models/adwait-cbor-structs.smithy new file mode 100644 index 0000000000..b88d930967 --- /dev/null +++ b/codegen-core/common-test-models/adwait-cbor-structs.smithy @@ -0,0 +1,142 @@ +$version: "2.0" + +namespace aws.protocoltests.rpcv2 + +use aws.protocoltests.shared#StringList +use smithy.protocols#rpcv2 +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + + +@httpRequestTests([ + { + id: "RpcV2CborSimpleScalarProperties", + protocol: rpcv2, + documentation: "Serializes simple scalar properties", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + method: "POST", + bodyMediaType: "application/cbor", + uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", + body: "v2lieXRlVmFsdWUFa2RvdWJsZVZhbHVl+z/+OVgQYk3TcWZhbHNlQm9vbGVhblZhbHVl9GpmbG9hdFZhbHVl+kDz989saW50ZWdlclZhbHVlGQEAaWxvbmdWYWx1ZRkmkWpzaG9ydFZhbHVlGSaqa3N0cmluZ1ZhbHVlZnNpbXBsZXB0cnVlQm9vbGVhblZhbHVl9f8=" + params: { + trueBooleanValue: true, + falseBooleanValue: false, + byteValue: 5, + doubleValue: 1.889, + floatValue: 7.624, + integerValue: 256, + shortValue: 9898, + longValue: 9873 + stringValue: "simple" + } + }, + { + id: "RpcV2CborClientDoesntSerializeNullStructureValues", + documentation: "RpcV2 Cbor should not serialize null structure values", + protocol: rpcv2, + method: "POST", + uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", + body: "v/8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + params: { + stringValue: null + }, + appliesTo: "client" + }, + { + id: "RpcV2CborServerDoesntDeSerializeNullStructureValues", + documentation: "RpcV2 Cbor should not deserialize null structure values", + protocol: rpcv2, + method: "POST", + uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", + body: "v2tzdHJpbmdWYWx1Zfb/", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + params: {}, + appliesTo: "server" + }, +]) +@httpResponseTests([ + { + id: "simple_scalar_structure", + protocol: rpcv2, + documentation: "Serializes simple scalar properties", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + bodyMediaType: "application/cbor", + body: "v2lieXRlVmFsdWUFa2RvdWJsZVZhbHVl+z/+OVgQYk3TcWZhbHNlQm9vbGVhblZhbHVl9GpmbG9hdFZhbHVl+kDz989saW50ZWdlclZhbHVlGQEAaWxvbmdWYWx1ZRkmkWpzaG9ydFZhbHVlGSaqa3N0cmluZ1ZhbHVlZnNpbXBsZXB0cnVlQm9vbGVhblZhbHVl9f8=", + code: 200, + params: { + trueBooleanValue: true, + falseBooleanValue: false, + byteValue: 5, + doubleValue: 1.889, + floatValue: 7.624, + integerValue: 256, + shortValue: 9898, + stringValue: "simple" + } + }, + { + id: "RpcV2CborClientDoesntDeSerializeNullStructureValues", + documentation: "RpcV2 Cbor should deserialize null structure values", + protocol: rpcv2, + body: "v2tzdHJpbmdWYWx1Zfb/", + code: 200, + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + params: {} + appliesTo: "client" + }, + { + id: "RpcV2CborServerDoesntSerializeNullStructureValues", + documentation: "RpcV2 Cbor should not serialize null structure values", + protocol: rpcv2, + body: "v/8=", + code: 200, + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + params: { + stringValue: null + }, + appliesTo: "server" + }, +]) +operation SimpleScalarProperties { + input: SimpleScalarStructure, + output: SimpleScalarStructure +} + + +structure SimpleScalarStructure { + trueBooleanValue: Boolean, + falseBooleanValue: Boolean, + byteValue: Byte, + doubleValue: Double, + floatValue: Float, + integerValue: Integer, + longValue: Long, + shortValue: Short, + stringValue: String, +} diff --git a/codegen-core/common-test-models/adwait-empty-input-output.smithy b/codegen-core/common-test-models/adwait-empty-input-output.smithy new file mode 100644 index 0000000000..186cbaaa30 --- /dev/null +++ b/codegen-core/common-test-models/adwait-empty-input-output.smithy @@ -0,0 +1,174 @@ +$version: "2.0" + +namespace aws.protocoltests.rpcv2 + +use smithy.protocols#rpcv2 +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + + +@httpRequestTests([ + { + id: "no_input", + protocol: rpcv2, + documentation: "Body is empty and no Content-Type header if no input", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + }, + forbidHeaders: [ + "Content-Type", + "X-Amz-Target" + ] + method: "POST", + uri: "/service/RpcV2Protocol/operation/NoInputOutput", + body: "" + }, + { + id: "no_input_server_allows_accept", + protocol: rpcv2, + documentation: "Servers should allow the Accept header to be set to the default content-type.", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + method: "POST", + uri: "/service/RpcV2Protocol/operation/NoInputOutput", + body: "", + appliesTo: "server" + }, + { + id: "no_input_server_allows_empty_cbor", + protocol: rpcv2, + documentation: "Servers should accept CBOR empty struct if no input.", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + method: "POST", + uri: "/service/RpcV2Protocol/operation/NoInputOutput", + body: "v/8=", + appliesTo: "server" + } +]) +@httpResponseTests([ + { + id: "no_output", + protocol: rpcv2, + documentation: "Body is empty and no Content-Type header if no response", + body: "", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + }, + forbidHeaders: [ + "Content-Type" + ] + code: 200, + }, + { + id: "no_output_client_allows_accept", + protocol: rpcv2, + documentation: "Servers should allow the accept header to be set to the default content-type.", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + body: "", + code: 200, + appliesTo: "client", + }, + { + id: "no_input_client_allows_empty_cbor", + protocol: rpcv2, + documentation: "Client should accept CBOR empty struct if no output", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + body: "v/8=", + code: 200, + appliesTo: "client", + } +]) +operation NoInputOutput {} + + +@httpRequestTests([ + { + id: "empty_input", + protocol: rpcv2, + documentation: "When Input structure is empty we write CBOR equivalent of {}", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + }, + forbidHeaders: [ + "X-Amz-Target" + ] + method: "POST", + uri: "/service/RpcV2Protocol/operation/EmptyInputOutput", + body: "v/8=", + }, +]) +@httpResponseTests([ + { + id: "empty_output", + protocol: rpcv2, + documentation: "When output structure is empty we write CBOR equivalent of {}", + body: "v/8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + code: 200, + }, +]) +operation EmptyInputOutput { + input: EmptyStructure, + output: EmptyStructure +} + +@httpRequestTests([ + { + id: "optional_input", + protocol: rpcv2, + documentation: "When input is empty we write CBOR equivalent of {}", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + }, + forbidHeaders: [ + "X-Amz-Target" + ] + method: "POST", + uri: "/service/RpcV2Protocol/operation/OptionalInputOutput", + body: "v/8=", + bodyMediaType: "application/cbor", + }, +]) +@httpResponseTests([ + { + id: "optional_output", + protocol: rpcv2, + documentation: "When output is empty we write CBOR equivalent of {}", + body: "v/8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + code: 200, + }, +]) +operation OptionalInputOutput { + input: SimpleStructure, + output: SimpleStructure +} diff --git a/codegen-core/common-test-models/adwait-main.smithy b/codegen-core/common-test-models/adwait-main.smithy new file mode 100644 index 0000000000..c827b1dd79 --- /dev/null +++ b/codegen-core/common-test-models/adwait-main.smithy @@ -0,0 +1,30 @@ +$version: "2.0" + +namespace aws.protocoltests.rpcv2 +use aws.api#service +use smithy.protocols#rpcv2 +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + +@service(sdkId: "Sample RpcV2 Protocol") +@rpcv2(format: ["cbor"]) +@title("RpcV2 Protocol Service") +service RpcV2Protocol { + version: "2020-07-14", + operations: [ + //Basic input/output tests + NoInputOutput, + EmptyInputOutput, + OptionalInputOutput, + + SimpleScalarProperties, + ] +} + +structure EmptyStructure { + +} + +structure SimpleStructure { + value: String, +} \ No newline at end of file diff --git a/codegen-core/common-test-models/rpcv2.smithy b/codegen-core/common-test-models/rpcv2.smithy new file mode 100644 index 0000000000..48a4fb694b --- /dev/null +++ b/codegen-core/common-test-models/rpcv2.smithy @@ -0,0 +1,206 @@ +$version: "2.0" + +// TODO Update namespace +namespace com.amazonaws.simple + +use smithy.framework#ValidationException +use smithy.protocols#rpcv2 +use smithy.test#httpResponseTests + +@rpcv2(format: ["cbor"]) +service RpcV2Service { + operations: [ + SimpleStructOperation, + ComplexStructOperation + ] +} + +// TODO RpcV2 should not use the `@http` trait. +@http(uri: "/simple-struct-operation", method: "POST") +operation SimpleStructOperation { + input: SimpleStruct + output: SimpleStruct + errors: [ValidationException] +} + +// TODO RpcV2 should not use the `@http` trait. +@http(uri: "/complex-struct-operation", method: "POST") +operation ComplexStructOperation { + input: ComplexStruct + output: ComplexStruct + errors: [ValidationException] +} + +apply SimpleStructOperation @httpResponseTests([ + { + id: "SimpleStruct", + protocol: "smithy.protocols#rpcv2", + code: 200, // Not used. + params: { + blob: "blobby blob", + boolean: false, + + string: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + byte: 69, + short: 70, + integer: 71, + long: 72, + + float: 0.69, + double: 0.6969, + + timestamp: 1546300800, + // document: { + // documentInteger: 69 + // } + enum: "DIAMOND" + + // With `@required`. + + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + // document: { + // documentInteger: 69 + // } + requiredEnum: "DIAMOND" + } + }, + // Same test, but leave optional types empty + { + id: "SimpleStructWithOptionsSetToNone", + protocol: "smithy.protocols#rpcv2", + code: 200, // Not used. + params: { + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + // document: { + // documentInteger: 69 + // } + requiredEnum: "DIAMOND" + } + } +]) + +structure SimpleStruct { + blob: Blob + boolean: Boolean + + string: String + + byte: Byte + short: Short + integer: Integer + long: Long + + float: Float + double: Double + + timestamp: Timestamp + // document: MyDocument + enum: Suit + + // With `@required`. + + @required requiredBlob: Blob + @required requiredBoolean: Boolean + + @required requiredString: String + + @required requiredByte: Byte + @required requiredShort: Short + @required requiredInteger: Integer + @required requiredLong: Long + + @required requiredFloat: Float + @required requiredDouble: Double + + @required requiredTimestamp: Timestamp + // @required requiredDocument: MyDocument + @required requiredEnum: Suit +} + +structure ComplexStruct { + structure: SimpleStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion + + structureList: StructureList + + // `@required` for good measure here. + @required complexList: ComplexList + @required complexMap: ComplexMap + @required complexUnion: ComplexUnion +} + +list StructureList { + member: SimpleStruct +} + +list SimpleList { + member: String +} + +map SimpleMap { + key: String + value: Integer +} + +union SimpleUnion { + blob: Blob + boolean: Boolean + string: String +} + +list ComplexList { + member: ComplexMap +} + +map ComplexMap { + key: String + value: ComplexUnion +} + +union ComplexUnion { + // Recursive path here. + complexStruct: ComplexStruct + + structure: SimpleStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion +} + +enum Suit { + DIAMOND + CLUB + HEART + SPADE +} + +// document MyDocument diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 8fc0f9b6a7..b4207eeb99 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -294,6 +294,7 @@ data class CargoDependency( val Hound: CargoDependency = CargoDependency("hound", CratesIo("3.4.0"), DependencyScope.Dev) val PrettyAssertions: CargoDependency = CargoDependency("pretty_assertions", CratesIo("1.3.0"), DependencyScope.Dev) + val SerdeCbor: CargoDependency = CargoDependency("serde_cbor", CratesIo("0.11"), DependencyScope.Dev) val SerdeJson: CargoDependency = CargoDependency("serde_json", CratesIo("1.0.0"), DependencyScope.Dev) val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev) val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev) @@ -327,6 +328,8 @@ data class CargoDependency( fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async") + fun smithyCbor(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-cbor") + fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums") fun smithyCompression(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-compression") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 10c8def399..6a98294e1b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -570,6 +570,7 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { val Test = Attribute("test") val TokioTest = Attribute(RuntimeType.Tokio.resolve("test").writable) + val TracedTest = Attribute(RuntimeType.TracingTest.resolve("traced_test").writable) val AwsSdkUnstableAttribute = Attribute(cfg("aws_sdk_unstable")) /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index da5d742647..c909d411f6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -272,6 +272,9 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val U64 = std.resolve("primitive::u64") val Vec = std.resolve("vec::Vec") + // primitive types + val StaticStr = RuntimeType("&'static str") + // external cargo dependency types val Bytes = CargoDependency.Bytes.toType().resolve("Bytes") val Http = CargoDependency.Http.toType() @@ -288,6 +291,9 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val PercentEncoding = CargoDependency.PercentEncoding.toType() val PrettyAssertions = CargoDependency.PrettyAssertions.toType() val Regex = CargoDependency.Regex.toType() + val Serde= CargoDependency.Serde.toType() + val SerdeDeserialize = Serde.resolve("Deserialize") + val SerdeSerialize = Serde.resolve("Serialize") val RegexLite = CargoDependency.RegexLite.toType() val Tokio = CargoDependency.Tokio.toType() val TokioStream = CargoDependency.TokioStream.toType() @@ -299,14 +305,11 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val ConstrainedTrait = RuntimeType("crate::constrained::Constrained", InlineDependency.constrained()) val MaybeConstrained = RuntimeType("crate::constrained::MaybeConstrained", InlineDependency.constrained()) - // serde types. Gated behind `CfgUnstable`. - val Serde = CargoDependency.Serde.toType() - val SerdeSerialize = Serde.resolve("Serialize") - val SerdeDeserialize = Serde.resolve("Deserialize") - // smithy runtime types fun smithyAsync(runtimeConfig: RuntimeConfig) = CargoDependency.smithyAsync(runtimeConfig).toType() + fun smithyCbor(runtimeConfig: RuntimeConfig) = CargoDependency.smithyCbor(runtimeConfig).toType() + fun smithyChecksums(runtimeConfig: RuntimeConfig) = CargoDependency.smithyChecksums(runtimeConfig).toType() fun smithyCompression(runtimeConfig: RuntimeConfig) = CargoDependency.smithyCompression(runtimeConfig).toType() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index de73ab760d..db51633fef 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -65,6 +66,8 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.letIf import java.math.BigDecimal +import software.amazon.smithy.model.traits.DefaultTrait +import software.amazon.smithy.rust.codegen.core.util.getTrait /** * Class describing an instantiator section that can be used in a customization. @@ -452,7 +455,7 @@ open class Instantiator( */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { - is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) + is MemberShape -> shape.getTrait()?.toNode() ?: fillDefaultValue(model.expectShape(shape.target)) // Aggregate shapes. is StructureShape -> Node.objectNode() @@ -487,7 +490,9 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val fractionalPart = num.remainder(BigDecimal.ONE) rust( "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), +// RuntimeType.dateTime(runtimeConfig), + // TODO + runtimeConfig.smithyRuntimeCrate("smithy-types", scope = DependencyScope.Dev).toType().resolve("DateTime"), ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 486b443a6a..b44bdfb84d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -174,10 +174,10 @@ open class AwsJson( override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index 236c297db9..c7b139bfd5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -27,11 +27,28 @@ interface Protocol { /** The timestamp format that should be used if no override is specified in the model */ val defaultTimestampFormat: TimestampFormatTrait.Format - /** Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. */ + /** + * Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ fun additionalRequestHeaders(operationShape: OperationShape): List> = emptyList() + /** + * Returns additional HTTP headers that should be included in HTTP responses for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ + fun additionalResponseHeaders(operationShape: OperationShape): List> = emptyList() + /** * Returns additional HTTP headers that should be included in HTTP responses for the given error shape. + * These headers are added to responses _in addition_ to those returned by `additionalResponseHeaders`; if a header + * added by this function has the same header name as one added by `additionalResponseHeaders`, the one added by + * `additionalResponseHeaders` takes precedence. + * * These MUST all be lowercase, or the application will panic, as per * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static */ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt index e40046f1d8..cb9b766771 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt @@ -39,7 +39,7 @@ class ProtocolFunctions( private val codegenContext: CodegenContext, ) { companion object { - private val serDeModule = RustModule.pubCrate("protocol_serde") + val serDeModule = RustModule.pubCrate("protocol_serde") fun crossOperationFn( fnName: String, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 641548fc11..c4e3980668 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -56,6 +56,12 @@ class RestJsonHttpBindingResolver( } } } + + // The spec does not mention whether we should set the `Content-Type` header when there is no modeled output. + // The protocol tests indicate it's optional: + // + // + // In our implementation, we opt to always set it to `application/json`. return super.responseContentType(operationShape) ?: "application/json" } } @@ -124,10 +130,10 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt new file mode 100644 index 0000000000..75b275fc8f --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt @@ -0,0 +1,141 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.core.util.outputShape + +// TODO Rename these to RpcV2Cbor +class RpcV2HttpBindingResolver( + private val model: Model, +) : HttpBindingResolver { + private fun bindings(shape: ToShapeId): List { + val members = shape.let { model.expectShape(it.toShapeId()) }.members() + // TODO(https://github.com/awslabs/smithy-rs/issues/2237): support non-streaming members too + if (members.size > 1 && members.any { it.isStreaming(model) }) { + throw CodegenException( + "We only support one payload member if that payload contains a streaming member." + + "Tracking issue to relax this constraint: https://github.com/awslabs/smithy-rs/issues/2237", + ) + } + + return members.map { + if (it.isStreaming(model)) { + HttpBindingDescriptor(it, HttpLocation.PAYLOAD, "document") + } else { + HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") + } + } + .toList() + } + + // TODO + // In the server, this is only used when the protocol actually supports the `@http` trait. + // However, we will have to do this for client support. Perhaps this method deserves a rename. + override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") + + override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape) + override fun responseBindings(operationShape: OperationShape) = bindings(operationShape.outputShape) + override fun errorResponseBindings(errorShape: ToShapeId) = bindings(errorShape) + + // TODO This should return null when operationShape has no members, and we should not rely on our janky + // `serverContentTypeCheckNoModeledInput`. Same goes for restJson1 protocol. + override fun requestContentType(operationShape: OperationShape): String = "application/cbor" + + /** + * > Responses for operations with no defined output type MUST NOT contain bodies in their HTTP responses. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun responseContentType(operationShape: OperationShape): String? { + // When `syntheticOutputTrait.originalId == null` it implies that the operation had no output defined + // in the Smithy model. + val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() + if (syntheticOutputTrait.originalId == null) { + return null + } + return requestContentType(operationShape) + } + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") +} + +/** + * TODO: Docs. + */ +open class RpcV2(val codegenContext: CodegenContext) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig + private val errorScope = arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), + "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "Response" to RuntimeType.Http.resolve("Response"), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + ) + private val jsonDeserModule = RustModule.private("json_deser") + + override val httpBindingResolver: HttpBindingResolver = RpcV2HttpBindingResolver(codegenContext.model) + + // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps + // using floating point seconds from the epoch. + override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS + + override fun additionalResponseHeaders(operationShape: OperationShape): List> = + listOf("smithy-protocol" to "rpc-v2-cbor") + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator(codegenContext, httpBindingResolver) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = + CborSerializerGenerator(codegenContext, httpBindingResolver) + + // TODO: Implement `RpcV2.parseHttpErrorMetadata` + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) { + rustTemplate( + """ + pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + #{json_errors}::parse_error_metadata(response.body(), response.headers()) + } + """, + *errorScope, + ) + } + + // TODO: Implement `RpcV2.parseEventStreamErrorMetadata` + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", jsonDeserModule) { + // `HeaderMap::new()` doesn't allocate. + rustTemplate( + """ + pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + #{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new()) + } + """, + *errorScope, + ) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt new file mode 100644 index 0000000000..012c695d16 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -0,0 +1,657 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.SparseTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.outputShape + +/** + * Class describing a CBOR parser section that can be used in a customization. + */ +sealed class CborParserSection(name: String) : Section(name) { + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember") +} + +/** + * Customization for the CBOR parser. + */ +typealias CborParserCustomization = NamedCustomization + +// TODO Add a `CborParserGeneratorTest` a la `CborSerializerGeneratorTest`. +class CborParserGenerator( + private val codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + /** See docs for this parameter in [JsonParserGenerator]. */ + private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + }, + private val customizations: List = listOf(), +) : StructuredDataParserGenerator { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + // TODO Use? + private val codegenTarget = codegenContext.target + private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() + private val protocolFunctions = ProtocolFunctions(codegenContext) + private val codegenScope = arrayOf( + "SmithyCbor" to smithyCbor, + "Decoder" to smithyCbor.resolve("Decoder"), + "Error" to smithyCbor.resolve("decode::DeserializeError"), + "HashMap" to RuntimeType.HashMap, + "Vec" to RuntimeType.Vec, + ) + + private fun listMemberParserFn( + listSymbol: Symbol, + isSparseList: Boolean, + memberShape: MemberShape, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn member( + mut list: #{ListSymbol}, + decoder: &mut #{Decoder}, + ) -> Result<#{ListSymbol}, #{Error}> + """, + *codegenScope, + "ListSymbol" to listSymbol, + ) { + val deserializeMemberWritable = deserializeMember(memberShape) + if (isSparseList) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeMember:W}?), + }; + """, + *codegenScope, + "DeserializeMember" to deserializeMemberWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeMember:W}?; + """, + "DeserializeMember" to deserializeMemberWritable, + ) + } + + if (returnUnconstrainedType) { + rust("list.0.push(value);") + } else { + rust("list.push(value);") + } + + rust("Ok(list)") + } + } + + private fun mapPairParserFnWritable( + keyTarget: StringShape, + valueShape: MemberShape, + isSparseMap: Boolean, + mapSymbol: Symbol, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn pair( + mut map: #{MapSymbol}, + decoder: &mut #{Decoder}, + ) -> Result<#{MapSymbol}, #{Error}> + """, + *codegenScope, + "MapSymbol" to mapSymbol, + ) { + val deserializeKeyWritable = deserializeString(keyTarget) + rustTemplate( + """ + let key = #{DeserializeKey:W}?; + """, + "DeserializeKey" to deserializeKeyWritable, + ) + val deserializeValueWritable = deserializeMember(valueShape) + if (isSparseMap) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeValue:W}?), + }; + """, + *codegenScope, + "DeserializeValue" to deserializeValueWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeValue:W}?; + """, + "DeserializeValue" to deserializeValueWritable, + ) + } + + if (returnUnconstrainedType) { + rust("map.0.insert(key, value);") + } else { + rust("map.insert(key, value);") + } + + rust("Ok(map)") + } + } + + private fun structurePairParserFnWritable(builderSymbol: Symbol, includedMembers: Collection) = writable { + rustBlockTemplate( + """ + fn pair( + mut builder: #{Builder}, + decoder: &mut #{Decoder} + ) -> Result<#{Builder}, #{Error}> + """, + *codegenScope, + "Builder" to builderSymbol, + ) { + withBlock("builder = match decoder.str()?.as_ref() {", "};") { + for (member in includedMembers) { + rustBlock("${member.memberName.dq()} =>") { + val callBuilderSetMemberFieldWritable = writable { + withBlock("builder.${member.setterName()}(", ")") { + conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { + val symbol = symbolProvider.toSymbol(member) + if (symbol.isRustBoxed()) { + rustBlock("") { + rustTemplate("let v = #{DeserializeMember:W}?;", + "DeserializeMember" to deserializeMember(member)) + + for (customization in customizations) { + customization.section( + CborParserSection.BeforeBoxingDeserializedMember( + member + ) + )(this) + } + rust("Box::new(v)") + } + } else { + rustTemplate("#{DeserializeMember:W}?", + "DeserializeMember" to deserializeMember(member)) + } + } + } + } + + if (member.isOptional) { + // Call `builder.set_member()` only if the value for the field on the wire is not null. + rustTemplate( + """ + ::aws_smithy_cbor::decode::set_optional(builder, decoder, |builder, decoder| { + Ok(#{MemberSettingWritable:W}) + })? + """, + "MemberSettingWritable" to callBuilderSetMemberFieldWritable + ) + } + else { + callBuilderSetMemberFieldWritable.invoke(this) + } + } + } + + rust( + """ + _ => { + decoder.skip()?; + builder + } + """) + } + rust("Ok(builder)") + } + } + + private fun unionPairParserFnWritable(shape: UnionShape) = writable { + val returnSymbolToParse = returnSymbolToParse(shape) + // TODO Test with unit variants + // TODO Test with all unit variants + rustBlockTemplate( + """ + fn pair( + decoder: &mut #{Decoder} + ) -> Result<#{UnionSymbol}, #{Error}> + """, + *codegenScope, + "UnionSymbol" to returnSymbolToParse.symbol, + ) { + withBlock("Ok(match decoder.str()? {", "})") { + for (member in shape.members()) { + val variantName = symbolProvider.toMemberName(member) + + withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { + deserializeMember(member).invoke(this) + } + } + // TODO Test client mode (parse unknown variant) and server mode (reject unknown variant). + // In client mode, resolve an unknown union variant to the unknown variant. + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + rust("_ => { todo!() }") + } + } + } + + private fun decodeStructureMapLoopWritable() = writable { + rustTemplate( + """ + match decoder.map()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + builder = pair(builder, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + builder = pair(builder, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + // TODO This should be DRYed up with `decodeStructureMapLoopWritable`. + private fun decodeMapLoopWritable() = writable { + rustTemplate( + """ + match decoder.map()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + map = pair(map, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + map = pair(map, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + // TODO This should be DRYed up with `decodeStructureMapLoopWritable`. + private fun decodeListLoop() = writable { + rustTemplate( + """ + match decoder.list()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + list = member(list, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + list = member(list, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + /** + * Reusable structure parser implementation that can be used to generate parsing code for + * operation, error and structure shapes. + * We still generate the parser symbol even if there are no included members because the server + * generation requires parsers for all input structures. + */ + private fun structureParser( + shape: Shape, + builderSymbol: Symbol, + includedMembers: List, + fnNameSuffix: String? = null, + ): RuntimeType { + return protocolFunctions.deserializeFn(shape, fnNameSuffix) { fnName -> + // TODO Test no members. +// val unusedMut = if (includedMembers.isEmpty()) "##[allow(unused_mut)] " else "" + // TODO Assert token stream ended. + rustTemplate( + """ + pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> Result<#{Builder}, #{Error}> { + #{StructurePairParserFn:W} + + let decoder = &mut #{Decoder}::new(value); + + #{DecodeStructureMapLoop:W} + + if decoder.position() != value.len() { + todo!() + } + + Ok(builder) + } + """, + "Builder" to builderSymbol, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + *codegenScope, + ) + } + } + + override fun payloadParser(member: MemberShape): RuntimeType { + UNREACHABLE("No protocol using CBOR serialization supports payload binding") + } + + override fun operationParser(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR deserializer if there is nothing bound to the HTTP body. + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + val outputShape = operationShape.outputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(outputShape), httpDocumentMembers) + } + + override fun errorParser(errorShape: StructureShape): RuntimeType? { + if (errorShape.members().isEmpty()) { + return null + } + return structureParser( + errorShape, + symbolProvider.symbolForBuilder(errorShape), + errorShape.members().toList(), + fnNameSuffix = "json_err", + ) + } + + override fun serverInputParser(operationShape: OperationShape): RuntimeType? { + val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (includedMembers.isEmpty()) { + return null + } + val inputShape = operationShape.inputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(inputShape), includedMembers) + } + + private fun RustWriter.deserializeMember(memberShape: MemberShape) = writable { + when (val target = model.expectShape(memberShape.target)) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("decoder.blob()") + is BooleanShape -> rust("decoder.boolean()") + + is StringShape -> deserializeString(target).invoke(this) + + is ByteShape -> rust("decoder.byte()") + is ShortShape -> rust("decoder.short()") + is IntegerShape -> rust("decoder.integer()") + is LongShape -> rust("decoder.long()") + + is FloatShape -> rust("decoder.float()") + is DoubleShape -> rust("decoder.double()") + + is TimestampShape -> rust("decoder.timestamp()") + + // TODO Document shapes have not been specced out yet. + // is DocumentShape -> rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope) + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + is StructureShape -> deserializeStruct(target) + is CollectionShape -> deserializeCollection(target) + is MapShape -> deserializeMap(target) + is UnionShape -> deserializeUnion(target) + else -> PANIC("unexpected shape: $target") + } + // TODO Boxing +// val symbol = symbolProvider.toSymbol(memberShape) +// if (symbol.isRustBoxed()) { +// for (customization in customizations) { +// customization.section(JsonParserSection.BeforeBoxingDeserializedMember(memberShape))(this) +// } +// rust(".map(Box::new)") +// } + } + + private fun RustWriter.deserializeString(target: StringShape, bubbleUp: Boolean = true) = writable { + // TODO Handle enum shapes + rust("decoder.string()") + } + + private fun RustWriter.deserializeCollection(shape: CollectionShape) { + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + // TODO Test `@sparse` and non-@sparse lists. + // - Clients should insert only non-null values in non-`@sparse` list. + // - Servers should reject upon encountering first null value in non-`@sparse` list. + // - Both clients and servers should insert null values in `@sparse` list. + + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = writable { + withBlock("let mut list = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{Vec}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + #{ListMemberParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeListLoop:W} + + Ok(list) + } + """, + "ReturnType" to returnSymbol, + "ListMemberParserFn" to listMemberParserFn( + returnSymbol, + isSparseList = shape.hasTrait(), + shape.member, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeListLoop" to decodeListLoop(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeMap(shape: MapShape) { + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + // TODO Test `@sparse` and non-@sparse maps. + // - Clients should insert only non-null values in non-`@sparse` map. + // - Servers should reject upon encountering first null value in non-`@sparse` map. + // - Both clients and servers should insert null values in `@sparse` map. + + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = writable { + withBlock("let mut map = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{HashMap}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + #{MapPairParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeMapLoop:W} + + Ok(map) + } + """, + "ReturnType" to returnSymbol, + "MapPairParserFn" to mapPairParserFnWritable( + keyTarget, + shape.value, + isSparseMap = shape.hasTrait(), + returnSymbol, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeMapLoop" to decodeMapLoopWritable(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeStruct(shape: StructureShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}>", + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + val builderSymbol = symbolProvider.symbolForBuilder(shape) + val includedMembers = shape.members() + + rustTemplate( + """ + #{StructurePairParserFn:W} + + let mut builder = #{Builder}::default(); + + #{DecodeStructureMapLoop:W} + """, + *codegenScope, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "Builder" to builderSymbol, + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + ) + + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(builder)") + } else { + rust("Ok(builder.build())") + } + } + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeUnion(shape: UnionShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{UnionSymbol}, #{Error}> { + #{UnionPairParserFnWritable} + + match decoder.map()? { + None => { + let variant = pair(decoder)?; + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + Ok(variant) + } + ty => Err( + #{Error}::unexpected_union_variant( + ty, + decoder.position(), + ), + ), + } + } + Some(1) => pair(decoder), + Some(_) => Err(#{Error}::mixed_union_variants(decoder.position())) + } + } + """, + "UnionSymbol" to returnSymbolToParse.symbol, + "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 4833f808fb..5f248aad33 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -77,8 +77,6 @@ sealed class JsonParserSection(name: String) : Section(name) { */ typealias JsonParserCustomization = NamedCustomization -data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) - class JsonParserGenerator( private val codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, @@ -447,7 +445,7 @@ class JsonParserGenerator( } private fun RustWriter.deserializeMap(shape: MapShape) { - val keyTarget = model.expectShape(shape.key.target) as StringShape + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) val isSparse = shape.hasTrait() val returnSymbolToParse = returnSymbolToParse(shape) val parser = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt new file mode 100644 index 0000000000..d78f2d98fd --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt @@ -0,0 +1,8 @@ +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol + +/** + * Given a shape, parsers need to know the symbol to parse and return, and whether it's unconstrained or not. + */ +data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt index f8b053d80f..fd7e6fcb28 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt @@ -12,8 +12,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType interface StructuredDataParserGenerator { /** - * Generate a parse function for a given targeted as a payload. + * Generate a parse function for a given shape targeted with `@httpPayload`. * Entry point for payload-based parsing. + * * Roughly: * ```rust * fn parse_my_struct(input: &[u8]) -> Result { @@ -49,6 +50,7 @@ interface StructuredDataParserGenerator { /** * Generate a parser for a server operation input structure + * * ```rust * fn deser_operation_crate_operation_my_operation_input( * value: &[u8], builder: my_operation_input::Builder diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt new file mode 100644 index 0000000000..68533ef7d1 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -0,0 +1,469 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.outputShape + +// TODO Cleanup commented and unused code. + +/** + * Class describing a JSON serializer section that can be used in a customization. + */ +sealed class CborSerializerSection(name: String) : Section(name) { + /** + * Mutate the serializer prior to serializing any structure members. Eg: this can be used to inject `__type` + * to record the error type in the case of an error structure. + */ + data class BeforeSerializingStructureMembers(val structureShape: StructureShape, val encoderBindingName: String) : + CborSerializerSection("ServerError") + + /** Manipulate the serializer context for a map prior to it being serialized. **/ + data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : + CborSerializerSection("BeforeIteratingOverMapOrCollection") +} + +/** + * Customization for the CBOR serializer. + */ +typealias CborSerializerCustomization = NamedCustomization + +class CborSerializerGenerator( + codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + private val customizations: List = listOf(), +) : StructuredDataSerializerGenerator { + data class Context( + /** Expression representing the value to write to the encoder */ + var valueExpression: ValueExpression, + /** Path in the JSON to get here, used for errors */ + val shape: T, + ) + + data class MemberContext( + /** Name for the variable bound to the encoder object **/ + val encoderBindingName: String, + /** Expression representing the value to write to the `Encoder` */ + var valueExpression: ValueExpression, + val shape: MemberShape, + /** Whether to serialize null values if the type is optional */ + val writeNulls: Boolean = false, + ) { + companion object { + fun collectionMember(context: Context, itemName: String): MemberContext = + MemberContext( + "encoder", + ValueExpression.Reference(itemName), + context.shape.member, + writeNulls = true, + ) + + fun mapMember(context: Context, key: String, value: String): MemberContext = + MemberContext( + "encoder.str($key)", + ValueExpression.Reference(value), + context.shape.value, + writeNulls = true, + ) + + fun structMember( + context: StructContext, + member: MemberShape, + symProvider: RustSymbolProvider, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Value("${context.localName}.${symProvider.toMemberName(member)}"), + member, + ) + + fun unionMember( + variantReference: String, + member: MemberShape, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Reference(variantReference), + member, + ) + + /** Returns an expression to encode a key member **/ + private fun encodeKeyExpression(name: String): String = + "encoder.str(${name.dq()})" + } + } + + // Specialized since it holds a JsonObjectWriter expression rather than a JsonValueWriter + data class StructContext( + /** Name of the variable that holds the struct */ + val localName: String, + val shape: StructureShape, + ) + + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val codegenTarget = codegenContext.target + private val runtimeConfig = codegenContext.runtimeConfig + private val protocolFunctions = ProtocolFunctions(codegenContext) + // TODO Cleanup + private val codegenScope = arrayOf( + "String" to RuntimeType.String, + "Error" to runtimeConfig.serializationError(), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), + "ByteSlab" to RuntimeType.ByteSlab, + ) + private val serializerUtil = SerializerUtil(model, symbolProvider) + + /** + * Reusable structure serializer implementation that can be used to generate serializing code for + * operation outputs or errors. + * This function is only used by the server, the client uses directly [serializeStructure]. + */ + private fun serverSerializer( + structureShape: StructureShape, + includedMembers: List, + error: Boolean, + ): RuntimeType { + val suffix = when (error) { + true -> "error" + else -> "output" + } + return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> + rustBlockTemplate( + "pub fn $fnName(value: &#{target}) -> Result, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(structureShape), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(Vec::new());", *codegenScope) + // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference. + rustBlock("") { + rust("let encoder = &mut encoder;") + serializeStructure(StructContext("value", structureShape), includedMembers) + } + rust("Ok(encoder.into_writer())") + } + } + } + + // TODO + override fun payloadSerializer(member: MemberShape): RuntimeType { + val target = model.expectShape(member.target) + return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> std::result::Result<#{ByteSlab}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(target), + ) { + rust("let mut out = String::new();") + rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + when (target) { + is StructureShape -> serializeStructure(StructContext("input", target)) + is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) + else -> throw IllegalStateException("json payloadSerializer only supports structs and unions") + } + rust("object.finish();") + rustTemplate("Ok(out.into_bytes())", *codegenScope) + } + } + } + + // TODO Unclear whether we'll need this. + override fun unsetStructure(structure: StructureShape): RuntimeType = + ProtocolFunctions.crossOperationFn("rest_json_unsetpayload") { fnName -> + rustTemplate( + """ + pub fn $fnName() -> #{ByteSlab} { + b"{}"[..].into() + } + """, + *codegenScope, + ) + } + + override fun unsetUnion(union: UnionShape): RuntimeType { + // TODO + TODO("Not yet implemented") + } + + override fun operationInputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR serializer if there is no CBOR body. + val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + + val inputShape = operationShape.inputShape(model) + return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", + *codegenScope, "target" to symbolProvider.toSymbol(inputShape), + ) { + rust("let mut out = String::new();") + rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + serializeStructure(StructContext("input", inputShape), httpDocumentMembers) + rust("object.finish();") + rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) + } + } + } + + override fun documentSerializer(): RuntimeType { + return ProtocolFunctions.crossOperationFn("serialize_document") { fnName -> + rustTemplate( + """ + pub fn $fnName(input: &#{Document}) -> #{ByteSlab} { + let mut out = String::new(); + #{JsonValueWriter}::new(&mut out).document(input); + out.into_bytes() + } + """, + "Document" to RuntimeType.document(runtimeConfig), *codegenScope, + ) + } + } + + override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation JSON serializer if there was no operation output shape in the + // original (untransformed) model. + val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() + if (syntheticOutputTrait.originalId == null) { + return null + } + + // TODO + // Note that, unlike the client, we serialize an empty JSON document `"{}"` if the operation output shape is + // empty (has no members). + // The client instead serializes an empty payload `""` in _both_ these scenarios: + // 1. there is no operation input shape; and + // 2. the operation input shape is empty (has no members). + // The first case gets reduced to the second, because all operations get a synthetic input shape with + // the [OperationNormalizer] transformation. + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + + val outputShape = operationShape.outputShape(model) + return serverSerializer(outputShape, httpDocumentMembers, error = false) + } + + override fun serverErrorSerializer(shape: ShapeId): RuntimeType { + val errorShape = model.expectShape(shape, StructureShape::class.java) + val includedMembers = + httpBindingResolver.errorResponseBindings(shape).filter { it.location == HttpLocation.DOCUMENT } + .map { it.member } + return serverSerializer(errorShape, includedMembers, error = true) + } + + private fun RustWriter.serializeStructure( + context: StructContext, + includedMembers: List? = null, + ) { + // TODO Need to inject `__type` when serializing errors. + val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> Result<(), #{Error}>", + "StructureSymbol" to symbolProvider.toSymbol(context.shape), + *codegenScope, + ) { + // TODO If all members are non-`Option`-al, we know AOT the map's size and can use `.map()` + // instead of `.begin_map()` for efficiency. Add test. + rust("encoder.begin_map();") + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeSerializingStructureMembers(context.shape, "encoder"))(this) + } + context.copy(localName = "input").also { inner -> + val members = includedMembers ?: inner.shape.members() + for (member in members) { + serializeMember(MemberContext.structMember(inner, member, symbolProvider)) + } + } + rust("encoder.end();") + rust("Ok(())") + } + } + rust("#T(encoder, ${context.localName})?;", structureSerializer) + } + + private fun RustWriter.serializeMember(context: MemberContext) { + val targetShape = model.expectShape(context.shape.target) + if (symbolProvider.toSymbol(context.shape).isOptional()) { + safeName().also { local -> + rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { + context.valueExpression = ValueExpression.Reference(local) + serializeMemberValue(context, targetShape) + } + if (context.writeNulls) { + rustBlock("else") { + rust("${context.encoderBindingName}.null();") + } + } + } + } else { + with(serializerUtil) { + ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { + serializeMemberValue(context, targetShape) + } + } + } + } + + private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { + val encoder = context.encoderBindingName + val value = context.valueExpression + val containerShape = model.expectShape(context.shape.container) + + when (target) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("$encoder.blob(${value.asRef()});") + is BooleanShape -> rust("$encoder.boolean(${value.asValue()});") + + is StringShape -> rust("$encoder.str(${value.name}.as_str());") + + is ByteShape -> rust("$encoder.byte(${value.asValue()});") + is ShortShape -> rust("$encoder.short(${value.asValue()});") + is IntegerShape -> rust("$encoder.integer(${value.asValue()});") + is LongShape -> rust("$encoder.long(${value.asValue()});") + + is FloatShape -> rust("$encoder.float(${value.asValue()});") + is DoubleShape -> rust("$encoder.double(${value.asValue()});") + + is TimestampShape -> rust("$encoder.timestamp(${value.asRef()});") + + // TODO Document shapes have not been specced out yet. + // is DocumentShape -> rust("$encoder.document(${value.asRef()});") + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + else -> { + // This condition is equivalent to `containerShape !is CollectionShape`. + if (containerShape is StructureShape || containerShape is UnionShape || containerShape is MapShape) { + rust("$encoder;") // Encode the member key. + } + when (target) { + is StructureShape -> serializeStructure(StructContext(value.name, target)) + is CollectionShape -> serializeCollection(Context(value, target)) + is MapShape -> serializeMap(Context(value, target)) + is UnionShape -> serializeUnion(Context(value, target)) + else -> UNREACHABLE("Smithy added a new aggregate shape: $target") + } + } + } + } + + private fun RustWriter.serializeCollection(context: Context) { + // `.expect()` safety: `From for usize` is not in the standard library, but the conversion should be + // infallible (unless we ever have 128-bit machines I guess). + // See https://users.rust-lang.org/t/cant-convert-usize-to-u64/6243. + // TODO Point to a `static` to not inflate the binary. + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust( + """ + encoder.array( + (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") + ); + """ + ) + val itemName = safeName("item") + rustBlock("for $itemName in ${context.valueExpression.asRef()}") { + serializeMember(MemberContext.collectionMember(context, itemName)) + } + } + + private fun RustWriter.serializeMap(context: Context) { + val keyName = safeName("key") + val valueName = safeName("value") + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust( + """ + encoder.map( + (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") + ); + """ + ) + rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { + val keyExpression = "$keyName.as_str()" + serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) + } + } + + private fun RustWriter.serializeUnion(context: Context) { + val unionSymbol = symbolProvider.toSymbol(context.shape) + val unionSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> Result<(), #{Error}>", + "UnionSymbol" to unionSymbol, + *codegenScope, + ) { + // A union is serialized identically as a `structure` shape, but only a single member can be set to a + // non-null value. + rust("encoder.map(1);") + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = if (member.isTargetUnit()) { + symbolProvider.toMemberName(member) + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + rustBlock("#T::$variantName =>", unionSymbol) { + serializeMember(MemberContext.unionMember("inner", member)) + } + } + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) + } + } + rust("Ok(())") + } + } + rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 46341bb09c..8b18dfa5c8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -212,12 +212,21 @@ class JsonSerializerGenerator( *codegenScope, "target" to symbolProvider.toSymbol(structureShape), ) { - rust("let mut out = String::new();") - rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + rustTemplate( + """ + let mut out = String::new(); + let mut object = #{JsonObjectWriter}::new(&mut out); + """, + *codegenScope, + ) serializeStructure(StructContext("object", "value", structureShape), includedMembers) customizations.forEach { it.section(makeSection(structureShape, "object"))(this) } - rust("object.finish();") - rustTemplate("Ok(out)", *codegenScope) + rust( + """ + object.finish(); + Ok(out) + """ + ) } } } diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index cab36dbbcc..1c9c9d4c7d 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -16,7 +16,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val pluginName = "rust-server-codegen" @@ -25,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-server-test/" dependencies { implementation(project(":codegen-server")) implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") @@ -44,6 +44,19 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), ), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), + // CodegenTest("aws.protocoltests.restxml#RestXml", "restXml"), + CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), + // Todo: change this to rpcv2extra + CodegenTest("com.amazonaws.simple#RpcV2Service", "rpcv2Extra", imports = listOf("$commonModels/rpcv2.smithy")), + CodegenTest( + "aws.protocoltests.rpcv2#RpcV2Protocol", + "adwait-main", + imports = listOf( + "$commonModels/adwait-main.smithy", + "$commonModels/adwait-cbor-structs.smithy", + "$commonModels/adwait-empty-input-output.smithy", + ) + ), CodegenTest( "com.amazonaws.constraints#ConstraintsService", "constraints_without_public_constrained_types", @@ -103,7 +116,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace", "generateCargoConfigToml") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-server-test/python/build.gradle.kts b/codegen-server-test/python/build.gradle.kts index f9129a2fde..b6ee22147a 100644 --- a/codegen-server-test/python/build.gradle.kts +++ b/codegen-server-test/python/build.gradle.kts @@ -16,7 +16,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val buildDir = layout.buildDirectory.get().asFile @@ -120,7 +119,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir), defaultRustDocFlags) +project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir)) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-server-test/typescript/build.gradle.kts b/codegen-server-test/typescript/build.gradle.kts index 428da17df6..22c27b7d90 100644 --- a/codegen-server-test/typescript/build.gradle.kts +++ b/codegen-server-test/typescript/build.gradle.kts @@ -16,7 +16,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val buildDir = layout.buildDirectory.get().asFile @@ -49,7 +48,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir), defaultRustDocFlags) +project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir)) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 0ba262225d..2dbc33df05 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -26,6 +26,7 @@ dependencies { implementation(project(":codegen-core")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in `constraints.smithy`, which is used // in `CustomValidationExceptionWithReasonDecoratorTest`. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 01df0c0a93..a9c488503d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig */ object ServerCargoDependency { val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1.74")) + val Base64SimdDev: CargoDependency = CargoDependency("base64-simd", CratesIo("0.8"), scope = DependencyScope.Dev) val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1")) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3")) val Mime: CargoDependency = CargoDependency("mime", CratesIo("0.3")) @@ -26,7 +27,7 @@ object ServerCargoDependency { val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) - val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) + val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), scope = DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt new file mode 100644 index 0000000000..59ead97478 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.escape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +/** + * Smithy RPC v2 CBOR requires errors to be serialized in server responses with an additional `__type` field. + */ +class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = when (section) { + is CborSerializerSection.BeforeSerializingStructureMembers -> + if (section.structureShape.hasTrait()) { + writable { + rust( + """ + ${section.encoderBindingName} + .str("__type") + .str("${escape(section.structureShape.id.toString())}"); + """ + ) + } + } else { + emptySection + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt new file mode 100644 index 0000000000..42eb1f6843 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we encode over a _constrained_ map or collection shape in a CBOR serializer, + * unwrap the wrapper newtype and take a shared reference to the actual value within it. + * That value will be a `std::collections::HashMap` for map shapes, and a `std::vec::Vec` for collection shapes. + */ +class BeforeEncodingMapOrCollectionCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = when (section) { + is CborSerializerSection.BeforeIteratingOverMapOrCollection -> writable { + check(section.shape is CollectionShape || section.shape is MapShape) + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 2fb76bf879..d4281b4ddf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -21,18 +21,27 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator @@ -70,8 +79,8 @@ interface ServerProtocol : Protocol { fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType /** - * In some protocols, such as restJson1, - * when there is no modeled body input, content type must not be set and the body must be empty. + * In some protocols, such as restJson1 and rpcv2, + * when there is no modeled body input, `content-type` must not be set and the body must be empty. * Returns a boolean indicating whether to perform this check. */ fun serverContentTypeCheckNoModeledInput(): Boolean = false @@ -166,7 +175,10 @@ class ServerAwsJsonProtocol( rust("""String::from("$serviceName.$operationName")""") } - override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.String + // TODO This could technically be `&static str` right? + override fun serverRouterRequestSpecType( + requestSpecModule: RuntimeType, + ): RuntimeType = RuntimeType.String override fun serverRouterRuntimeConstructor() = when (version) { @@ -255,8 +267,8 @@ class ServerRestXmlProtocol( } /** - * A customization to, just before we box a recursive member that we've deserialized into `Option`, convert it into - * `MaybeConstrained` if the target shape can reach a constrained shape. + * A customization to, just before we box a recursive member that we've deserialized from JSON into `Option`, convert + * it into `MaybeConstrained` if the target shape can reach a constrained shape. */ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : JsonParserCustomization() { @@ -276,3 +288,74 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa else -> emptySection } } + +/** + * A customization to, just before we box a recursive member that we've deserialized from CBOR into `T` held in a + * variable binding `v`, convert it into `MaybeConstrained` if the target shape can reach a constrained shape. + */ +class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization(val codegenContext: ServerCodegenContext) : + CborParserCustomization() { + override fun section(section: CborParserSection): Writable = when (section) { + is CborParserSection.BeforeBoxingDeserializedMember -> writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust("let v = v.into();") + } + } + } +} + +class ServerRpcV2Protocol( + private val serverCodegenContext: ServerCodegenContext, +) : RpcV2(serverCodegenContext), ServerProtocol { + val runtimeConfig = codegenContext.runtimeConfig + + override val protocolModulePath = "rpc_v2" + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator( + serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( + serverCodegenContext, + ), + ), + ) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator { + return CborSerializerGenerator( + codegenContext, + httpBindingResolver, + listOf( + BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), + AddTypeFieldToServerErrorsCborCustomization(), + ), + ) + } + + override fun markerStruct() = ServerRuntimeType.protocol("RpcV2", "rpc_v2", runtimeConfig) + + override fun routerType() = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::rpc_v2::router::RpcV2Router") + + override fun serverRouterRequestSpec( + operationShape: OperationShape, + operationName: String, + serviceName: String, + requestSpecModule: RuntimeType, + ) = writable { + // This is just the key used by the router's map to store and lookup operations, it's completely arbitrary. + // We use the same key used by the awsJson1.x routers for simplicity. + rust("$serviceName.$operationName".dq()) + } + + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = + RuntimeType.StaticStr + + override fun serverRouterRuntimeConstructor() = "rpc_v2_router" + + override fun serverContentTypeCheckNoModeledInput() = false +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index c35518319c..c9ffa06cc4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -26,6 +26,7 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords @@ -42,6 +43,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors +import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember @@ -100,6 +102,7 @@ class ServerProtocolTestGenerator( private val codegenScope = arrayOf( + "Base64SimdDev" to ServerCargoDependency.Base64SimdDev.toType(), "Bytes" to RuntimeType.Bytes, "SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), "Http" to RuntimeType.Http, @@ -278,6 +281,11 @@ class ServerProtocolTestGenerator( testModuleWriter.newlinePrefix = "" Attribute.TokioTest.render(testModuleWriter) + Attribute.TracedTest.render(testModuleWriter) + // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in + // case the code rendered by the test does not make use of `tracing` at all. + val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } + testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") @@ -309,7 +317,7 @@ class ServerProtocolTestGenerator( } with(httpRequestTestCase) { - renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest(uri, method, headers, body.orNull(), bodyMediaType.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) @@ -387,7 +395,9 @@ class ServerProtocolTestGenerator( rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) + // TODO(https://github.com/smithy-lang/smithy/issues/1932): we send `null` for `bodyMediaType` for now but + // the Smithy protocol test should give it to us. + renderHttpRequest(uri.get(), method, headers, body.orNull(), bodyMediaType = null, queryParams, host.orNull()) } makeRequest( @@ -405,6 +415,7 @@ class ServerProtocolTestGenerator( method: String, headers: Map, body: String?, + bodyMediaType: String?, queryParams: List, host: String?, ) { @@ -434,8 +445,17 @@ class ServerProtocolTestGenerator( // // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() + + val encodedBody = + """ + #{Bytes}::from( + #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( + "`body` field of Smithy protocol test is not correctly base64 encoded" + ) + ) + """ - "#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))" + "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { "#{SmithyHttpServer}::body::Body::empty()" } @@ -660,13 +680,13 @@ class ServerProtocolTestGenerator( rustWriter.rustTemplate( """ // No body. - #{AssertEq}(std::str::from_utf8(&body).unwrap(), ""); + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) } else { assertOk(rustWriter) { - rustWriter.rust( + rust( "#T(&body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 910acf4a70..ac966853e6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -114,7 +114,7 @@ typealias ServerHttpBoundProtocolCustomization = NamedCustomization = listOf(), additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( - protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), - ) { + protocol, + ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), +) { + // TODO Delete, unused // Define suffixes for operation input / output / error wrappers companion object { const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper" @@ -604,13 +605,35 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } + private fun setResponseHeaderIfAbsent(writer: RustWriter, headerName: String, headerValue: String) { + // We can be a tad more efficient if there's a `const` `HeaderName` in the `http` crate that matches. + // https://docs.rs/http/latest/http/header/index.html#constants + val headerNameExpr = if (headerName == "content-type") { + "#{http}::header::CONTENT_TYPE" + } else { + "#{http}::header::HeaderName::from_static(\"$headerName\")" + } + + writer.rustTemplate( + """ + builder = #{header_util}::set_response_header_if_absent( + builder, + $headerNameExpr, + "${writer.escape(headerValue)}", + ); + """, + *codegenScope, + ) + } + + /** * Sets HTTP response headers for the operation's output shape or the operation's error shape. * It will generate response headers for the operation's output shape, unless [errorShape] is non-null, in which * case it will generate response headers for the given error shape. * * It sets three groups of headers in order. Headers from one group take precedence over headers in a later group. - * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. = null + * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ @@ -626,7 +649,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust( """ builder = #{T}($outputOwnedOrBorrowed, builder)?; - """.trimIndent(), + """, addHeadersFn, ) } @@ -635,32 +658,17 @@ class ServerHttpBoundProtocolTraitImplGenerator( // to allow operations that bind a member to `Content-Type` (which we set earlier) to take precedence (this is // because we always use `set_response_header_if_absent`, so the _first_ header value we set for a given // header name is the one that takes precedence). - val contentType = httpBindingResolver.responseContentType(operationShape) - if (contentType != null) { - rustTemplate( - """ - builder = #{header_util}::set_response_header_if_absent( - builder, - #{http}::header::CONTENT_TYPE, - "$contentType" - ); - """, - *codegenScope, - ) + httpBindingResolver.responseContentType(operationShape)?.let { contentTypeValue -> + setResponseHeaderIfAbsent(this, "content-type", contentTypeValue) + } + + for ((headerName, headerValue) in protocol.additionalResponseHeaders(operationShape)) { + setResponseHeaderIfAbsent(this, headerName, headerValue) } if (errorShape != null) { for ((headerName, headerValue) in protocol.additionalErrorResponseHeaders(errorShape)) { - rustTemplate( - """ - builder = #{header_util}::set_response_header_if_absent( - builder, - http::header::HeaderName::from_static("$headerName"), - "${escape(headerValue)}" - ); - """, - *codegenScope, - ) + setResponseHeaderIfAbsent(this, headerName, headerValue) } } } @@ -747,13 +755,14 @@ class ServerHttpBoundProtocolTraitImplGenerator( "RequestParts" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("http::RequestParts"), ) val parser = structuredDataParser.serverInputParser(operationShape) - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { // `null` is only returned by Smithy when there are no members, but we know there's at least one, since // there's something to parse (i.e. `parser != null`), so `!!` is safe here. val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! rustTemplate("let bytes = #{Hyper}::body::to_bytes(body).await?;", *codegenScope) + // TODO Isn't this VERY wrong? If there's modeled operation input, we must reject if there's no payload! + // We currently accept and silently build empty input! rustBlock("if !bytes.is_empty()") { rustTemplate( """ @@ -793,6 +802,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( serverRenderUriPathParser(this, operationShape) serverRenderQueryStringParser(this, operationShape) + val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { rustTemplate( @@ -1300,6 +1310,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( * Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the * shape targeted by the `httpPayload` trait. */ + // TODO This should not live here. Plus, only some protocols support `@httpPayload`. private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol { check(binding.location == HttpLocation.PAYLOAD) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ae87ec5723..446f805851 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -20,7 +21,7 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator -class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { +class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { is ServerHttpBoundProtocolSection.WrapStreamPayload -> @@ -79,6 +80,8 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap { + override fun protocol(codegenContext: ServerCodegenContext): Protocol = + ServerRpcV2Protocol(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = + ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2Protocol(codegenContext)) + + override fun support(): ProtocolSupport { + return ProtocolSupport( + /* Client support */ + requestSerialization = false, + requestBodySerialization = false, + responseDeserialization = false, + errorDeserialization = false, + /* Server support */ + requestDeserialization = true, + requestBodyDeserialization = true, + responseSerialization = true, + errorSerialization = true, + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt new file mode 100644 index 0000000000..3b69814ffd --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +// TODO This won't be needed since we'll cover it with a proper integration test. +internal class RpcV2Test { + val model = """ + ${"\$"}version: "2.0" + + namespace com.amazonaws.simple + + use smithy.protocols#rpcv2 + + @rpcv2(format: ["cbor"]) + service RpcV2Service { + version: "SomeVersion", + operations: [RpcV2Operation], + } + + @http(uri: "/operation", method: "POST") + operation RpcV2Operation { + input: OperationInputOutput + output: OperationInputOutput + } + + structure OperationInputOutput { + message: String + } + """.asSmithyModel() + + @Test + fun `generate a rpc v2 service that compiles`() { + serverIntegrationTest(model) { _, _ -> } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt new file mode 100644 index 0000000000..ed418ad71f --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -0,0 +1,126 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.protocoltests.traits.AppliesTo +import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.io.File + +internal class CborSerializerGeneratorTest { + class DeriveSerdeDeserializeSymbolMetadataProvider( + private val base: RustSymbolProvider, + ) : SymbolMetadataProvider(base) { + private fun addDeriveSerdeDeserialize(shape: Shape): RustMetadata { + check(shape !is MemberShape) + + val baseMetadata = base.toSymbol(shape).expectRustMetadata() + return baseMetadata.withDerives(RuntimeType.SerdeDeserialize) + } + + override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + + override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeDeserialize(structureShape) + override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeDeserialize(unionShape) + override fun enumMeta(stringShape: StringShape) = addDeriveSerdeDeserialize(stringShape) + + override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeDeserialize(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeDeserialize(mapShape) + override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveSerdeDeserialize(stringShape) + override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveSerdeDeserialize(numberShape) + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeDeserialize(blobShape) + } + + @Test + fun `we serialize and serde_cbor deserializes round trip`() { + val model = File("../codegen-core/common-test-models/rpcv2.smithy").readText().asSmithyModel() + + val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { + override val name: String = "Add `#[derive(serde::Deserialize)]`" + override val order: Byte = 0 + + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = + DeriveSerdeDeserializeSymbolMetadataProvider(base) + } + + serverIntegrationTest( + model, + additionalDecorators = listOf(addDeriveSerdeSerializeDecorator), + ) { codegenContext, rustCrate -> + val codegenScope = arrayOf( + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "SerdeCbor" to CargoDependency.SerdeCbor.toType(), + ) + + val instantiator = ServerInstantiator(codegenContext) + val rpcV2 = RpcV2(codegenContext) + + for (operationShape in codegenContext.model.operationShapes) { + val outputShape = operationShape.outputShape(codegenContext.model) + // TODO Use `httpRequestTests` and error tests too. + val tests = operationShape.getTrait() + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { + ServerProtocolTestGenerator.TestCase.ResponseTest( + it, + outputShape, + ) + } + val serializeFn = rpcV2 + .structuredDataSerializer() + .operationOutputSerializer(operationShape) + ?: continue // Skip if there's nothing to serialize. + + // TODO Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on + // which we can't `#[derive(Deserialize)]`. + rustCrate.withModule(ProtocolFunctions.serDeModule) { + for ((idx, test) in tests.withIndex()) { + unitTest("TODO_$idx") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes = #{SerializeFn}(&expected) + .expect("generated CBOR serializer failed"); + let actual = #{SerdeCbor}::from_slice(&bytes) + .expect("serde_cbor failed deserializing from bytes"); + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(test.targetShape, test.testCase.params), + "SerializeFn" to serializeFn, + *codegenScope, + ) + } + } + } + } + } + } +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d92869a661..f65cd8c4ec 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -8,8 +8,11 @@ members = [ "pokemon-service-lambda", "pokemon-service-server-sdk", "pokemon-service-client", +<<<<<<< HEAD +======= "pokemon-service-client-usage", +>>>>>>> release-2023-11-01 ] [profile.release] diff --git a/examples/Makefile b/examples/Makefile index 9c2bf9061d..791248c0a0 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -3,16 +3,23 @@ CUR_DIR := $(shell pwd) GRADLE := $(SRC_DIR)/gradlew SERVER_SDK_DST := $(CUR_DIR)/pokemon-service-server-sdk CLIENT_SDK_DST := $(CUR_DIR)/pokemon-service-client +RPCV2_SDK_DST := $(CUR_DIR)/rpcv2-server-sdk +RPCV2_CLIENT_DST := $(CUR_DIR)/rpcv2-pokemon-client + SERVER_SDK_SRC := $(SRC_DIR)/codegen-server-test/build/smithyprojections/codegen-server-test/pokemon-service-server-sdk/rust-server-codegen CLIENT_SDK_SRC := $(SRC_DIR)/codegen-client-test/build/smithyprojections/codegen-client-test/pokemon-service-client/rust-client-codegen +RPCV2_SDK_SRC := $(SRC_DIR)/codegen-server-test/build/smithyprojections/codegen-server-test/rpcv2/rust-server-codegen +RPCV2_CLIENT_SRC := $(SRC_DIR)/codegen-client-test/build/smithyprojections/codegen-client-test/rpcv2-pokemon-client/rust-client-codegen all: codegen codegen: - $(GRADLE) --project-dir $(SRC_DIR) -P modules='pokemon-service-server-sdk,pokemon-service-client' :codegen-client-test:assemble :codegen-server-test:assemble - mkdir -p $(SERVER_SDK_DST) $(CLIENT_SDK_DST) + $(GRADLE) --project-dir $(SRC_DIR) -P modules='pokemon-service-server-sdk,pokemon-service-client,rpcv2,rpcv2-pokemon-client' :codegen-client-test:clean :codegen-server-test:clean :codegen-client-test:assemble :codegen-server-test:assemble + mkdir -p $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(RPCV2_SDK_DST) $(RPCV2_CLIENT_DST) cp -av $(SERVER_SDK_SRC)/* $(SERVER_SDK_DST)/ cp -av $(CLIENT_SDK_SRC)/* $(CLIENT_SDK_DST)/ + cp -av $(RPCV2_SDK_SRC)/* $(RPCV2_SDK_DST)/ + cp -av $(RPCV2_CLIENT_SRC)/* $(RPCV2_CLIENT_DST)/ build: codegen cargo build @@ -39,6 +46,6 @@ lambda_invoke: cargo lambda invoke pokemon-service-lambda --data-file pokemon-service/tests/fixtures/example-apigw-request.json distclean: clean - rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) Cargo.lock + rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(RPCV2_SDK_DST) $(RPCV2_CLIENT_DST) Cargo.lock .PHONY: all diff --git a/gradle.properties b/gradle.properties index 6d06ec08dc..c9e5ab7774 100644 --- a/gradle.properties +++ b/gradle.properties @@ -31,10 +31,4 @@ kotlinVersion=1.9.20 ktlintVersion=1.0.1 kotestVersion=5.8.0 # Avoid registering dependencies/plugins/tasks that are only used for testing purposes -isTestingEnabled=true - -# TODO(https://github.com/smithy-lang/smithy-rs/issues/1068): Once doc normalization -# is completed, warnings can be prohibited in rustdoc. -# -# defaultRustDocFlags=-D warnings -defaultRustDocFlags= +isTestingEnabled=true \ No newline at end of file diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 3a618e6bbe..d7853b0099 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "inlineable", "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml new file mode 100644 index 0000000000..49c176a8f2 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "aws-smithy-cbor" +version = "0.0.0-smithy-rs-head" +authors = [ + "AWS Rust SDK Team ", + "David Pérez ", +] +description = "CBOR utilities for smithy-rs." +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/awslabs/smithy-rs" + +[dependencies.minicbor] +version = "0.19.1" # TODO Update +features = [ + # To write to a `Vec`: https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E + "alloc", + # To support reading `f16` to accomodate fewer bytes transmitted that fit the value. + "half", +] + +[dependencies] +aws-smithy-types = { path = "../aws-smithy-types" } + +[package.metadata.docs.rs] +all-features = true +targets = ["x86_64-unknown-linux-gnu"] +rustdoc-args = ["--cfg", "docsrs"] +# End of docs.rs metadata + +[dev-dependencies] +criterion = "0.5.1" + +[[bench]] +name = "string" +harness = false + +[[bench]] +name = "blob" +harness = false diff --git a/rust-runtime/aws-smithy-cbor/LICENSE b/rust-runtime/aws-smithy-cbor/LICENSE new file mode 100644 index 0000000000..67db858821 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/rust-runtime/aws-smithy-cbor/README.md b/rust-runtime/aws-smithy-cbor/README.md new file mode 100644 index 0000000000..4d0ad07413 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/README.md @@ -0,0 +1,7 @@ +# aws-smithy-cbor + +TODO + + +This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/awslabs/smithy-rs) code generator. In most cases, it should not be used directly. + diff --git a/rust-runtime/aws-smithy-cbor/benches/blob.rs b/rust-runtime/aws-smithy-cbor/benches/blob.rs new file mode 100644 index 0000000000..dd4960da0d --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/blob.rs @@ -0,0 +1,21 @@ +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn blob_benchmark(c: &mut Criterion) { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + let blob_indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, 0x79, + 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, 0x4e, 0x20, + 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0xff, + ]; + + c.bench_function("blob", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&blob_indefinite_bytes); + let _ = black_box(decoder.blob()); + }) + }); +} + +criterion_group!(benches, blob_benchmark); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/benches/string.rs b/rust-runtime/aws-smithy-cbor/benches/string.rs new file mode 100644 index 0000000000..18d6b2ceee --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/string.rs @@ -0,0 +1,131 @@ +use std::borrow::Cow; + +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn str_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("definite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +pub fn string_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result { + decoder.str_iter()?.collect() +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +fn str_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result, minicbor::decode::Error> { + // This implementation uses `next` twice to see if there is + // another str chunk. If there is, it returns a owned `String`. + let mut chunks_iter = decoder.str_iter()?; + let head = match chunks_iter.next() { + Some(Ok(head)) => head, + None => return Ok(Cow::Borrowed("")), + Some(Err(e)) => return Err(e), + }; + + match chunks_iter.next() { + None => Ok(Cow::Borrowed(head)), + Some(Err(e)) => Err(e), + Some(Ok(next)) => { + let mut concatenated_string = String::from(head); + concatenated_string.push_str(next); + for chunk in chunks_iter { + concatenated_string.push_str(chunk?); + } + Ok(Cow::Owned(concatenated_string)) + } + } +} + +// We have two `string` implementations. One uses `collect` the other +// uses `String::new` followed by `string::push`. +pub fn string_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("definite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); + + c.bench_function("indefinite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("indefinite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); +} + +criterion_group!(benches, string_benchmark, str_benchmark,); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/src/data.rs b/rust-runtime/aws-smithy-cbor/src/data.rs new file mode 100644 index 0000000000..a6eab6c549 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/data.rs @@ -0,0 +1,97 @@ +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)] +pub enum Type { + Bool, + Null, + Undefined, + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + Int, + F16, + F32, + F64, + Simple, + Bytes, + BytesIndef, + String, + StringIndef, + Array, + ArrayIndef, + Map, + MapIndef, + Tag, + Break, + Unknown(u8), +} + +impl Type { + pub(crate) fn new(ty: minicbor::data::Type) -> Self { + match ty { + minicbor::data::Type::Bool => Self::Bool, + minicbor::data::Type::Null => Self::Null, + minicbor::data::Type::Undefined => Self::Undefined, + minicbor::data::Type::U8 => Self::U8, + minicbor::data::Type::U16 => Self::U16, + minicbor::data::Type::U32 => Self::U32, + minicbor::data::Type::U64 => Self::U64, + minicbor::data::Type::I8 => Self::I8, + minicbor::data::Type::I16 => Self::I16, + minicbor::data::Type::I32 => Self::I32, + minicbor::data::Type::I64 => Self::I64, + minicbor::data::Type::Int => Self::Int, + minicbor::data::Type::F16 => Self::F16, + minicbor::data::Type::F32 => Self::F32, + minicbor::data::Type::F64 => Self::F64, + minicbor::data::Type::Simple => Self::Simple, + minicbor::data::Type::Bytes => Self::Bytes, + minicbor::data::Type::BytesIndef => Self::BytesIndef, + minicbor::data::Type::String => Self::String, + minicbor::data::Type::StringIndef => Self::StringIndef, + minicbor::data::Type::Array => Self::Array, + minicbor::data::Type::ArrayIndef => Self::ArrayIndef, + minicbor::data::Type::Map => Self::Map, + minicbor::data::Type::MapIndef => Self::MapIndef, + minicbor::data::Type::Tag => Self::Tag, + minicbor::data::Type::Break => Self::Break, + minicbor::data::Type::Unknown(byte) => Self::Unknown(byte), + } + } + + // This is just the reverse mapping of `new`. + pub(crate) fn into_minicbor_type(self) -> minicbor::data::Type { + match self { + Type::Bool => minicbor::data::Type::Bool, + Type::Null => minicbor::data::Type::Null, + Type::Undefined => minicbor::data::Type::Undefined, + Type::U8 => minicbor::data::Type::U8, + Type::U16 => minicbor::data::Type::U16, + Type::U32 => minicbor::data::Type::U32, + Type::U64 => minicbor::data::Type::U64, + Type::I8 => minicbor::data::Type::I8, + Type::I16 => minicbor::data::Type::I16, + Type::I32 => minicbor::data::Type::I32, + Type::I64 => minicbor::data::Type::I64, + Type::Int => minicbor::data::Type::Int, + Type::F16 => minicbor::data::Type::F16, + Type::F32 => minicbor::data::Type::F32, + Type::F64 => minicbor::data::Type::F64, + Type::Simple => minicbor::data::Type::Simple, + Type::Bytes => minicbor::data::Type::Bytes, + Type::BytesIndef => minicbor::data::Type::BytesIndef, + Type::String => minicbor::data::Type::String, + Type::StringIndef => minicbor::data::Type::StringIndef, + Type::Array => minicbor::data::Type::Array, + Type::ArrayIndef => minicbor::data::Type::ArrayIndef, + Type::Map => minicbor::data::Type::Map, + Type::MapIndef => minicbor::data::Type::MapIndef, + Type::Tag => minicbor::data::Type::Tag, + Type::Break => minicbor::data::Type::Break, + Type::Unknown(byte) => minicbor::data::Type::Unknown(byte), + } + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs new file mode 100644 index 0000000000..bbff32bfd5 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -0,0 +1,314 @@ +use std::borrow::Cow; + +use aws_smithy_types::{Blob, DateTime}; +use minicbor::decode::Error; + +use crate::data::Type; + +/// Provides functions for decoding a CBOR object with a known schema. +/// +/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema +/// is known in advance. Therefore, the caller can determine which object key exists at the current +/// position by calling `str` method, and call the relevant function based on the predetermined schema +/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip +/// over the element. +#[derive(Debug, Clone)] +pub struct Decoder<'b> { + decoder: minicbor::Decoder<'b>, +} + +/// When any of the decode methods are called they look for that particular data type at the current +/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned. +#[derive(Debug)] +pub struct DeserializeError { + #[allow(dead_code)] + _inner: Error, +} + +impl std::fmt::Display for DeserializeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self._inner.fmt(f) + } +} + +impl std::error::Error for DeserializeError {} + +impl DeserializeError { + pub(crate) fn new(inner: Error) -> Self { + Self { _inner: inner } + } + + /// More than one union variant was detected: `unexpected_type` was unexpected. + pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self { + Self { + _inner: Error::type_mismatch(unexpected_type.into_minicbor_type()) + .with_message("encountered unexpected union variant; expected end of union") + .at(at), + } + } + + /// More than one union variant was detected, but we never even got to parse the first one. + pub fn mixed_union_variants(at: usize) -> Self { + Self { + _inner: Error::message("encountered mixed variants in union; expected end of union") + .at(at), + } + } + + /// An unexpected type was encountered. + // We handle this one when decoding sparse collections: we have to expect either a `null` or an + // item, so we try decoding both. + pub fn is_type_mismatch(&self) -> bool { + self._inner.is_type_mismatch() + } +} + + +/// Macro for delegating method calls to the decoder. +/// +/// This macro generates wrapper methods for calling specific encoder methods on the decoder +/// and returning the result with error handling. +/// +/// # Example +/// +/// ``` +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the decoder. +/// encode_str_wrapper => encode_str(String); +/// /// Wrapper method for encoding method `encode_int` on the decoder. +/// encode_int_wrapper => encode_int(i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => { + $( + pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> { + self.decoder.$encoder_name().map_err(DeserializeError::new) + } + )+ + }; +} + +impl<'b> Decoder<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { + decoder: minicbor::Decoder::new(bytes), + } + } + + pub fn datatype(&self) -> Result { + self.decoder + .datatype() + .map(Type::new) + .map_err(DeserializeError::new) + } + + delegate_method! { + /// Skips the current CBOR element. + skip => skip(()); + /// Reads a boolean at the current position. + boolean => bool(bool); + /// Reads a byte at the current position. + byte => i8(i8); + /// Reads a short at the current position. + short => i16(i16); + /// Reads a integer at the current position. + integer => i32(i32); + /// Reads a long at the current position. + long => i64(i64); + /// Reads a float at the current position. + float => f32(f32); + /// Reads a double at the current position. + double => f64(f64); + /// Reads a null CBOR element at the current position. + null => null(()); + /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`. + list => array(Option); + /// Returns the number of elements in a definite map. For indefinite map it returns a `None`. + map => map(Option); + } + + /// Returns the current position of the buffer, which will be decoded when any of the methods is called. + pub fn position(&self) -> usize { + self.decoder.position() + } + + /// Returns a `cow::Borrowed(&str)` if the element at the current position in the buffer is a definite + /// length string. Otherwise, it returns a `cow::Owned(String)` if the element at the current position is an + /// indefinite-length string. An error is returned if the element is neither a definite length nor an + /// indefinite-length string. + pub fn str(&mut self) -> Result, DeserializeError> { + let bookmark = self.decoder.position(); + match self.decoder.str() { + Ok(str_value) => Ok(Cow::Borrowed(str_value)), + Err(e) if e.is_type_mismatch() => { + // Move the position back to the start of the CBOR element and then try + // decoding it as a indefinite length string. + self.decoder.set_position(bookmark); + Ok(Cow::Owned(self.string()?)) + } + Err(e) => Err(DeserializeError::new(e)), + } + } + + /// Allocates and returns a `String` if the element at the current position in the buffer is either a + /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type. + pub fn string(&mut self) -> Result { + let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?; + let head = iter.next(); + + let decoded_string = match head { + None => String::new(), + Some(head) => { + let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?); + for chunk in iter { + combined_chunks.push_str(chunk.map_err(DeserializeError::new)?); + } + combined_chunks + } + }; + + Ok(decoded_string) + } + + /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise, + /// a `DeserializeError` error is returned. + pub fn blob(&mut self) -> Result { + let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?; + let parts: Vec<&[u8]> = iter + .collect::>() + .map_err(DeserializeError::new)?; + + Ok(if parts.len() == 1 { + Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part. + } else { + Blob::new(parts.concat()) // Concatenate all parts into a single Blob. + }) + } + + /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise, + /// a `DeserializeError` error is returned. + pub fn timestamp(&mut self) -> Result { + let tag = self.decoder.tag().map_err(DeserializeError::new)?; + + if !matches!(tag, minicbor::data::Tag::Timestamp) { + Err(DeserializeError::new(Error::message( + "expected timestamp tag", + ))) + } else { + let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?; + Ok(DateTime::from_secs_f64(epoch_seconds)) + } + } +} + +#[derive(Debug)] +pub struct ArrayIter<'a, 'b, T> { + inner: minicbor::decode::ArrayIter<'a, 'b, T>, +} + +impl<'a, 'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'a, 'b, T> { + type Item = Result; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +#[derive(Debug)] +pub struct MapIter<'a, 'b, K, V> { + inner: minicbor::decode::MapIter<'a, 'b, K, V>, +} + +impl<'a, 'b, K, V> Iterator for MapIter<'a, 'b, K, V> +where + K: minicbor::Decode<'b, ()>, + V: minicbor::Decode<'b, ()>, +{ + type Item = Result<(K, V), DeserializeError>; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +pub fn set_optional(builder: B, decoder: &mut Decoder, f: F) -> Result +where + F: Fn(B, &mut Decoder) -> Result, +{ + match decoder.datatype()? { + crate::data::Type::Null => { + decoder.null()?; + Ok(builder) + } + _ => f(builder, decoder), + } +} + +#[cfg(test)] +mod tests { + use crate::Decoder; + + #[test] + fn test_definite_str_is_cow_borrowed() { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + let mut decoder = Decoder::new(&definite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Borrowed(_))); + } + + #[test] + fn test_indefinite_str_is_cow_owned() { + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, + 0x79, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Owned(_))); + } + + #[test] + fn test_empty_str_works() { + let bytes = [0x60]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.str().expect("could not decode empty str"); + assert_eq!(member, ""); + } + + #[test] + fn test_empty_blob_works() { + let bytes = [0x40]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.blob().expect("could not decode an empty blob"); + assert_eq!(member, aws_smithy_types::Blob::new(&[])); + } + + #[test] + fn test_indefinite_length_blob() { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff + let indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, + 0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, + 0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, + 0x61, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.blob().expect("could not decode blob"); + assert_eq!( + member, + aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes()) + ); + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs new file mode 100644 index 0000000000..475071ffcc --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -0,0 +1,92 @@ +use aws_smithy_types::{Blob, DateTime}; + +/// Macro for delegating method calls to the encoder. +/// +/// This macro generates wrapper methods for calling specific encoder methods on the encoder +/// and returning a mutable reference to self for method chaining. +/// +/// # Example +/// +/// ``` +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the encoder. +/// encode_str_wrapper => encode_str(data: &str); +/// /// Wrapper method for encoding method `encode_int` on the encoder. +/// encode_int_wrapper => encode_int(value: i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($($param_name:ident : $param_type:ty),*);)+) => { + $( + pub fn $wrapper_name(&mut self, $($param_name: $param_type),*) -> &mut Self { + self.encoder.$encoder_name($($param_name)*).expect(INFALLIBLE_WRITE); + self + } + )+ + }; +} + +#[derive(Debug, Clone)] +pub struct Encoder { + encoder: minicbor::Encoder>, +} + +// TODO docs +const INFALLIBLE_WRITE: &str = "write failed"; + +impl Encoder { + pub fn new(writer: Vec) -> Self { + Self { + encoder: minicbor::Encoder::new(writer), + } + } + + delegate_method! { + /// Writes a fixed length array of given length. + array => array(len: u64); + /// Used when we know the size in advance, i.e.: + /// - when a struct has all non-`Option`al members. + /// - when serializing `union` shapes (they can only have one member set). + map => map(len: u64); + /// Used when it's not cheap to calculate the size, i.e. when the struct has one or more + /// `Option`al members. + begin_map => begin_map(); + /// Writes a definite length string. + str => str(x: &str); + /// Writes a boolean value. + boolean => bool(x: bool); + /// Writes a byte value. + byte => i8(x: i8); + /// Writes a short value. + short => i16(x: i16); + /// Writes an integer value. + integer => i32(x: i32); + /// Writes an long value. + long => i64(x: i64); + /// Writes an float value. + float => f32(x: f32); + /// Writes an double value. + double => f64(x: f64); + /// Writes a null tag. + null => null(); + /// Writes an end tag. + end => end(); + } + + pub fn blob(&mut self, x: &Blob) -> &mut Self { + self.encoder.bytes(x.as_ref()).expect(INFALLIBLE_WRITE); + self + } + + pub fn timestamp(&mut self, x: &DateTime) -> &mut Self { + self.encoder + .tag(minicbor::data::Tag::Timestamp) + .expect(INFALLIBLE_WRITE); + self.encoder.f64(x.as_secs_f64()).expect(INFALLIBLE_WRITE); + self + } + + pub fn into_writer(self) -> Vec { + self.encoder.into_writer() + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/lib.rs b/rust-runtime/aws-smithy-cbor/src/lib.rs new file mode 100644 index 0000000000..1a547fedee --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/lib.rs @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! CBOR abstractions for Smithy. + +pub mod data; +pub mod decode; +pub mod encode; + +pub use decode::Decoder; +pub use encode::Encoder; diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 18b12aaa01..0009e5c0d5 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.61.2" +version = "0.61.3" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" @@ -23,6 +23,7 @@ aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["http-02x"] } aws-smithy-types = { path = "../aws-smithy-types", features = ["http-body-0-4-x", "hyper-0-14-x"] } aws-smithy-xml = { path = "../aws-smithy-xml" } +aws-smithy-cbor = { path = "../aws-smithy-cbor" } bytes = "1.1" futures-util = { version = "0.3.29", default-features = false } http = "0.2" diff --git a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml new file mode 100644 index 0000000000..e6ae4c9f6c --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "rpcv2-service" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +# Local paths +aws-smithy-http-server = { path = "../../" } +hyper = "0.14.24" +tokio = "1.25.0" +rpcv2-server-sdk = { path = "../rpcv2-server-sdk/", package = "rpcv2" } +rpcv2-pokemon-client = { path = "../rpcv2-pokemon-client/" } diff --git a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs new file mode 100644 index 0000000000..96e2eb0221 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs @@ -0,0 +1,32 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +use std::net::SocketAddr; + +use aws_smithy_http_server::{body::Body, routing::Route}; +use rpcv2_server_sdk::{error, input, output, RpcV2Service}; + +async fn handler( + input: input::RpcV2OperationInput, +) -> Result { + println!("{input:#?}"); + + todo!() +} + +#[tokio::main] +async fn main() { + let service: RpcV2Service> = rpcv2_server_sdk::RpcV2Service::builder_without_plugins() + .rpc_v2_operation(handler) + .build() + .unwrap(); + + let server = service.into_make_service(); + let bind: SocketAddr = "127.0.0.1:6969" + .parse() + .expect("unable to parse the server bind address and port"); + + println!("Binding {bind}"); + hyper::Server::bind(&bind).serve(server).await.unwrap(); +} diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index 295c670b27..74fc44b8df 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -39,7 +39,7 @@ pub enum Error { // This constant determines when the `TinyMap` implementation switches from being a `Vec` to a // `HashMap`. This is chosen to be 15 as a result of the discussion around // https://github.com/smithy-lang/smithy-rs/pull/1429#issuecomment-1147516546 -const ROUTE_CUTOFF: usize = 15; +pub(crate) const ROUTE_CUTOFF: usize = 15; /// A [`Router`] supporting [`AWS JSON 1.0`] and [`AWS JSON 1.1`] protocols. /// @@ -65,6 +65,7 @@ impl AwsJsonRouter { } } + // TODO This function is not used? Codegen should probably delegate to this function. /// Applies type erasure to the inner route using [`Route::new`]. pub fn boxed(self) -> AwsJsonRouter> where diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs index 0b3faf0136..4686e94e07 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs @@ -9,6 +9,7 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; +pub mod rpc_v2; use crate::rejection::MissingContentTypeReason; use aws_smithy_runtime_api::http::Headers as SmithyHeaders; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs new file mode 100644 index 0000000000..716befd77a --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub mod rejection; +pub mod router; +pub mod runtime_error; + +// TODO: Fill link +/// [Smithy RPC V2](). +pub struct RpcV2; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs new file mode 100644 index 0000000000..2ec8b957af --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::num::TryFromIntError; + +use crate::rejection::MissingContentTypeReason; +use aws_smithy_runtime_api::http::HttpError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ResponseRejection { + #[error("invalid bound HTTP status code; status codes must be inside the 100-999 range: {0}")] + InvalidHttpStatusCode(TryFromIntError), + #[error("error serializing CBOR-encoded body: {0}")] + Serialization(#[from] aws_smithy_types::error::operation::SerializationError), + #[error("error building HTTP response: {0}")] + HttpBuild(#[from] http::Error), +} + +#[derive(Debug, Error)] +pub enum RequestRejection { + #[error("error converting non-streaming body to bytes: {0}")] + BufferHttpBodyBytes(crate::Error), + #[error("request contains invalid value for `Accept` header")] + NotAcceptable, + #[error("expected `Content-Type` header not found: {0}")] + MissingContentType(#[from] MissingContentTypeReason), + #[error("error deserializing request HTTP body as CBOR: {0}")] + CborDeserialize(#[from] aws_smithy_cbor::decode::DeserializeError), + // Unlike the other protocols, RPC v2 uses CBOR, a binary serialization format, so we take in a + // `Vec` here instead of `String`. + #[error("request does not adhere to modeled constraints")] + ConstraintViolation(Vec), + + /// Typically happens when the request has headers that are not valid UTF-8. + #[error("failed to convert request: {0}")] + HttpConversion(#[from] HttpError), +} + +impl From for RequestRejection { + fn from(_err: std::convert::Infallible) -> Self { + match _err {} + } +} + +convert_to_request_rejection!(hyper::Error, BufferHttpBodyBytes); +convert_to_request_rejection!(Box, BufferHttpBodyBytes); diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs new file mode 100644 index 0000000000..3e73b8b9e0 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -0,0 +1,409 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::convert::Infallible; +use std::str::FromStr; + +use http::header::ToStrError; +use http::HeaderMap; +use once_cell::sync::Lazy; +use regex::Regex; +use thiserror::Error; +use tower::Layer; +use tower::Service; + +use crate::body::empty; +use crate::body::BoxBody; +use crate::extension::RuntimeErrorExtension; +use crate::protocol::aws_json_11::router::ROUTE_CUTOFF; +use crate::response::IntoResponse; +use crate::routing::tiny_map::TinyMap; +use crate::routing::Route; +use crate::routing::Router; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; + +use super::RpcV2; + +pub use crate::protocol::rest::router::*; + +/// An RPC v2 routing error. +#[derive(Debug, Error)] +pub enum Error { + /// Method was not `POST`. + #[error("method not POST")] + MethodNotAllowed, + /// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` + /// header. + #[error("contains forbidden headers")] + ForbiddenHeaders, + /// Unable to parse `smithy-protocol` header into a valid wire format value. + #[error("failed to parse `smithy-protocol` header into a valid wire format value")] + InvalidWireFormatHeader(#[from] WireFormatError), + /// Operation not found. + #[error("operation not found")] + NotFound, +} + +// TODO Docs +#[derive(Debug, Clone)] +pub struct RpcV2Router { + routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, +} + +/// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` +/// header. An `rpcv2` request is malformed if it contains either of these headers. Server-side +/// implementations MUST reject such requests for security reasons. +const FORBIDDEN_HEADERS: &'static [&'static str] = &["x-amz-target", "x-amzn-target"]; + +/// Matches the `Identifier` ABNF rule in +/// . +const IDENTIFIER_PATTERN: &'static str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; + +impl RpcV2Router { + // TODO Consider building a nom parser + fn uri_path_regex() -> &'static Regex { + // Every request for the `rpcv2Cbor` protocol MUST be sent to a URL with the + // following form: `{prefix?}/service/{serviceName}/operation/{operationName}` + // + // * The optional `prefix` segment may span multiple path segments and is not + // utilized by the Smithy RPC v2 CBOR protocol. For example, a service could + // use a `v1` prefix for the following URL path: `v1/service/FooService/operation/BarOperation` + // * The `serviceName` segment MUST be replaced by the [`shape + // name`](https://smithy.io/2.0/spec/model.html#grammar-token-smithy-Identifier) + // of the service's [Shape ID](https://smithy.io/2.0/spec/model.html#shape-id) + // in the Smithy model. The `serviceName` produced by client implementations + // MUST NOT contain the namespace of the `service` shape. Service + // implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it + // the same as if only the name was specified. For example, if the `service`'s + // absolute shape ID is `com.example#TheService`, a service should accept both + // `TheService` and `com.example.TheService` as values for the `serviceName` + // segment. + static PATH_REGEX: Lazy = Lazy::new(|| { + Regex::new(&format!( + r#"/service/({}\.)*(?P{})/operation/(?P{})$"#, + IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, + )) + .unwrap() + }); + + &PATH_REGEX + } + + pub fn wire_format_regex() -> &'static Regex { + static SMITHY_PROTOCOL_REGEX: Lazy = Lazy::new(|| Regex::new(r#"^rpc-v2-(?P\w+)$"#).unwrap()); + + &SMITHY_PROTOCOL_REGEX + } + + pub fn boxed(self) -> RpcV2Router> + where + S: Service, Response = http::Response, Error = Infallible>, + S: Send + Clone + 'static, + S::Future: Send + 'static, + { + RpcV2Router { + routes: self.routes.into_iter().map(|(key, s)| (key, Route::new(s))).collect(), + } + } + + /// Applies a [`Layer`] uniformly to all routes. + pub fn layer(self, layer: L) -> RpcV2Router + where + L: Layer, + { + RpcV2Router { + routes: self + .routes + .into_iter() + .map(|(key, route)| (key, layer.layer(route))) + .collect(), + } + } +} + +// TODO: Implement (current body copied from the rest xml impl) +// and document. +/// A Smithy RPC V2 routing error. +impl IntoResponse for Error { + fn into_response(self) -> http::Response { + match self { + Error::NotFound => http::Response::builder() + .status(http::StatusCode::NOT_FOUND) + // TODO + .header(http::header::CONTENT_TYPE, "application/xml") + .extension(RuntimeErrorExtension::new( + UNKNOWN_OPERATION_EXCEPTION.to_string(), + )) + .body(empty()) + .expect("invalid HTTP response for REST XML routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"), + Error::MethodNotAllowed => method_disallowed(), + // TODO + _ => todo!(), + } + } +} + +/// Errors that can happen when parsing the wire format from the `smithy-protocol` header. +#[derive(Debug, Error)] +pub enum WireFormatError { + /// Header not found. + #[error("`smithy-protocol` header not found")] + HeaderNotFound, + /// Header value is not visible ASCII. + #[error("`smithy-protocol` header not visible ASCII")] + HeaderValueNotVisibleAscii(ToStrError), + /// Header value does not match the `rpc-v2-{format}` pattern. The actual parsed header value + /// is stored in the tuple struct. + // https://doc.rust-lang.org/std/fmt/index.html#escaping + #[error("`smithy-protocol` header does not match the `rpc-v2-{{format}}` pattern: `{0}`")] + HeaderValueNotValid(String), + /// Header value matches the `rpc-v2-{format}` pattern, but the `format` is not supported. The + /// actual parsed header value is stored in the tuple struct. + #[error("found unsupported `smithy-protocol` wire format: `{0}`")] + WireFormatNotSupported(String), +} + +/// Smithy RPC V2 requests have a `smithy-protocol` header with the value +/// `"rpc-v2-{format}"`, where `format` is one of the supported wire formats +/// by the protocol (see [`WireFormat`]). +fn parse_wire_format_from_header(headers: &HeaderMap) -> Result { + let header = headers.get("smithy-protocol").ok_or(WireFormatError::HeaderNotFound)?; + let header = header + .to_str() + .map_err(|e| WireFormatError::HeaderValueNotVisibleAscii(e))?; + let captures = RpcV2Router::<()>::wire_format_regex() + .captures(header) + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let format = captures + .name("format") + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let wire_format_parse_res: Result = format.as_str().parse(); + wire_format_parse_res.map_err(|_| WireFormatError::WireFormatNotSupported(header.to_owned())) +} + +/// Supported wire formats by RPC V2. +enum WireFormat { + Cbor, +} + +struct WireFormatFromStrError; + +impl FromStr for WireFormat { + type Err = WireFormatFromStrError; + + fn from_str(format: &str) -> Result { + match format { + "cbor" => Ok(Self::Cbor), + _ => Err(WireFormatFromStrError), + } + } +} + +impl Router for RpcV2Router { + type Service = S; + + type Error = Error; + + fn match_route(&self, request: &http::Request) -> Result { + // Only `Method::POST` is allowed. + if request.method() != http::Method::POST { + return Err(Error::MethodNotAllowed); + } + + // Some headers are not allowed. + let request_has_forbidden_header = FORBIDDEN_HEADERS + .into_iter() + .any(|&forbidden_header| request.headers().contains_key(forbidden_header)); + if request_has_forbidden_header { + return Err(Error::ForbiddenHeaders); + } + + // Wire format has to be specified and supported. + let _wire_format = parse_wire_format_from_header(request.headers())?; + + // Extract the service name and the operation name from the request URI. + let request_path = request.uri().path(); + let regex = Self::uri_path_regex(); + + tracing::trace!(%request_path, "capturing service and operation from URI"); + let captures = regex.captures(request_path).ok_or(Error::NotFound)?; + let (service, operation) = (&captures["service"], &captures["operation"]); + tracing::trace!(%service, %operation, "captured service and operation from URI"); + + // Lookup in the `TinyMap` for a route for the target. + let route = self + .routes + .get((format!("{service}.{operation}")).as_str()) + .ok_or(Error::NotFound)?; + Ok(route.clone()) + } +} + +impl FromIterator<(&'static str, S)> for RpcV2Router { + fn from_iter>(iter: T) -> Self { + Self { + routes: iter.into_iter().collect(), + } + } +} + +#[cfg(test)] +mod tests { + use http::{HeaderMap, HeaderValue, Method}; + use regex::Regex; + + use crate::protocol::test_helpers::req; + + use super::{Error, Router, RpcV2Router}; + + fn identifier_regex() -> Regex { + Regex::new(&format!("^{}$", super::IDENTIFIER_PATTERN)).unwrap() + } + + #[test] + fn valid_identifiers() { + let valid_identifiers = vec!["a", "_a", "_0", "__0", "variable123", "_underscored_variable"]; + + for id in &valid_identifiers { + assert!(identifier_regex().is_match(id), "'{}' is incorrectly rejected", id); + } + } + + #[test] + fn invalid_identifiers() { + let invalid_identifiers = vec![ + "0", + "123starts_with_digit", + "@invalid_start_character", + " space_in_identifier", + "invalid-character", + "invalid@character", + "no#hashes", + ]; + + for id in &invalid_identifiers { + assert!(!identifier_regex().is_match(id), "'{}' is incorrectly accepted", id); + } + } + + #[test] + fn uri_regex_works_accepts() { + let regex = RpcV2Router::<()>::uri_path_regex(); + + for uri in [ + "/service/Service/operation/Operation", + "prefix/69/service/Service/operation/Operation", + // Here the prefix is up to the last occurrence of the string `/service`. + "prefix/69/service/Service/operation/Operation/service/Service/operation/Operation", + // Service implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it the same as + // if only the name was specified. For example, if the `service`'s absolute shape ID is + // `com.example#TheService`, a service should accept both `TheService` and + // `com.example.TheService` as values for the `serviceName` segment. + "/service/aws.protocoltests.rpcv2.Service/operation/Operation", + "/service/namespace.Service/operation/Operation", + ] { + let captures = regex.captures(uri).unwrap(); + assert_eq!("Service", &captures["service"], "uri: {}", uri); + assert_eq!("Operation", &captures["operation"], "uri: {}", uri); + } + } + + #[test] + fn uri_regex_works_rejects() { + let regex = RpcV2Router::<()>::uri_path_regex(); + + for uri in [ + "", + "foo", + "/servicee/Service/operation/Operation", + "/service/Service", + "/service/Service/operation/", + "/service/Service/operation/Operation/", + "/service/Service/operation/Operation/invalid-suffix", + "/service/namespace.foo#Service/operation/Operation", + "/service/namespace-Service/operation/Operation", + "/service/.Service/operation/Operation", + ] { + assert!(regex.captures(uri).is_none(), "uri: {}", uri); + } + } + + #[test] + fn wire_format_regex_works() { + let regex = RpcV2Router::<()>::wire_format_regex(); + + let captures = regex.captures("rpc-v2-something").unwrap(); + assert_eq!("something", &captures["format"]); + + let captures = regex.captures("rpc-v2-SomethingElse").unwrap(); + assert_eq!("SomethingElse", &captures["format"]); + + let invalid = regex.captures("rpc-v1-something"); + assert!(invalid.is_none()); + } + + /// Helper function returning the only strictly required header. + fn headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static("rpc-v2-cbor")); + headers + } + + #[test] + fn simple_routing() { + let router: RpcV2Router<_> = ["Service.Operation"].into_iter().map(|op| (op, ())).collect(); + let good_uri = "/prefix/service/Service/operation/Operation"; + + // The request should match. + let routing_result = router.match_route(&req(&Method::POST, good_uri, Some(headers()))); + assert!(routing_result.is_ok()); + + // The request would be valid if it used `Method::POST`. + let invalid_request = req(&Method::GET, good_uri, Some(headers())); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::MethodNotAllowed) + )); + + // The request would be valid if it did not have forbidden headers. + for forbidden_header_name in ["x-amz-target", "x-amzn-target"] { + let mut headers = headers(); + headers.insert(forbidden_header_name, HeaderValue::from_static("Service.Operation")); + let invalid_request = req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::ForbiddenHeaders) + )); + } + + for bad_uri in [ + // These requests would be valid if they used correct URIs. + "/prefix/Service/Service/operation/Operation", + "/prefix/service/Service/operation/Operation/suffix", + // These requests would be valid if their URI matched an existing operation. + "/prefix/service/ThisServiceDoesNotExist/operation/Operation", + "/prefix/service/Service/operation/ThisOperationDoesNotExist", + ] { + let invalid_request = &req(&Method::POST, bad_uri, Some(headers())); + assert!(matches!(router.match_route(&invalid_request), Err(Error::NotFound))); + } + + // The request would be valid if it specified a supported wire format in the + // `smithy-protocol` header. + for header_name in ["bad-header", "rpc-v2-json", "foo-rpc-v2-cbor", "rpc-v2-cbor-foo"] { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static(header_name)); + let invalid_request = &req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::InvalidWireFormatHeader(_)) + )); + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs new file mode 100644 index 0000000000..d8ee60796a --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs @@ -0,0 +1,82 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::response::IntoResponse; +use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; +use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2::RpcV2}; +use http::StatusCode; + +use super::rejection::{RequestRejection, ResponseRejection}; + +#[derive(Debug)] +pub enum RuntimeError { + Serialization(crate::Error), + InternalFailure(crate::Error), + NotAcceptable, + UnsupportedMediaType, + Validation(Vec), +} + +impl RuntimeError { + pub fn name(&self) -> &'static str { + match self { + Self::Serialization(_) => "SerializationException", + Self::InternalFailure(_) => "InternalFailureException", + Self::NotAcceptable => "NotAcceptableException", + Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Validation(_) => "ValidationException", + } + } + + pub fn status_code(&self) -> StatusCode { + match self { + Self::Serialization(_) => StatusCode::BAD_REQUEST, + Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, + Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::Validation(_) => StatusCode::BAD_REQUEST, + } + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) + } +} + +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + let res = http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/cbor") + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + // TODO + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + // See https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html#empty-body-serialization + _ => crate::body::to_boxed("{}"), + }; + + res.body(body) + .expect(INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE) + } +} + +impl From for RuntimeError { + fn from(err: ResponseRejection) -> Self { + Self::Serialization(crate::Error::new(err)) + } +} + +impl From for RuntimeError { + fn from(err: RequestRejection) -> Self { + match err { + RequestRejection::ConstraintViolation(reason) => Self::Validation(reason), + _ => Self::Serialization(crate::Error::new(err)), + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 14f124f687..ede1f5117b 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -37,7 +37,6 @@ use futures_util::{ use http::Response; use http_body::Body as HttpBody; use tower::{util::Oneshot, Service, ServiceExt}; -use tracing::debug; use crate::{ body::{boxed, BoxBody}, @@ -191,12 +190,13 @@ where } fn call(&mut self, req: http::Request) -> Self::Future { + tracing::debug!("inside routing service call"); match self.router.match_route(&req) { // Successfully routed, use the routes `Service::call`. Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), // Failed to route, use the `R::Error`s `IntoResponse

`. Err(error) => { - debug!(%error, "failed to route"); + tracing::debug!(%error, "failed to route"); RoutingFuture::from_response(error.into_response()) } } diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index c056bbc020..e50222b200 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.60.7" +version = "0.60.8" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" @@ -10,6 +10,9 @@ repository = "https://github.com/smithy-lang/smithy-rs" [dependencies] # Not perfect for our needs, but good for now assert-json-diff = "1.1" +base64-simd = "0.8" +cbor-diag = "0.1" +serde_cbor = "0.11" http = "0.2.1" pretty_assertions = "1.3" regex-lite = "0.1.5" @@ -18,7 +21,6 @@ serde_json = "1" thiserror = "1.0.40" aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client"] } - [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 07839e2a38..63328ebba4 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -306,10 +306,12 @@ pub fn require_headers( #[derive(Clone)] pub enum MediaType { - /// Json media types are deserialized and compared + /// JSON media types are deserialized and compared Json, /// XML media types are normalized and compared Xml, + /// CBOR media types are decoded from base64 to binary and compared + Cbor, /// For x-www-form-urlencoded, do some map order comparison shenanigans UrlEncodedForm, /// Other media types are compared literally @@ -322,6 +324,7 @@ impl> From for MediaType { "application/json" => MediaType::Json, "application/x-amz-json-1.1" => MediaType::Json, "application/xml" => MediaType::Xml, + "application/cbor" => MediaType::Cbor, "application/x-www-form-urlencoded" => MediaType::UrlEncodedForm, other => MediaType::Other(other.to_string()), } @@ -336,11 +339,11 @@ pub fn validate_body>( let body_str = std::str::from_utf8(actual_body.as_ref()); match (media_type, body_str) { (MediaType::Json, Ok(actual_body)) => try_json_eq(expected_body, actual_body), - (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(expected_body, actual_body), (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(actual_body, expected_body), (MediaType::Xml, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "XML".to_owned(), found: "input was not valid UTF-8".to_owned(), @@ -352,6 +355,7 @@ pub fn validate_body>( expected: "x-www-form-urlencoded".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Cbor, _) => try_cbor_eq(actual_body, expected_body), (MediaType::Other(media_type), Ok(actual_body)) => { if actual_body != expected_body { Err(ProtocolTestFailure::BodyDidNotMatch { @@ -410,6 +414,63 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure> } } +fn try_cbor_eq>( + actual_body: T, + expected_body: &str, +) -> Result<(), ProtocolTestFailure> { + let decoded = base64_simd::STANDARD + .decode_to_vec(expected_body) + .expect("smithy protocol test `body` property is not properly base64 encoded"); + let expected_cbor_value: serde_cbor::Value = + serde_cbor::from_slice(decoded.as_slice()).unwrap(); + let actual_cbor_value: serde_cbor::Value = + serde_cbor::from_slice(actual_body.as_ref()).unwrap(); // TODO Don't panic + let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body); + + if expected_cbor_value != actual_cbor_value { + let expected_body_annotated_hex: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_hex(); + let expected_body_diag: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_diag_pretty(); + let actual_body_annotated_hex: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_hex(); + let actual_body_diag: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_diag_pretty(); + + Err(ProtocolTestFailure::BodyDidNotMatch { + comparison: PrettyString(format!( + "{}", + Comparison::new(&expected_cbor_value, &actual_cbor_value) + )), + // The last newline is important because the panic message ends with a `.` + hint: format!( + "expected body in diagnostic format: +{} +actual body in diagnostic format: +{} +expected body in annotated hex: +{} +actual body in annotated hex: +{} +actual body in base64 (useful to update the protocol test): +{} +", + expected_body_diag, + actual_body_diag, + expected_body_annotated_hex, + actual_body_annotated_hex, + actual_body_base64, + ), + }) + } else { + Ok(()) + } +} + #[cfg(test)] mod tests { use crate::{ diff --git a/rust-runtime/aws-smithy-types/src/error/operation.rs b/rust-runtime/aws-smithy-types/src/error/operation.rs index 5b61f1acfe..ad5032fc4c 100644 --- a/rust-runtime/aws-smithy-types/src/error/operation.rs +++ b/rust-runtime/aws-smithy-types/src/error/operation.rs @@ -21,6 +21,7 @@ pub struct SerializationError { kind: SerializationErrorKind, } +// TODO The docs in `main` are wrong. impl SerializationError { /// An error that occurs when serialization of an operation fails for an unknown reason. pub fn unknown_variant(union: &'static str) -> Self { From f1b75082d44b28b3d3156a45442a52017e6a1420 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 4 Jun 2024 13:46:00 +0200 Subject: [PATCH 02/77] Small fixes --- codegen-client-test/build.gradle.kts | 5 - .../adwait-cbor-structs.smithy | 142 -------------- .../adwait-empty-input-output.smithy | 174 ------------------ .../common-test-models/adwait-main.smithy | 30 --- .../{rpcv2.smithy => rpcv2-extras.smithy} | 18 +- .../protocols/parse/CborParserGenerator.kt | 2 +- codegen-server-test/build.gradle.kts | 12 +- .../serialize/CborSerializerGeneratorTest.kt | 2 +- rust-runtime/aws-smithy-cbor/src/decode.rs | 7 +- 9 files changed, 18 insertions(+), 374 deletions(-) delete mode 100644 codegen-core/common-test-models/adwait-cbor-structs.smithy delete mode 100644 codegen-core/common-test-models/adwait-empty-input-output.smithy delete mode 100644 codegen-core/common-test-models/adwait-main.smithy rename codegen-core/common-test-models/{rpcv2.smithy => rpcv2-extras.smithy} (92%) diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index a1c6de5c23..c69a4ae719 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -111,11 +111,6 @@ val allCodegenTests = listOf( "pokemon-service-awsjson-client", dependsOn = listOf("pokemon-awsjson.smithy", "pokemon-common.smithy"), ), - ClientTest( - "com.amazonaws.simple#RpcV2Service", - "rpcv2-pokemon-client", - dependsOn = listOf("rpcv2.smithy") - ), ClientTest("aws.protocoltests.misc#QueryCompatService", "query-compat-test", dependsOn = listOf("aws-json-query-compat.smithy")), ).map(ClientTest::toCodegenTest) diff --git a/codegen-core/common-test-models/adwait-cbor-structs.smithy b/codegen-core/common-test-models/adwait-cbor-structs.smithy deleted file mode 100644 index b88d930967..0000000000 --- a/codegen-core/common-test-models/adwait-cbor-structs.smithy +++ /dev/null @@ -1,142 +0,0 @@ -$version: "2.0" - -namespace aws.protocoltests.rpcv2 - -use aws.protocoltests.shared#StringList -use smithy.protocols#rpcv2 -use smithy.test#httpRequestTests -use smithy.test#httpResponseTests - - -@httpRequestTests([ - { - id: "RpcV2CborSimpleScalarProperties", - protocol: rpcv2, - documentation: "Serializes simple scalar properties", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - method: "POST", - bodyMediaType: "application/cbor", - uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", - body: "v2lieXRlVmFsdWUFa2RvdWJsZVZhbHVl+z/+OVgQYk3TcWZhbHNlQm9vbGVhblZhbHVl9GpmbG9hdFZhbHVl+kDz989saW50ZWdlclZhbHVlGQEAaWxvbmdWYWx1ZRkmkWpzaG9ydFZhbHVlGSaqa3N0cmluZ1ZhbHVlZnNpbXBsZXB0cnVlQm9vbGVhblZhbHVl9f8=" - params: { - trueBooleanValue: true, - falseBooleanValue: false, - byteValue: 5, - doubleValue: 1.889, - floatValue: 7.624, - integerValue: 256, - shortValue: 9898, - longValue: 9873 - stringValue: "simple" - } - }, - { - id: "RpcV2CborClientDoesntSerializeNullStructureValues", - documentation: "RpcV2 Cbor should not serialize null structure values", - protocol: rpcv2, - method: "POST", - uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", - body: "v/8=", - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - params: { - stringValue: null - }, - appliesTo: "client" - }, - { - id: "RpcV2CborServerDoesntDeSerializeNullStructureValues", - documentation: "RpcV2 Cbor should not deserialize null structure values", - protocol: rpcv2, - method: "POST", - uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", - body: "v2tzdHJpbmdWYWx1Zfb/", - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - params: {}, - appliesTo: "server" - }, -]) -@httpResponseTests([ - { - id: "simple_scalar_structure", - protocol: rpcv2, - documentation: "Serializes simple scalar properties", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Content-Type": "application/cbor" - } - bodyMediaType: "application/cbor", - body: "v2lieXRlVmFsdWUFa2RvdWJsZVZhbHVl+z/+OVgQYk3TcWZhbHNlQm9vbGVhblZhbHVl9GpmbG9hdFZhbHVl+kDz989saW50ZWdlclZhbHVlGQEAaWxvbmdWYWx1ZRkmkWpzaG9ydFZhbHVlGSaqa3N0cmluZ1ZhbHVlZnNpbXBsZXB0cnVlQm9vbGVhblZhbHVl9f8=", - code: 200, - params: { - trueBooleanValue: true, - falseBooleanValue: false, - byteValue: 5, - doubleValue: 1.889, - floatValue: 7.624, - integerValue: 256, - shortValue: 9898, - stringValue: "simple" - } - }, - { - id: "RpcV2CborClientDoesntDeSerializeNullStructureValues", - documentation: "RpcV2 Cbor should deserialize null structure values", - protocol: rpcv2, - body: "v2tzdHJpbmdWYWx1Zfb/", - code: 200, - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Content-Type": "application/cbor" - } - params: {} - appliesTo: "client" - }, - { - id: "RpcV2CborServerDoesntSerializeNullStructureValues", - documentation: "RpcV2 Cbor should not serialize null structure values", - protocol: rpcv2, - body: "v/8=", - code: 200, - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Content-Type": "application/cbor" - } - params: { - stringValue: null - }, - appliesTo: "server" - }, -]) -operation SimpleScalarProperties { - input: SimpleScalarStructure, - output: SimpleScalarStructure -} - - -structure SimpleScalarStructure { - trueBooleanValue: Boolean, - falseBooleanValue: Boolean, - byteValue: Byte, - doubleValue: Double, - floatValue: Float, - integerValue: Integer, - longValue: Long, - shortValue: Short, - stringValue: String, -} diff --git a/codegen-core/common-test-models/adwait-empty-input-output.smithy b/codegen-core/common-test-models/adwait-empty-input-output.smithy deleted file mode 100644 index 186cbaaa30..0000000000 --- a/codegen-core/common-test-models/adwait-empty-input-output.smithy +++ /dev/null @@ -1,174 +0,0 @@ -$version: "2.0" - -namespace aws.protocoltests.rpcv2 - -use smithy.protocols#rpcv2 -use smithy.test#httpRequestTests -use smithy.test#httpResponseTests - - -@httpRequestTests([ - { - id: "no_input", - protocol: rpcv2, - documentation: "Body is empty and no Content-Type header if no input", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - }, - forbidHeaders: [ - "Content-Type", - "X-Amz-Target" - ] - method: "POST", - uri: "/service/RpcV2Protocol/operation/NoInputOutput", - body: "" - }, - { - id: "no_input_server_allows_accept", - protocol: rpcv2, - documentation: "Servers should allow the Accept header to be set to the default content-type.", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - method: "POST", - uri: "/service/RpcV2Protocol/operation/NoInputOutput", - body: "", - appliesTo: "server" - }, - { - id: "no_input_server_allows_empty_cbor", - protocol: rpcv2, - documentation: "Servers should accept CBOR empty struct if no input.", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - method: "POST", - uri: "/service/RpcV2Protocol/operation/NoInputOutput", - body: "v/8=", - appliesTo: "server" - } -]) -@httpResponseTests([ - { - id: "no_output", - protocol: rpcv2, - documentation: "Body is empty and no Content-Type header if no response", - body: "", - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - }, - forbidHeaders: [ - "Content-Type" - ] - code: 200, - }, - { - id: "no_output_client_allows_accept", - protocol: rpcv2, - documentation: "Servers should allow the accept header to be set to the default content-type.", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - body: "", - code: 200, - appliesTo: "client", - }, - { - id: "no_input_client_allows_empty_cbor", - protocol: rpcv2, - documentation: "Client should accept CBOR empty struct if no output", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - } - body: "v/8=", - code: 200, - appliesTo: "client", - } -]) -operation NoInputOutput {} - - -@httpRequestTests([ - { - id: "empty_input", - protocol: rpcv2, - documentation: "When Input structure is empty we write CBOR equivalent of {}", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - }, - forbidHeaders: [ - "X-Amz-Target" - ] - method: "POST", - uri: "/service/RpcV2Protocol/operation/EmptyInputOutput", - body: "v/8=", - }, -]) -@httpResponseTests([ - { - id: "empty_output", - protocol: rpcv2, - documentation: "When output structure is empty we write CBOR equivalent of {}", - body: "v/8=", - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Content-Type": "application/cbor" - } - code: 200, - }, -]) -operation EmptyInputOutput { - input: EmptyStructure, - output: EmptyStructure -} - -@httpRequestTests([ - { - id: "optional_input", - protocol: rpcv2, - documentation: "When input is empty we write CBOR equivalent of {}", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Accept": "application/cbor", - "Content-Type": "application/cbor" - }, - forbidHeaders: [ - "X-Amz-Target" - ] - method: "POST", - uri: "/service/RpcV2Protocol/operation/OptionalInputOutput", - body: "v/8=", - bodyMediaType: "application/cbor", - }, -]) -@httpResponseTests([ - { - id: "optional_output", - protocol: rpcv2, - documentation: "When output is empty we write CBOR equivalent of {}", - body: "v/8=", - bodyMediaType: "application/cbor", - headers: { - "smithy-protocol": "rpc-v2-cbor", - "Content-Type": "application/cbor" - } - code: 200, - }, -]) -operation OptionalInputOutput { - input: SimpleStructure, - output: SimpleStructure -} diff --git a/codegen-core/common-test-models/adwait-main.smithy b/codegen-core/common-test-models/adwait-main.smithy deleted file mode 100644 index c827b1dd79..0000000000 --- a/codegen-core/common-test-models/adwait-main.smithy +++ /dev/null @@ -1,30 +0,0 @@ -$version: "2.0" - -namespace aws.protocoltests.rpcv2 -use aws.api#service -use smithy.protocols#rpcv2 -use smithy.test#httpRequestTests -use smithy.test#httpResponseTests - -@service(sdkId: "Sample RpcV2 Protocol") -@rpcv2(format: ["cbor"]) -@title("RpcV2 Protocol Service") -service RpcV2Protocol { - version: "2020-07-14", - operations: [ - //Basic input/output tests - NoInputOutput, - EmptyInputOutput, - OptionalInputOutput, - - SimpleScalarProperties, - ] -} - -structure EmptyStructure { - -} - -structure SimpleStructure { - value: String, -} \ No newline at end of file diff --git a/codegen-core/common-test-models/rpcv2.smithy b/codegen-core/common-test-models/rpcv2-extras.smithy similarity index 92% rename from codegen-core/common-test-models/rpcv2.smithy rename to codegen-core/common-test-models/rpcv2-extras.smithy index 48a4fb694b..6714bbfea3 100644 --- a/codegen-core/common-test-models/rpcv2.smithy +++ b/codegen-core/common-test-models/rpcv2-extras.smithy @@ -1,13 +1,13 @@ $version: "2.0" -// TODO Update namespace -namespace com.amazonaws.simple +namespace smithy.protocoltests.rpcv2Cbor use smithy.framework#ValidationException -use smithy.protocols#rpcv2 +use smithy.protocols#rpcv2Cbor use smithy.test#httpResponseTests -@rpcv2(format: ["cbor"]) + +@rpcv2Cbor service RpcV2Service { operations: [ SimpleStructOperation, @@ -34,7 +34,7 @@ operation ComplexStructOperation { apply SimpleStructOperation @httpResponseTests([ { id: "SimpleStruct", - protocol: "smithy.protocols#rpcv2", + protocol: "smithy.protocols#rpcv2Cbor", code: 200, // Not used. params: { blob: "blobby blob", @@ -81,7 +81,7 @@ apply SimpleStructOperation @httpResponseTests([ // Same test, but leave optional types empty { id: "SimpleStructWithOptionsSetToNone", - protocol: "smithy.protocols#rpcv2", + protocol: "smithy.protocols#rpcv2Cbor", code: 200, // Not used. params: { requiredBlob: "blobby blob", @@ -150,7 +150,7 @@ structure ComplexStruct { map: SimpleMap union: SimpleUnion - structureList: StructureList + structureList: StructList // `@required` for good measure here. @required complexList: ComplexList @@ -158,7 +158,7 @@ structure ComplexStruct { @required complexUnion: ComplexUnion } -list StructureList { +list StructList { member: SimpleStruct } @@ -171,6 +171,7 @@ map SimpleMap { value: Integer } +// TODO Cut ticket to Smithy: their protocol tests don't have unions union SimpleUnion { blob: Blob boolean: Boolean @@ -203,4 +204,5 @@ enum Suit { SPADE } +// TODO Documents are not supported in RPC v2 CBOR. // document MyDocument diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 012c695d16..d7e5426a87 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -277,7 +277,7 @@ class CborParserGenerator( *codegenScope, "UnionSymbol" to returnSymbolToParse.symbol, ) { - withBlock("Ok(match decoder.str()? {", "})") { + withBlock("Ok(match decoder.str()?.as_ref() {", "})") { for (member in shape.members()) { val variantName = symbolProvider.toMemberName(member) diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 1c9c9d4c7d..7692ffb93f 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -46,16 +46,10 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), // CodegenTest("aws.protocoltests.restxml#RestXml", "restXml"), CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), - // Todo: change this to rpcv2extra - CodegenTest("com.amazonaws.simple#RpcV2Service", "rpcv2Extra", imports = listOf("$commonModels/rpcv2.smithy")), CodegenTest( - "aws.protocoltests.rpcv2#RpcV2Protocol", - "adwait-main", - imports = listOf( - "$commonModels/adwait-main.smithy", - "$commonModels/adwait-cbor-structs.smithy", - "$commonModels/adwait-empty-input-output.smithy", - ) + "smithy.protocoltests.rpcv2Cbor#RpcV2Service", + "rpcv2_extras", + imports = listOf("$commonModels/rpcv2-extras.smithy") ), CodegenTest( "com.amazonaws.constraints#ConstraintsService", diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt index ed418ad71f..a84ef8979e 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -62,7 +62,7 @@ internal class CborSerializerGeneratorTest { @Test fun `we serialize and serde_cbor deserializes round trip`() { - val model = File("../codegen-core/common-test-models/rpcv2.smithy").readText().asSmithyModel() + val model = File("../codegen-core/common-test-models/rpc-v2-extras.smithy").readText().asSmithyModel() val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { override val name: String = "Add `#[derive(serde::Deserialize)]`" diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index bbff32bfd5..9a0f33e209 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -63,7 +63,6 @@ impl DeserializeError { } } - /// Macro for delegating method calls to the decoder. /// /// This macro generates wrapper methods for calling specific encoder methods on the decoder @@ -133,8 +132,8 @@ impl<'b> Decoder<'b> { self.decoder.position() } - /// Returns a `cow::Borrowed(&str)` if the element at the current position in the buffer is a definite - /// length string. Otherwise, it returns a `cow::Owned(String)` if the element at the current position is an + /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite + /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an /// indefinite-length string. An error is returned if the element is neither a definite length nor an /// indefinite-length string. pub fn str(&mut self) -> Result, DeserializeError> { @@ -143,7 +142,7 @@ impl<'b> Decoder<'b> { Ok(str_value) => Ok(Cow::Borrowed(str_value)), Err(e) if e.is_type_mismatch() => { // Move the position back to the start of the CBOR element and then try - // decoding it as a indefinite length string. + // decoding it as an indefinite length string. self.decoder.set_position(bookmark); Ok(Cow::Owned(self.string()?)) } From e0e0366747d0dc4d54fd66be282072168fcb9cac Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 4 Jun 2024 14:14:42 +0200 Subject: [PATCH 03/77] Fix e2e tests --- buildSrc/src/main/kotlin/CodegenTestCommon.kt | 5 ++--- examples/Makefile | 13 +++---------- rust-runtime/aws-smithy-cbor/src/decode.rs | 6 +++--- rust-runtime/aws-smithy-cbor/src/encode.rs | 2 +- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index 1977be9b11..d1aac0f42a 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -52,13 +52,12 @@ fun toRustCrateName(input: String): String { "proc_macro" ) - // Then within your function, you could include a check against this set if (input.isBlank()) { throw IllegalArgumentException("Rust crate name cannot be empty") } val lowerCased = input.lowercase() - // Replace any sequence of characters that are not lowercase letters, numbers, or underscores with a single underscore. - val sanitized = lowerCased.replace(Regex("[^a-z0-9_]+"), "_") + // Replace any sequence of characters that are not lowercase letters, numbers, dashes, or underscores with a single underscore. + val sanitized = lowerCased.replace(Regex("[^a-z0-9_-]+"), "_") // Trim leading or trailing underscores. val trimmed = sanitized.trim('_') // Check if the resulting string is empty, purely numeric, or a reserved name diff --git a/examples/Makefile b/examples/Makefile index 791248c0a0..9c2bf9061d 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -3,23 +3,16 @@ CUR_DIR := $(shell pwd) GRADLE := $(SRC_DIR)/gradlew SERVER_SDK_DST := $(CUR_DIR)/pokemon-service-server-sdk CLIENT_SDK_DST := $(CUR_DIR)/pokemon-service-client -RPCV2_SDK_DST := $(CUR_DIR)/rpcv2-server-sdk -RPCV2_CLIENT_DST := $(CUR_DIR)/rpcv2-pokemon-client - SERVER_SDK_SRC := $(SRC_DIR)/codegen-server-test/build/smithyprojections/codegen-server-test/pokemon-service-server-sdk/rust-server-codegen CLIENT_SDK_SRC := $(SRC_DIR)/codegen-client-test/build/smithyprojections/codegen-client-test/pokemon-service-client/rust-client-codegen -RPCV2_SDK_SRC := $(SRC_DIR)/codegen-server-test/build/smithyprojections/codegen-server-test/rpcv2/rust-server-codegen -RPCV2_CLIENT_SRC := $(SRC_DIR)/codegen-client-test/build/smithyprojections/codegen-client-test/rpcv2-pokemon-client/rust-client-codegen all: codegen codegen: - $(GRADLE) --project-dir $(SRC_DIR) -P modules='pokemon-service-server-sdk,pokemon-service-client,rpcv2,rpcv2-pokemon-client' :codegen-client-test:clean :codegen-server-test:clean :codegen-client-test:assemble :codegen-server-test:assemble - mkdir -p $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(RPCV2_SDK_DST) $(RPCV2_CLIENT_DST) + $(GRADLE) --project-dir $(SRC_DIR) -P modules='pokemon-service-server-sdk,pokemon-service-client' :codegen-client-test:assemble :codegen-server-test:assemble + mkdir -p $(SERVER_SDK_DST) $(CLIENT_SDK_DST) cp -av $(SERVER_SDK_SRC)/* $(SERVER_SDK_DST)/ cp -av $(CLIENT_SDK_SRC)/* $(CLIENT_SDK_DST)/ - cp -av $(RPCV2_SDK_SRC)/* $(RPCV2_SDK_DST)/ - cp -av $(RPCV2_CLIENT_SRC)/* $(RPCV2_CLIENT_DST)/ build: codegen cargo build @@ -46,6 +39,6 @@ lambda_invoke: cargo lambda invoke pokemon-service-lambda --data-file pokemon-service/tests/fixtures/example-apigw-request.json distclean: clean - rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(RPCV2_SDK_DST) $(RPCV2_CLIENT_DST) Cargo.lock + rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) Cargo.lock .PHONY: all diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index 9a0f33e209..6e3dd42ea3 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -65,12 +65,12 @@ impl DeserializeError { /// Macro for delegating method calls to the decoder. /// -/// This macro generates wrapper methods for calling specific encoder methods on the decoder -/// and returning the result with error handling. +/// This macro generates wrapper methods for calling specific methods on the decoder and returning +/// the result with error handling. /// /// # Example /// -/// ``` +/// ```ignore /// delegate_method! { /// /// Wrapper method for encoding method `encode_str` on the decoder. /// encode_str_wrapper => encode_str(String); diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs index 475071ffcc..60663a9a7b 100644 --- a/rust-runtime/aws-smithy-cbor/src/encode.rs +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -7,7 +7,7 @@ use aws_smithy_types::{Blob, DateTime}; /// /// # Example /// -/// ``` +/// ```ignore /// delegate_method! { /// /// Wrapper method for encoding method `encode_str` on the encoder. /// encode_str_wrapper => encode_str(data: &str); From bf0157a3c9e760066efbb922ac78384701d34f2c Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 6 Jun 2024 14:34:34 +0200 Subject: [PATCH 04/77] Merge conflict --- examples/Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index f65cd8c4ec..a374adf6f0 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -8,11 +8,7 @@ members = [ "pokemon-service-lambda", "pokemon-service-server-sdk", "pokemon-service-client", -<<<<<<< HEAD -======= "pokemon-service-client-usage", - ->>>>>>> release-2023-11-01 ] [profile.release] From 853eb893c0996db2ff5b08a40b4b27ebd9c446c3 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 6 Jun 2024 14:34:57 +0200 Subject: [PATCH 05/77] Clear RUSTDOCFLAGS --- buildSrc/src/main/kotlin/CodegenTestCommon.kt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index d1aac0f42a..294969a3f4 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -322,6 +322,10 @@ fun Project.registerCargoCommandsTasks(outputDir: File) { this.tasks.register(Cargo.DOCS.toString) { dependsOn(dependentTasks) workingDir(outputDir) + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3194#issuecomment-2147657902) + // Clear `RUSTDOCFLAGS` because in the CI Docker image we bake in `-D warnings`, but we currently + // generate docs with warnings. + environment("RUSTDOCFLAGS", "") commandLine("cargo", "doc", "--no-deps", "--document-private-items") } From b7d7d6021895b76475691344ebef4d7d8bfa179b Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 6 Jun 2024 14:38:52 +0200 Subject: [PATCH 06/77] Fix CborSerializerGeneratorTest --- .../smithy/protocols/serialize/CborSerializerGeneratorTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt index a84ef8979e..1f59abd014 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -62,7 +62,7 @@ internal class CborSerializerGeneratorTest { @Test fun `we serialize and serde_cbor deserializes round trip`() { - val model = File("../codegen-core/common-test-models/rpc-v2-extras.smithy").readText().asSmithyModel() + val model = File("../codegen-core/common-test-models/rpcv2-extras.smithy").readText().asSmithyModel() val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { override val name: String = "Add `#[derive(serde::Deserialize)]`" From 73379019e925a3ce9fd2f0aa096b56105fa2c50a Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 10 Jun 2024 14:23:38 +0200 Subject: [PATCH 07/77] Add clarifying comment to upgrade.rs --- rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs index 15e87ebc19..f0b72a79ef 100644 --- a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs +++ b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs @@ -167,6 +167,8 @@ where type Future = UpgradeFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // The check that the inner service is ready is done by `Oneshot` in `UpgradeFuture`'s + // implementation. Poll::Ready(Ok(())) } From ab660157f3427b7f21377b2eba7e0953c003ac24 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 11 Jun 2024 17:44:01 +0200 Subject: [PATCH 08/77] Fix request `Content-Type` header checking in servers This fixes two bugs: 1. `Content-Type` header checking was succeeding when no `Content-Type` header was present but one was expected. 2. When a shape was @httpPayload`-bound, `Content-Type` header checking occurred even when no payload was being sent. In this case it is not necessary to check the header, since there is no content. Code has been refactored and cleaned up. The crux of the logic is now easier to understand, and contained in `content_type_header_classifier`. --- .../rest-json-extras-2310.smithy | 36 +++ .../rest-json-extras-2314.smithy | 158 +++++++++++++ .../rest-json-extras-2315.smithy | 62 +++++ .../rest-json-extras.smithy | 9 + codegen-server-test/build.gradle.kts | 10 +- .../protocol/ServerProtocolTestGenerator.kt | 3 + .../ServerHttpBoundProtocolGenerator.kt | 105 ++++----- .../src/protocol/mod.rs | 211 ++++++++---------- .../aws-smithy-http-server/src/rejection.rs | 6 +- 9 files changed, 430 insertions(+), 170 deletions(-) create mode 100644 codegen-core/common-test-models/rest-json-extras-2310.smithy create mode 100644 codegen-core/common-test-models/rest-json-extras-2314.smithy create mode 100644 codegen-core/common-test-models/rest-json-extras-2315.smithy diff --git a/codegen-core/common-test-models/rest-json-extras-2310.smithy b/codegen-core/common-test-models/rest-json-extras-2310.smithy new file mode 100644 index 0000000000..1c72f492ce --- /dev/null +++ b/codegen-core/common-test-models/rest-json-extras-2310.smithy @@ -0,0 +1,36 @@ +$version: "1.0" + +namespace aws.protocoltests.restjson + +use aws.protocols#restJson1 +use smithy.test#httpMalformedRequestTests + +@http(method: "POST", uri: "/MalformedContentTypeWithBody") +operation MalformedContentTypeWithBody2 { + input: GreetingStruct +} + +structure GreetingStruct { + salutation: String, +} + +apply MalformedContentTypeWithBody2 @httpMalformedRequestTests([ + { + id: "RestJsonWithBodyExpectsApplicationJsonContentTypeNoHeaders", + documentation: """ + When there is modeled input, the content type must be application/json""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedContentTypeWithBody", + body: "{}", + }, + response: { + code: 415, + headers: { + "x-amzn-errortype": "UnsupportedMediaTypeException" + } + }, + tags: [ "content-type" ] + } +]) diff --git a/codegen-core/common-test-models/rest-json-extras-2314.smithy b/codegen-core/common-test-models/rest-json-extras-2314.smithy new file mode 100644 index 0000000000..cdeb59f112 --- /dev/null +++ b/codegen-core/common-test-models/rest-json-extras-2314.smithy @@ -0,0 +1,158 @@ +$version: "1.0" + +namespace aws.protocoltests.restjson + +use aws.protocols#restJson1 +use smithy.test#httpMalformedRequestTests +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + +/// This example serializes a blob shape in the payload. +/// +/// In this example, no JSON document is synthesized because the payload is +/// not a structure or a union type. +@http(uri: "/HttpPayloadTraits", method: "POST") +operation HttpPayloadTraits2 { + input: HttpPayloadTraitsInputOutput, + output: HttpPayloadTraitsInputOutput +} + +apply HttpPayloadTraits2 @httpRequestTests([ + { + id: "RestJsonHttpPayloadTraitsWithBlobAcceptsNoContentType", + documentation: """ + Servers must accept no content type for blob inputs + without the media type trait.""", + protocol: restJson1, + method: "POST", + uri: "/HttpPayloadTraits", + body: "This is definitely a jpeg", + bodyMediaType: "application/octet-stream", + headers: { + "X-Foo": "Foo", + }, + params: { + foo: "Foo", + blob: "This is definitely a jpeg" + }, + appliesTo: "server", + tags: [ "content-type" ] + } +]) + +/// This example serializes a string shape in the payload. +/// +/// In this example, no JSON document is synthesized because the payload is +/// not a structure or a union type. +@http(uri: "/HttpPayloadTraitOnString", method: "POST") +operation HttpPayloadTraitOnString2 { + input: HttpPayloadTraitOnStringInputOutput, + output: HttpPayloadTraitOnStringInputOutput +} + +structure HttpPayloadTraitOnStringInputOutput { + @httpPayload + foo: String, +} + +apply HttpPayloadTraitOnString2 @httpRequestTests([ + { + id: "RestJsonHttpPayloadTraitOnString", + documentation: "Serializes a string in the HTTP payload", + protocol: restJson1, + method: "POST", + uri: "/HttpPayloadTraitOnString", + body: "Foo", + bodyMediaType: "text/plain", + headers: { + "Content-Type": "text/plain", + }, + requireHeaders: [ + "Content-Length" + ], + params: { + foo: "Foo", + } + }, +]) + +apply HttpPayloadTraitOnString2 @httpResponseTests([ + { + id: "RestJsonHttpPayloadTraitOnString", + documentation: "Serializes a string in the HTTP payload", + protocol: restJson1, + code: 200, + body: "Foo", + bodyMediaType: "text/plain", + headers: { + "Content-Type": "text/plain", + }, + params: { + foo: "Foo", + } + }, +]) + +apply HttpPayloadTraitOnString2 @httpMalformedRequestTests([ + { + id: "RestJsonHttpPayloadTraitOnStringNoContentType", + documentation: "Serializes a string in the HTTP payload without a content-type header", + protocol: restJson1, + request: { + method: "POST", + uri: "/HttpPayloadTraitOnString", + body: "Foo", + // We expect a `Content-Type` header but none was provided. + }, + response: { + code: 415, + headers: { + "x-amzn-errortype": "UnsupportedMediaTypeException" + } + }, + tags: [ "content-type" ] + }, + { + id: "RestJsonHttpPayloadTraitOnStringWrongContentType", + documentation: "Serializes a string in the HTTP payload without the expected content-type header", + protocol: restJson1, + request: { + method: "POST", + uri: "/HttpPayloadTraitOnString", + body: "Foo", + headers: { + // We expect `text/plain`. + "Content-Type": "application/json", + }, + }, + response: { + code: 415, + headers: { + "x-amzn-errortype": "UnsupportedMediaTypeException" + } + }, + tags: [ "content-type" ] + }, + { + id: "RestJsonHttpPayloadTraitOnStringUnsatisfiableAccept", + documentation: "Serializes a string in the HTTP payload with an unstatisfiable accept header", + protocol: restJson1, + request: { + method: "POST", + uri: "/HttpPayloadTraitOnString", + body: "Foo", + headers: { + "Content-Type": "text/plain", + // We can't satisfy this requirement; the server will return `text/plain`. + "Accept": "application/json", + }, + }, + response: { + code: 406, + headers: { + "x-amzn-errortype": "NotAcceptableException" + } + }, + tags: [ "accept" ] + }, +]) diff --git a/codegen-core/common-test-models/rest-json-extras-2315.smithy b/codegen-core/common-test-models/rest-json-extras-2315.smithy new file mode 100644 index 0000000000..e979c86b84 --- /dev/null +++ b/codegen-core/common-test-models/rest-json-extras-2315.smithy @@ -0,0 +1,62 @@ +$version: "2.0" + +namespace aws.protocoltests.restjson + +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests +use smithy.framework#ValidationException + +@http(uri: "/EnumPayload2", method: "POST") +@httpRequestTests([ + { + id: "RestJsonEnumPayloadRequest2", + uri: "/EnumPayload2", + headers: { "Content-Type": "text/plain" }, + body: "enumvalue", + params: { payload: "enumvalue" }, + method: "POST", + protocol: "aws.protocols#restJson1" + } +]) +@httpResponseTests([ + { + id: "RestJsonEnumPayloadResponse2", + headers: { "Content-Type": "text/plain" }, + body: "enumvalue", + params: { payload: "enumvalue" }, + protocol: "aws.protocols#restJson1", + code: 200 + } +]) +operation HttpEnumPayload2 { + input: EnumPayloadInput, + output: EnumPayloadInput + errors: [ValidationException] +} + +@http(uri: "/StringPayload2", method: "POST") +@httpRequestTests([ + { + id: "RestJsonStringPayloadRequest2", + uri: "/StringPayload2", + headers: { "Content-Type": "text/plain" }, + body: "rawstring", + params: { payload: "rawstring" }, + method: "POST", + protocol: "aws.protocols#restJson1" + } +]) +@httpResponseTests([ + { + id: "RestJsonStringPayloadResponse2", + headers: { "Content-Type": "text/plain" }, + body: "rawstring", + params: { payload: "rawstring" }, + protocol: "aws.protocols#restJson1", + code: 200 + } +]) +operation HttpStringPayload2 { + input: StringPayloadInput, + output: StringPayloadInput +} diff --git a/codegen-core/common-test-models/rest-json-extras.smithy b/codegen-core/common-test-models/rest-json-extras.smithy index 73f45e7dfd..2952b2a2f2 100644 --- a/codegen-core/common-test-models/rest-json-extras.smithy +++ b/codegen-core/common-test-models/rest-json-extras.smithy @@ -66,6 +66,14 @@ service RestJsonExtras { CaseInsensitiveErrorOperation, EmptyStructWithContentOnWireOp, QueryPrecedence, + // TODO(https://github.com/smithy-lang/smithy/pull/2314) + HttpPayloadTraitOnString2, + HttpPayloadTraits2, + // TODO(https://github.com/smithy-lang/smithy/pull/2310) + MalformedContentTypeWithBody2, + // TODO(https://github.com/smithy-lang/smithy/pull/2315) + HttpEnumPayload2, + HttpStringPayload2, ], errors: [ExtraError] } @@ -101,6 +109,7 @@ structure ExtraError {} id: "StringPayload", uri: "/StringPayload", body: "rawstring", + headers: { "Content-Type": "text/plain" }, params: { payload: "rawstring" }, method: "POST", protocol: "aws.protocols#restJson1" diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 7692ffb93f..56671e3923 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -71,7 +71,15 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> CodegenTest( "aws.protocoltests.restjson#RestJsonExtras", "rest_json_extras", - imports = listOf("$commonModels/rest-json-extras.smithy"), + imports = listOf( + "$commonModels/rest-json-extras.smithy", + // TODO(https://github.com/smithy-lang/smithy/pull/2310): Can be deleted when consumed in next Smithy version. + "$commonModels/rest-json-extras-2310.smithy", + // TODO(https://github.com/smithy-lang/smithy/pull/2314): Can be deleted when consumed in next Smithy version. + "$commonModels/rest-json-extras-2314.smithy", + // TODO(https://github.com/smithy-lang/smithy/pull/2315): Can be deleted when consumed in next Smithy version. + "$commonModels/rest-json-extras-2315.smithy", + ), ), CodegenTest( "aws.protocoltests.restjson.validation#RestJsonValidation", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index c9ffa06cc4..073cae0205 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -832,6 +832,9 @@ class ServerProtocolTestGenerator( FailingTest(REST_JSON, "RestJsonEndpointTrait", TestType.Request), FailingTest(REST_JSON, "RestJsonEndpointTraitWithHostLabel", TestType.Request), FailingTest(REST_JSON, "RestJsonOmitsEmptyListQueryValues", TestType.Request), + // TODO(https://github.com/smithy-lang/smithy/pull/2315): Can be deleted when fixed tests are consumed in next Smithy version + FailingTest(REST_JSON, "RestJsonEnumPayloadRequest", TestType.Request), + FailingTest(REST_JSON, "RestJsonStringPayloadRequest", TestType.Request), // Tests involving `@range` on floats. // Pending resolution from the Smithy team, see https://github.com/smithy-lang/smithy-rs/issues/2007. FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case0", TestType.MalformedRequest), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index ac966853e6..5e921a83ea 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -219,7 +219,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( * we require the HTTP body to be fully read in memory before parsing or deserialization. * From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize * an HTTP response to `Bytes`. - * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server. + * These traits are the public entrypoint of the ser/de logic of the generated server. */ private fun RustWriter.renderTraits( inputSymbol: Symbol, @@ -259,35 +259,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate(init, *codegenScope) } } - // This checks for the expected `Content-Type` header if the `@httpPayload` trait is present, as dictated by - // the core Smithy library, which _does not_ require deserializing the payload. - // If no members have `@httpPayload`, the expected `Content-Type` header as dictated _by the protocol_ is - // checked later on for non-streaming operations, in `serverRenderShapeParser`: that check _does_ require at - // least buffering the entire payload, since the check must only be performed if the payload is empty. - val verifyRequestContentTypeHeader = - writable { - operationShape - .inputShape(model) - .members() - .find { it.hasTrait() } - ?.let { payload -> - val target = model.expectShape(payload.target) - if (!target.isBlobShape || target.hasTrait()) { - // `null` is only returned by Smithy when there are no members, but we know there's at least - // the one with `@httpPayload`, so `!!` is safe here. - val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! - rustTemplate( - """ - #{SmithyHttpServer}::protocol::content_type_header_classifier_http( - request.headers(), - Some("$expectedRequestContentType"), - )?; - """, - *codegenScope, - ) - } - } - } // Implement `from_request` trait for input types. val inputFuture = "${inputSymbol.name}Future" @@ -326,7 +297,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( fn from_request(request: #{http}::Request) -> Self::Future { let fut = async move { #{verifyAcceptHeader:W} - #{verifyRequestContentTypeHeader:W} #{parse_request}(request) .await .map_err(Into::into) @@ -348,7 +318,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( "parse_request" to serverParseRequest(operationShape), "verifyAcceptHeader" to verifyAcceptHeader, "verifyAcceptHeaderStaticContentTypeInit" to verifyAcceptHeaderStaticContentTypeInit, - "verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader, ) // Implement `into_response` for output types. @@ -777,42 +746,42 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } } + for (binding in bindings) { val member = binding.member val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + val valueToSet = if (symbolProvider.toSymbol(binding.member).isOptional()) { + "Some(value)" + } else { + "value" + } if (parsedValue != null) { - rust("if let Some(value) = ") - parsedValue(this) - rust( + rustTemplate( """ - { - input = input.${member.setterName()}(${ - if (symbolProvider.toSymbol(binding.member).isOptional()) { - "Some(value)" - } else { - "value" - } - }); + if let Some(value) = #{ParsedValue:W} { + input = input.${member.setterName()}($valueToSet) } """, + "ParsedValue" to parsedValue ) } } + serverRenderUriPathParser(this, operationShape) serverRenderQueryStringParser(this, operationShape) + // If there's no modeled operation input, some protocols require that `Content-Type` header not be present. val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { - conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { - rustTemplate( - """ - #{SmithyHttpServer}::protocol::content_type_header_empty_body_no_modeled_input(&headers)?; - """, - *codegenScope, - ) - } + rustTemplate( + """ + #{SmithyHttpServer}::protocol::content_type_header_classifier_smithy(&headers, None)?; + """, + *codegenScope, + ) } + val err = if (ServerBuilderGenerator.hasFallibleBuilder( inputShape, @@ -860,14 +829,48 @@ class ServerHttpBoundProtocolTraitImplGenerator( *codegenScope, ) } else { + // This checks for the expected `Content-Type` header if the `@httpPayload` trait is present, as dictated by + // the core Smithy library, which _does not_ require deserializing the payload. + // If no members have `@httpPayload`, the expected `Content-Type` header as dictated _by the protocol_ is + // checked later on for non-streaming operations, in `serverRenderShapeParser`. + // Both checks require buffering the entire payload, since the check must only be performed if the payload is + // not empty. + val verifyRequestContentTypeHeader = + writable { + operationShape + .inputShape(model) + .members() + .find { it.hasTrait() } + ?.let { payload -> + val target = model.expectShape(payload.target) + if (!target.isBlobShape || target.hasTrait()) { + // `null` is only returned by Smithy when there are no members, but we know there's at least + // the one with `@httpPayload`, so `!!` is safe here. + val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! + rustTemplate( + """ + if !bytes.is_empty() { + #{SmithyHttpServer}::protocol::content_type_header_classifier_smithy( + &headers, + Some("$expectedRequestContentType"), + )?; + } + """, + *codegenScope, + ) + } + } + } rustTemplate( """ { let bytes = #{Hyper}::body::to_bytes(body).await?; + #{VerifyRequestContentTypeHeader:W} #{Deserializer}(&bytes)? } """, "Deserializer" to deserializer, + "VerifyRequestContentTypeHeader" to verifyRequestContentTypeHeader, *codegenScope, ) } diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs index 4686e94e07..83c111f76b 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs @@ -14,7 +14,7 @@ pub mod rpc_v2; use crate::rejection::MissingContentTypeReason; use aws_smithy_runtime_api::http::Headers as SmithyHeaders; use http::header::CONTENT_TYPE; -use http::{HeaderMap, HeaderValue}; +use http::HeaderMap; #[cfg(test)] pub mod test_helpers { @@ -47,81 +47,63 @@ fn parse_mime(content_type: &str) -> Result Result<(), MissingContentTypeReason> { - if headers.contains_key(http::header::CONTENT_TYPE) { - let found_mime = headers - .get(http::header::CONTENT_TYPE) - .unwrap() // The header is present, `unwrap` will not panic. - .parse::() - .map_err(MissingContentTypeReason::MimeParseError)?; - Err(MissingContentTypeReason::UnexpectedMimeType { - expected_mime: None, - found_mime: Some(found_mime), - }) - } else { - Ok(()) - } -} - -/// Checks that the `content-type` header is valid from a Smithy `Headers`. +/// Checks that the `content-type` header from a `SmithyHeaders` matches what we expect. #[allow(clippy::result_large_err)] pub fn content_type_header_classifier_smithy( headers: &SmithyHeaders, expected_content_type: Option<&'static str>, ) -> Result<(), MissingContentTypeReason> { - match headers.get(CONTENT_TYPE) { - Some(content_type) => content_type_header_classifier(content_type, expected_content_type), - None => Ok(()), - } + let actual_content_type = headers.get(CONTENT_TYPE); + content_type_header_classifier(actual_content_type, expected_content_type) } -/// Checks that the `content-type` header is valid from a `http::HeaderMap`. +/// Checks that the `content-type` header matches what we expect. #[allow(clippy::result_large_err)] -pub fn content_type_header_classifier_http( - headers: &HeaderMap, +fn content_type_header_classifier( + actual_content_type: Option<&str>, expected_content_type: Option<&'static str>, ) -> Result<(), MissingContentTypeReason> { - if let Some(content_type) = headers.get(http::header::CONTENT_TYPE) { - let content_type = content_type.to_str().map_err(MissingContentTypeReason::ToStrError)?; - content_type_header_classifier(content_type, expected_content_type) - } else { - Ok(()) + fn parse_expected_mime(expected_content_type: &str) -> mime::Mime { + let mime = expected_content_type + .parse::() + // `expected_content_type` comes from the codegen. + .expect("BUG: MIME parsing failed, `expected_content_type` is not valid; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues"); + debug_assert_eq!( + mime, expected_content_type, + "BUG: expected `content-type` header value we own from codegen should coincide with its mime type; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues", + ); + mime } -} -/// Checks that the `content-type` header is valid. -#[allow(clippy::result_large_err)] -fn content_type_header_classifier( - content_type: &str, - expected_content_type: Option<&'static str>, -) -> Result<(), MissingContentTypeReason> { - let found_mime = parse_mime(content_type)?; - // There is a `content-type` header. - // If there is an implied content type, they must match. - if let Some(expected_content_type) = expected_content_type { - let expected_mime = expected_content_type - .parse::() - // `expected_content_type` comes from the codegen. - .expect("BUG: MIME parsing failed, `expected_content_type` is not valid. Please file a bug report under https://github.com/smithy-lang/smithy-rs/issues"); - if expected_content_type != found_mime { - return Err(MissingContentTypeReason::UnexpectedMimeType { + match (actual_content_type, expected_content_type) { + (None, None) => Ok(()), + (None, Some(expected_content_type)) => { + let expected_mime = parse_expected_mime(expected_content_type); + Err(MissingContentTypeReason::UnexpectedMimeType { expected_mime: Some(expected_mime), + found_mime: None, + }) + } + (Some(actual_content_type), None) => { + let found_mime = parse_mime(actual_content_type)?; + Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, found_mime: Some(found_mime), - }); + }) + } + (Some(actual_content_type), Some(expected_content_type)) => { + let expected_mime = parse_expected_mime(expected_content_type); + let found_mime = parse_mime(actual_content_type)?; + if expected_mime != found_mime { + Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: Some(expected_mime), + found_mime: Some(found_mime), + }) + } else { + Ok(()) + } } - } else { - // `content-type` header and no modeled input (mismatch). - return Err(MissingContentTypeReason::UnexpectedMimeType { - expected_mime: None, - found_mime: Some(found_mime), - }); } - Ok(()) } pub fn accept_header_classifier(headers: &HeaderMap, content_type: &mime::Mime) -> bool { @@ -164,8 +146,8 @@ mod tests { use aws_smithy_runtime_api::http::Headers; use http::header::{HeaderValue, ACCEPT, CONTENT_TYPE}; - fn req_content_type(content_type: &'static str) -> HeaderMap { - let mut headers = HeaderMap::new(); + fn req_content_type_smithy(content_type: &'static str) -> SmithyHeaders { + let mut headers = SmithyHeaders::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str(content_type).unwrap()); headers } @@ -176,75 +158,72 @@ mod tests { headers } - const EXPECTED_MIME_APPLICATION_JSON: Option<&'static str> = Some("application/json"); + const APPLICATION_JSON: Option<&'static str> = Some("application/json"); - #[test] - fn check_content_type_header_empty_body_no_modeled_input() { - assert!(content_type_header_empty_body_no_modeled_input(&Headers::new()).is_ok()); + // Validates the rejection type since we cannot implement `PartialEq` + // for `MissingContentTypeReason`. + fn assert_unexpected_mime_type( + result: Result<(), MissingContentTypeReason>, + actually_expected_mime: Option, + actually_found_mime: Option, + ) { + match result { + Ok(()) => panic!("content-type validation is expected to fail"), + Err(e) => match e { + MissingContentTypeReason::UnexpectedMimeType { + expected_mime, + found_mime, + } => { + assert_eq!(actually_expected_mime, expected_mime); + assert_eq!(actually_found_mime, found_mime); + } + _ => panic!("unexpected `MissingContentTypeReason`: {}", e), + }, + } } #[test] - fn check_invalid_content_type_header_empty_body_no_modeled_input() { - let mut valid = Headers::new(); - valid.insert(CONTENT_TYPE, "application/json"); - let result = content_type_header_empty_body_no_modeled_input(&valid).unwrap_err(); - assert!(matches!( - result, - MissingContentTypeReason::UnexpectedMimeType { - expected_mime: None, - found_mime: Some(_) - } - )); + fn check_valid_content_type() { + let headers = req_content_type_smithy("application/json"); + assert!(content_type_header_classifier_smithy(&headers, APPLICATION_JSON,).is_ok()); } #[test] fn check_invalid_content_type() { let invalid = vec!["application/jason", "text/xml"]; for invalid_mime in invalid { - let headers = req_content_type(invalid_mime); - let mut results = Vec::new(); - results.push(content_type_header_classifier_http( - &headers, - EXPECTED_MIME_APPLICATION_JSON, - )); - results.push(content_type_header_classifier_smithy( - &Headers::try_from(headers).unwrap(), - EXPECTED_MIME_APPLICATION_JSON, - )); + let headers = req_content_type_smithy(invalid_mime); + let results = vec![content_type_header_classifier_smithy(&headers, APPLICATION_JSON)]; - // Validates the rejection type since we cannot implement `PartialEq` - // for `MissingContentTypeReason`. + let actually_expected_mime = Some(parse_mime(APPLICATION_JSON.unwrap()).unwrap()); for result in results { - match result { - Ok(()) => panic!("Content-type validation is expected to fail"), - Err(e) => match e { - MissingContentTypeReason::UnexpectedMimeType { - expected_mime, - found_mime, - } => { - assert_eq!( - expected_mime.unwrap(), - "application/json".parse::().unwrap() - ); - assert_eq!(found_mime, invalid_mime.parse::().ok()); - } - _ => panic!("Unexpected `MissingContentTypeReason`: {}", e), - }, - } + let actually_found_mime = invalid_mime.parse::().ok(); + assert_unexpected_mime_type(result, actually_expected_mime.clone(), actually_found_mime); } } } #[test] - fn check_missing_content_type_is_allowed() { - let result = content_type_header_classifier_http(&HeaderMap::new(), EXPECTED_MIME_APPLICATION_JSON); - assert!(result.is_ok()); + fn check_missing_content_type_is_not_allowed() { + let actually_expected_mime = Some(parse_mime(APPLICATION_JSON.unwrap()).unwrap()); + let result = content_type_header_classifier_smithy(&SmithyHeaders::new(), APPLICATION_JSON); + assert_unexpected_mime_type(result, actually_expected_mime, None); + } + + #[test] + fn check_missing_content_type_is_expected() { + let headers = req_content_type_smithy(APPLICATION_JSON.unwrap()); + let actually_found_mime = Some(parse_mime(APPLICATION_JSON.unwrap()).unwrap()); + let actually_expected_mime = None; + + let result = content_type_header_classifier_smithy(&headers, None); + assert_unexpected_mime_type(result, actually_expected_mime, actually_found_mime); } #[test] fn check_not_parsable_content_type() { - let request = req_content_type("123"); - let result = content_type_header_classifier_http(&request, EXPECTED_MIME_APPLICATION_JSON); + let request = req_content_type_smithy("123"); + let result = content_type_header_classifier_smithy(&request, APPLICATION_JSON); assert!(matches!( result.unwrap_err(), MissingContentTypeReason::MimeParseError(_) @@ -253,9 +232,15 @@ mod tests { #[test] fn check_non_ascii_visible_characters_content_type() { - let request = req_content_type("application/💩"); - let result = content_type_header_classifier_http(&request, EXPECTED_MIME_APPLICATION_JSON); - assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_))); + // Note that for Smithy headers, validation fails when attempting to parse the mime type, + // unlike with `http`'s `HeaderMap`, that would fail when checking the header value is + // valid (~ASCII string). + let request = req_content_type_smithy("application/💩"); + let result = content_type_header_classifier_smithy(&request, APPLICATION_JSON); + assert!(matches!( + result.unwrap_err(), + MissingContentTypeReason::MimeParseError(_) + )); } #[test] diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 1f1e247435..ae3d2854fb 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -11,13 +11,9 @@ use thiserror::Error; pub enum MissingContentTypeReason { #[error("headers taken by another extractor")] HeadersTakenByAnotherExtractor, - #[error("no `Content-Type` header")] - NoContentTypeHeader, - #[error("`Content-Type` header value is not a valid HTTP header value: {0}")] - ToStrError(http::header::ToStrError), #[error("invalid `Content-Type` header value mime type: {0}")] MimeParseError(mime::FromStrError), - #[error("unexpected `Content-Type` header value; expected {expected_mime:?}, found {found_mime:?}")] + #[error("unexpected `Content-Type` header value; expected mime {expected_mime:?}, found mime {found_mime:?}")] UnexpectedMimeType { expected_mime: Option, found_mime: Option, From 89004f750038927e6499632b4dad2e69f58ba232 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 11 Jun 2024 18:21:32 +0200 Subject: [PATCH 09/77] ServerProtocolTestGenerator --- .../protocol/ServerProtocolTestGenerator.kt | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 073cae0205..49388e106a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -446,13 +446,19 @@ class ServerProtocolTestGenerator( // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - val encodedBody = + // TODO This for RPC v2; use `bodyMediaType`, see GitHub issue above +// val encodedBody = +// """ +// #{Bytes}::from( +// #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( +// "`body` field of Smithy protocol test is not correctly base64 encoded" +// ) +// ) +// """ + // TODO This for other protocols + val encodedBody = """ - #{Bytes}::from( - #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( - "`body` field of Smithy protocol test is not correctly base64 encoded" - ) - ) + #{Bytes}::from_static($sanitizedBody.as_bytes()) """ "#{SmithyHttpServer}::body::Body::from($encodedBody)" From 0bc7e2344fbe0a0fecc374eaf92513b749f3519a Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 13 Jun 2024 13:46:18 +0200 Subject: [PATCH 10/77] Address some TODOs --- .../smithy/protocols/ClientProtocolLoader.kt | 4 +- .../common-test-models/pokemon-awsjson.smithy | 1 + ...-extras.smithy => rpcv2Cbor-extras.smithy} | 20 ++----- .../core/smithy/generators/Instantiator.kt | 4 +- .../protocols/{RpcV2.kt => RpcV2Cbor.kt} | 9 ++-- codegen-server-test/build.gradle.kts | 4 +- .../generators/protocol/ServerProtocol.kt | 9 ++-- .../protocol/ServerProtocolTestGenerator.kt | 52 +++++++++++++------ .../protocols/ServerRpcV2CborFactory.kt | 6 +-- .../{RpcV2Test.kt => RpcV2CborTest.kt} | 2 +- .../serialize/CborSerializerGeneratorTest.kt | 4 +- 11 files changed, 62 insertions(+), 53 deletions(-) rename codegen-core/common-test-models/{rpcv2-extras.smithy => rpcv2Cbor-extras.smithy} (88%) rename codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/{RpcV2.kt => RpcV2Cbor.kt} (96%) rename codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/{RpcV2Test.kt => RpcV2CborTest.kt} (97%) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index 2af64bbad8..dc49c5a6c6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -29,7 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml -import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : @@ -121,7 +121,7 @@ class ClientRestXmlFactory( } class ClientRpcV2CborFactory : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2(codegenContext) + override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2Cbor(codegenContext) override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = OperationGenerator(codegenContext, protocol(codegenContext)) diff --git a/codegen-core/common-test-models/pokemon-awsjson.smithy b/codegen-core/common-test-models/pokemon-awsjson.smithy index 16eab7df90..77e78a58d7 100644 --- a/codegen-core/common-test-models/pokemon-awsjson.smithy +++ b/codegen-core/common-test-models/pokemon-awsjson.smithy @@ -27,6 +27,7 @@ service PokemonService { } /// Capture Pokémons via event streams. +@http(uri: "/simple-struct-operation", method: "POST") operation CapturePokemon { input := { events: AttemptCapturingPokemonEvent diff --git a/codegen-core/common-test-models/rpcv2-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy similarity index 88% rename from codegen-core/common-test-models/rpcv2-extras.smithy rename to codegen-core/common-test-models/rpcv2Cbor-extras.smithy index 6714bbfea3..44d9b51c16 100644 --- a/codegen-core/common-test-models/rpcv2-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -15,7 +15,8 @@ service RpcV2Service { ] } -// TODO RpcV2 should not use the `@http` trait. +// TODO(https://github.com/smithy-lang/smithy/issues/2326): Smithy should not +// allow HTTP binding traits in this protocol. @http(uri: "/simple-struct-operation", method: "POST") operation SimpleStructOperation { input: SimpleStruct @@ -23,7 +24,6 @@ operation SimpleStructOperation { errors: [ValidationException] } -// TODO RpcV2 should not use the `@http` trait. @http(uri: "/complex-struct-operation", method: "POST") operation ComplexStructOperation { input: ComplexStruct @@ -51,9 +51,6 @@ apply SimpleStructOperation @httpResponseTests([ double: 0.6969, timestamp: 1546300800, - // document: { - // documentInteger: 69 - // } enum: "DIAMOND" // With `@required`. @@ -72,9 +69,6 @@ apply SimpleStructOperation @httpResponseTests([ requiredDouble: 0.6969, requiredTimestamp: 1546300800, - // document: { - // documentInteger: 69 - // } requiredEnum: "DIAMOND" } }, @@ -98,9 +92,6 @@ apply SimpleStructOperation @httpResponseTests([ requiredDouble: 0.6969, requiredTimestamp: 1546300800, - // document: { - // documentInteger: 69 - // } requiredEnum: "DIAMOND" } } @@ -121,7 +112,6 @@ structure SimpleStruct { double: Double timestamp: Timestamp - // document: MyDocument enum: Suit // With `@required`. @@ -171,7 +161,8 @@ map SimpleMap { value: Integer } -// TODO Cut ticket to Smithy: their protocol tests don't have unions +// TODO(https://github.com/smithy-lang/smithy/issues/2325): Upstream protocol +// test suite doesn't cover unions. union SimpleUnion { blob: Blob boolean: Boolean @@ -203,6 +194,3 @@ enum Suit { HEART SPADE } - -// TODO Documents are not supported in RPC v2 CBOR. -// document MyDocument diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index db51633fef..621c00cbc4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -490,9 +490,7 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val fractionalPart = num.remainder(BigDecimal.ONE) rust( "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", -// RuntimeType.dateTime(runtimeConfig), - // TODO - runtimeConfig.smithyRuntimeCrate("smithy-types", scope = DependencyScope.Dev).toType().resolve("DateTime"), + RuntimeType.dateTime(runtimeConfig), ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt similarity index 96% rename from codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index 75b275fc8f..acd93e3883 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -26,8 +26,7 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape -// TODO Rename these to RpcV2Cbor -class RpcV2HttpBindingResolver( +class RpcV2CborHttpBindingResolver( private val model: Model, ) : HttpBindingResolver { private fun bindings(shape: ToShapeId): List { @@ -50,7 +49,7 @@ class RpcV2HttpBindingResolver( .toList() } - // TODO + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) // In the server, this is only used when the protocol actually supports the `@http` trait. // However, we will have to do this for client support. Perhaps this method deserves a rename. override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") @@ -84,7 +83,7 @@ class RpcV2HttpBindingResolver( /** * TODO: Docs. */ -open class RpcV2(val codegenContext: CodegenContext) : Protocol { +open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, @@ -97,7 +96,7 @@ open class RpcV2(val codegenContext: CodegenContext) : Protocol { ) private val jsonDeserModule = RustModule.private("json_deser") - override val httpBindingResolver: HttpBindingResolver = RpcV2HttpBindingResolver(codegenContext.model) + override val httpBindingResolver: HttpBindingResolver = RpcV2CborHttpBindingResolver(codegenContext.model) // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps // using floating point seconds from the epoch. diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 56671e3923..76549038a2 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -48,8 +48,8 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), CodegenTest( "smithy.protocoltests.rpcv2Cbor#RpcV2Service", - "rpcv2_extras", - imports = listOf("$commonModels/rpcv2-extras.smithy") + "rpcv2Cbor_extras", + imports = listOf("$commonModels/rpcv2Cbor-extras.smithy") ), CodegenTest( "com.amazonaws.constraints#ConstraintsService", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index d4281b4ddf..738f4faa88 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -21,7 +21,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml -import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization @@ -33,7 +33,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParse import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency @@ -308,9 +307,9 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborPa } } -class ServerRpcV2Protocol( +class ServerRpcV2CborProtocol( private val serverCodegenContext: ServerCodegenContext, -) : RpcV2(serverCodegenContext), ServerProtocol { +) : RpcV2Cbor(serverCodegenContext), ServerProtocol { val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath = "rpc_v2" @@ -349,6 +348,8 @@ class ServerRpcV2Protocol( ) = writable { // This is just the key used by the router's map to store and lookup operations, it's completely arbitrary. // We use the same key used by the awsJson1.x routers for simplicity. + // The router will extract the service name and the operation name from the URI, build this key, and lookup the + // operation stored there. rust("$serviceName.$operationName".dq()) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 49388e106a..3d8ae6a4f2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -280,12 +280,12 @@ class ServerProtocolTestGenerator( testModuleWriter.rust("Test ID: ${testCase.id}") testModuleWriter.newlinePrefix = "" - Attribute.TokioTest.render(testModuleWriter) - Attribute.TracedTest.render(testModuleWriter) // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in // case the code rendered by the test does not make use of `tracing` at all. val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) + Attribute.TokioTest.render(testModuleWriter) + Attribute.TracedTest.render(testModuleWriter) if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") @@ -317,7 +317,16 @@ class ServerProtocolTestGenerator( } with(httpRequestTestCase) { - renderHttpRequest(uri, method, headers, body.orNull(), bodyMediaType.orNull(), queryParams, host.orNull()) + renderHttpRequest( + uri, + method, + headers, + body.orNull(), + bodyMediaType.orNull(), + protocol, + queryParams, + host.orNull(), + ) } if (protocolSupport.requestBodyDeserialization) { makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) @@ -397,7 +406,16 @@ class ServerProtocolTestGenerator( // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. // TODO(https://github.com/smithy-lang/smithy/issues/1932): we send `null` for `bodyMediaType` for now but // the Smithy protocol test should give it to us. - renderHttpRequest(uri.get(), method, headers, body.orNull(), bodyMediaType = null, queryParams, host.orNull()) + renderHttpRequest( + uri.get(), + method, + headers, + body.orNull(), + bodyMediaType = null, + testCase.protocol, + queryParams, + host.orNull(), + ) } makeRequest( @@ -416,6 +434,7 @@ class ServerProtocolTestGenerator( headers: Map, body: String?, bodyMediaType: String?, + protocol: ShapeId, queryParams: List, host: String?, ) { @@ -446,20 +465,23 @@ class ServerProtocolTestGenerator( // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - // TODO This for RPC v2; use `bodyMediaType`, see GitHub issue above -// val encodedBody = -// """ -// #{Bytes}::from( -// #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( -// "`body` field of Smithy protocol test is not correctly base64 encoded" -// ) -// ) -// """ - // TODO This for other protocols - val encodedBody = + // TODO(https://github.com/smithy-lang/smithy/issues/1932): We're using the `protocol` field as a + // proxy for `bodyMediaType`. This works because `rpcv2Cbor` happens to be the only protocol where + // the body is base64-encoded in the protocol test, but checking `bodyMediaType` should be a more + // resilient check. + val encodedBody = if (protocol.toShapeId() == ShapeId.from("smithy.protocols#rpcv2Cbor")) { + """ + #{Bytes}::from( + #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( + "`body` field of Smithy protocol test is not correctly base64 encoded" + ) + ) + """ + } else { """ #{Bytes}::from_static($sanitizedBody.as_bytes()) """ + } "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt index 1dc2ce0679..a51119978b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt @@ -9,14 +9,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Proto import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2Protocol +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2CborProtocol class ServerRpcV2CborFactory : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): Protocol = - ServerRpcV2Protocol(codegenContext) + ServerRpcV2CborProtocol(codegenContext) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = - ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2Protocol(codegenContext)) + ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2CborProtocol(codegenContext)) override fun support(): ProtocolSupport { return ProtocolSupport( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt similarity index 97% rename from codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt rename to codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt index 3b69814ffd..5c804f8739 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt @@ -10,7 +10,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest // TODO This won't be needed since we'll cover it with a proper integration test. -internal class RpcV2Test { +internal class RpcV2CborTest { val model = """ ${"\$"}version: "2.0" diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt index 1f59abd014..8819e8e95f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -25,7 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.getTrait @@ -82,7 +82,7 @@ internal class CborSerializerGeneratorTest { ) val instantiator = ServerInstantiator(codegenContext) - val rpcV2 = RpcV2(codegenContext) + val rpcV2 = RpcV2Cbor(codegenContext) for (operationShape in codegenContext.model.operationShapes) { val outputShape = operationShape.outputShape(codegenContext.model) From ce621fd25c724a5cb187e4e9a586a03b25c26374 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 18 Jun 2024 12:21:04 +0200 Subject: [PATCH 11/77] save work --- .../smithy/protocols/HttpBindingResolver.kt | 8 ++++ .../core/smithy/protocols/RpcV2Cbor.kt | 47 +++++++++++++++++-- .../generators/protocol/ServerProtocol.kt | 2 +- .../ServerHttpBoundProtocolGenerator.kt | 27 ++++++++++- 4 files changed, 76 insertions(+), 8 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt index ad36e79190..afeaf5e1ce 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt @@ -169,6 +169,10 @@ open class HttpTraitHttpBindingResolver( model: Model, ): TimestampFormatTrait.Format = httpIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat) + /** + * Note that `null` will be returned and hence `Content-Type` will not be set when operation input has no members. + * This is in line with what protocol tests assert. + */ override fun requestContentType(operationShape: OperationShape): String? = httpIndex.determineRequestContentType( operationShape, @@ -176,6 +180,10 @@ open class HttpTraitHttpBindingResolver( contentTypes.eventStreamContentType, ).orNull() + /** + * Note that `null` will be returned and hence `Content-Type` will not be set when operation output has no members. + * This is in line with what protocol tests assert. + */ override fun responseContentType(operationShape: OperationShape): String? = httpIndex.determineResponseContentType( operationShape, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index acd93e3883..5c7e0404e7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ToShapeId @@ -20,15 +21,21 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParse import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.outputShape class RpcV2CborHttpBindingResolver( private val model: Model, + private val contentTypes: ProtocolContentTypes, ) : HttpBindingResolver { + private val httpIndex: HttpBindingIndex = HttpBindingIndex.of(model) + private fun bindings(shape: ToShapeId): List { val members = shape.let { model.expectShape(it.toShapeId()) }.members() // TODO(https://github.com/awslabs/smithy-rs/issues/2237): support non-streaming members too @@ -58,11 +65,28 @@ class RpcV2CborHttpBindingResolver( override fun responseBindings(operationShape: OperationShape) = bindings(operationShape.outputShape) override fun errorResponseBindings(errorShape: ToShapeId) = bindings(errorShape) - // TODO This should return null when operationShape has no members, and we should not rely on our janky - // `serverContentTypeCheckNoModeledInput`. Same goes for restJson1 protocol. - override fun requestContentType(operationShape: OperationShape): String = "application/cbor" + /** + * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#requests + * > Requests for operations with no defined input type MUST NOT contain bodies in their HTTP requests. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun requestContentType(operationShape: OperationShape): String? { + // When `syntheticInputTrait.originalId == null` it implies that the operation had no input defined + // in the Smithy model. + val syntheticInputTrait = operationShape.inputShape(model).expectTrait() + if (syntheticInputTrait.originalId == null) { + return null + } + + return httpIndex.determineRequestContentType( + operationShape, + contentTypes.requestDocument, + contentTypes.eventStreamContentType, + ).orNull() + } /** + * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#responses * > Responses for operations with no defined output type MUST NOT contain bodies in their HTTP responses. * > The `Content-Type` for the serialization format MUST NOT be set. */ @@ -73,7 +97,12 @@ class RpcV2CborHttpBindingResolver( if (syntheticOutputTrait.originalId == null) { return null } - return requestContentType(operationShape) + + return httpIndex.determineResponseContentType( + operationShape, + contentTypes.responseDocument, + contentTypes.eventStreamContentType, + ).orNull() } override fun eventStreamMessageContentType(memberShape: MemberShape): String? = @@ -96,7 +125,15 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { ) private val jsonDeserModule = RustModule.private("json_deser") - override val httpBindingResolver: HttpBindingResolver = RpcV2CborHttpBindingResolver(codegenContext.model) + override val httpBindingResolver: HttpBindingResolver = RpcV2CborHttpBindingResolver( + codegenContext.model, + ProtocolContentTypes( + requestDocument = "application/cbor", + responseDocument = "application/cbor", + eventStreamContentType = "application/vnd.amazon.eventstream", + eventStreamMessageContentType = "application/cbor", + ), + ) // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps // using floating point seconds from the epoch. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 738f4faa88..e7379e8a44 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -78,7 +78,7 @@ interface ServerProtocol : Protocol { fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType /** - * In some protocols, such as restJson1 and rpcv2, + * In some protocols, such as restJson1 and rpcv2Cbor, * when there is no modeled body input, `content-type` must not be set and the body must be empty. * Returns a boolean indicating whether to perform this check. */ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 5e921a83ea..a3f2e1051e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -730,8 +730,28 @@ class ServerHttpBoundProtocolTraitImplGenerator( // there's something to parse (i.e. `parser != null`), so `!!` is safe here. val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! rustTemplate("let bytes = #{Hyper}::body::to_bytes(body).await?;", *codegenScope) - // TODO Isn't this VERY wrong? If there's modeled operation input, we must reject if there's no payload! - // We currently accept and silently build empty input! + // Note that the server is being very lenient here. We're accepting an empty body for when there is modeled + // operation input; we simply parse it as empty operation input. + // This behavior applies to all protocols. This might seem like a bug, but it isn't. There's protocol tests + // that assert that the server should be lenient and accept both empty payloads and no payload + // when there is modeled input: + // + // * [restJson1]: clients omit the payload altogether when the input is empty! So services must accept this. + // * [rpcv2Cbor]: services must accept no payload or empty CBOR map for operations with modeled input. + // + // For the AWS JSON 1.x protocols, services are lenient in the case when there is no modeled input: + // + // * [awsJson1_0]: services must accept no payload or empty JSON document payload for operations with no modeled input + // * [awsJson1_1]: services must accept no payload or empty JSON document payload for operations with no modeled input + // + // However, it's true that there are no tests pinning server behavior when there is _empty_ input. There's + // a [consultation with Smithy] to remedy this. Until that gets resolved, in the meantime, we are being lenient. + // + // [restJson1]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/restJson1/empty-input-output.smithy#L22 + // [awsJson1_0]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/awsJson1_0/empty-input-output.smithy + // [awsJson1_1]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/awsJson1_1/empty-operation.smithy + // [rpcv2Cbor]: https://github.com/smithy-lang/smithy/blob/main/smithy-protocol-tests/model/rpcv2Cbor/empty-input-output.smithy + // [consultation with Smithy]: https://github.com/smithy-lang/smithy/issues/2327 rustBlock("if !bytes.is_empty()") { rustTemplate( """ @@ -782,6 +802,9 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } + // TODO What about when there's no modeled operation input but the payload is not empty? In some protocols we + // must accept `{}` but we currently accept anything! + val err = if (ServerBuilderGenerator.hasFallibleBuilder( inputShape, From 04e1838854ba8335ae84a69343c26f907a8a6974 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 19 Jun 2024 15:11:51 +0200 Subject: [PATCH 12/77] fix merge conflict, fix Content-Type --- aws/sdk/build.gradle.kts | 4 -- .../core/smithy/protocols/RpcV2Cbor.kt | 39 ++++++------------- .../serialize/CborSerializerGenerator.kt | 6 +-- .../serialize/JsonSerializerGenerator.kt | 4 +- .../core/smithy/traits/SyntheticInputTrait.kt | 9 ++++- .../smithy/traits/SyntheticOutputTrait.kt | 10 ++++- .../transformers/OperationNormalizer.kt | 22 ++++++++++- .../ServerHttpBoundProtocolGenerator.kt | 3 +- 8 files changed, 54 insertions(+), 43 deletions(-) diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index d11573937a..b8d4eab906 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -453,9 +453,7 @@ tasks["assemble"].apply { outputs.upToDateWhen { false } } -<<<<<<< HEAD project.registerCargoCommandsTasks(outputDir.asFile) -======= tasks.register("copyCheckedInCargoLock") { description = "Copy the checked in Cargo.lock file back to the build directory" this.outputs.upToDateWhen { false } @@ -463,8 +461,6 @@ tasks.register("copyCheckedInCargoLock") { into(outputDir) } -project.registerCargoCommandsTasks(outputDir.asFile, defaultRustDocFlags) ->>>>>>> upstream/main project.registerGenerateCargoConfigTomlTask(outputDir.asFile) //The task name "test" is already registered by one of our plugins diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index 5c7e0404e7..a93c428a8c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborS import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -34,8 +35,6 @@ class RpcV2CborHttpBindingResolver( private val model: Model, private val contentTypes: ProtocolContentTypes, ) : HttpBindingResolver { - private val httpIndex: HttpBindingIndex = HttpBindingIndex.of(model) - private fun bindings(shape: ToShapeId): List { val members = shape.let { model.expectShape(it.toShapeId()) }.members() // TODO(https://github.com/awslabs/smithy-rs/issues/2237): support non-streaming members too @@ -70,41 +69,25 @@ class RpcV2CborHttpBindingResolver( * > Requests for operations with no defined input type MUST NOT contain bodies in their HTTP requests. * > The `Content-Type` for the serialization format MUST NOT be set. */ - override fun requestContentType(operationShape: OperationShape): String? { - // When `syntheticInputTrait.originalId == null` it implies that the operation had no input defined - // in the Smithy model. - val syntheticInputTrait = operationShape.inputShape(model).expectTrait() - if (syntheticInputTrait.originalId == null) { - return null + override fun requestContentType(operationShape: OperationShape): String? = + if (OperationNormalizer.hadUserModeledOperationInput(operationShape, model)) { + "application/cbor" + } else { + null } - return httpIndex.determineRequestContentType( - operationShape, - contentTypes.requestDocument, - contentTypes.eventStreamContentType, - ).orNull() - } - /** * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#responses * > Responses for operations with no defined output type MUST NOT contain bodies in their HTTP responses. * > The `Content-Type` for the serialization format MUST NOT be set. */ - override fun responseContentType(operationShape: OperationShape): String? { - // When `syntheticOutputTrait.originalId == null` it implies that the operation had no output defined - // in the Smithy model. - val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() - if (syntheticOutputTrait.originalId == null) { - return null + override fun responseContentType(operationShape: OperationShape): String? = + if (OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { + "application/cbor" + } else { + null } - return httpIndex.determineResponseContentType( - operationShape, - contentTypes.responseDocument, - contentTypes.eventStreamContentType, - ).orNull() - } - override fun eventStreamMessageContentType(memberShape: MemberShape): String? = ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 68533ef7d1..28b204bdbb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait @@ -267,10 +268,9 @@ class CborSerializerGenerator( } override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { - // Don't generate an operation JSON serializer if there was no operation output shape in the + // Don't generate an operation CBOR serializer if there was no operation output shape in the // original (untransformed) model. - val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() - if (syntheticOutputTrait.originalId == null) { + if (!OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { return null } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 8b18dfa5c8..b3f16d90c2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -48,6 +48,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -313,8 +314,7 @@ class JsonSerializerGenerator( override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { // Don't generate an operation JSON serializer if there was no operation output shape in the // original (untransformed) model. - val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() - if (syntheticOutputTrait.originalId == null) { + if (!OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { return null } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt index e4b95f5cde..b0c61d3d6c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt @@ -12,8 +12,13 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic input (see `OperationNormalizer.kt`) * - * All operations are normalized to have an input, even when they are defined without on. This is done for backwards - * compatibility and to produce a consistent API. + * All operations are normalized to have an input, even when they are defined without one. + * This is NOT done for backwards-compatibility, as adding an operation input is a breaking change + * (see ). + * + * It is only done to produce a consistent API. + * TODO(https://github.com/smithy-lang/smithy-rs/issues/3577): In the server, we'd like to stop adding + * these synthetic inputs. */ class SyntheticInputTrait( val operation: ShapeId, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt index d34ca0e39d..cc310e8a47 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt @@ -12,8 +12,14 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic output (see `OperationNormalizer.kt`) * - * All operations are normalized to have an output, even when they are defined without on. This is done for backwards - * compatibility and to produce a consistent API. + * All operations are normalized to have an output, even when they are defined without one. + * + * This is NOT done for backwards-compatibility, as adding an operation output is a breaking change + * (see ). + * + * It is only done to produce a consistent API. + * TODO(https://github.com/smithy-lang/smithy-rs/issues/3577): In the server, we'd like to stop adding + * these synthetic outputs. */ class SyntheticOutputTrait constructor(val operation: ShapeId, val originalId: ShapeId?) : AnnotationTrait(ID, Node.objectNode()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index 4092174b55..cbc7f00508 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -14,11 +14,13 @@ import software.amazon.smithy.model.traits.InputTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.rename import java.util.Optional -import kotlin.streams.toList /** * Generate synthetic Input and Output structures for operations. @@ -43,6 +45,24 @@ object OperationNormalizer { private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace + ".synthetic", "${this.id.name}Output") + /** + * Returns `true` if the user had originally modeled an operation input shape on the given [operation]; + * `false` if the transform added a synthetic one. + */ + fun hadUserModeledOperationInput(operation: OperationShape, model: Model): Boolean { + val syntheticInputTrait = operation.inputShape(model).expectTrait() + return syntheticInputTrait.originalId != null + } + + /** + * Returns `true` if the user had originally modeled an operation output shape on the given [operation]; + * `false` if the transform added a synthetic one. + */ + fun hadUserModeledOperationOutput(operation: OperationShape, model: Model): Boolean { + val syntheticOutputTrait = operation.outputShape(model).expectTrait() + return syntheticOutputTrait.originalId != null + } + /** * Add synthetic input & output shapes to every Operation in model. The generated shapes will be marked with * [SyntheticInputTrait] and [SyntheticOutputTrait] respectively. Shapes will be added _even_ if the operation does diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 171ce2c6b2..587e94a476 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -60,6 +60,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.smithy.wrapOptional import software.amazon.smithy.rust.codegen.core.util.dq @@ -793,7 +794,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( serverRenderQueryStringParser(this, operationShape) // If there's no modeled operation input, some protocols require that `Content-Type` header not be present. - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null + val noInputs = !OperationNormalizer.hadUserModeledOperationInput(operationShape, model) if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { rustTemplate( """ From c1b9741f042e7af5d1328a35cb2bb0e23a5d4da4 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 19 Jun 2024 15:36:11 +0200 Subject: [PATCH 13/77] misc fixes --- .../core/smithy/protocols/RpcV2Cbor.kt | 34 ++++--------------- .../protocols/parse/CborParserGenerator.kt | 5 ++- .../protocol/ServerProtocolTestGenerator.kt | 5 +-- .../serialize/CborSerializerGeneratorTest.kt | 2 +- 4 files changed, 10 insertions(+), 36 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index a93c428a8c..955f330297 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -71,7 +71,7 @@ class RpcV2CborHttpBindingResolver( */ override fun requestContentType(operationShape: OperationShape): String? = if (OperationNormalizer.hadUserModeledOperationInput(operationShape, model)) { - "application/cbor" + contentTypes.requestDocument } else { null } @@ -83,7 +83,7 @@ class RpcV2CborHttpBindingResolver( */ override fun responseContentType(operationShape: OperationShape): String? = if (OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { - "application/cbor" + contentTypes.responseDocument } else { null } @@ -92,9 +92,6 @@ class RpcV2CborHttpBindingResolver( ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") } -/** - * TODO: Docs. - */ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( @@ -131,30 +128,11 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { override fun structuredDataSerializer(): StructuredDataSerializerGenerator = CborSerializerGenerator(codegenContext, httpBindingResolver) - // TODO: Implement `RpcV2.parseHttpErrorMetadata` + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) { - rustTemplate( - """ - pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - #{json_errors}::parse_error_metadata(response.body(), response.headers()) - } - """, - *errorScope, - ) - } + TODO("rpcv2Cbor client support has not yet been implemented") - // TODO: Implement `RpcV2.parseEventStreamErrorMetadata` + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_event_stream_error_metadata", jsonDeserModule) { - // `HeaderMap::new()` doesn't allocate. - rustTemplate( - """ - pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - #{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new()) - } - """, - *errorScope, - ) - } + TODO("rpcv2Cbor event streams have not yet been implemented") } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index d7e5426a87..2114ff7c19 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -467,14 +467,13 @@ class CborParserGenerator( is TimestampShape -> rust("decoder.timestamp()") - // TODO Document shapes have not been specced out yet. - // is DocumentShape -> rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope) - // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html is StructureShape -> deserializeStruct(target) is CollectionShape -> deserializeCollection(target) is MapShape -> deserializeMap(target) is UnionShape -> deserializeUnion(target) + + // Note that no protocol using CBOR serialization supports `document` shapes. else -> PANIC("unexpected shape: $target") } // TODO Boxing diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 3d8ae6a4f2..4316d4de3c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -231,9 +231,6 @@ class ServerProtocolTestGenerator( } } - private fun OperationShape.toName(): String = - RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(this).name.toSnakeCase()) - /** * Filter out test cases that are disabled or don't match the service protocol */ @@ -333,7 +330,7 @@ class ServerProtocolTestGenerator( checkHandlerWasEntered(this) } - // Explicitly warn if the test case defined parameters that we aren't doing anything with + // Explicitly warn if the test case defined parameters that we aren't doing anything with. with(httpRequestTestCase) { if (authScheme.isPresent) { logger.warning("Test case provided authScheme but this was ignored") diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt index 8819e8e95f..e190d2a16e 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -62,7 +62,7 @@ internal class CborSerializerGeneratorTest { @Test fun `we serialize and serde_cbor deserializes round trip`() { - val model = File("../codegen-core/common-test-models/rpcv2-extras.smithy").readText().asSmithyModel() + val model = File("../codegen-core/common-test-models/rpcv2Cbor-extras.smithy").readText().asSmithyModel() val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { override val name: String = "Add `#[derive(serde::Deserialize)]`" From c5f1df5d5abcb7fb2624cbafb0417ffa216bf215 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 20 Jun 2024 14:12:38 +0200 Subject: [PATCH 14/77] Trying to make CborSerializerGeneratorTest use the upstream protocol test model; but switching to refactoring protocol test generation now --- .../codegen/core/rustlang/CargoDependency.kt | 22 +++++----- .../codegen/core/smithy/CoreRustSettings.kt | 12 +++--- .../rust/codegen/core/smithy/RuntimeType.kt | 6 +-- .../core/smithy/generators/Instantiator.kt | 42 ++++++++++++------- .../NamingObstacleCourseTestModels.kt | 2 +- .../smithy/rust/codegen/core/testutil/Rust.kt | 3 +- codegen-server/build.gradle.kts | 3 ++ .../server/smithy/ServerCodegenVisitor.kt | 2 +- .../smithy/generators/ServerInstantiator.kt | 7 +++- .../serialize/CborSerializerGeneratorTest.kt | 42 ++++++++++++++----- 10 files changed, 92 insertions(+), 49 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index e2b7b9ae40..05acbd0c33 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolDependencyContainer import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq import java.nio.file.Path @@ -41,6 +42,11 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { ) + dependencies().flatMap { it.dependencies } } + open fun toDevDependency(): RustDependency = when (this) { + is CargoDependency -> this.toDevDependency() + is InlineDependency -> PANIC("it does not make sense for an inline dependency to be a dev-dependency") + } + companion object { private const val PROPERTY_KEY = "rustdep" @@ -71,9 +77,7 @@ class InlineDependency( return renderer.hashCode().toString() } - override fun dependencies(): List { - return extraDependencies - } + override fun dependencies(): List = extraDependencies fun key() = "${module.fullyQualifiedPath()}::$name" @@ -170,7 +174,7 @@ data class Feature(val name: String, val default: Boolean, val deps: List = listOf(), private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER, private val customWritable: CustomWritable = NoCustomWritable(), + /** + * The protocol test may provide data for missing members (because we transformed the model). + * This flag makes it so that it is simply ignored, and code generation continues. + **/ + private val ignoreMissingMembers: Boolean = false, ) { data class Ctx( // The `http` crate requires that headers be lowercase, but Smithy protocol tests @@ -213,12 +218,12 @@ open class Instantiator( ")", // The conditions are not commutative: note client builders always take in `Option`. conditional = - symbol.isOptional() || - ( - model.expectShape(memberShape.container) is StructureShape && - builderKindBehavior.doesSetterTakeInOption( - memberShape, - ) + symbol.isOptional() || + ( + model.expectShape(memberShape.container) is StructureShape && + builderKindBehavior.doesSetterTakeInOption( + memberShape, + ) ), *preludeScope, ) { @@ -425,8 +430,13 @@ open class Instantiator( } } - data.members.forEach { (key, value) -> - val memberShape = shape.expectMember(key.value) + for ((key, value) in data.members) { + val memberShape = shape.getMember(key.value).getOrNull() + ?: if (ignoreMissingMembers) { + continue + } else { + throw CodegenException("Protocol test defines data for member shape `${key.value}`, but member shape was not found on structure shape ${shape.id}") + } renderMemberHelper(memberShape, value) } @@ -455,7 +465,8 @@ open class Instantiator( */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { - is MemberShape -> shape.getTrait()?.toNode() ?: fillDefaultValue(model.expectShape(shape.target)) + is MemberShape -> shape.getTrait()?.toNode() + ?: fillDefaultValue(model.expectShape(shape.target)) // Aggregate shapes. is StructureShape -> Node.objectNode() @@ -490,7 +501,7 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val fractionalPart = num.remainder(BigDecimal.ONE) rust( "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), + RuntimeType.dateTime(runtimeConfig).toDevDependencyType(), ) } @@ -503,12 +514,12 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va if (shape.hasTrait()) { rust( "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig), + RuntimeType.byteStream(runtimeConfig).toDevDependencyType(), ) } else { rust( "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig), + RuntimeType.blob(runtimeConfig).toDevDependencyType(), ) } @@ -521,7 +532,8 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va rust( """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + RuntimeType.smithyTypes(runtimeConfig).toDevDependencyType() + .resolve("primitive::Parse"), ) } @@ -536,7 +548,7 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va is BooleanShape -> rust(data.asBooleanNode().get().toString()) is DocumentShape -> rustBlock("") { - val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() + val smithyJson = CargoDependency.smithyJson(runtimeConfig).toDevDependency().toType() rustTemplate( """ let json_bytes = br##"${Node.prettyPrintJson(data)}"##; diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt index 5508611972..e0279c53c0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt @@ -178,7 +178,7 @@ object NamingObstacleCourseTestModels { /** * This targets two bug classes: * - operation inputs used as nested outputs - * - operation outputs used as nested outputs + * - operation outputs used as nested inputs */ fun reusedInputOutputShapesModel(protocol: Trait) = """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt index d97d32a444..321df02805 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt @@ -88,7 +88,8 @@ object TestWorkspace { private val cargoLock: File by lazy { var curFile = File(this.javaClass.protectionDomain.codeSource.location.path) - while (!curFile.endsWith("smithy-rs")) { + // TODO This is not a robust check. + while (!curFile.endsWith("smithy-rs") && !curFile.endsWith("SmithyRsSource")) { curFile = curFile.parentFile } diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 2dbc33df05..49e0462888 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -31,6 +31,9 @@ dependencies { // `smithy.framework#ValidationException` is defined here, which is used in `constraints.smithy`, which is used // in `CustomValidationExceptionWithReasonDecoratorTest`. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") + + // It's handy to re-use protocol test suite models from Smithy in our Kotlin tests. + testImplementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") } java { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 87b5506616..a0e604e57b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -608,7 +608,7 @@ open class ServerCodegenVisitor( } /** - * Generate protocol tests. This method can be overridden by other languages such has Python. + * Generate protocol tests. This method can be overridden by other languages such as Python. */ open fun protocolTests() { rustCrate.withModule(ServerRustModule.Operation) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index d06c21ff70..8dc0c3f8cf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -70,7 +70,11 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat codegenContext.symbolProvider.toSymbol(memberShape).isOptional() } -class ServerInstantiator(codegenContext: CodegenContext, customWritable: CustomWritable = NoCustomWritable()) : +class ServerInstantiator( + codegenContext: CodegenContext, + customWritable: CustomWritable = NoCustomWritable(), + ignoreMissingMembers: Boolean = false, +) : Instantiator( codegenContext.symbolProvider, codegenContext.model, @@ -81,6 +85,7 @@ class ServerInstantiator(codegenContext: CodegenContext, customWritable: CustomW // Construct with direct pattern to more closely replicate actual server customer usage constructPattern = InstantiatorConstructPattern.DIRECT, customWritable = customWritable, + ignoreMissingMembers = ignoreMissingMembers, ) class ServerBuilderInstantiator( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt index e190d2a16e..88a00dd2b9 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape @@ -14,10 +15,13 @@ import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocoltests.traits.AppliesTo import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -26,7 +30,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor -import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -34,17 +38,20 @@ import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegen import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest -import java.io.File +import java.util.function.Predicate internal class CborSerializerGeneratorTest { class DeriveSerdeDeserializeSymbolMetadataProvider( private val base: RustSymbolProvider, ) : SymbolMetadataProvider(base) { + private val serdeDeserialize = + CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Deserialize") + private fun addDeriveSerdeDeserialize(shape: Shape): RustMetadata { check(shape !is MemberShape) val baseMetadata = base.toSymbol(shape).expectRustMetadata() - return baseMetadata.withDerives(RuntimeType.SerdeDeserialize) + return baseMetadata.withDerives(serdeDeserialize) } override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() @@ -62,8 +69,6 @@ internal class CborSerializerGeneratorTest { @Test fun `we serialize and serde_cbor deserializes round trip`() { - val model = File("../codegen-core/common-test-models/rpcv2Cbor-extras.smithy").readText().asSmithyModel() - val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { override val name: String = "Add `#[derive(serde::Deserialize)]`" override val order: Byte = 0 @@ -72,16 +77,32 @@ internal class CborSerializerGeneratorTest { DeriveSerdeDeserializeSymbolMetadataProvider(base) } + // Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on + // which we can't `#[derive(serde::Deserialize)]`. + val model = Model.assembler().discoverModels().assemble().result.get() + val removeTimestampAndBlobShapes: Predicate = Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + targetShape is BlobShape || targetShape is TimestampShape + } + is BlobShape, is TimestampShape -> true + else -> false + } + } + val transformedModel = ModelTransformer.create().removeShapesIf(model, removeTimestampAndBlobShapes) + serverIntegrationTest( - model, + transformedModel, additionalDecorators = listOf(addDeriveSerdeSerializeDecorator), + params = IntegrationTestParams(service = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol") ) { codegenContext, rustCrate -> val codegenScope = arrayOf( "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), "SerdeCbor" to CargoDependency.SerdeCbor.toType(), ) - val instantiator = ServerInstantiator(codegenContext) + val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true) val rpcV2 = RpcV2Cbor(codegenContext) for (operationShape in codegenContext.model.operationShapes) { @@ -94,6 +115,7 @@ internal class CborSerializerGeneratorTest { outputShape, ) } + val serializeFn = rpcV2 .structuredDataSerializer() .operationOutputSerializer(operationShape) @@ -102,8 +124,8 @@ internal class CborSerializerGeneratorTest { // TODO Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on // which we can't `#[derive(Deserialize)]`. rustCrate.withModule(ProtocolFunctions.serDeModule) { - for ((idx, test) in tests.withIndex()) { - unitTest("TODO_$idx") { + for (test in tests) { + unitTest("we_serialize_and_serde_cbor_deserializes_${test.id}") { rustTemplate( """ let expected = #{InstantiateShape:W}; @@ -111,7 +133,7 @@ internal class CborSerializerGeneratorTest { .expect("generated CBOR serializer failed"); let actual = #{SerdeCbor}::from_slice(&bytes) .expect("serde_cbor failed deserializing from bytes"); - #{AssertEq}(expected, actual); + #{AssertEq}(expected, actual); """, "InstantiateShape" to instantiator.generate(test.targetShape, test.testCase.params), "SerializeFn" to serializeFn, From 3ac44d0d62a7d6be4cf13947e530fb400a00032c Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 20 Jun 2024 17:19:16 +0200 Subject: [PATCH 15/77] Comment in OperationNormalizer --- ...colTestGenerator.kt => ClientProtocolTestGenerator.kt} | 0 .../core/smithy/transformers/OperationNormalizer.kt | 8 +++++--- 2 files changed, 5 insertions(+), 3 deletions(-) rename codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/{ProtocolTestGenerator.kt => ClientProtocolTestGenerator.kt} (100%) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt similarity index 100% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt rename to codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index cbc7f00508..dc2c84bd52 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -25,14 +25,16 @@ import java.util.Optional /** * Generate synthetic Input and Output structures for operations. * - * Operation input/output shapes can be retroactively added. In order to support this while maintaining backwards compatibility, - * we need to generate input/output shapes for all operations in a backwards compatible way. - * * This works by **adding** new shapes to the model for operation inputs & outputs. These new shapes have `SyntheticInputTrait` * and `SyntheticOutputTrait` attached to them as needed. This enables downstream code generators to determine if a shape is * "real" vs. a shape created as a synthetic input/output. * * The trait also tracks the original shape id for certain serialization tasks that require it to exist. + * + * Note that adding/removing operation input/output [is a breaking change]; the only reason why we synthetically add them + * is to produce a consistent API. + * + * [is a breaking change]: */ object OperationNormalizer { // Functions to construct synthetic shape IDs—Don't rely on these in external code. From e84ce5f75ad137faeb7899a9a9e2d8a787a9adc6 Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 21 Jun 2024 12:44:17 +0200 Subject: [PATCH 16/77] Refactor and DRY up protocol test generation --- .../rustsdk/AwsFluentClientDecorator.kt | 6 +- .../client/smithy/ClientCodegenVisitor.kt | 4 +- .../customize/ClientCodegenDecorator.kt | 2 +- .../smithy/customize/ConditionalDecorator.kt | 2 +- .../protocol/ClientProtocolTestGenerator.kt | 290 +------- .../rust/codegen/core/rustlang/RustWriter.kt | 14 +- .../protocol/ProtocolTestGenerator.kt | 311 +++++++++ .../smithy/PythonServerCodegenVisitor.kt | 3 +- .../server/smithy/ServerCodegenVisitor.kt | 16 +- .../protocol/ServerProtocolTestGenerator.kt | 642 +++++------------- 10 files changed, 554 insertions(+), 736 deletions(-) create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index 95e754ff0c..7b14bf4414 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -13,8 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDocs import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -28,6 +27,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.serviceNameOrDefault import software.amazon.smithy.rustsdk.customize.s3.S3ExpressFluentClientCustomization @@ -91,7 +91,7 @@ class AwsFluentClientDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, baseGenerator: ProtocolTestGenerator, ): ProtocolTestGenerator = - DefaultProtocolTestGenerator( + ClientProtocolTestGenerator( codegenContext, baseGenerator.protocolSupport, baseGenerator.operationShape, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index d057b7e6f5..7e8737f50d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -23,7 +23,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGen import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolTestGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations @@ -322,7 +322,7 @@ class ClientCodegenVisitor( // render protocol tests into `operation.rs` (note operationWriter vs. inputWriter) codegenDecorator.protocolTestGenerator( codegenContext, - DefaultProtocolTestGenerator( + ClientProtocolTestGenerator( codegenContext, protocolGeneratorFactory.support(), operationShape, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index c676e5259a..c4aec33b59 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -16,10 +16,10 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGen import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import java.util.ServiceLoader import java.util.logging.Logger diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ConditionalDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ConditionalDecorator.kt index 355d49b41f..d0a66c6b6a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ConditionalDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ConditionalDecorator.kt @@ -17,7 +17,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization @@ -25,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomiza import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator /** * Delegating decorator that only applies when a condition is true diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 47ea9f5626..34a718ecd5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -5,45 +5,37 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol -import software.amazon.smithy.codegen.core.CodegenException -import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.AppliesTo -import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase -import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase -import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport -import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10 +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase +import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.outputShape -import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import java.util.logging.Logger import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType as RT @@ -54,21 +46,18 @@ data class ClientCreationParams( val clientName: String, ) -interface ProtocolTestGenerator { - val codegenContext: ClientCodegenContext - val protocolSupport: ProtocolSupport - val operationShape: OperationShape - - fun render(writer: RustWriter) -} - /** - * Generate protocol tests for an operation + * Generate client protocol tests for an [operationShape]. */ -class DefaultProtocolTestGenerator( +class ClientProtocolTestGenerator( override val codegenContext: ClientCodegenContext, override val protocolSupport: ProtocolSupport, override val operationShape: OperationShape, + + override val expectFail: Set = ExpectFail, + override val runOnly: Set = emptySet(), + override val disabledTests: Set = emptySet(), + private val renderClientCreation: RustWriter.(ClientCreationParams) -> Unit = { params -> rustTemplate( """ @@ -82,126 +71,53 @@ class DefaultProtocolTestGenerator( ) }, ) : ProtocolTestGenerator { + companion object { + private val ExpectFail = + setOf( + // Failing because we don't serialize default values if they match the default. + FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", TestCaseKind.Request), + FailingTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", TestCaseKind.Request), + FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput", TestCaseKind.Request), + ) + } + private val rc = codegenContext.runtimeConfig private val logger = Logger.getLogger(javaClass.name) private val inputShape = operationShape.inputShape(codegenContext.model) private val outputShape = operationShape.outputShape(codegenContext.model) - private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape) - private val operationIndex = OperationIndex.of(codegenContext.model) private val instantiator = ClientInstantiator(codegenContext) private val codegenScope = arrayOf( - "SmithyHttp" to RT.smithyHttp(rc), "AssertEq" to RT.PrettyAssertions.resolve("assert_eq!"), "Uri" to RT.Http.resolve("Uri"), ) - sealed class TestCase { - abstract val testCase: HttpMessageTestCase - - data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase() - - data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) : - TestCase() - } - override fun render(writer: RustWriter) { - val requestTests = - operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.RequestTest(it) } - val responseTests = - operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.ResponseTest(it, outputShape) } - val errorTests = - operationIndex.getErrors(operationShape).flatMap { error -> - val testCases = - error.getTrait() - ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty() - testCases.map { TestCase.ResponseTest(it, error) } - } - val allTests: List = (requestTests + responseTests + errorTests).filterMatching() - if (allTests.isNotEmpty()) { - val operationName = operationSymbol.name - val testModuleName = "${operationName.toSnakeCase()}_request_test" - val additionalAttributes = - listOf( - Attribute(allow("unreachable_code", "unused_variables")), - ) - writer.withInlineModule( - RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes), - null, - ) { - renderAllTestCases(allTests) - } + val allTests = allTestCases(AppliesTo.CLIENT) + if (allTests.isEmpty()) { + return + } + + writer.withInlineModule(protocolTestsModule(), null) { + renderAllTestCases(allTests) } } private fun RustWriter.renderAllTestCases(allTests: List) { - allTests.forEach { - renderTestCaseBlock(it.testCase, this) { + for (it in allTests) { + renderTestCaseBlock(it, this) { when (it) { is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) + is TestCase.MalformedRequestTest -> PANIC("Client protocol test generation does not support HTTP compliance test case type `$it`") } } } } - /** - * Filter out test cases that are disabled or don't match the service protocol - */ - private fun List.filterMatching(): List { - return if (RunOnly.isNullOrEmpty()) { - this.filter { testCase -> - testCase.testCase.protocol == codegenContext.protocol && - !DisableTests.contains(testCase.testCase.id) - } - } else { - this.filter { RunOnly.contains(it.testCase.id) } - } - } - - private fun renderTestCaseBlock( - testCase: HttpMessageTestCase, - testModuleWriter: RustWriter, - block: Writable, - ) { - testModuleWriter.newlinePrefix = "/// " - testCase.documentation.map { - testModuleWriter.writeWithNoFormatting(it) - } - testModuleWriter.write("Test ID: ${testCase.id}") - testModuleWriter.newlinePrefix = "" - - Attribute.TokioTest.render(testModuleWriter) - Attribute.TracedTest.render(testModuleWriter) - // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in - // case the code rendered by the test does not make use of `tracing` at all. - val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } - testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) - - val action = when (testCase) { - is HttpResponseTestCase -> Action.Response - is HttpRequestTestCase -> Action.Request - else -> throw CodegenException("unknown test case type") - } - if (expectFail(testCase)) { - testModuleWriter.writeWithNoFormatting("#[should_panic]") - } - val fnName = - when (action) { - is Action.Response -> "_response" - is Action.Request -> "_request" - } - Attribute.AllowUnusedMut.render(testModuleWriter) - testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnName()") { - block(this) - } - } - private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { if (!protocolSupport.requestSerialization) { rust("/* test case disabled for this protocol (not yet supported) */") @@ -284,18 +200,6 @@ class DefaultProtocolTestGenerator( } } - private fun HttpMessageTestCase.action(): Action = - when (this) { - is HttpRequestTestCase -> Action.Request - is HttpResponseTestCase -> Action.Response - else -> throw CodegenException("Unknown test case type") - } - - private fun expectFail(testCase: HttpMessageTestCase): Boolean = - ExpectFail.find { - it.id == testCase.id && it.action == testCase.action() && it.service == codegenContext.serviceShape.id.toString() - } != null - private fun RustWriter.renderHttpResponseTestCase( testCase: HttpResponseTestCase, expectedShape: StructureShape, @@ -442,58 +346,6 @@ class DefaultProtocolTestGenerator( } } - private fun checkRequiredHeaders( - rustWriter: RustWriter, - actualExpression: String, - requireHeaders: List, - ) { - basicCheck( - requireHeaders, - rustWriter, - "required_headers", - actualExpression, - "require_headers", - ) - } - - private fun checkForbidHeaders( - rustWriter: RustWriter, - actualExpression: String, - forbidHeaders: List, - ) { - basicCheck( - forbidHeaders, - rustWriter, - "forbidden_headers", - actualExpression, - "forbid_headers", - ) - } - - private fun checkHeaders( - rustWriter: RustWriter, - actualExpression: String, - headers: Map, - ) { - if (headers.isEmpty()) { - return - } - val variableName = "expected_headers" - rustWriter.withBlock("let $variableName = [", "];") { - writeWithNoFormatting( - headers.entries.joinToString(",") { - "(${it.key.dq()}, ${it.value.dq()})" - }, - ) - } - assertOk(rustWriter) { - write( - "#T($actualExpression, $variableName)", - RT.protocolTest(rc, "validate_headers"), - ) - } - } - private fun checkRequiredQueryParams( rustWriter: RustWriter, requiredParams: List, @@ -526,80 +378,4 @@ class DefaultProtocolTestGenerator( "&http_request", "validate_query_string", ) - - private fun basicCheck( - params: List, - rustWriter: RustWriter, - expectedVariableName: String, - actualExpression: String, - checkFunction: String, - ) { - if (params.isEmpty()) { - return - } - rustWriter.withBlock("let $expectedVariableName = ", ";") { - strSlice(this, params) - } - assertOk(rustWriter) { - write( - "#T($actualExpression, $expectedVariableName)", - RT.protocolTest(rc, checkFunction), - ) - } - } - - /** - * wraps `inner` in a call to `aws_smithy_protocol_test::assert_ok`, a convenience wrapper - * for pretty printing protocol test helper results - */ - private fun assertOk( - rustWriter: RustWriter, - inner: Writable, - ) { - rustWriter.write("#T(", RT.protocolTest(rc, "assert_ok")) - inner(rustWriter) - rustWriter.write(");") - } - - private fun strSlice( - writer: RustWriter, - args: List, - ) { - writer.withBlock("&[", "]") { - write(args.joinToString(",") { it.dq() }) - } - } - - companion object { - sealed class Action { - object Request : Action() - - object Response : Action() - } - - data class FailingTest(val service: String, val id: String, val action: Action) - - // These tests fail due to shortcomings in our implementation. - // These could be configured via runtime configuration, but since this won't be long-lasting, - // it makes sense to do the simplest thing for now. - // The test will _fail_ if these pass, so we will discover & remove if we fix them by accident - private val JsonRpc10 = "aws.protocoltests.json10#JsonRpc10" - private val AwsJson11 = "aws.protocoltests.json#JsonProtocol" - private val RestJson = "aws.protocoltests.restjson#RestJson" - private val RestXml = "aws.protocoltests.restxml#RestXml" - private val AwsQuery = "aws.protocoltests.query#AwsQuery" - private val Ec2Query = "aws.protocoltests.ec2#AwsEc2" - private val ExpectFail = - setOf( - // Failing because we don't serialize default values if they match the default - FailingTest(JsonRpc10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", Action.Request), - FailingTest(JsonRpc10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", Action.Request), - FailingTest(JsonRpc10, "AwsJson10ClientPopulatesDefaultValuesInInput", Action.Request), - ) - private val RunOnly: Set? = null - - // These tests are not even attempted to be generated, either because they will not compile - // or because they are flaky - private val DisableTests: Set = setOf() - } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index 94b3dc67c6..80ae291537 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -378,7 +378,13 @@ fun > T.docs( vararg args: Any, newlinePrefix: String = "/// ", trimStart: Boolean = true, + /** If `false`, will disable templating in `args` into `#{T}` spans */ + templating: Boolean = true, ): T { + if (!templating && args.isNotEmpty()) { + PANIC("Templating was disabled yet the following arguments were passed in: $args") + } + // Because writing docs relies on the newline prefix, ensure that there was a new line written // before we write the docs this.ensureNewline() @@ -392,7 +398,13 @@ fun > T.docs( else -> it }.replace("\t", " ") // Rustdoc warns on tabs in documentation } - write(cleaned, *args) + + if (templating) { + write(cleaned, *args) + } else { + writeWithNoFormatting(cleaned) + } + popState() return this } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt new file mode 100644 index 0000000000..b93c4a4dae --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -0,0 +1,311 @@ +package software.amazon.smithy.rust.codegen.core.smithy.generators.protocol + +import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.protocoltests.traits.AppliesTo +import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestCase +import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestsTrait +import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase +import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait +import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase +import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * Common interface to generate protocol tests for a given [operationShape]. + */ +interface ProtocolTestGenerator { + val codegenContext: CodegenContext + val protocolSupport: ProtocolSupport + val operationShape: OperationShape + + /** + * We expect these tests to fail due to shortcomings in our implementation. + * They will _fail_ if they pass, so we will discover and remove them if we fix them by accident. + **/ + val expectFail: Set + + /** Only generate these tests; useful to temporarily set and shorten development cycles */ + val runOnly: Set + + /** + * These tests are not even attempted to be generated, either because they will not compile + * or because they are flaky. + */ + val disabledTests: Set + + /** The Rust module in which we should generate the protocol tests for [operationShape]. */ + fun protocolTestsModule(): RustModule.LeafModule { + val operationName = codegenContext.symbolProvider.toSymbol(operationShape).name + val testModuleName = "${operationName.toSnakeCase()}_test" + val additionalAttributes = + listOf(Attribute(allow("unreachable_code", "unused_variables"))) + return RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes) + } + + /** The entry point to render the protocol tests, invoked by the code generators. */ + fun render(writer: RustWriter) + + /** Filter out test cases that are disabled or don't match the service protocol. */ + fun List.filterMatching(): List = if (runOnly.isEmpty()) { + this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } + } else { + this.filter { testCase -> runOnly.contains(testCase.id) } + } + + /** Do we expect this [testCase] to fail? */ + fun expectFail(testCase: TestCase): Boolean = + expectFail.find { + it.id == testCase.id && it.kind == testCase.kind && it.service == codegenContext.serviceShape.id.toString() + } != null + + /** + * Parses from the model and returns all test cases for [operationShape] applying to the [appliesTo] artifact type + * that should be rendered by implementors. + **/ + fun allTestCases(appliesTo: AppliesTo): List { + val operationIndex = OperationIndex.of(codegenContext.model) + val outputShape = operationShape.outputShape(codegenContext.model) + + val requestTests = + operationShape.getTrait() + ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.RequestTest(it) } + + // `@httpResponseTests` trait can apply to operation shapes and structure shapes with the `@error` trait. + // Find both kinds for the operation for which we're generating protocol tests. + val responseTestsOnOperations = + operationShape.getTrait() + ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.ResponseTest(it, outputShape) } + val responseTestsOnErrors = + operationIndex.getErrors(operationShape).flatMap { error -> + val testCases = + error.getTrait() + ?.getTestCasesFor(appliesTo).orEmpty() + testCases.map { TestCase.ResponseTest(it, error) } + } + + // `@httpMalformedRequestTests` only make sense for servers. + val malformedRequestTests = if (appliesTo == AppliesTo.SERVER) { + operationShape.getTrait() + ?.testCases.orEmpty().map { TestCase.MalformedRequestTest(it) } + } else { + emptyList() + } + + // Note there's no `@httpMalformedResponseTests`: https://github.com/smithy-lang/smithy/issues/2334 + + val allTests: List = + (requestTests + responseTestsOnOperations + responseTestsOnErrors + malformedRequestTests) + .filterMatching() + return allTests + } + + fun renderTestCaseBlock( + testCase: TestCase, + testModuleWriter: RustWriter, + block: Writable, + ) { + if (testCase.documentation != null) { + testModuleWriter.docs(testCase.documentation!!, templating = false) + } + testModuleWriter.docs("Test ID: ${testCase.id}") + + // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in + // case the code rendered by the test does not make use of `tracing` at all. + val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } + testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) + Attribute.TokioTest.render(testModuleWriter) + Attribute.TracedTest.render(testModuleWriter) + + if (expectFail(testCase)) { + testModuleWriter.writeWithNoFormatting("#[should_panic]") + } + val fnNameSuffix = + when (testCase) { + is TestCase.ResponseTest -> "_response" + is TestCase.RequestTest -> "_request" + is TestCase.MalformedRequestTest -> "_malformed_request" + } + // TODO Do we need this one? + Attribute.AllowUnusedMut.render(testModuleWriter) + testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnNameSuffix()") { + block(this) + } + } + + fun checkRequiredHeaders( + rustWriter: RustWriter, + actualExpression: String, + requireHeaders: List, + ) { + basicCheck( + requireHeaders, + rustWriter, + "required_headers", + actualExpression, + "require_headers", + ) + } + + fun checkForbidHeaders( + rustWriter: RustWriter, + actualExpression: String, + forbidHeaders: List, + ) { + basicCheck( + forbidHeaders, + rustWriter, + "forbidden_headers", + actualExpression, + "forbid_headers", + ) + } + + fun checkHeaders( + rustWriter: RustWriter, + actualExpression: String, + headers: Map, + ) { + if (headers.isEmpty()) { + return + } + val variableName = "expected_headers" + rustWriter.withBlock("let $variableName = [", "];") { + writeWithNoFormatting( + headers.entries.joinToString(",") { + "(${it.key.dq()}, ${it.value.dq()})" + }, + ) + } + assertOk(rustWriter) { + write( + "#T($actualExpression, $variableName)", + RuntimeType.protocolTest(codegenContext.runtimeConfig, "validate_headers"), + ) + } + } + + fun basicCheck( + params: List, + rustWriter: RustWriter, + expectedVariableName: String, + actualExpression: String, + checkFunction: String, + ) { + if (params.isEmpty()) { + return + } + rustWriter.withBlock("let $expectedVariableName = ", ";") { + strSlice(this, params) + } + assertOk(rustWriter) { + rustWriter.rust( + "#T($actualExpression, $expectedVariableName)", + RuntimeType.protocolTest(codegenContext.runtimeConfig, checkFunction), + ) + } + } + + /** + * Wraps `inner` in a call to `aws_smithy_protocol_test::assert_ok`, a convenience wrapper + * for pretty printing protocol test helper results. + */ + fun assertOk( + rustWriter: RustWriter, + inner: Writable, + ) { + rustWriter.rust("#T(", RuntimeType.protocolTest(codegenContext.runtimeConfig, "assert_ok")) + inner(rustWriter) + rustWriter.write(");") + } + + private fun strSlice( + writer: RustWriter, + args: List, + ) { + writer.withBlock("&[", "]") { + rust(args.joinToString(",") { it.dq() }) + } + } +} + +/** + * Service shape IDs in common protocol test suites defined upstream. + */ +object ServiceShapeId { + const val AWS_JSON_10 = "aws.protocoltests.json10#JsonRpc10" + const val AWS_JSON_11 = "aws.protocoltests.json#JsonProtocol" + const val REST_JSON = "aws.protocoltests.restjson#RestJson" + const val REST_XML = "aws.protocoltests.restxml#RestXml" + const val AWS_QUERY = "aws.protocoltests.query#AwsQuery" + const val EC2_QUERY = "aws.protocoltests.ec2#AwsEc2" + const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation" +} + +data class FailingTest(val service: String, val id: String, val kind: TestCaseKind) + +sealed class TestCaseKind { + data object Request : TestCaseKind() + data object Response : TestCaseKind() + data object MalformedRequest : TestCaseKind() +} + +sealed class TestCase { + data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() + data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() + data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() + + /* + * `HttpRequestTestCase` and `HttpResponseTestCase` both implement `HttpMessageTestCase`, but + * `HttpMalformedRequestTestCase` doesn't, so we have to define the following trivial delegators to provide a nice + * common accessor API. + */ + + val id: String + get() = when (this) { + is RequestTest -> this.testCase.id + is MalformedRequestTest -> this.testCase.id + is ResponseTest -> this.testCase.id + } + + val protocol: ShapeId + get() = when (this) { + is RequestTest -> this.testCase.protocol + is MalformedRequestTest -> this.testCase.protocol + is ResponseTest -> this.testCase.protocol + } + + val kind: TestCaseKind + get() = when (this) { + is RequestTest -> TestCaseKind.Request + is ResponseTest -> TestCaseKind.Response + is MalformedRequestTest -> TestCaseKind.MalformedRequest + } + + val documentation: String? + get() = when (this) { + is RequestTest -> this.testCase.documentation.orNull() + is ResponseTest -> this.testCase.documentation.orNull() + is MalformedRequestTest -> this.testCase.documentation.orNull() + } +} diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 19ea83426a..836f1d234b 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -17,6 +17,7 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig @@ -222,7 +223,7 @@ class PythonServerCodegenVisitor( } } - override fun protocolTests() { + override fun protocolTestsForOperation(writer: RustWriter, operationShape: OperationShape) { logger.warning("[python-server-codegen] Protocol tests are disabled for this language") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index a0e604e57b..2c14798d34 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -610,12 +610,8 @@ open class ServerCodegenVisitor( /** * Generate protocol tests. This method can be overridden by other languages such as Python. */ - open fun protocolTests() { - rustCrate.withModule(ServerRustModule.Operation) { - ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), protocolGenerator).render( - this, - ) - } + open fun protocolTestsForOperation(writer: RustWriter, shape: OperationShape) { + ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), shape).render(writer) } /** @@ -648,9 +644,6 @@ open class ServerCodegenVisitor( ServerRuntimeTypesReExportsGenerator(codegenContext).render(this) } - // Generate protocol tests. - protocolTests() - // Generate service module. rustCrate.withModule(ServerRustModule.Service) { ServerServiceGenerator( @@ -693,6 +686,11 @@ open class ServerCodegenVisitor( codegenDecorator.postprocessOperationGenerateAdditionalStructures(shape) .forEach { structureShape -> this.structureShape(structureShape) } + + // Generate protocol tests. + rustCrate.withModule(ServerRustModule.Operation) { + protocolTestsForOperation(this, shape) + } } override fun blobShape(shape: BlobShape) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 4316d4de3c..e8c73fae18 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.DoubleShape @@ -17,21 +16,12 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.AppliesTo import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestCase -import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpMalformedResponseBodyDefinition import software.amazon.smithy.protocoltests.traits.HttpMalformedResponseDefinition import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase -import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase -import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -41,11 +31,17 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10 +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors -import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -55,24 +51,176 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency -import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator import java.util.logging.Logger import kotlin.reflect.KFunction1 /** - * Generate protocol tests for an operation + * Generate server protocol tests for an [operationShape]. */ class ServerProtocolTestGenerator( - private val codegenContext: CodegenContext, - private val protocolSupport: ProtocolSupport, - private val protocolGenerator: ServerProtocolGenerator, -) { + override val codegenContext: CodegenContext, + override val protocolSupport: ProtocolSupport, + override val operationShape: OperationShape, + + override val expectFail: Set = ExpectFail, + override val runOnly: Set = emptySet(), + override val disabledTests: Set = DisabledTests, +): ProtocolTestGenerator { + companion object { + private val ExpectFail: Set = + setOf( + // Endpoint trait is not implemented yet, see https://github.com/smithy-lang/smithy-rs/issues/950. + FailingTest(REST_JSON, "RestJsonEndpointTrait", TestCaseKind.Request), + FailingTest(REST_JSON, "RestJsonEndpointTraitWithHostLabel", TestCaseKind.Request), + FailingTest(REST_JSON, "RestJsonOmitsEmptyListQueryValues", TestCaseKind.Request), + // TODO(https://github.com/smithy-lang/smithy/pull/2315): Can be deleted when fixed tests are consumed in next Smithy version + FailingTest(REST_JSON, "RestJsonEnumPayloadRequest", TestCaseKind.Request), + FailingTest(REST_JSON, "RestJsonStringPayloadRequest", TestCaseKind.Request), + // Tests involving `@range` on floats. + // Pending resolution from the Smithy team, see https://github.com/smithy-lang/smithy-rs/issues/2007. + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloat", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloat", TestCaseKind.MalformedRequest), + // Tests involving floating point shapes and the `@range` trait; see https://github.com/smithy-lang/smithy-rs/issues/2007 + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloatOverride", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloatOverride", TestCaseKind.MalformedRequest), + // Some tests for the S3 service (restXml). + FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestCaseKind.Response), + FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3PathAddressing", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAddressing", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestCaseKind.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationNoErrorWrappingResponse", TestCaseKind.Response), + // AwsJson1.0 failing tests. + FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestCaseKind.Request), + FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestCaseKind.Request), + // AwsJson1.1 failing tests. + FailingTest(AWS_JSON_11, "AwsJson11EndpointTraitWithHostLabel", TestCaseKind.Request), + FailingTest(AWS_JSON_11, "AwsJson11EndpointTrait", TestCaseKind.Request), + FailingTest(AWS_JSON_11, "parses_the_request_id_from_the_response", TestCaseKind.Response), + // TODO(https://github.com/awslabs/smithy/issues/1683): This has been marked as failing until resolution of said issue + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBlobList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStringList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsByteList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsShortList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntegerList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsLongList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsTimestampList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsDateTimeList", TestCaseKind.MalformedRequest), + FailingTest( + REST_JSON_VALIDATION, + "RestJsonMalformedUniqueItemsHttpDateList_case0", + TestCaseKind.MalformedRequest, + ), + FailingTest( + REST_JSON_VALIDATION, + "RestJsonMalformedUniqueItemsHttpDateList_case1", + TestCaseKind.MalformedRequest, + ), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsEnumList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntEnumList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsListList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStructureList", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case1", TestCaseKind.MalformedRequest), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/2472): We don't respect the `@internal` trait + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case1", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case1", TestCaseKind.MalformedRequest), + // TODO(https://github.com/awslabs/smithy/issues/1737): Specs on @internal, @tags, and enum values need to be clarified + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case0", TestCaseKind.MalformedRequest), + FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case1", TestCaseKind.MalformedRequest), + // These tests are broken because they are missing a target header. + FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultsWhenMissingInRequestBody", TestCaseKind.Request), + FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsWhenMissingInRequestBody", TestCaseKind.Request), + // Response defaults are not set when builders are not used https://github.com/smithy-lang/smithy-rs/issues/3339 + FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsInResponseWhenMissingInParams", TestCaseKind.Response), + FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", TestCaseKind.Response), + ) + + private val DisabledTests = + setOf( + // TODO(https://github.com/smithy-lang/smithy-rs/issues/2891): Implement support for `@requestCompression` + "SDKAppendedGzipAfterProvidedEncoding_restJson1", + "SDKAppendedGzipAfterProvidedEncoding_restXml", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_ec2Query", + "SDKAppliedContentEncoding_awsJson1_0", + "SDKAppliedContentEncoding_awsJson1_1", + "SDKAppliedContentEncoding_awsQuery", + "SDKAppliedContentEncoding_ec2Query", + "SDKAppliedContentEncoding_restJson1", + "SDKAppliedContentEncoding_restXml", + // RestXml S3 tests that fail to compile + "S3EscapeObjectKeyInUriLabel", + "S3EscapePathObjectKeyInUriLabel", + "S3PreservesLeadingDotSegmentInUriLabel", + "S3PreservesEmbeddedDotSegmentInUriLabel", + ) + + // TODO(https://github.com/awslabs/smithy/issues/1506) + private fun fixRestJsonMalformedPatternReDOSString( + testCase: HttpMalformedRequestTestCase, + ): HttpMalformedRequestTestCase { + val brokenResponse = testCase.response + val brokenBody = brokenResponse.body.get() + val fixedBody = + HttpMalformedResponseBodyDefinition.builder() + .mediaType(brokenBody.mediaType) + .contents( + """ + { + "message" : "1 validation error detected. Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", + "fieldList" : [{"message": "Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", "path": "/evilString"}] + } + """.trimIndent(), + ) + .build() + + return testCase.toBuilder() + .response(brokenResponse.toBuilder().body(fixedBody).build()) + .build() + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1288): Move the fixed versions into + // `rest-json-extras.smithy` and put the unfixed ones in `ExpectFail`: this has the + // advantage that once our upstream PRs get merged and we upgrade to the next Smithy release, our build will + // fail and we will take notice to remove the fixes from `rest-json-extras.smithy`. This is exactly what the + // client does. + private val BrokenMalformedRequestTests: + Map, KFunction1> = + // TODO(https://github.com/awslabs/smithy/issues/1506) + mapOf( + Pair( + REST_JSON_VALIDATION, + "RestJsonMalformedPatternReDOSString", + ) to ::fixRestJsonMalformedPatternReDOSString, + ) + } + private val logger = Logger.getLogger(javaClass.name) private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider - private val operationIndex = OperationIndex.of(codegenContext.model) + private val operationSymbol = symbolProvider.toSymbol(operationShape) private val serviceName = codegenContext.serviceShape.id.name.toPascalCase() private val operations = @@ -104,147 +252,36 @@ class ServerProtocolTestGenerator( arrayOf( "Base64SimdDev" to ServerCargoDependency.Base64SimdDev.toType(), "Bytes" to RuntimeType.Bytes, - "SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), - "Http" to RuntimeType.Http, "Hyper" to RuntimeType.Hyper, "Tokio" to ServerCargoDependency.TokioDev.toType(), "Tower" to RuntimeType.Tower, "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), - "Router" to ServerRuntimeType.router(codegenContext.runtimeConfig), ) - sealed class TestCase { - abstract val id: String - abstract val documentation: String? - abstract val protocol: ShapeId - abstract val testType: TestType - - data class RequestTest(val testCase: HttpRequestTestCase, val operationShape: OperationShape) : TestCase() { - override val id: String = testCase.id - override val documentation: String? = testCase.documentation.orNull() - override val protocol: ShapeId = testCase.protocol - override val testType: TestType = TestType.Request - } - - data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() { - override val id: String = testCase.id - override val documentation: String? = testCase.documentation.orNull() - override val protocol: ShapeId = testCase.protocol - override val testType: TestType = TestType.Response - } - - data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() { - override val id: String = testCase.id - override val documentation: String? = testCase.documentation.orNull() - override val protocol: ShapeId = testCase.protocol - override val testType: TestType = TestType.MalformedRequest - } - } - - fun render(writer: RustWriter) { - for (operation in operations) { - renderOperationTestCases(operation, writer) + override fun render(writer: RustWriter) { + val allTests = allTestCases(AppliesTo.SERVER).fixBroken() + if (allTests.isEmpty()) { + return } - } - private fun renderOperationTestCases( - operationShape: OperationShape, - writer: RustWriter, - ) { - val outputShape = operationShape.outputShape(codegenContext.model) - val operationSymbol = symbolProvider.toSymbol(operationShape) - - val requestTests = - operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it, operationShape) } - val responseTests = - operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.ResponseTest(it, outputShape) } - val errorTests = - operationIndex.getErrors(operationShape).flatMap { error -> - val testCases = - error.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty() - testCases.map { TestCase.ResponseTest(it, error) } - } - val malformedRequestTests = - operationShape.getTrait() - ?.testCases.orEmpty().map { TestCase.MalformedRequestTest(it) } - val allTests: List = - (requestTests + responseTests + errorTests + malformedRequestTests) - .filterMatching() - .fixBroken() - - if (allTests.isNotEmpty()) { - val operationName = operationSymbol.name - val module = - RustModule.LeafModule( - "server_${operationName.toSnakeCase()}_test", - RustMetadata( - additionalAttributes = - listOf( - Attribute.CfgTest, - Attribute(allow("unreachable_code", "unused_variables")), - ), - visibility = Visibility.PRIVATE, - ), - inline = true, - ) - writer.withInlineModule(module, null) { - renderAllTestCases(operationShape, allTests) - } + writer.withInlineModule(protocolTestsModule(), null) { + renderAllTestCases(allTests) } } - private fun RustWriter.renderAllTestCases( - operationShape: OperationShape, - allTests: List, - ) { - allTests.forEach { - val operationSymbol = symbolProvider.toSymbol(operationShape) + private fun RustWriter.renderAllTestCases(allTests: List) { + for (it in allTests) { renderTestCaseBlock(it, this) { when (it) { - is TestCase.RequestTest -> - this.renderHttpRequestTestCase( - it.testCase, - operationShape, - operationSymbol, - ) - - is TestCase.ResponseTest -> - this.renderHttpResponseTestCase( - it.testCase, - it.targetShape, - operationShape, - operationSymbol, - ) - - is TestCase.MalformedRequestTest -> - this.renderHttpMalformedRequestTestCase( - it.testCase, - operationShape, - operationSymbol, - ) + is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) + is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) + is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase) } } } } - /** - * Filter out test cases that are disabled or don't match the service protocol - */ - private fun List.filterMatching(): List { - return if (RunOnly.isNullOrEmpty()) { - this.filter { testCase -> - testCase.protocol == codegenContext.protocol && - !DisableTests.contains(testCase.id) - } - } else { - this.filter { RunOnly.contains(it.id) } - } - } - // This function applies a "fix function" to each broken test before we synthesize it. // Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong, usually because they have // not been written with a server-side perspective in mind. @@ -264,50 +301,12 @@ class ServerProtocolTestGenerator( } } - private fun renderTestCaseBlock( - testCase: TestCase, - testModuleWriter: RustWriter, - block: Writable, - ) { - testModuleWriter.newlinePrefix = "/// " - if (testCase.documentation != null) { - testModuleWriter.writeWithNoFormatting(testCase.documentation) - } - - testModuleWriter.rust("Test ID: ${testCase.id}") - testModuleWriter.newlinePrefix = "" - - // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in - // case the code rendered by the test does not make use of `tracing` at all. - val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } - testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) - Attribute.TokioTest.render(testModuleWriter) - Attribute.TracedTest.render(testModuleWriter) - - if (expectFail(testCase)) { - testModuleWriter.writeWithNoFormatting("#[should_panic]") - } - val fnNameSuffix = - when (testCase.testType) { - is TestType.Response -> "_response" - is TestType.Request -> "_request" - is TestType.MalformedRequest -> "_malformed_request" - } - testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnNameSuffix()") { - block(this) - } - } - /** * Renders an HTTP request test case. * We are given an HTTP request in the test case, and we assert that when we deserialize said HTTP request into * an operation's input shape, the resulting shape is of the form we expect, as defined in the test case. */ - private fun RustWriter.renderHttpRequestTestCase( - httpRequestTestCase: HttpRequestTestCase, - operationShape: OperationShape, - operationSymbol: Symbol, - ) { + private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { if (!protocolSupport.requestDeserialization) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -341,23 +340,13 @@ class ServerProtocolTestGenerator( } } - private fun expectFail(testCase: TestCase): Boolean = - ExpectFail.find { - it.id == testCase.id && it.testType == testCase.testType && it.service == codegenContext.serviceShape.id.toString() - } != null - /** * Renders an HTTP response test case. * We are given an operation output shape or an error shape in the `params` field, and we assert that when we * serialize said shape, the resulting HTTP response is of the form we expect, as defined in the test case. * [shape] is either an operation output shape or an error shape. */ - private fun RustWriter.renderHttpResponseTestCase( - testCase: HttpResponseTestCase, - shape: StructureShape, - operationShape: OperationShape, - operationSymbol: Symbol, - ) { + private fun RustWriter.renderHttpResponseTestCase(testCase: HttpResponseTestCase, shape: StructureShape) { val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( @@ -389,11 +378,7 @@ class ServerProtocolTestGenerator( * We are given a request definition and a response definition, and we have to assert that the request is rejected * with the given response. */ - private fun RustWriter.renderHttpMalformedRequestTestCase( - testCase: HttpMalformedRequestTestCase, - operationShape: OperationShape, - operationSymbol: Symbol, - ) { + private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) { val (_, outputT) = operationInputOutputTypes[operationShape]!! val panicMessage = "request should have been rejected, but we accepted it; we parsed operation input `{:?}`" @@ -736,269 +721,4 @@ class ServerProtocolTestGenerator( *codegenScope, ) } - - private fun checkRequiredHeaders( - rustWriter: RustWriter, - actualExpression: String, - requireHeaders: List, - ) { - basicCheck( - requireHeaders, - rustWriter, - "required_headers", - actualExpression, - "require_headers", - ) - } - - private fun checkForbidHeaders( - rustWriter: RustWriter, - actualExpression: String, - forbidHeaders: List, - ) { - basicCheck( - forbidHeaders, - rustWriter, - "forbidden_headers", - actualExpression, - "forbid_headers", - ) - } - - private fun checkHeaders( - rustWriter: RustWriter, - actualExpression: String, - headers: Map, - ) { - if (headers.isEmpty()) { - return - } - val variableName = "expected_headers" - rustWriter.withBlock("let $variableName = [", "];") { - writeWithNoFormatting( - headers.entries.joinToString(",") { - "(${it.key.dq()}, ${it.value.dq()})" - }, - ) - } - assertOk(rustWriter) { - rust( - "#T($actualExpression, $variableName)", - RuntimeType.protocolTest(codegenContext.runtimeConfig, "validate_headers"), - ) - } - } - - private fun basicCheck( - params: List, - rustWriter: RustWriter, - expectedVariableName: String, - actualExpression: String, - checkFunction: String, - ) { - if (params.isEmpty()) { - return - } - rustWriter.withBlock("let $expectedVariableName = ", ";") { - strSlice(this, params) - } - assertOk(rustWriter) { - rustWriter.rust( - "#T($actualExpression, $expectedVariableName)", - RuntimeType.protocolTest(codegenContext.runtimeConfig, checkFunction), - ) - } - } - - /** - * wraps `inner` in a call to `aws_smithy_protocol_test::assert_ok`, a convenience wrapper - * for pretty printing protocol test helper results - */ - private fun assertOk( - rustWriter: RustWriter, - inner: Writable, - ) { - rustWriter.rust("#T(", RuntimeType.protocolTest(codegenContext.runtimeConfig, "assert_ok")) - inner(rustWriter) - rustWriter.write(");") - } - - private fun strSlice( - writer: RustWriter, - args: List, - ) { - writer.withBlock("&[", "]") { - rust(args.joinToString(",") { it.dq() }) - } - } - - companion object { - sealed class TestType { - object Request : TestType() - - object Response : TestType() - - object MalformedRequest : TestType() - } - - data class FailingTest(val service: String, val id: String, val testType: TestType) - - // These tests fail due to shortcomings in our implementation. - // These could be configured via runtime configuration, but since this won't be long-lasting, - // it makes sense to do the simplest thing for now. - // The test will _fail_ if these pass, so we will discover & remove if we fix them by accident - private const val AWS_JSON11 = "aws.protocoltests.json#JsonProtocol" - private const val AWS_JSON10 = "aws.protocoltests.json10#JsonRpc10" - private const val REST_JSON = "aws.protocoltests.restjson#RestJson" - private const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation" - private val ExpectFail: Set = - setOf( - // Endpoint trait is not implemented yet, see https://github.com/smithy-lang/smithy-rs/issues/950. - FailingTest(REST_JSON, "RestJsonEndpointTrait", TestType.Request), - FailingTest(REST_JSON, "RestJsonEndpointTraitWithHostLabel", TestType.Request), - FailingTest(REST_JSON, "RestJsonOmitsEmptyListQueryValues", TestType.Request), - // TODO(https://github.com/smithy-lang/smithy/pull/2315): Can be deleted when fixed tests are consumed in next Smithy version - FailingTest(REST_JSON, "RestJsonEnumPayloadRequest", TestType.Request), - FailingTest(REST_JSON, "RestJsonStringPayloadRequest", TestType.Request), - // Tests involving `@range` on floats. - // Pending resolution from the Smithy team, see https://github.com/smithy-lang/smithy-rs/issues/2007. - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), - // Tests involving floating point shapes and the `@range` trait; see https://github.com/smithy-lang/smithy-rs/issues/2007 - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloatOverride", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloatOverride", TestType.MalformedRequest), - // Some tests for the S3 service (restXml). - FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestType.Response), - FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3PathAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationNoErrorWrappingResponse", TestType.Response), - // AwsJson1.0 failing tests. - FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request), - FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request), - // AwsJson1.1 failing tests. - FailingTest(AWS_JSON11, "AwsJson11EndpointTraitWithHostLabel", TestType.Request), - FailingTest(AWS_JSON11, "AwsJson11EndpointTrait", TestType.Request), - FailingTest(AWS_JSON11, "parses_the_request_id_from_the_response", TestType.Response), - // TODO(https://github.com/awslabs/smithy/issues/1683): This has been marked as failing until resolution of said issue - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBlobList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStringList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsByteList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsShortList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntegerList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsLongList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsTimestampList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsDateTimeList", TestType.MalformedRequest), - FailingTest( - REST_JSON_VALIDATION, - "RestJsonMalformedUniqueItemsHttpDateList_case0", - TestType.MalformedRequest, - ), - FailingTest( - REST_JSON_VALIDATION, - "RestJsonMalformedUniqueItemsHttpDateList_case1", - TestType.MalformedRequest, - ), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsEnumList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntEnumList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsListList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStructureList", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case1", TestType.MalformedRequest), - // TODO(https://github.com/smithy-lang/smithy-rs/issues/2472): We don't respect the `@internal` trait - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case1", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case1", TestType.MalformedRequest), - // TODO(https://github.com/awslabs/smithy/issues/1737): Specs on @internal, @tags, and enum values need to be clarified - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case0", TestType.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case1", TestType.MalformedRequest), - // These tests are broken because they are missing a target header - FailingTest(AWS_JSON10, "AwsJson10ServerPopulatesNestedDefaultsWhenMissingInRequestBody", TestType.Request), - FailingTest(AWS_JSON10, "AwsJson10ServerPopulatesDefaultsWhenMissingInRequestBody", TestType.Request), - // Response defaults are not set when builders are not used https://github.com/smithy-lang/smithy-rs/issues/3339 - FailingTest(AWS_JSON10, "AwsJson10ServerPopulatesDefaultsInResponseWhenMissingInParams", TestType.Response), - FailingTest(AWS_JSON10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", TestType.Response), - ) - private val RunOnly: Set? = null - - // These tests are not even attempted to be generated, either because they will not compile - // or because they are flaky - private val DisableTests = - setOf( - // TODO(https://github.com/smithy-lang/smithy-rs/issues/2891): Implement support for `@requestCompression` - "SDKAppendedGzipAfterProvidedEncoding_restJson1", - "SDKAppendedGzipAfterProvidedEncoding_restXml", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_ec2Query", - "SDKAppliedContentEncoding_awsJson1_0", - "SDKAppliedContentEncoding_awsJson1_1", - "SDKAppliedContentEncoding_awsQuery", - "SDKAppliedContentEncoding_ec2Query", - "SDKAppliedContentEncoding_restJson1", - "SDKAppliedContentEncoding_restXml", - // RestXml S3 tests that fail to compile - "S3EscapeObjectKeyInUriLabel", - "S3EscapePathObjectKeyInUriLabel", - "S3PreservesLeadingDotSegmentInUriLabel", - "S3PreservesEmbeddedDotSegmentInUriLabel", - ) - - // TODO(https://github.com/awslabs/smithy/issues/1506) - private fun fixRestJsonMalformedPatternReDOSString( - testCase: HttpMalformedRequestTestCase, - ): HttpMalformedRequestTestCase { - val brokenResponse = testCase.response - val brokenBody = brokenResponse.body.get() - val fixedBody = - HttpMalformedResponseBodyDefinition.builder() - .mediaType(brokenBody.mediaType) - .contents( - """ - { - "message" : "1 validation error detected. Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", - "fieldList" : [{"message": "Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", "path": "/evilString"}] - } - """.trimIndent(), - ) - .build() - - return testCase.toBuilder() - .response(brokenResponse.toBuilder().body(fixedBody).build()) - .build() - } - - // TODO(https://github.com/smithy-lang/smithy-rs/issues/1288): Move the fixed versions into - // `rest-json-extras.smithy` and put the unfixed ones in `ExpectFail`: this has the - // advantage that once our upstream PRs get merged and we upgrade to the next Smithy release, our build will - // fail and we will take notice to remove the fixes from `rest-json-extras.smithy`. This is exactly what the - // client does. - private val BrokenMalformedRequestTests: - Map, KFunction1> = - // TODO(https://github.com/awslabs/smithy/issues/1506) - mapOf( - Pair( - REST_JSON_VALIDATION, - "RestJsonMalformedPatternReDOSString", - ) to ::fixRestJsonMalformedPatternReDOSString, - ) - } } From 9c1b290d7b4defef3e8e64acd6e0726bb5243905 Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 21 Jun 2024 16:18:09 +0200 Subject: [PATCH 17/77] Cherry-pick this over to protocol test improvements branch --- .../protocol/ClientProtocolTestGenerator.kt | 29 +++---- .../protocol/ProtocolTestGenerator.kt | 80 ++++++++++++------- .../protocol/ServerProtocolTestGenerator.kt | 37 ++++----- 3 files changed, 79 insertions(+), 67 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 34a718ecd5..65a8e21559 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -53,11 +53,6 @@ class ClientProtocolTestGenerator( override val codegenContext: ClientCodegenContext, override val protocolSupport: ProtocolSupport, override val operationShape: OperationShape, - - override val expectFail: Set = ExpectFail, - override val runOnly: Set = emptySet(), - override val disabledTests: Set = emptySet(), - private val renderClientCreation: RustWriter.(ClientCreationParams) -> Unit = { params -> rustTemplate( """ @@ -70,7 +65,7 @@ class ClientProtocolTestGenerator( "Client" to ClientRustModule.root.toType().resolve("Client"), ) }, -) : ProtocolTestGenerator { +) : ProtocolTestGenerator() { companion object { private val ExpectFail = setOf( @@ -81,6 +76,15 @@ class ClientProtocolTestGenerator( ) } + override val appliesTo: AppliesTo + get() = AppliesTo.CLIENT + override val expectFail: Set + get() = ExpectFail + override val runOnly: Set + get() = emptySet() + override val disabledTests: Set + get() = emptySet() + private val rc = codegenContext.runtimeConfig private val logger = Logger.getLogger(javaClass.name) @@ -95,18 +99,7 @@ class ClientProtocolTestGenerator( "Uri" to RT.Http.resolve("Uri"), ) - override fun render(writer: RustWriter) { - val allTests = allTestCases(AppliesTo.CLIENT) - if (allTests.isEmpty()) { - return - } - - writer.withInlineModule(protocolTestsModule(), null) { - renderAllTestCases(allTests) - } - } - - private fun RustWriter.renderAllTestCases(allTests: List) { + override fun RustWriter.renderAllTestCases(allTests: List) { for (it in allTests) { renderTestCaseBlock(it, this) { when (it) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index b93c4a4dae..98e4480b9c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -1,6 +1,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.protocol import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape @@ -25,6 +26,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly +import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.orNull @@ -34,28 +36,29 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase /** * Common interface to generate protocol tests for a given [operationShape]. */ -interface ProtocolTestGenerator { - val codegenContext: CodegenContext - val protocolSupport: ProtocolSupport - val operationShape: OperationShape +abstract class ProtocolTestGenerator { + abstract val codegenContext: CodegenContext + abstract val protocolSupport: ProtocolSupport + abstract val operationShape: OperationShape + abstract val appliesTo: AppliesTo /** * We expect these tests to fail due to shortcomings in our implementation. * They will _fail_ if they pass, so we will discover and remove them if we fix them by accident. **/ - val expectFail: Set + abstract val expectFail: Set /** Only generate these tests; useful to temporarily set and shorten development cycles */ - val runOnly: Set + abstract val runOnly: Set /** * These tests are not even attempted to be generated, either because they will not compile * or because they are flaky. */ - val disabledTests: Set + abstract val disabledTests: Set /** The Rust module in which we should generate the protocol tests for [operationShape]. */ - fun protocolTestsModule(): RustModule.LeafModule { + private fun protocolTestsModule(): RustModule.LeafModule { val operationName = codegenContext.symbolProvider.toSymbol(operationShape).name val testModuleName = "${operationName.toSnakeCase()}_test" val additionalAttributes = @@ -64,33 +67,50 @@ interface ProtocolTestGenerator { } /** The entry point to render the protocol tests, invoked by the code generators. */ - fun render(writer: RustWriter) + fun render(writer: RustWriter) { + val allTests = allTestCases().fixBroken() + if (allTests.isEmpty()) { + return + } + + writer.withInlineModule(protocolTestsModule(), null) { + renderAllTestCases(allTests) + } + } + + /** Implementors should describe how to render the test cases. **/ + abstract fun RustWriter.renderAllTestCases(allTests: List) + + /** + * This function applies a "fix function" to each broken test before we synthesize it. + * Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong. + * We try to contribute fixes upstream to pare down this function to the identity function. + */ + open fun List.fixBroken(): List = this /** Filter out test cases that are disabled or don't match the service protocol. */ - fun List.filterMatching(): List = if (runOnly.isEmpty()) { + private fun List.filterMatching(): List = if (runOnly.isEmpty()) { this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } } else { this.filter { testCase -> runOnly.contains(testCase.id) } } /** Do we expect this [testCase] to fail? */ - fun expectFail(testCase: TestCase): Boolean = + private fun expectFail(testCase: TestCase): Boolean = expectFail.find { it.id == testCase.id && it.kind == testCase.kind && it.service == codegenContext.serviceShape.id.toString() } != null - /** - * Parses from the model and returns all test cases for [operationShape] applying to the [appliesTo] artifact type - * that should be rendered by implementors. - **/ - fun allTestCases(appliesTo: AppliesTo): List { + fun requestTestCases(): List { + val requestTests = operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() + .map { TestCase.RequestTest(it) } + return requestTests.filterMatching() + } + + fun responseTestCases(): List { val operationIndex = OperationIndex.of(codegenContext.model) val outputShape = operationShape.outputShape(codegenContext.model) - val requestTests = - operationShape.getTrait() - ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.RequestTest(it) } - // `@httpResponseTests` trait can apply to operation shapes and structure shapes with the `@error` trait. // Find both kinds for the operation for which we're generating protocol tests. val responseTestsOnOperations = @@ -104,6 +124,10 @@ interface ProtocolTestGenerator { testCases.map { TestCase.ResponseTest(it, error) } } + return (responseTestsOnOperations + responseTestsOnErrors).filterMatching() + } + + fun malformedRequestTestCases(): List { // `@httpMalformedRequestTests` only make sense for servers. val malformedRequestTests = if (appliesTo == AppliesTo.SERVER) { operationShape.getTrait() @@ -111,14 +135,16 @@ interface ProtocolTestGenerator { } else { emptyList() } + return malformedRequestTests.filterMatching() + } + /** + * Parses from the model and returns all test cases for [operationShape] applying to the [appliesTo] artifact type + * that should be rendered by implementors. + **/ + fun allTestCases(): List = // Note there's no `@httpMalformedResponseTests`: https://github.com/smithy-lang/smithy/issues/2334 - - val allTests: List = - (requestTests + responseTestsOnOperations + responseTestsOnErrors + malformedRequestTests) - .filterMatching() - return allTests - } + requestTestCases() + responseTestCases() + malformedRequestTestCases() fun renderTestCaseBlock( testCase: TestCase, @@ -146,8 +172,6 @@ interface ProtocolTestGenerator { is TestCase.RequestTest -> "_request" is TestCase.MalformedRequestTest -> "_malformed_request" } - // TODO Do we need this one? - Attribute.AllowUnusedMut.render(testModuleWriter) testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnNameSuffix()") { block(this) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index e8c73fae18..73bf62e449 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -62,11 +62,7 @@ class ServerProtocolTestGenerator( override val codegenContext: CodegenContext, override val protocolSupport: ProtocolSupport, override val operationShape: OperationShape, - - override val expectFail: Set = ExpectFail, - override val runOnly: Set = emptySet(), - override val disabledTests: Set = DisabledTests, -): ProtocolTestGenerator { +): ProtocolTestGenerator() { companion object { private val ExpectFail: Set = setOf( @@ -216,6 +212,15 @@ class ServerProtocolTestGenerator( ) } + override val appliesTo: AppliesTo + get() = AppliesTo.SERVER + override val expectFail: Set + get() = ExpectFail + override val runOnly: Set + get() = emptySet() + override val disabledTests: Set + get() = DisabledTests + private val logger = Logger.getLogger(javaClass.name) private val model = codegenContext.model @@ -259,18 +264,7 @@ class ServerProtocolTestGenerator( "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), ) - override fun render(writer: RustWriter) { - val allTests = allTestCases(AppliesTo.SERVER).fixBroken() - if (allTests.isEmpty()) { - return - } - - writer.withInlineModule(protocolTestsModule(), null) { - renderAllTestCases(allTests) - } - } - - private fun RustWriter.renderAllTestCases(allTests: List) { + override fun RustWriter.renderAllTestCases(allTests: List) { for (it in allTests) { renderTestCaseBlock(it, this) { when (it) { @@ -282,10 +276,11 @@ class ServerProtocolTestGenerator( } } - // This function applies a "fix function" to each broken test before we synthesize it. - // Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong, usually because they have - // not been written with a server-side perspective in mind. - private fun List.fixBroken(): List = + /** + * Broken tests in the `awslabs/smithy` repository are usually wrong because they have not been written + * with a server-side perspective in mind. + */ + override fun List.fixBroken(): List = this.map { when (it) { is TestCase.MalformedRequestTest -> { From 97db4a494f43ba304b0093cac6bb508bf118c993 Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 21 Jun 2024 18:36:56 +0200 Subject: [PATCH 18/77] rename file; going to put protocol test generation in CodegenDecorator now --- ...GeneratorSerdeRoundTripIntegrationTest.kt} | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) rename codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/{CborSerializerGeneratorTest.kt => CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt} (82%) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt similarity index 82% rename from codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt rename to codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index 88a00dd2b9..3ea131b811 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -18,8 +18,6 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.transform.ModelTransformer -import software.amazon.smithy.protocoltests.traits.AppliesTo -import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata @@ -28,19 +26,27 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.getTrait -import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator -import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRpcV2CborFactory import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import java.util.function.Predicate -internal class CborSerializerGeneratorTest { +/** + * This lives in `codegen-server` because we want to run a full integration test for convenience, + * but there's really nothing server-specific here. We're just testing that the CBOR (de)serializers work like + * the ones generated by `serde_cbor`. This is a good exhaustive litmus test for correctness, since `serde_cbor` + * is battle-tested. + */ +internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { class DeriveSerdeDeserializeSymbolMetadataProvider( private val base: RustSymbolProvider, ) : SymbolMetadataProvider(base) { @@ -106,25 +112,24 @@ internal class CborSerializerGeneratorTest { val rpcV2 = RpcV2Cbor(codegenContext) for (operationShape in codegenContext.model.operationShapes) { - val outputShape = operationShape.outputShape(codegenContext.model) - // TODO Use `httpRequestTests` and error tests too. - val tests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { - ServerProtocolTestGenerator.TestCase.ResponseTest( - it, - outputShape, - ) - } + val serverProtocolTestGenerator = + ServerProtocolTestGenerator(codegenContext, ServerRpcV2CborFactory().support(), operationShape) + val tests = + serverProtocolTestGenerator.requestTestCases() + serverProtocolTestGenerator.responseTestCases() val serializeFn = rpcV2 .structuredDataSerializer() .operationOutputSerializer(operationShape) ?: continue // Skip if there's nothing to serialize. - // TODO Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on - // which we can't `#[derive(Deserialize)]`. rustCrate.withModule(ProtocolFunctions.serDeModule) { for (test in tests) { + val (targetShape, params) = when (test) { + is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.RequestTest -> operationShape.inputShape(codegenContext.model) to test.testCase.params + is TestCase.ResponseTest -> test.targetShape to test.testCase.params + } + unitTest("we_serialize_and_serde_cbor_deserializes_${test.id}") { rustTemplate( """ @@ -135,7 +140,7 @@ internal class CborSerializerGeneratorTest { .expect("serde_cbor failed deserializing from bytes"); #{AssertEq}(expected, actual); """, - "InstantiateShape" to instantiator.generate(test.targetShape, test.testCase.params), + "InstantiateShape" to instantiator.generate(targetShape, params), "SerializeFn" to serializeFn, *codegenScope, ) From b923bf8da5d6299cd5afa8bab660cff03d0200a4 Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 24 Jun 2024 13:56:37 +0200 Subject: [PATCH 19/77] save work --- .../customize/ClientCodegenDecorator.kt | 16 ------- .../smithy/customize/CoreCodegenDecorator.kt | 17 +++++++ .../server/smithy/ServerCodegenVisitor.kt | 9 +++- ...rGeneratorSerdeRoundTripIntegrationTest.kt | 46 ++++++++++++++++++- 4 files changed, 70 insertions(+), 18 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index c4aec33b59..ecc2c3132b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -93,14 +93,6 @@ interface ClientCodegenDecorator : CoreCodegenDecorator, ): List = baseCustomizations - - /** - * Hook to override the protocol test generator - */ - fun protocolTestGenerator( - codegenContext: ClientCodegenContext, - baseGenerator: ProtocolTestGenerator, - ): ProtocolTestGenerator = baseGenerator } /** @@ -176,14 +168,6 @@ open class CombinedClientCodegenDecorator(decorators: List - decorator.protocolTestGenerator(codegenContext, gen) - } - companion object { fun fromClasspath( context: PluginContext, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt index 55b0a147d4..fa83c62208 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomiza import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.util.deepMergeWith import java.util.ServiceLoader import java.util.logging.Logger @@ -120,6 +121,14 @@ interface CoreCodegenDecorator { * Hook for customizing symbols by inserting an additional symbol provider. */ fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = base + + /** + * Hook to override the protocol test generator. + */ + fun protocolTestGenerator( + codegenContext: CodegenContext, + baseGenerator: ProtocolTestGenerator, + ): ProtocolTestGenerator = baseGenerator } /** @@ -199,6 +208,14 @@ abstract class CombinedCoreCodegenDecorator + decorator.protocolTestGenerator(codegenContext, gen) + } + /** * Combines customizations from multiple ordered codegen decorators. * diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 2c14798d34..cf5c81df54 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -611,7 +611,14 @@ open class ServerCodegenVisitor( * Generate protocol tests. This method can be overridden by other languages such as Python. */ open fun protocolTestsForOperation(writer: RustWriter, shape: OperationShape) { - ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), shape).render(writer) + codegenDecorator.protocolTestGenerator( + codegenContext, + ServerProtocolTestGenerator( + codegenContext, + protocolGeneratorFactory.support(), + shape, + ), + ).render(writer) } /** diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index 3ea131b811..52c4d4ab2d 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -12,27 +12,36 @@ import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocoltests.traits.AppliesTo import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator @@ -83,6 +92,41 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { DeriveSerdeDeserializeSymbolMetadataProvider(base) } + // Don't generate protocol tests, because it'll attempt to pull out `params` for member shapes we'll remove + // from the model. + val noProtocolTestsDecorator = object : ServerCodegenDecorator { + override val name: String = "Don't generate protocol tests" + override val order: Byte = 0 + + override fun protocolTestGenerator( + codegenContext: ServerCodegenContext, + baseGenerator: ProtocolTestGenerator + ): ProtocolTestGenerator { + val noOpProtocolTestsGenerator = object : ProtocolTestGenerator() { + override val codegenContext: CodegenContext + get() = PANIC("We'll never need this") + override val protocolSupport: ProtocolSupport + get() = PANIC("We'll never need this") + override val operationShape: OperationShape + get() = PANIC("We'll never need this") + override val appliesTo: AppliesTo + get() = PANIC("We'll never need this") + override val expectFail: Set + get() = PANIC("We'll never need this") + override val runOnly: Set + get() = PANIC("We'll never need this") + override val disabledTests: Set + get() = PANIC("We'll never need this") + + override fun RustWriter.renderAllTestCases(allTests: List) { + // No-op. + } + + } + return noOpProtocolTestsGenerator + } + } + // Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on // which we can't `#[derive(serde::Deserialize)]`. val model = Model.assembler().discoverModels().assemble().result.get() @@ -100,7 +144,7 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { serverIntegrationTest( transformedModel, - additionalDecorators = listOf(addDeriveSerdeSerializeDecorator), + additionalDecorators = listOf(addDeriveSerdeSerializeDecorator, noProtocolTestsDecorator), params = IntegrationTestParams(service = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol") ) { codegenContext, rustCrate -> val codegenScope = arrayOf( From 975fc67e8581909dc9be06190ce7f56427e5a844 Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 24 Jun 2024 21:14:32 +0200 Subject: [PATCH 20/77] CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest finally works --- .../protocol/ProtocolTestGenerator.kt | 10 +- .../serialize/CborSerializerGenerator.kt | 6 +- .../serialize/JsonSerializerGenerator.kt | 2 - .../testutil/ServerCodegenIntegrationTest.kt | 1 + ...rGeneratorSerdeRoundTripIntegrationTest.kt | 151 +++++++++++++----- 5 files changed, 119 insertions(+), 51 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 98e4480b9c..14025df452 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -166,13 +166,8 @@ abstract class ProtocolTestGenerator { if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } - val fnNameSuffix = - when (testCase) { - is TestCase.ResponseTest -> "_response" - is TestCase.RequestTest -> "_request" - is TestCase.MalformedRequestTest -> "_malformed_request" - } - testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnNameSuffix()") { + val fnNameSuffix = testCase.kind.toString().toSnakeCase() + testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}_$fnNameSuffix()") { block(this) } } @@ -280,6 +275,7 @@ object ServiceShapeId { const val AWS_JSON_10 = "aws.protocoltests.json10#JsonRpc10" const val AWS_JSON_11 = "aws.protocoltests.json#JsonProtocol" const val REST_JSON = "aws.protocoltests.restjson#RestJson" + const val RPC_V2_CBOR = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol" const val REST_XML = "aws.protocoltests.restxml#RestXml" const val AWS_QUERY = "aws.protocoltests.query#AwsQuery" const val EC2_QUERY = "aws.protocoltests.ec2#AwsEc2" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 28b204bdbb..839d6f312d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -392,13 +392,13 @@ class CborSerializerGenerator( } private fun RustWriter.serializeCollection(context: Context) { + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } // `.expect()` safety: `From for usize` is not in the standard library, but the conversion should be // infallible (unless we ever have 128-bit machines I guess). // See https://users.rust-lang.org/t/cant-convert-usize-to-u64/6243. // TODO Point to a `static` to not inflate the binary. - for (customization in customizations) { - customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) - } rust( """ encoder.array( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index b3f16d90c2..bd8efdbb6f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -47,10 +47,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.outputShape diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt index 8c0254904e..a40673eb4a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.build.PluginContext import software.amazon.smithy.build.SmithyBuildPlugin import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.codegenIntegrationTest import software.amazon.smithy.rust.codegen.server.smithy.RustServerCodegenPlugin diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index 52c4d4ab2d..f0b7a4401e 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -8,18 +8,22 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocoltests.traits.AppliesTo +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata @@ -33,18 +37,19 @@ import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE -import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2CborProtocol import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRpcV2CborFactory import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import java.util.function.Predicate @@ -69,7 +74,15 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { return baseMetadata.withDerives(serdeDeserialize) } - override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun memberMeta(memberShape: MemberShape): RustMetadata { + val baseMetadata = base.toSymbol(memberShape).expectRustMetadata() + return baseMetadata.copy( + additionalAttributes = baseMetadata.additionalAttributes + Attribute( + """serde(rename = "${memberShape.memberName}")""", + isDeriveHelper = true, + ), + ) + } override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeDeserialize(structureShape) override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeDeserialize(unionShape) @@ -104,19 +117,19 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { ): ProtocolTestGenerator { val noOpProtocolTestsGenerator = object : ProtocolTestGenerator() { override val codegenContext: CodegenContext - get() = PANIC("We'll never need this") + get() = baseGenerator.codegenContext override val protocolSupport: ProtocolSupport - get() = PANIC("We'll never need this") + get() = baseGenerator.protocolSupport override val operationShape: OperationShape - get() = PANIC("We'll never need this") + get() = baseGenerator.operationShape override val appliesTo: AppliesTo - get() = PANIC("We'll never need this") + get() = baseGenerator.appliesTo override val expectFail: Set - get() = PANIC("We'll never need this") + get() = baseGenerator.expectFail override val runOnly: Set - get() = PANIC("We'll never need this") + get() = baseGenerator.runOnly override val disabledTests: Set - get() = PANIC("We'll never need this") + get() = baseGenerator.disabledTests override fun RustWriter.renderAllTestCases(allTests: List) { // No-op. @@ -127,9 +140,12 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { } } + val model = Model.assembler().discoverModels().assemble().result.get() + // Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on // which we can't `#[derive(serde::Deserialize)]`. - val model = Model.assembler().discoverModels().assemble().result.get() + // Note we can't use `ModelTransformer.removeShapes` because it will leave the model in an inconsistent state + // when removing list/set shape member shapes. val removeTimestampAndBlobShapes: Predicate = Predicate { shape -> when (shape) { is MemberShape -> { @@ -137,57 +153,114 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { targetShape is BlobShape || targetShape is TimestampShape } is BlobShape, is TimestampShape -> true + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + targetShape is BlobShape || targetShape is TimestampShape + } else -> false } } - val transformedModel = ModelTransformer.create().removeShapesIf(model, removeTimestampAndBlobShapes) + fun removeShapesByShapeId(shapeIds: Set): Predicate { + val predicate: Predicate = Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + shapeIds.contains(targetShape.id) + } + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + shapeIds.contains(targetShape.id) + } + else -> { + shapeIds.contains(shape.id) + } + } + } + return predicate + } + + + val modelTransformer = ModelTransformer.create() + val transformedModel = modelTransformer.removeShapesIf( + modelTransformer.removeShapesIf(model, removeTimestampAndBlobShapes), + // These enums do not serialize their variants using the Rust members' names. + // We'd have to tack on `#[serde(rename = "name")]` using the proper name defined in the Smithy enum definition. + // But we have no way of injecting that attribute on Rust enum variants in the code generator. + // So we just remove these problematic shapes. + removeShapesByShapeId( + setOf( + ShapeId.from("smithy.protocoltests.shared#FooEnum"), + ShapeId.from("smithy.protocoltests.rpcv2Cbor#TestEnum"), + ), + ), + ) + val serviceShape = transformedModel.expectShape(ShapeId.from(RPC_V2_CBOR)) serverIntegrationTest( transformedModel, additionalDecorators = listOf(addDeriveSerdeSerializeDecorator, noProtocolTestsDecorator), - params = IntegrationTestParams(service = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol") + params = IntegrationTestParams(service = serviceShape.id.toString()) ) { codegenContext, rustCrate -> + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1147): NaN != NaN. Ideally we when we address + // this issue, we'd re-use the structure shape comparison code that both client and server protocol test + // generators would use. + val expectFail = setOf("RpcV2CborSupportsNaNFloatOutputs") + val codegenScope = arrayOf( "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), "SerdeCbor" to CargoDependency.SerdeCbor.toType(), ) val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true) - val rpcV2 = RpcV2Cbor(codegenContext) + val rpcV2 = ServerRpcV2CborProtocol(codegenContext) for (operationShape in codegenContext.model.operationShapes) { val serverProtocolTestGenerator = ServerProtocolTestGenerator(codegenContext, ServerRpcV2CborFactory().support(), operationShape) + // The SDK can only serialize operation outputs, so we only ask for response tests. val tests = - serverProtocolTestGenerator.requestTestCases() + serverProtocolTestGenerator.responseTestCases() - - val serializeFn = rpcV2 - .structuredDataSerializer() - .operationOutputSerializer(operationShape) - ?: continue // Skip if there's nothing to serialize. + serverProtocolTestGenerator.responseTestCases() rustCrate.withModule(ProtocolFunctions.serDeModule) { for (test in tests) { - val (targetShape, params) = when (test) { + when (test) { is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") - is TestCase.RequestTest -> operationShape.inputShape(codegenContext.model) to test.testCase.params - is TestCase.ResponseTest -> test.targetShape to test.testCase.params - } + is TestCase.RequestTest -> UNREACHABLE("we did not ask for tests of this kind") +// is TestCase.RequestTest -> operationShape.inputShape(codegenContext.model) to test.testCase.params + is TestCase.ResponseTest -> { + val targetShape = test.targetShape + val params = test.testCase.params + + val serializeFn = if (targetShape.hasTrait()) { + rpcV2.structuredDataSerializer().serverErrorSerializer(targetShape.id) + } else { + rpcV2.structuredDataSerializer().operationOutputSerializer(operationShape) + } + + if (serializeFn == null) { + // Skip if there's nothing to serialize. + continue + } - unitTest("we_serialize_and_serde_cbor_deserializes_${test.id}") { - rustTemplate( - """ - let expected = #{InstantiateShape:W}; - let bytes = #{SerializeFn}(&expected) - .expect("generated CBOR serializer failed"); - let actual = #{SerdeCbor}::from_slice(&bytes) - .expect("serde_cbor failed deserializing from bytes"); - #{AssertEq}(expected, actual); - """, - "InstantiateShape" to instantiator.generate(targetShape, params), - "SerializeFn" to serializeFn, - *codegenScope, - ) + if (expectFail.contains(test.id)) { + writeWithNoFormatting("#[should_panic]") + } + unitTest("we_serialize_and_serde_cbor_deserializes_${test.id.toSnakeCase()}_${test.kind.toString().toSnakeCase()}") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes = #{SerializeFn}(&expected) + .expect("generated CBOR serializer failed"); + let actual = #{SerdeCbor}::from_slice(&bytes) + .expect("serde_cbor failed deserializing from bytes"); + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(targetShape, params), + "SerializeFn" to serializeFn, + *codegenScope, + ) + } + } } } } From 010bc1b3adf729fb4669a6deab54675ca1c1b5e8 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 25 Jun 2024 14:07:37 +0200 Subject: [PATCH 21/77] CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest works deserializer too --- ...rGeneratorSerdeRoundTripIntegrationTest.kt | 202 ++++++++++++------ 1 file changed, 133 insertions(+), 69 deletions(-) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index f0b7a4401e..c24f88a88b 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -40,16 +40,20 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Proto import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2CborProtocol +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRpcV2CborFactory import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import java.util.function.Predicate @@ -61,17 +65,19 @@ import java.util.function.Predicate * is battle-tested. */ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { - class DeriveSerdeDeserializeSymbolMetadataProvider( + class DeriveSerdeSerializeDeserializeSymbolMetadataProvider( private val base: RustSymbolProvider, ) : SymbolMetadataProvider(base) { private val serdeDeserialize = CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Deserialize") + private val serdeSerialize = + CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Serialize") - private fun addDeriveSerdeDeserialize(shape: Shape): RustMetadata { + private fun addDeriveSerdeSerializeDeserialize(shape: Shape): RustMetadata { check(shape !is MemberShape) val baseMetadata = base.toSymbol(shape).expectRustMetadata() - return baseMetadata.withDerives(serdeDeserialize) + return baseMetadata.withDerives(serdeSerialize, serdeDeserialize) } override fun memberMeta(memberShape: MemberShape): RustMetadata { @@ -84,63 +90,19 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { ) } - override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeDeserialize(structureShape) - override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeDeserialize(unionShape) - override fun enumMeta(stringShape: StringShape) = addDeriveSerdeDeserialize(stringShape) + override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeSerializeDeserialize(structureShape) + override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeSerializeDeserialize(unionShape) + override fun enumMeta(stringShape: StringShape) = addDeriveSerdeSerializeDeserialize(stringShape) - override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeDeserialize(listShape) - override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeDeserialize(mapShape) - override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveSerdeDeserialize(stringShape) - override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveSerdeDeserialize(numberShape) - override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeDeserialize(blobShape) + override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeSerializeDeserialize(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeSerializeDeserialize(mapShape) + override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveSerdeSerializeDeserialize(stringShape) + override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveSerdeSerializeDeserialize(numberShape) + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeSerializeDeserialize(blobShape) } - @Test - fun `we serialize and serde_cbor deserializes round trip`() { - val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { - override val name: String = "Add `#[derive(serde::Deserialize)]`" - override val order: Byte = 0 - - override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = - DeriveSerdeDeserializeSymbolMetadataProvider(base) - } - - // Don't generate protocol tests, because it'll attempt to pull out `params` for member shapes we'll remove - // from the model. - val noProtocolTestsDecorator = object : ServerCodegenDecorator { - override val name: String = "Don't generate protocol tests" - override val order: Byte = 0 - - override fun protocolTestGenerator( - codegenContext: ServerCodegenContext, - baseGenerator: ProtocolTestGenerator - ): ProtocolTestGenerator { - val noOpProtocolTestsGenerator = object : ProtocolTestGenerator() { - override val codegenContext: CodegenContext - get() = baseGenerator.codegenContext - override val protocolSupport: ProtocolSupport - get() = baseGenerator.protocolSupport - override val operationShape: OperationShape - get() = baseGenerator.operationShape - override val appliesTo: AppliesTo - get() = baseGenerator.appliesTo - override val expectFail: Set - get() = baseGenerator.expectFail - override val runOnly: Set - get() = baseGenerator.runOnly - override val disabledTests: Set - get() = baseGenerator.disabledTests - - override fun RustWriter.renderAllTestCases(allTests: List) { - // No-op. - } - - } - return noOpProtocolTestsGenerator - } - } - - val model = Model.assembler().discoverModels().assemble().result.get() + fun prepareRpcV2CborModel(): Model { + var model = Model.assembler().discoverModels().assemble().result.get() // Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on // which we can't `#[derive(serde::Deserialize)]`. @@ -179,9 +141,8 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { return predicate } - val modelTransformer = ModelTransformer.create() - val transformedModel = modelTransformer.removeShapesIf( + model = modelTransformer.removeShapesIf( modelTransformer.removeShapesIf(model, removeTimestampAndBlobShapes), // These enums do not serialize their variants using the Rust members' names. // We'd have to tack on `#[serde(rename = "name")]` using the proper name defined in the Smithy enum definition. @@ -195,16 +156,65 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { ), ) - val serviceShape = transformedModel.expectShape(ShapeId.from(RPC_V2_CBOR)) + return model + } + + @Test + fun `serde_cbor round trip`() { + val addDeriveSerdeSerializeDeserializeDecorator = object : ServerCodegenDecorator { + override val name: String = "Add `#[derive(serde::Serialize, serde::Deserialize)]`" + override val order: Byte = 0 + + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = + DeriveSerdeSerializeDeserializeSymbolMetadataProvider(base) + } + + // Don't generate protocol tests, because it'll attempt to pull out `params` for member shapes we'll remove + // from the model. + val noProtocolTestsDecorator = object : ServerCodegenDecorator { + override val name: String = "Don't generate protocol tests" + override val order: Byte = 0 + + override fun protocolTestGenerator( + codegenContext: ServerCodegenContext, + baseGenerator: ProtocolTestGenerator + ): ProtocolTestGenerator { + val noOpProtocolTestsGenerator = object : ProtocolTestGenerator() { + override val codegenContext: CodegenContext + get() = baseGenerator.codegenContext + override val protocolSupport: ProtocolSupport + get() = baseGenerator.protocolSupport + override val operationShape: OperationShape + get() = baseGenerator.operationShape + override val appliesTo: AppliesTo + get() = baseGenerator.appliesTo + override val expectFail: Set + get() = baseGenerator.expectFail + override val runOnly: Set + get() = baseGenerator.runOnly + override val disabledTests: Set + get() = baseGenerator.disabledTests + + override fun RustWriter.renderAllTestCases(allTests: List) { + // No-op. + } + + } + return noOpProtocolTestsGenerator + } + } + + val model = prepareRpcV2CborModel() + val serviceShape = model.expectShape(ShapeId.from(RPC_V2_CBOR)) serverIntegrationTest( - transformedModel, - additionalDecorators = listOf(addDeriveSerdeSerializeDecorator, noProtocolTestsDecorator), + model, + additionalDecorators = listOf(addDeriveSerdeSerializeDeserializeDecorator, noProtocolTestsDecorator), params = IntegrationTestParams(service = serviceShape.id.toString()) ) { codegenContext, rustCrate -> // TODO(https://github.com/smithy-lang/smithy-rs/issues/1147): NaN != NaN. Ideally we when we address // this issue, we'd re-use the structure shape comparison code that both client and server protocol test // generators would use. - val expectFail = setOf("RpcV2CborSupportsNaNFloatOutputs") + val expectFail = setOf("RpcV2CborSupportsNaNFloatInputs", "RpcV2CborSupportsNaNFloatOutputs") val codegenScope = arrayOf( "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), @@ -217,16 +227,16 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { for (operationShape in codegenContext.model.operationShapes) { val serverProtocolTestGenerator = ServerProtocolTestGenerator(codegenContext, ServerRpcV2CborFactory().support(), operationShape) - // The SDK can only serialize operation outputs, so we only ask for response tests. - val tests = - serverProtocolTestGenerator.responseTestCases() rustCrate.withModule(ProtocolFunctions.serDeModule) { - for (test in tests) { + // The SDK can only serialize operation outputs, so we only ask for response tests. + val responseTests = + serverProtocolTestGenerator.responseTestCases() + + for (test in responseTests) { when (test) { is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") is TestCase.RequestTest -> UNREACHABLE("we did not ask for tests of this kind") -// is TestCase.RequestTest -> operationShape.inputShape(codegenContext.model) to test.testCase.params is TestCase.ResponseTest -> { val targetShape = test.targetShape val params = test.testCase.params @@ -250,7 +260,7 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { """ let expected = #{InstantiateShape:W}; let bytes = #{SerializeFn}(&expected) - .expect("generated CBOR serializer failed"); + .expect("our generated CBOR serializer failed"); let actual = #{SerdeCbor}::from_slice(&bytes) .expect("serde_cbor failed deserializing from bytes"); #{AssertEq}(expected, actual); @@ -263,6 +273,60 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { } } } + + // The SDK can only deserialize operation inputs, so we only ask for request tests. + val requestTests = + serverProtocolTestGenerator.requestTestCases() + val inputShape = operationShape.inputShape(codegenContext.model) + val err = + if (ServerBuilderGenerator.hasFallibleBuilder( + inputShape, + codegenContext.model, + codegenContext.symbolProvider, + takeInUnconstrainedTypes = true, + ) + ) { + """.expect("builder failed to build")""" + } else { + "" + } + + for (test in requestTests) { + when (test) { + is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.ResponseTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.RequestTest -> { + val targetShape = operationShape.inputShape(codegenContext.model) + val params = test.testCase.params + + val deserializeFn = rpcV2.structuredDataParser().serverInputParser(operationShape) + ?: // Skip if there's nothing to serialize. + continue + + if (expectFail.contains(test.id)) { + writeWithNoFormatting("#[should_panic]") + } + unitTest("serde_cbor_serializes_and_we_deserialize_${test.id.toSnakeCase()}_${test.kind.toString().toSnakeCase()}") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes: Vec = #{SerdeCbor}::to_vec(&expected) + .expect("serde_cbor failed serializing to `Vec`"); + let input = #{InputBuilder}::default(); + let input = #{DeserializeFn}(&bytes, input) + .expect("our generated CBOR deserializer failed"); + let actual = input.build()$err; + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(targetShape, params), + "DeserializeFn" to deserializeFn, + "InputBuilder" to inputShape.serverBuilderSymbol(codegenContext), + *codegenScope, + ) + } + } + } + } } } } From 465c86d6f6fef1d5ca610c16c2b2822f898c9da9 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 25 Jun 2024 17:05:06 +0200 Subject: [PATCH 22/77] Unions with unit variants compile --- .../rpcv2Cbor-extras.smithy | 10 ++++- codegen-core/common-test-models/simple.smithy | 6 +++ .../protocols/parse/CborParserGenerator.kt | 32 ++++++++------ .../serialize/CborSerializerGenerator.kt | 16 +++++-- .../serialize/JsonSerializerGenerator.kt | 18 ++++---- .../smithy/rust/codegen/core/util/Smithy.kt | 4 ++ .../server/smithy/protocols/RpcV2CborTest.kt | 42 ------------------- 7 files changed, 62 insertions(+), 66 deletions(-) delete mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index 44d9b51c16..b308fef37b 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -139,6 +139,7 @@ structure ComplexStruct { list: SimpleList map: SimpleMap union: SimpleUnion + unitUnion: UnitUnion structureList: StructList @@ -162,11 +163,18 @@ map SimpleMap { } // TODO(https://github.com/smithy-lang/smithy/issues/2325): Upstream protocol -// test suite doesn't cover unions. +// test suite doesn't cover unions. While the generated SDK compiles, we're not +// exercising the (de)serializers with actual values. union SimpleUnion { blob: Blob boolean: Boolean string: String + unit: Unit +} + +union UnitUnion { + unitA: Unit + unitB: Unit } list ComplexList { diff --git a/codegen-core/common-test-models/simple.smithy b/codegen-core/common-test-models/simple.smithy index c7e58c8e4a..b4a3256dcb 100644 --- a/codegen-core/common-test-models/simple.smithy +++ b/codegen-core/common-test-models/simple.smithy @@ -19,4 +19,10 @@ operation Operation { structure OperationInputOutput { message: String + unitUnion: UnitUnion +} + +union UnitUnion { + unitA: Unit + message: String } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 2114ff7c19..5944f4305d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -48,21 +48,17 @@ import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.outputShape -/** - * Class describing a CBOR parser section that can be used in a customization. - */ +/** Class describing a CBOR parser section that can be used in a customization. */ sealed class CborParserSection(name: String) : Section(name) { data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember") } -/** - * Customization for the CBOR parser. - */ +/** Customization for the CBOR parser. */ typealias CborParserCustomization = NamedCustomization -// TODO Add a `CborParserGeneratorTest` a la `CborSerializerGeneratorTest`. class CborParserGenerator( private val codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, @@ -75,7 +71,6 @@ class CborParserGenerator( private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig - // TODO Use? private val codegenTarget = codegenContext.target private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() private val protocolFunctions = ProtocolFunctions(codegenContext) @@ -239,10 +234,11 @@ class CborParserGenerator( // Call `builder.set_member()` only if the value for the field on the wire is not null. rustTemplate( """ - ::aws_smithy_cbor::decode::set_optional(builder, decoder, |builder, decoder| { + #{SmithyCbor}::decode::set_optional(builder, decoder, |builder, decoder| { Ok(#{MemberSettingWritable:W}) })? """, + *codegenScope, "MemberSettingWritable" to callBuilderSetMemberFieldWritable ) } @@ -266,8 +262,6 @@ class CborParserGenerator( private fun unionPairParserFnWritable(shape: UnionShape) = writable { val returnSymbolToParse = returnSymbolToParse(shape) - // TODO Test with unit variants - // TODO Test with all unit variants rustBlockTemplate( """ fn pair( @@ -281,8 +275,20 @@ class CborParserGenerator( for (member in shape.members()) { val variantName = symbolProvider.toMemberName(member) - withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { - deserializeMember(member).invoke(this) + if (member.isTargetUnit()) { + rust( + """ + ${member.memberName.dq()} => { + decoder.skip()?; + #T::$variantName + } + """, + returnSymbolToParse.symbol + ) + } else { + withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { + deserializeMember(member).invoke(this) + } } } // TODO Test client mode (parse unknown variant) and server mode (reject unknown variant). diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 839d6f312d..d509d048d0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.IntegerShape @@ -47,6 +48,7 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.isUnit import software.amazon.smithy.rust.codegen.core.util.outputShape // TODO Cleanup commented and unused code. @@ -137,7 +139,6 @@ class CborSerializerGenerator( } } - // Specialized since it holds a JsonObjectWriter expression rather than a JsonValueWriter data class StructContext( /** Name of the variable that holds the struct */ val localName: String, @@ -300,6 +301,16 @@ class CborSerializerGenerator( context: StructContext, includedMembers: List? = null, ) { + if (context.shape.isUnit()) { + rust( + """ + encoder.begin_map(); + encoder.end(); + """ + ) + return + } + // TODO Need to inject `__type` when serializing errors. val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> rustBlockTemplate( @@ -371,8 +382,7 @@ class CborSerializerGenerator( is TimestampShape -> rust("$encoder.timestamp(${value.asRef()});") - // TODO Document shapes have not been specced out yet. - // is DocumentShape -> rust("$encoder.document(${value.asRef()});") + is DocumentShape -> UNREACHABLE("Smithy RPC v2 CBOR does not support `document` shapes") // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html else -> { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index bd8efdbb6f..14cf21d6cb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -492,13 +492,17 @@ class JsonSerializerGenerator( rust("let mut $objectName = ${context.writerExpression}.start_object();") // We call inner only when context's shape is not the Unit type. // If it were, calling inner would generate the following function: - // pub fn serialize_structure_crate_model_unit( - // object: &mut aws_smithy_json::serialize::JsonObjectWriter, - // input: &crate::model::Unit, - // ) -> Result<(), aws_smithy_http::operation::error::SerializationError> { - // let (_, _) = (object, input); - // Ok(()) - // } + // + // ```rust + // pub fn serialize_structure_crate_model_unit( + // object: &mut aws_smithy_json::serialize::JsonObjectWriter, + // input: &crate::model::Unit, + // ) -> Result<(), aws_smithy_http::operation::error::SerializationError> { + // let (_, _) = (object, input); + // Ok(()) + // } + // ``` + // // However, this would cause a compilation error at a call site because it cannot // extract data out of the Unit type that corresponds to the variable "input" above. if (!context.shape.isTargetUnit()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt index b167f05d2d..b3759a6132 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt @@ -73,6 +73,10 @@ fun MemberShape.isOutputEventStream(model: Model): Boolean { private val unitShapeId = ShapeId.from("smithy.api#Unit") +fun Shape.isUnit(): Boolean { + return this.id == unitShapeId +} + fun MemberShape.isTargetUnit(): Boolean { return this.target == unitShapeId } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt deleted file mode 100644 index 5c804f8739..0000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2CborTest.kt +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.protocols - -import org.junit.jupiter.api.Test -import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest - -// TODO This won't be needed since we'll cover it with a proper integration test. -internal class RpcV2CborTest { - val model = """ - ${"\$"}version: "2.0" - - namespace com.amazonaws.simple - - use smithy.protocols#rpcv2 - - @rpcv2(format: ["cbor"]) - service RpcV2Service { - version: "SomeVersion", - operations: [RpcV2Operation], - } - - @http(uri: "/operation", method: "POST") - operation RpcV2Operation { - input: OperationInputOutput - output: OperationInputOutput - } - - structure OperationInputOutput { - message: String - } - """.asSmithyModel() - - @Test - fun `generate a rpc v2 service that compiles`() { - serverIntegrationTest(model) { _, _ -> } - } -} From b9b34ec3a0777f88cdfdc5948ea83a94d5c4b073 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 25 Jun 2024 18:23:45 +0200 Subject: [PATCH 23/77] Unknown union variants --- .../protocols/parse/CborParserGenerator.kt | 28 +++++++++++++++---- rust-runtime/aws-smithy-cbor/src/decode.rs | 14 ++++++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 5944f4305d..da9f420074 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -37,6 +37,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed @@ -291,11 +293,27 @@ class CborParserGenerator( } } } - // TODO Test client mode (parse unknown variant) and server mode (reject unknown variant). - // In client mode, resolve an unknown union variant to the unknown variant. - // In server mode, use strict parsing. - // Consultation: https://github.com/awslabs/smithy/issues/1222 - rust("_ => { todo!() }") + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. + true -> + rustTemplate( + """ + _ => { + decoder.skip()?; + Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) + } + """, + "Union" to returnSymbolToParse.symbol, + *codegenScope, + ) + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + false -> + rustTemplate( + "variant => return Err(#{Error}::unknown_union_variant(variant, decoder.position()))", + *codegenScope, + ) + } } } } diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index 6e3dd42ea3..f17f480ec8 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -47,11 +47,21 @@ impl DeserializeError { } } + /// Unknown union variant was detected. Servers reject unknown union varaints. + pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self { + Self { + _inner: Error::message(format!("encountered unknown union variant {}", variant_name)) + .at(at), + } + } + /// More than one union variant was detected, but we never even got to parse the first one. pub fn mixed_union_variants(at: usize) -> Self { Self { - _inner: Error::message("encountered mixed variants in union; expected end of union") - .at(at), + _inner: Error::message( + "encountered mixed variants in union; expected a single union variant to be set", + ) + .at(at), } } From 40648983bb2c0b0f3bfa805c46525cd2e2a4e0c1 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 26 Jun 2024 13:20:21 +0200 Subject: [PATCH 24/77] DRY up collection decoding --- .../protocols/parse/CborParserGenerator.kt | 84 +++++++------------ 1 file changed, 28 insertions(+), 56 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index da9f420074..b2788ea9e4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -318,37 +318,32 @@ class CborParserGenerator( } } - private fun decodeStructureMapLoopWritable() = writable { - rustTemplate( - """ - match decoder.map()? { - None => loop { - match decoder.datatype()? { - #{SmithyCbor}::data::Type::Break => { - decoder.skip()?; - break; - } - _ => { - builder = pair(builder, decoder)?; - } - }; - }, - Some(n) => { - for _ in 0..n { - builder = pair(builder, decoder)?; - } - } - }; - """, - *codegenScope, - ) + enum class CollectionKind { + Map, + List; + + /** Method to invoke on the decoder to decode this collection kind. **/ + fun decoderMethodName() = when (this) { + Map -> "map" + List -> "list" + } } - // TODO This should be DRYed up with `decodeStructureMapLoopWritable`. - private fun decodeMapLoopWritable() = writable { + /** + * Decode a collection of homogeneous CBOR data items: a map or an array. + * The first branch of the `match` corresponds to when the collection is encoded using variable-length encoding; + * the second branch corresponds to fixed-length encoding. + * + * https://www.rfc-editor.org/rfc/rfc8949.html#name-indefinite-length-arrays-an + */ + private fun decodeCollectionLoopWritable( + collectionKind: CollectionKind, + variableBindingName: String, + decodeItemFnName: String, + ) = writable { rustTemplate( """ - match decoder.map()? { + match decoder.${collectionKind.decoderMethodName()}()? { None => loop { match decoder.datatype()? { #{SmithyCbor}::data::Type::Break => { @@ -356,13 +351,13 @@ class CborParserGenerator( break; } _ => { - map = pair(map, decoder)?; + $variableBindingName = $decodeItemFnName($variableBindingName, decoder)?; } }; }, Some(n) => { for _ in 0..n { - map = pair(map, decoder)?; + $variableBindingName = $decodeItemFnName($variableBindingName, decoder)?; } } }; @@ -371,32 +366,9 @@ class CborParserGenerator( ) } - // TODO This should be DRYed up with `decodeStructureMapLoopWritable`. - private fun decodeListLoop() = writable { - rustTemplate( - """ - match decoder.list()? { - None => loop { - match decoder.datatype()? { - #{SmithyCbor}::data::Type::Break => { - decoder.skip()?; - break; - } - _ => { - list = member(list, decoder)?; - } - }; - }, - Some(n) => { - for _ in 0..n { - list = member(list, decoder)?; - } - } - }; - """, - *codegenScope, - ) - } + private fun decodeStructureMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "builder", "pair") + private fun decodeMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "map", "pair") + private fun decodeListLoopWritable() = decodeCollectionLoopWritable(CollectionKind.List, "list", "member") /** * Reusable structure parser implementation that can be used to generate parsing code for @@ -552,7 +524,7 @@ class CborParserGenerator( returnUnconstrainedType = returnUnconstrainedType, ), "InitContainerWritable" to initContainerWritable, - "DecodeListLoop" to decodeListLoop(), + "DecodeListLoop" to decodeListLoopWritable(), *codegenScope, ) } From dc8452d56971aa3b05294ecad892fb3778f4e73c Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 26 Jun 2024 15:52:31 +0200 Subject: [PATCH 25/77] Assert token stream ends --- .../rpcv2Cbor-extras.smithy | 91 ++++++++++++++++++- .../protocol/ProtocolTestGenerator.kt | 1 + .../protocols/parse/CborParserGenerator.kt | 5 +- .../protocol/ServerProtocolTestGenerator.kt | 6 +- rust-runtime/aws-smithy-cbor/src/decode.rs | 7 ++ .../src/protocol/rpc_v2/mod.rs | 1 + .../src/protocol/rpc_v2/runtime_error.rs | 10 +- 7 files changed, 110 insertions(+), 11 deletions(-) diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index b308fef37b..96dc7f1047 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -5,13 +5,15 @@ namespace smithy.protocoltests.rpcv2Cbor use smithy.framework#ValidationException use smithy.protocols#rpcv2Cbor use smithy.test#httpResponseTests - +use smithy.test#httpMalformedRequestTests @rpcv2Cbor service RpcV2Service { operations: [ - SimpleStructOperation, + SimpleStructOperation ComplexStructOperation + EmptyStructOperation + SingleMemberStructOperation ] } @@ -24,13 +26,89 @@ operation SimpleStructOperation { errors: [ValidationException] } -@http(uri: "/complex-struct-operation", method: "POST") operation ComplexStructOperation { input: ComplexStruct output: ComplexStruct errors: [ValidationException] } +operation EmptyStructOperation { + input: EmptyStruct + output: EmptyStruct +} + +operation SingleMemberStructOperation { + input: SingleMemberStruct + output: SingleMemberStruct +} + +// TODO We fail this one, cut issue +apply EmptyStructOperation @httpMalformedRequestTests([ + { + id: "AdditionalTokensEmptyStruct", + documentation: """ + When additional tokens are found past where we expect the end of the body, + the request should be rejected with a serialization exception.""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2Service/operation/EmptyStructOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // Two empty variable-length encoded CBOR maps back to back. + body: "v/+//w==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // An empty CBOR map. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type`. + contents: "oA==" + } + } + } + + } +]) + +apply SingleMemberStructOperation @httpMalformedRequestTests([ + { + id: "AdditionalTokensSingleMemberStruct", + documentation: """ + When additional tokens are found past where we expect the end of the body, + the request should be rejected with a serialization exception.""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2Service/operation/SingleMemberStructOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // Two empty variable-length encoded CBOR maps back to back. + body: "v/+//w==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // An empty CBOR map. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type`. + contents: "oA==" + } + } + } + + } +]) + apply SimpleStructOperation @httpResponseTests([ { id: "SimpleStruct", @@ -136,6 +214,7 @@ structure SimpleStruct { structure ComplexStruct { structure: SimpleStruct + emptyStructure: EmptyStruct list: SimpleList map: SimpleMap union: SimpleUnion @@ -149,6 +228,12 @@ structure ComplexStruct { @required complexUnion: ComplexUnion } +structure EmptyStruct { } + +structure SingleMemberStruct { + message: String +} + list StructList { member: SimpleStruct } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 14025df452..18abece09a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -276,6 +276,7 @@ object ServiceShapeId { const val AWS_JSON_11 = "aws.protocoltests.json#JsonProtocol" const val REST_JSON = "aws.protocoltests.restjson#RestJson" const val RPC_V2_CBOR = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol" + const val RPC_V2_CBOR_EXTRAS = "smithy.protocoltests.rpcv2Cbor#RpcV2Service" const val REST_XML = "aws.protocoltests.restxml#RestXml" const val AWS_QUERY = "aws.protocoltests.query#AwsQuery" const val EC2_QUERY = "aws.protocoltests.ec2#AwsEc2" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index b2788ea9e4..7c6af55358 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -383,9 +383,6 @@ class CborParserGenerator( fnNameSuffix: String? = null, ): RuntimeType { return protocolFunctions.deserializeFn(shape, fnNameSuffix) { fnName -> - // TODO Test no members. -// val unusedMut = if (includedMembers.isEmpty()) "##[allow(unused_mut)] " else "" - // TODO Assert token stream ended. rustTemplate( """ pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> Result<#{Builder}, #{Error}> { @@ -396,7 +393,7 @@ class CborParserGenerator( #{DecodeStructureMapLoop:W} if decoder.position() != value.len() { - todo!() + return Err(#{Error}::expected_end_of_stream(decoder.position())); } Ok(builder) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 73bf62e449..5f8421b4f8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -31,7 +31,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator @@ -39,7 +38,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember @@ -148,6 +149,9 @@ class ServerProtocolTestGenerator( // Response defaults are not set when builders are not used https://github.com/smithy-lang/smithy-rs/issues/3339 FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsInResponseWhenMissingInParams", TestCaseKind.Response), FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", TestCaseKind.Response), + + // TODO + FailingTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct", TestCaseKind.MalformedRequest), ) private val DisabledTests = diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index f17f480ec8..27ada4f7c9 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -65,6 +65,13 @@ impl DeserializeError { } } + /// Expected end of stream but more data is available. + pub fn expected_end_of_stream(at: usize) -> Self { + Self { + _inner: Error::message("encountered additional data; expected end of stream").at(at), + } + } + /// An unexpected type was encountered. // We handle this one when decoding sparse collections: we have to expect either a `null` or an // item, so we try decoding both. diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs index 716befd77a..cec31e5aee 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs @@ -7,6 +7,7 @@ pub mod rejection; pub mod router; pub mod runtime_error; +// TODO Rename to RpcV2Cbor // TODO: Fill link /// [Smithy RPC V2](). pub struct RpcV2; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs index d8ee60796a..8375611639 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs @@ -6,6 +6,7 @@ use crate::response::IntoResponse; use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2::RpcV2}; +use bytes::Bytes; use http::StatusCode; use super::rejection::{RequestRejection, ResponseRejection}; @@ -54,11 +55,14 @@ impl IntoResponse for RuntimeError { .header("Content-Type", "application/cbor") .extension(RuntimeErrorExtension::new(self.name().to_string())); - // TODO + // https://cbor.nemo157.com/#type=hex&value=a0 + const EMPTY_CBOR_MAP: Bytes = Bytes::from_static(&[0xa0]); + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing + // `__type`. let body = match self { RuntimeError::Validation(reason) => crate::body::to_boxed(reason), - // See https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html#empty-body-serialization - _ => crate::body::to_boxed("{}"), + _ => crate::body::to_boxed(EMPTY_CBOR_MAP), }; res.body(body) From 6ace376a213439300ec76e07a93fa780c1d7f548 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 27 Jun 2024 12:58:34 +0200 Subject: [PATCH 26/77] Add TODO about accepting any body when operation input is empty/non-existent --- codegen-core/common-test-models/rpcv2Cbor-extras.smithy | 1 - .../smithy/generators/protocol/ServerProtocolTestGenerator.kt | 2 +- .../smithy/protocols/ServerHttpBoundProtocolGenerator.kt | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index 96dc7f1047..8a97ebe3c5 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -42,7 +42,6 @@ operation SingleMemberStructOperation { output: SingleMemberStruct } -// TODO We fail this one, cut issue apply EmptyStructOperation @httpMalformedRequestTests([ { id: "AdditionalTokensEmptyStruct", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 5f8421b4f8..03d1965f67 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -150,7 +150,7 @@ class ServerProtocolTestGenerator( FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsInResponseWhenMissingInParams", TestCaseKind.Response), FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", TestCaseKind.Response), - // TODO + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): This affects all protocols FailingTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct", TestCaseKind.MalformedRequest), ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 587e94a476..31a55130f9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -804,8 +804,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - // TODO What about when there's no modeled operation input but the payload is not empty? In some protocols we - // must accept `{}` but we currently accept anything! + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): we should inject a check here that asserts that + // the body contents are valid when there is empty operation input or no operation input. val err = if (ServerBuilderGenerator.hasFallibleBuilder( From 75407f9d19671d3bc1b05b5e25bbaa1a1db070c5 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 27 Jun 2024 15:22:54 +0200 Subject: [PATCH 27/77] Address some TODOs, serialize __type correctly; now going to work on fixBroken --- .../rpcv2Cbor-extras.smithy | 62 ++++++++++- codegen-core/common-test-models/simple.smithy | 6 -- .../codegen/core/smithy/CodegenContext.kt | 2 +- .../protocols/parse/CborParserGenerator.kt | 44 ++++---- .../protocols/parse/JsonParserGenerator.kt | 3 +- .../serialize/CborSerializerGenerator.kt | 100 ++++-------------- .../StructuredDataSerializerGenerator.kt | 8 +- ...ypeFieldToServerErrorsCborCustomization.kt | 2 +- 8 files changed, 111 insertions(+), 116 deletions(-) diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index 8a97ebe3c5..a633ec9e4f 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -11,6 +11,7 @@ use smithy.test#httpMalformedRequestTests service RpcV2Service { operations: [ SimpleStructOperation + ErrorSerializationOperation ComplexStructOperation EmptyStructOperation SingleMemberStructOperation @@ -26,6 +27,12 @@ operation SimpleStructOperation { errors: [ValidationException] } +operation ErrorSerializationOperation { + input: SimpleStruct + output: ValidationException + errors: [ValidationException] +} + operation ComplexStructOperation { input: ComplexStruct output: ComplexStruct @@ -99,19 +106,68 @@ apply SingleMemberStructOperation @httpMalformedRequestTests([ mediaType: "application/cbor", assertion: { // An empty CBOR map. - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type`. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type` because `SerializationException` is not modeled. contents: "oA==" } } } + } +]) +apply ErrorSerializationOperation @httpMalformedRequestTests([ + { + id: "ErrorSerializationIncludesTypeField", + documentation: """ + When invalid input is provided the request should be rejected with + a validation exception, and a `__type` field should be included""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2Service/operation/ErrorSerializationOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // An empty CBOR map. We're missing a lot of `@required` members! + body: "oA==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // TODO Adjust + contents: "oA==" + } + } + } } ]) +apply ErrorSerializationOperation @httpResponseTests([ + { + id: "OperationOutputSerializationDoesNotIncludeTypeField", + documentation: """ + Despite the operation output being a structure shape with the `@error` trait, + `__type` field should not be included, because we're not serializing a + server error response""", + protocol: rpcv2Cbor, + // TODO This should fail with another code! + code: 200, + params: { + message: "ValidationException message field" + } + bodyMediaType: "application/cbor" + // TODO Adjust + body: "" + } +)] + apply SimpleStructOperation @httpResponseTests([ { id: "SimpleStruct", - protocol: "smithy.protocols#rpcv2Cbor", + protocol: rpcv2Cbor, code: 200, // Not used. params: { blob: "blobby blob", @@ -152,7 +208,7 @@ apply SimpleStructOperation @httpResponseTests([ // Same test, but leave optional types empty { id: "SimpleStructWithOptionsSetToNone", - protocol: "smithy.protocols#rpcv2Cbor", + protocol: rpcv2Cbor, code: 200, // Not used. params: { requiredBlob: "blobby blob", diff --git a/codegen-core/common-test-models/simple.smithy b/codegen-core/common-test-models/simple.smithy index b4a3256dcb..c7e58c8e4a 100644 --- a/codegen-core/common-test-models/simple.smithy +++ b/codegen-core/common-test-models/simple.smithy @@ -19,10 +19,4 @@ operation Operation { structure OperationInputOutput { message: String - unitUnion: UnitUnion -} - -union UnitUnion { - unitA: Unit - message: String } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt index eba03b6f71..cafc58587c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt @@ -53,7 +53,7 @@ abstract class CodegenContext( * Several code generators are reused by both the client and server plugins, but only deviate in small and contained * parts (e.g. changing a return type or adding an attribute). * Instead of splitting the generator in two or setting up an inheritance relationship, sometimes it's best - * to just lookup this flag. + * to just look up this flag. */ open val target: CodegenTarget, ) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 7c6af55358..237c8b988e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -23,6 +23,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -212,8 +213,10 @@ class CborParserGenerator( val symbol = symbolProvider.toSymbol(member) if (symbol.isRustBoxed()) { rustBlock("") { - rustTemplate("let v = #{DeserializeMember:W}?;", - "DeserializeMember" to deserializeMember(member)) + rustTemplate( + "let v = #{DeserializeMember:W}?;", + "DeserializeMember" to deserializeMember(member) + ) for (customization in customizations) { customization.section( @@ -226,7 +229,8 @@ class CborParserGenerator( } } else { rustTemplate("#{DeserializeMember:W}?", - "DeserializeMember" to deserializeMember(member)) + "DeserializeMember" to deserializeMember(member) + ) } } } @@ -442,7 +446,7 @@ class CborParserGenerator( return structureParser(operationShape, symbolProvider.symbolForBuilder(inputShape), includedMembers) } - private fun RustWriter.deserializeMember(memberShape: MemberShape) = writable { + private fun deserializeMember(memberShape: MemberShape) = writable { when (val target = model.expectShape(memberShape.target)) { // Simple shapes: https://smithy.io/2.0/spec/simple-types.html is BlobShape -> rust("decoder.blob()") @@ -469,29 +473,24 @@ class CborParserGenerator( // Note that no protocol using CBOR serialization supports `document` shapes. else -> PANIC("unexpected shape: $target") } - // TODO Boxing -// val symbol = symbolProvider.toSymbol(memberShape) -// if (symbol.isRustBoxed()) { -// for (customization in customizations) { -// customization.section(JsonParserSection.BeforeBoxingDeserializedMember(memberShape))(this) -// } -// rust(".map(Box::new)") -// } } - private fun RustWriter.deserializeString(target: StringShape, bubbleUp: Boolean = true) = writable { - // TODO Handle enum shapes - rust("decoder.string()") + private fun deserializeString(target: StringShape) = writable { + when (target.hasTrait()) { + true -> { + if (this@CborParserGenerator.returnSymbolToParse(target).isUnconstrained) { + rust("decoder.string()") + } else { + rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) + } + } + false -> rust("decoder.string()") + } } private fun RustWriter.deserializeCollection(shape: CollectionShape) { val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) - // TODO Test `@sparse` and non-@sparse lists. - // - Clients should insert only non-null values in non-`@sparse` list. - // - Servers should reject upon encountering first null value in non-`@sparse` list. - // - Both clients and servers should insert null values in `@sparse` list. - val parser = protocolFunctions.deserializeFn(shape) { fnName -> val initContainerWritable = writable { withBlock("let mut list = ", ";") { @@ -532,11 +531,6 @@ class CborParserGenerator( val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) - // TODO Test `@sparse` and non-@sparse maps. - // - Clients should insert only non-null values in non-`@sparse` map. - // - Servers should reject upon encountering first null value in non-`@sparse` map. - // - Both clients and servers should insert null values in `@sparse` map. - val parser = protocolFunctions.deserializeFn(shape) { fnName -> val initContainerWritable = writable { withBlock("let mut map = ", ";") { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 5f248aad33..cf0676ebd0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -337,8 +337,7 @@ class JsonParserGenerator( rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) } } - - else -> rust("u.into_owned()") + false -> rust("u.into_owned()") } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index d509d048d0..db783efd5a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -41,12 +41,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait -import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.isUnit import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -61,8 +58,11 @@ sealed class CborSerializerSection(name: String) : Section(name) { * Mutate the serializer prior to serializing any structure members. Eg: this can be used to inject `__type` * to record the error type in the case of an error structure. */ - data class BeforeSerializingStructureMembers(val structureShape: StructureShape, val encoderBindingName: String) : - CborSerializerSection("ServerError") + data class BeforeSerializingStructureMembers( + val structureShape: StructureShape, + val encoderBindingName: String, + val isServerErrorResponse: Boolean, + ) : CborSerializerSection("ServerError") /** Manipulate the serializer context for a map prior to it being serialized. **/ data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : @@ -191,45 +191,16 @@ class CborSerializerGenerator( } } - // TODO + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun payloadSerializer(member: MemberShape): RuntimeType { - val target = model.expectShape(member.target) - return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> - rustBlockTemplate( - "pub fn $fnName(input: &#{target}) -> std::result::Result<#{ByteSlab}, #{Error}>", - *codegenScope, - "target" to symbolProvider.toSymbol(target), - ) { - rust("let mut out = String::new();") - rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) - when (target) { - is StructureShape -> serializeStructure(StructContext("input", target)) - is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) - else -> throw IllegalStateException("json payloadSerializer only supports structs and unions") - } - rust("object.finish();") - rustTemplate("Ok(out.into_bytes())", *codegenScope) - } - } + TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") } - // TODO Unclear whether we'll need this. override fun unsetStructure(structure: StructureShape): RuntimeType = - ProtocolFunctions.crossOperationFn("rest_json_unsetpayload") { fnName -> - rustTemplate( - """ - pub fn $fnName() -> #{ByteSlab} { - b"{}"[..].into() - } - """, - *codegenScope, - ) - } + UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") - override fun unsetUnion(union: UnionShape): RuntimeType { - // TODO - TODO("Not yet implemented") - } + override fun unsetUnion(union: UnionShape): RuntimeType = + UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") override fun operationInputSerializer(operationShape: OperationShape): RuntimeType? { // Don't generate an operation CBOR serializer if there is no CBOR body. @@ -238,35 +209,12 @@ class CborSerializerGenerator( return null } - val inputShape = operationShape.inputShape(model) - return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> - rustBlockTemplate( - "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", - *codegenScope, "target" to symbolProvider.toSymbol(inputShape), - ) { - rust("let mut out = String::new();") - rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) - serializeStructure(StructContext("input", inputShape), httpDocumentMembers) - rust("object.finish();") - rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) - } - } + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + TODO("Client implementation should fill this out") } - override fun documentSerializer(): RuntimeType { - return ProtocolFunctions.crossOperationFn("serialize_document") { fnName -> - rustTemplate( - """ - pub fn $fnName(input: &#{Document}) -> #{ByteSlab} { - let mut out = String::new(); - #{JsonValueWriter}::new(&mut out).document(input); - out.into_bytes() - } - """, - "Document" to RuntimeType.document(runtimeConfig), *codegenScope, - ) - } - } + override fun documentSerializer(): RuntimeType = + UNREACHABLE("No protocol using CBOR supports `document` shapes, so we don't need to implement this") override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { // Don't generate an operation CBOR serializer if there was no operation output shape in the @@ -275,16 +223,7 @@ class CborSerializerGenerator( return null } - // TODO - // Note that, unlike the client, we serialize an empty JSON document `"{}"` if the operation output shape is - // empty (has no members). - // The client instead serializes an empty payload `""` in _both_ these scenarios: - // 1. there is no operation input shape; and - // 2. the operation input shape is empty (has no members). - // The first case gets reduced to the second, because all operations get a synthetic input shape with - // the [OperationNormalizer] transformation. val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) - val outputShape = operationShape.outputShape(model) return serverSerializer(outputShape, httpDocumentMembers, error = false) } @@ -300,6 +239,8 @@ class CborSerializerGenerator( private fun RustWriter.serializeStructure( context: StructContext, includedMembers: List? = null, + /** Whether we're serializing a top-level structure shape corresponding for a server operation response. */ + isServerErrorResponse: Boolean = false, ) { if (context.shape.isUnit()) { rust( @@ -311,7 +252,6 @@ class CborSerializerGenerator( return } - // TODO Need to inject `__type` when serializing errors. val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> rustBlockTemplate( "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> Result<(), #{Error}>", @@ -322,7 +262,13 @@ class CborSerializerGenerator( // instead of `.begin_map()` for efficiency. Add test. rust("encoder.begin_map();") for (customization in customizations) { - customization.section(CborSerializerSection.BeforeSerializingStructureMembers(context.shape, "encoder"))(this) + customization.section( + CborSerializerSection.BeforeSerializingStructureMembers( + context.shape, + "encoder", + isServerErrorResponse, + ), + )(this) } context.copy(localName = "input").also { inner -> val members = includedMembers ?: inner.shape.members() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt index 92b28d89fc..a85646673d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt @@ -25,12 +25,15 @@ interface StructuredDataSerializerGenerator { fun payloadSerializer(member: MemberShape): RuntimeType /** - * Generate the correct data when attempting to serialize a structure that is unset + * Generate the correct data when attempting to serialize a structure that is unset. * * ```rust * fn rest_json_unset_struct_payload() -> Vec { * ... * } + * ``` + * + * This method is only invoked when serializing an `@httpPayload`. */ fun unsetStructure(structure: StructureShape): RuntimeType @@ -41,6 +44,9 @@ interface StructuredDataSerializerGenerator { * fn rest_json_unset_union_payload() -> Vec { * ... * } + * ``` + * + * This method is only invoked when serializing an `@httpPayload`. */ fun unsetUnion(union: UnionShape): RuntimeType diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt index 59ead97478..d60b2941fa 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -21,7 +21,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { override fun section(section: CborSerializerSection): Writable = when (section) { is CborSerializerSection.BeforeSerializingStructureMembers -> - if (section.structureShape.hasTrait()) { + if (section.isServerErrorResponse && section.structureShape.hasTrait()) { writable { rust( """ From 16ca1883b17518da53819fe22b799241cde0ae4e Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 28 Jun 2024 17:53:04 +0200 Subject: [PATCH 28/77] Improve broken protocol test case generation We currently "hotfix" a broken protocol test in-memory, but there's no mechanism that alerts us when the broken protocol test has been fixed upstream when updating our Smithy version. This commit introduces such a mechanism by generating both the original and the fixed test, with a `#[should_panic]` attribute on the former, so that the test fails when all its assertions succeed. With this change, in general this approach of fixing tests in-memory should now be used over adding the broken test to `expectFail` and adding the fixed test to a `-extras.smithy` Smithy model, which is substantially more effort. --- .../protocol/ClientProtocolTestGenerator.kt | 10 +- .../protocol/ProtocolTestGenerator.kt | 237 +++++++++++++++--- .../protocol/ServerProtocolTestGenerator.kt | 207 +++++++-------- 3 files changed, 303 insertions(+), 151 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 65a8e21559..f9bad514d3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -23,7 +23,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator @@ -70,9 +70,9 @@ class ClientProtocolTestGenerator( private val ExpectFail = setOf( // Failing because we don't serialize default values if they match the default. - FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", TestCaseKind.Request), - FailingTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", TestCaseKind.Request), - FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput", TestCaseKind.Request), + FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse"), + FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults"), + FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"), ) } @@ -84,6 +84,8 @@ class ClientProtocolTestGenerator( get() = emptySet() override val disabledTests: Set get() = emptySet() + override val brokenTests: Set + get() = emptySet() private val rc = codegenContext.runtimeConfig private val logger = Logger.getLogger(javaClass.name) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 18abece09a..eb4d301a1f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -1,7 +1,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.protocol import software.amazon.smithy.model.knowledge.OperationIndex -import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape @@ -21,7 +20,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustInlineTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -45,9 +44,17 @@ abstract class ProtocolTestGenerator { /** * We expect these tests to fail due to shortcomings in our implementation. * They will _fail_ if they pass, so we will discover and remove them if we fix them by accident. - **/ + */ abstract val expectFail: Set + /** + * We expect these tests to fail because their definitions are broken. + * We map from a failing test to a "hotfix" function that can mutate the test in-memory and return a fixed version of it. + * The tests will _fail_ if they pass, so we will discover and remove the hotfix if we're updating to a newer + * version of Smithy where the test was fixed upstream. + */ + abstract val brokenTests: Set + /** Only generate these tests; useful to temporarily set and shorten development cycles */ abstract val runOnly: Set @@ -57,18 +64,22 @@ abstract class ProtocolTestGenerator { */ abstract val disabledTests: Set + private val serviceShapeId: ShapeId + get() = codegenContext.serviceShape.id + /** The Rust module in which we should generate the protocol tests for [operationShape]. */ private fun protocolTestsModule(): RustModule.LeafModule { val operationName = codegenContext.symbolProvider.toSymbol(operationShape).name val testModuleName = "${operationName.toSnakeCase()}_test" - val additionalAttributes = - listOf(Attribute(allow("unreachable_code", "unused_variables"))) + val additionalAttributes = listOf(Attribute(allow("unreachable_code", "unused_variables"))) return RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes) } /** The entry point to render the protocol tests, invoked by the code generators. */ fun render(writer: RustWriter) { - val allTests = allTestCases().fixBroken() + val allTests = allTestCases().flatMap { + fixBrokenTestCase(it) + } if (allTests.isEmpty()) { return } @@ -78,15 +89,60 @@ abstract class ProtocolTestGenerator { } } - /** Implementors should describe how to render the test cases. **/ - abstract fun RustWriter.renderAllTestCases(allTests: List) - /** - * This function applies a "fix function" to each broken test before we synthesize it. - * Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong. - * We try to contribute fixes upstream to pare down this function to the identity function. + * This function applies a "hotfix function" to a broken test case before we synthesize it. + * Broken tests are those whose definitions in the `smithy-lang/smithy` repository are wrong. + * We try to contribute fixes upstream to pare down the list of broken tests. + * If the test is broken, we synthesize it in two versions: the original broken test with a `#[should_panic]` + * attribute, so get alerted if the test now passes, and the fixed version, which should pass. */ - open fun List.fixBroken(): List = this + private fun fixBrokenTestCase(it: TestCase): List = if (!it.isBroken()) { + listOf(it) + } else { + assert(it.expectFail()) + + val brokenTest = it.findInBroken()!! + var fixed = brokenTest.fixIt(it) + + val intro = "The hotfix function for broken test case ${it.kind} ${it.id}" + val moreInfo = + """This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}]. + |We are tracking things here: [${brokenTest.trackedIn.joinToString()}].""".trimMargin() + + // Something must change... + if (it == fixed) { + PANIC( + """$intro did not make any modifications. It is likely that the test case was + |fixed upstream, and you're now updating the Smithy version; in this case, remove the hotfix + |function, as the test is no longer broken. + |$moreInfo""".trimMargin(), + ) + } + + // ... but the hotfix function is not allowed to change the test case kind... + if (it.kind != fixed.kind) { + PANIC( + """$intro changed the test case kind. This is not allowed. + |$moreInfo""".trimMargin(), + ) + } + + // ... nor its id. + if (it.id != fixed.id) { + PANIC( + """$intro changed the test case id. This is not allowed. + |$moreInfo""".trimMargin(), + ) + } + + // The latter is because we're going to generate the fixed version with an identifiable suffix. + fixed = fixed.suffixIdWith("_hotfixed") + + listOf(it, fixed) + } + + /** Implementors should describe how to render the test cases. **/ + abstract fun RustWriter.renderAllTestCases(allTests: List) /** Filter out test cases that are disabled or don't match the service protocol. */ private fun List.filterMatching(): List = if (runOnly.isEmpty()) { @@ -95,11 +151,23 @@ abstract class ProtocolTestGenerator { this.filter { testCase -> runOnly.contains(testCase.id) } } - /** Do we expect this [testCase] to fail? */ - private fun expectFail(testCase: TestCase): Boolean = - expectFail.find { - it.id == testCase.id && it.kind == testCase.kind && it.service == codegenContext.serviceShape.id.toString() - } != null + private fun TestCase.toFailingTest(): FailingTest = when (this) { + is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id) + is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id) + is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id) + } + + /** Do we expect this test case to fail? */ + private fun TestCase.expectFail(): Boolean = this.isBroken() || expectFail.contains(this.toFailingTest()) + + /** Is this test case broken? */ + private fun TestCase.isBroken(): Boolean = this.findInBroken() != null + + private fun TestCase.findInBroken(): BrokenTest? = brokenTests.find { brokenTest -> + (this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) || + (this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) || + (this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id) + } fun requestTestCases(): List { val requestTests = operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() @@ -114,15 +182,12 @@ abstract class ProtocolTestGenerator { // `@httpResponseTests` trait can apply to operation shapes and structure shapes with the `@error` trait. // Find both kinds for the operation for which we're generating protocol tests. val responseTestsOnOperations = - operationShape.getTrait() - ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.ResponseTest(it, outputShape) } - val responseTestsOnErrors = - operationIndex.getErrors(operationShape).flatMap { error -> - val testCases = - error.getTrait() - ?.getTestCasesFor(appliesTo).orEmpty() - testCases.map { TestCase.ResponseTest(it, error) } - } + operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() + .map { TestCase.ResponseTest(it, outputShape) } + val responseTestsOnErrors = operationIndex.getErrors(operationShape).flatMap { error -> + val testCases = error.getTrait()?.getTestCasesFor(appliesTo).orEmpty() + testCases.map { TestCase.ResponseTest(it, error) } + } return (responseTestsOnOperations + responseTestsOnErrors).filterMatching() } @@ -130,8 +195,8 @@ abstract class ProtocolTestGenerator { fun malformedRequestTestCases(): List { // `@httpMalformedRequestTests` only make sense for servers. val malformedRequestTests = if (appliesTo == AppliesTo.SERVER) { - operationShape.getTrait() - ?.testCases.orEmpty().map { TestCase.MalformedRequestTest(it) } + operationShape.getTrait()?.testCases.orEmpty() + .map { TestCase.MalformedRequestTest(it) } } else { emptyList() } @@ -152,6 +217,7 @@ abstract class ProtocolTestGenerator { block: Writable, ) { if (testCase.documentation != null) { + testModuleWriter.rust("") testModuleWriter.docs(testCase.documentation!!, templating = false) } testModuleWriter.docs("Test ID: ${testCase.id}") @@ -159,11 +225,11 @@ abstract class ProtocolTestGenerator { // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in // case the code rendered by the test does not make use of `tracing` at all. val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } - testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) + testModuleWriter.rustInlineTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) Attribute.TokioTest.render(testModuleWriter) Attribute.TracedTest.render(testModuleWriter) - if (expectFail(testCase)) { + if (testCase.expectFail()) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } val fnNameSuffix = testCase.kind.toString().toSnakeCase() @@ -268,6 +334,52 @@ abstract class ProtocolTestGenerator { } } +sealed class BrokenTest( + open val serviceShapeId: String, + open val id: String, + + /** A non-exhaustive set of Smithy versions where the test was found to be broken. */ + open val inAtLeast: Set, + /** + * GitHub URLs related to the test brokenness, like a GitHub issue in Smithy where we reported the test was broken, + * or a PR where we fixed it. + **/ + open val trackedIn: Set +) { + data class RequestTest( + override val serviceShapeId: String, + override val id: String, + override val inAtLeast: Set, + override val trackedIn: Set, + val howToFixItFn: (TestCase.RequestTest) -> TestCase.RequestTest, + ) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn) + + data class ResponseTest( + override val serviceShapeId: String, + override val id: String, + override val inAtLeast: Set, + override val trackedIn: Set, + val howToFixItFn: (TestCase.ResponseTest) -> TestCase.ResponseTest, + ) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn) + + data class MalformedRequestTest( + override val serviceShapeId: String, + override val id: String, + override val inAtLeast: Set, + override val trackedIn: Set, + val howToFixItFn: (TestCase.MalformedRequestTest) -> TestCase.MalformedRequestTest, + ) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn) + + fun fixIt(testToFix: TestCase): TestCase { + check(testToFix.id == this.id) + return when (this) { + is MalformedRequestTest -> howToFixItFn(testToFix as TestCase.MalformedRequestTest) + is RequestTest -> howToFixItFn(testToFix as TestCase.RequestTest) + is ResponseTest -> howToFixItFn(testToFix as TestCase.ResponseTest) + } + } +} + /** * Service shape IDs in common protocol test suites defined upstream. */ @@ -283,7 +395,14 @@ object ServiceShapeId { const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation" } -data class FailingTest(val service: String, val id: String, val kind: TestCaseKind) +sealed class FailingTest(open val serviceShapeId: String, open val id: String) { + data class RequestTest(override val serviceShapeId: String, override val id: String) : + FailingTest(serviceShapeId, id) + data class ResponseTest(override val serviceShapeId: String, override val id: String) : + FailingTest(serviceShapeId, id) + data class MalformedRequestTest(override val serviceShapeId: String, override val id: String) : + FailingTest(serviceShapeId, id) +} sealed class TestCaseKind { data object Request : TestCaseKind() @@ -292,9 +411,57 @@ sealed class TestCaseKind { } sealed class TestCase { - data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() - data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() - data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() + /* + * The properties of these data classes don't implement `equals()` usefully in Smithy, so we delegate to `equals()` + * of their `Node` representations. + */ + + data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is RequestTest) return false + return testCase.toNode().equals(other.testCase.toNode()) + } + + override fun hashCode(): Int = testCase.hashCode() + } + + data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ResponseTest) return false + return testCase.toNode().equals(other.testCase.toNode()) + } + + override fun hashCode(): Int = testCase.hashCode() + } + + data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is MalformedRequestTest) return false + return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation && this.testCase.request.toNode() + .equals(other.testCase.request.toNode()) && this.testCase.response.toNode() + .equals(other.testCase.response.toNode()) + } + + override fun hashCode(): Int = testCase.hashCode() + } + + fun suffixIdWith(suffix: String): TestCase = when (this) { + is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix)) + is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix)) + is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape) + } + + private fun HttpRequestTestCase.suffixIdWith(suffix: String): HttpRequestTestCase = + this.toBuilder().id(this.id + suffix).build() + + private fun HttpResponseTestCase.suffixIdWith(suffix: String): HttpResponseTestCase = + this.toBuilder().id(this.id + suffix).build() + + private fun HttpMalformedRequestTestCase.suffixIdWith(suffix: String): HttpMalformedRequestTestCase = + this.toBuilder().id(this.id + suffix).build() /* * `HttpRequestTestCase` and `HttpResponseTestCase` both implement `HttpMessageTestCase`, but diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 03d1965f67..22828d1bd1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator @@ -40,7 +41,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember @@ -54,7 +54,6 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator import java.util.logging.Logger -import kotlin.reflect.KFunction1 /** * Generate server protocol tests for an [operationShape]. @@ -63,95 +62,113 @@ class ServerProtocolTestGenerator( override val codegenContext: CodegenContext, override val protocolSupport: ProtocolSupport, override val operationShape: OperationShape, -): ProtocolTestGenerator() { +) : ProtocolTestGenerator() { companion object { private val ExpectFail: Set = setOf( // Endpoint trait is not implemented yet, see https://github.com/smithy-lang/smithy-rs/issues/950. - FailingTest(REST_JSON, "RestJsonEndpointTrait", TestCaseKind.Request), - FailingTest(REST_JSON, "RestJsonEndpointTraitWithHostLabel", TestCaseKind.Request), - FailingTest(REST_JSON, "RestJsonOmitsEmptyListQueryValues", TestCaseKind.Request), + FailingTest.RequestTest(REST_JSON, "RestJsonEndpointTrait"), + FailingTest.RequestTest(REST_JSON, "RestJsonEndpointTraitWithHostLabel"), + FailingTest.RequestTest(REST_JSON, "RestJsonOmitsEmptyListQueryValues"), // TODO(https://github.com/smithy-lang/smithy/pull/2315): Can be deleted when fixed tests are consumed in next Smithy version - FailingTest(REST_JSON, "RestJsonEnumPayloadRequest", TestCaseKind.Request), - FailingTest(REST_JSON, "RestJsonStringPayloadRequest", TestCaseKind.Request), + FailingTest.RequestTest(REST_JSON, "RestJsonEnumPayloadRequest"), + FailingTest.RequestTest(REST_JSON, "RestJsonStringPayloadRequest"), // Tests involving `@range` on floats. // Pending resolution from the Smithy team, see https://github.com/smithy-lang/smithy-rs/issues/2007. - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloat", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloat", TestCaseKind.MalformedRequest), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloat_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloat"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloat"), // Tests involving floating point shapes and the `@range` trait; see https://github.com/smithy-lang/smithy-rs/issues/2007 - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloatOverride", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloatOverride", TestCaseKind.MalformedRequest), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeFloatOverride_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMaxFloatOverride"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedRangeMinFloatOverride"), // Some tests for the S3 service (restXml). - FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestCaseKind.Response), - FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3PathAddressing", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAddressing", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestCaseKind.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationNoErrorWrappingResponse", TestCaseKind.Response), + FailingTest.ResponseTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3PathAddressing"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAddressing"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing"), + FailingTest.RequestTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred"), + FailingTest.ResponseTest("com.amazonaws.s3#AmazonS3", "S3OperationNoErrorWrappingResponse"), // AwsJson1.0 failing tests. - FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestCaseKind.Request), - FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestCaseKind.Request), + FailingTest.RequestTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel"), + FailingTest.RequestTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait"), // AwsJson1.1 failing tests. - FailingTest(AWS_JSON_11, "AwsJson11EndpointTraitWithHostLabel", TestCaseKind.Request), - FailingTest(AWS_JSON_11, "AwsJson11EndpointTrait", TestCaseKind.Request), - FailingTest(AWS_JSON_11, "parses_the_request_id_from_the_response", TestCaseKind.Response), + FailingTest.RequestTest(AWS_JSON_11, "AwsJson11EndpointTraitWithHostLabel"), + FailingTest.RequestTest(AWS_JSON_11, "AwsJson11EndpointTrait"), + FailingTest.ResponseTest(AWS_JSON_11, "parses_the_request_id_from_the_response"), // TODO(https://github.com/awslabs/smithy/issues/1683): This has been marked as failing until resolution of said issue - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBlobList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStringList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsByteList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsShortList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntegerList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsLongList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsTimestampList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsDateTimeList", TestCaseKind.MalformedRequest), - FailingTest( + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBlobList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsBooleanList_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStringList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsByteList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsShortList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntegerList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsLongList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsTimestampList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsDateTimeList"), + FailingTest.MalformedRequestTest( REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsHttpDateList_case0", - TestCaseKind.MalformedRequest, ), - FailingTest( + FailingTest.MalformedRequestTest( REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsHttpDateList_case1", - TestCaseKind.MalformedRequest, ), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsEnumList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntEnumList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsListList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStructureList", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case1", TestCaseKind.MalformedRequest), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsEnumList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsIntEnumList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsListList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsStructureList"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedUniqueItemsUnionList_case1"), // TODO(https://github.com/smithy-lang/smithy-rs/issues/2472): We don't respect the `@internal` trait - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case1", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case1", TestCaseKind.MalformedRequest), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumList_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapKey_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumMapValue_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumString_case1"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumUnion_case1"), // TODO(https://github.com/awslabs/smithy/issues/1737): Specs on @internal, @tags, and enum values need to be clarified - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case0", TestCaseKind.MalformedRequest), - FailingTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case1", TestCaseKind.MalformedRequest), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case0"), + FailingTest.MalformedRequestTest(REST_JSON_VALIDATION, "RestJsonMalformedEnumTraitString_case1"), // These tests are broken because they are missing a target header. - FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultsWhenMissingInRequestBody", TestCaseKind.Request), - FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsWhenMissingInRequestBody", TestCaseKind.Request), + FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultsWhenMissingInRequestBody"), + FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsWhenMissingInRequestBody"), // Response defaults are not set when builders are not used https://github.com/smithy-lang/smithy-rs/issues/3339 - FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsInResponseWhenMissingInParams", TestCaseKind.Response), - FailingTest(AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", TestCaseKind.Response), + FailingTest.ResponseTest(AWS_JSON_10, "AwsJson10ServerPopulatesDefaultsInResponseWhenMissingInParams"), + FailingTest.ResponseTest( + AWS_JSON_10, + "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", + ), // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): This affects all protocols - FailingTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct", TestCaseKind.MalformedRequest), + FailingTest.MalformedRequestTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct"), + ) + + private val BrokenTests: + Set = + setOf( + BrokenTest.MalformedRequestTest( + REST_JSON_VALIDATION, + "RestJsonMalformedPatternReDOSString", + howToFixItFn = ::fixRestJsonMalformedPatternReDOSString, + inAtLeast = setOf("1.26.2", "1.49.0"), + trackedIn = setOf( + // TODO(https://github.com/awslabs/smithy/issues/1506) + "https://github.com/awslabs/smithy/issues/1506", + // TODO(https://github.com/smithy-lang/smithy/pull/2340) + "https://github.com/smithy-lang/smithy/pull/2340", + ), + ), ) private val DisabledTests = @@ -176,11 +193,8 @@ class ServerProtocolTestGenerator( "S3PreservesEmbeddedDotSegmentInUriLabel", ) - // TODO(https://github.com/awslabs/smithy/issues/1506) - private fun fixRestJsonMalformedPatternReDOSString( - testCase: HttpMalformedRequestTestCase, - ): HttpMalformedRequestTestCase { - val brokenResponse = testCase.response + private fun fixRestJsonMalformedPatternReDOSString(testCase: TestCase.MalformedRequestTest): TestCase.MalformedRequestTest { + val brokenResponse = testCase.testCase.response val brokenBody = brokenResponse.body.get() val fixedBody = HttpMalformedResponseBodyDefinition.builder() @@ -195,31 +209,20 @@ class ServerProtocolTestGenerator( ) .build() - return testCase.toBuilder() - .response(brokenResponse.toBuilder().body(fixedBody).build()) - .build() - } - - // TODO(https://github.com/smithy-lang/smithy-rs/issues/1288): Move the fixed versions into - // `rest-json-extras.smithy` and put the unfixed ones in `ExpectFail`: this has the - // advantage that once our upstream PRs get merged and we upgrade to the next Smithy release, our build will - // fail and we will take notice to remove the fixes from `rest-json-extras.smithy`. This is exactly what the - // client does. - private val BrokenMalformedRequestTests: - Map, KFunction1> = - // TODO(https://github.com/awslabs/smithy/issues/1506) - mapOf( - Pair( - REST_JSON_VALIDATION, - "RestJsonMalformedPatternReDOSString", - ) to ::fixRestJsonMalformedPatternReDOSString, + return TestCase.MalformedRequestTest( + testCase.testCase.toBuilder() + .response(brokenResponse.toBuilder().body(fixedBody).build()) + .build(), ) + } } override val appliesTo: AppliesTo get() = AppliesTo.SERVER override val expectFail: Set get() = ExpectFail + override val brokenTests: Set + get() = BrokenTests override val runOnly: Set get() = emptySet() override val disabledTests: Set @@ -280,26 +283,6 @@ class ServerProtocolTestGenerator( } } - /** - * Broken tests in the `awslabs/smithy` repository are usually wrong because they have not been written - * with a server-side perspective in mind. - */ - override fun List.fixBroken(): List = - this.map { - when (it) { - is TestCase.MalformedRequestTest -> { - val howToFixIt = BrokenMalformedRequestTests[Pair(codegenContext.serviceShape.id.toString(), it.id)] - if (howToFixIt == null) { - it - } else { - val fixed = howToFixIt(it.testCase) - TestCase.MalformedRequestTest(fixed) - } - } - else -> it - } - } - /** * Renders an HTTP request test case. * We are given an HTTP request in the test case, and we assert that when we deserialize said HTTP request into @@ -350,7 +333,7 @@ class ServerProtocolTestGenerator( if (!protocolSupport.responseSerialization || ( !protocolSupport.errorSerialization && shape.hasTrait() - ) + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -445,7 +428,7 @@ class ServerProtocolTestGenerator( // // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - + // TODO(https://github.com/smithy-lang/smithy/issues/1932): We're using the `protocol` field as a // proxy for `bodyMediaType`. This works because `rpcv2Cbor` happens to be the only protocol where // the body is base64-encoded in the protocol test, but checking `bodyMediaType` should be a more From a503606f7426958993104844561b13a3066d6765 Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 28 Jun 2024 19:21:26 +0200 Subject: [PATCH 29/77] Fix serialization of __type --- .../smithy/protocols/serialize/CborSerializerGenerator.kt | 6 +++++- .../generators/protocol/ServerProtocolTestGenerator.kt | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index db783efd5a..39cf4d5217 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -184,7 +184,11 @@ class CborSerializerGenerator( // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference. rustBlock("") { rust("let encoder = &mut encoder;") - serializeStructure(StructContext("value", structureShape), includedMembers) + serializeStructure( + StructContext("value", structureShape), + includedMembers, + isServerErrorResponse = true, + ) } rust("Ok(encoder.into_writer())") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 126fd685b4..8b7f0da116 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors From a8ba4fb8f623956213a7f9c46446f7119187ddee Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 2 Jul 2024 17:57:46 +0200 Subject: [PATCH 30/77] fixes in rpcv2Cbor_extras --- .../rpcv2Cbor-extras.smithy | 29 +++++++++++-------- .../serialize/CborSerializerGenerator.kt | 5 ---- ...ypeFieldToServerErrorsCborCustomization.kt | 23 ++++++++++++++- ...rGeneratorSerdeRoundTripIntegrationTest.kt | 6 ++++ 4 files changed, 45 insertions(+), 18 deletions(-) diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index a633ec9e4f..37ad8f21d1 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -29,7 +29,7 @@ operation SimpleStructOperation { operation ErrorSerializationOperation { input: SimpleStruct - output: ValidationException + output: ErrorSerializationOperationOutput errors: [ValidationException] } @@ -73,7 +73,7 @@ apply EmptyStructOperation @httpMalformedRequestTests([ mediaType: "application/cbor", assertion: { // An empty CBOR map. - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type`. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type` because `SerializationException` is not modeled. contents: "oA==" } } @@ -137,8 +137,7 @@ apply ErrorSerializationOperation @httpMalformedRequestTests([ body: { mediaType: "application/cbor", assertion: { - // TODO Adjust - contents: "oA==" + contents: "v2ZfX3R5cGV4JHNtaXRoeS5mcmFtZXdvcmsjVmFsaWRhdGlvbkV4Y2VwdGlvbmdtZXNzYWdleGsxIHZhbGlkYXRpb24gZXJyb3IgZGV0ZWN0ZWQuIFZhbHVlIGF0ICcvcmVxdWlyZWRCbG9iJyBmYWlsZWQgdG8gc2F0aXNmeSBjb25zdHJhaW50OiBNZW1iZXIgbXVzdCBub3QgYmUgbnVsbGlmaWVsZExpc3SBv2RwYXRobS9yZXF1aXJlZEJsb2JnbWVzc2FnZXhOVmFsdWUgYXQgJy9yZXF1aXJlZEJsb2InIGZhaWxlZCB0byBzYXRpc2Z5IGNvbnN0cmFpbnQ6IE1lbWJlciBtdXN0IG5vdCBiZSBudWxs//8=" } } } @@ -147,22 +146,24 @@ apply ErrorSerializationOperation @httpMalformedRequestTests([ apply ErrorSerializationOperation @httpResponseTests([ { - id: "OperationOutputSerializationDoesNotIncludeTypeField", + id: "OperationOutputSerializationQuestionablyIncludesTypeField", documentation: """ Despite the operation output being a structure shape with the `@error` trait, - `__type` field should not be included, because we're not serializing a - server error response""", + `__type` field should, in a strict interpretation of the spec, not be included, + because we're not serializing a server error response. However, we do, because + there shouldn't™️ be any harm in doing so, and it greatly simplifies the + code generator. This test just pins this behavior in case we ever modify it.""", protocol: rpcv2Cbor, - // TODO This should fail with another code! code: 200, params: { - message: "ValidationException message field" + errorShape: { + message: "ValidationException message field" + } } bodyMediaType: "application/cbor" - // TODO Adjust - body: "" + body: "v2plcnJvclNoYXBlv2ZfX3R5cGV4JHNtaXRoeS5mcmFtZXdvcmsjVmFsaWRhdGlvbkV4Y2VwdGlvbmdtZXNzYWdleCFWYWxpZGF0aW9uRXhjZXB0aW9uIG1lc3NhZ2UgZmllbGT//w==" } -)] +]) apply SimpleStructOperation @httpResponseTests([ { @@ -230,6 +231,10 @@ apply SimpleStructOperation @httpResponseTests([ } ]) +structure ErrorSerializationOperationOutput { + errorShape: ValidationException +} + structure SimpleStruct { blob: Blob boolean: Boolean diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 39cf4d5217..ca8bdcd223 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -61,7 +61,6 @@ sealed class CborSerializerSection(name: String) : Section(name) { data class BeforeSerializingStructureMembers( val structureShape: StructureShape, val encoderBindingName: String, - val isServerErrorResponse: Boolean, ) : CborSerializerSection("ServerError") /** Manipulate the serializer context for a map prior to it being serialized. **/ @@ -187,7 +186,6 @@ class CborSerializerGenerator( serializeStructure( StructContext("value", structureShape), includedMembers, - isServerErrorResponse = true, ) } rust("Ok(encoder.into_writer())") @@ -243,8 +241,6 @@ class CborSerializerGenerator( private fun RustWriter.serializeStructure( context: StructContext, includedMembers: List? = null, - /** Whether we're serializing a top-level structure shape corresponding for a server operation response. */ - isServerErrorResponse: Boolean = false, ) { if (context.shape.isUnit()) { rust( @@ -270,7 +266,6 @@ class CborSerializerGenerator( CborSerializerSection.BeforeSerializingStructureMembers( context.shape, "encoder", - isServerErrorResponse, ), )(this) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt index d60b2941fa..ce2e590519 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -17,11 +17,32 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext /** * Smithy RPC v2 CBOR requires errors to be serialized in server responses with an additional `__type` field. + * + * Note that we apply this customization when serializing _any_ structure with the `@error` trait, regardless if it's + * an error response or not. Consider this model: + * + * ```smithy + * operation ErrorSerializationOperation { + * input: SimpleStruct + * output: ErrorSerializationOperationOutput + * errors: [ValidationException] + * } + * + * structure ErrorSerializationOperationOutput { + * errorShape: ValidationException + * } + * ``` + * + * `ValidationException` is re-used across the operation output and the operation error. The `__type` field will + * appear when serializing both. + * + * Strictly speaking, the spec says we should only add `__type` when serializing an operation error response, but + * there shouldn't™️ be any harm in always including it, which simplifies the code generator. */ class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { override fun section(section: CborSerializerSection): Writable = when (section) { is CborSerializerSection.BeforeSerializingStructureMembers -> - if (section.isServerErrorResponse && section.structureShape.hasTrait()) { + if (section.structureShape.hasTrait()) { writable { rust( """ diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index c24f88a88b..71187f1516 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -34,6 +34,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator @@ -57,6 +58,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilde import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRpcV2CborFactory import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import java.util.function.Predicate +import java.util.logging.Logger /** * This lives in `codegen-server` because we want to run a full integration test for convenience, @@ -188,8 +190,12 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { get() = baseGenerator.operationShape override val appliesTo: AppliesTo get() = baseGenerator.appliesTo + override val logger: Logger + get() = Logger.getLogger(javaClass.name) override val expectFail: Set get() = baseGenerator.expectFail + override val brokenTests: Set + get() = emptySet() override val runOnly: Set get() = baseGenerator.runOnly override val disabledTests: Set From 8da229874b68c48dc9672de4de3a82dbf33df54b Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 2 Jul 2024 18:03:21 +0200 Subject: [PATCH 31/77] ./gradlew ktlintFormat --- buildSrc/src/main/kotlin/CodegenTestCommon.kt | 53 ++- .../customize/ClientCodegenDecorator.kt | 1 - .../smithy/protocols/ClientProtocolLoader.kt | 19 +- .../codegen/core/rustlang/CargoDependency.kt | 9 +- .../codegen/core/smithy/CoreRustSettings.kt | 11 +- .../rust/codegen/core/smithy/RuntimeType.kt | 2 +- .../core/smithy/generators/Instantiator.kt | 34 +- .../protocol/ProtocolTestGenerator.kt | 62 +-- .../core/smithy/protocols/RpcV2Cbor.kt | 47 ++- .../protocols/parse/CborParserGenerator.kt | 375 +++++++++--------- .../serialize/CborSerializerGenerator.kt | 153 +++---- .../serialize/JsonSerializerGenerator.kt | 2 +- .../transformers/OperationNormalizer.kt | 10 +- .../server/smithy/ServerCodegenVisitor.kt | 5 +- ...ypeFieldToServerErrorsCborCustomization.kt | 28 +- ...ncodingMapOrCollectionCborCustomization.kt | 30 +- .../smithy/generators/ServerInstantiator.kt | 22 +- .../generators/protocol/ServerProtocol.kt | 38 +- .../protocol/ServerProtocolTestGenerator.kt | 19 +- .../ServerHttpBoundProtocolGenerator.kt | 25 +- .../protocols/ServerRpcV2CborFactory.kt | 7 +- .../testutil/ServerCodegenIntegrationTest.kt | 1 - ...rGeneratorSerdeRoundTripIntegrationTest.kt | 204 +++++----- 23 files changed, 617 insertions(+), 540 deletions(-) diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index 475475160c..202e6a0c83 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -30,27 +30,24 @@ fun generateImports(imports: List): String = } fun toRustCrateName(input: String): String { - val rustKeywords = setOf( - // Strict Keywords. - "as", "break", "const", "continue", "crate", "else", "enum", "extern", - "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", - "move", "mut", "pub", "ref", "return", "Self", "self", "static", "struct", - "super", "trait", "true", "type", "unsafe", "use", "where", "while", - - // Weak Keywords. - "dyn", "async", "await", "try", - - // Reserved for Future Use. - "abstract", "become", "box", "do", "final", "macro", "override", "priv", - "typeof", "unsized", "virtual", "yield", - - // Primitive Types. - "bool", "char", "i8", "i16", "i32", "i64", "i128", "isize", - "u8", "u16", "u32", "u64", "u128", "usize", "f32", "f64", "str", - - // Additional significant identifiers. - "proc_macro" - ) + val rustKeywords = + setOf( + // Strict Keywords. + "as", "break", "const", "continue", "crate", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", + "move", "mut", "pub", "ref", "return", "Self", "self", "static", "struct", + "super", "trait", "true", "type", "unsafe", "use", "where", "while", + // Weak Keywords. + "dyn", "async", "await", "try", + // Reserved for Future Use. + "abstract", "become", "box", "do", "final", "macro", "override", "priv", + "typeof", "unsized", "virtual", "yield", + // Primitive Types. + "bool", "char", "i8", "i16", "i32", "i64", "i128", "isize", + "u8", "u16", "u32", "u64", "u128", "usize", "f32", "f64", "str", + // Additional significant identifiers. + "proc_macro", + ) if (input.isBlank()) { throw IllegalArgumentException("Rust crate name cannot be empty") @@ -61,16 +58,16 @@ fun toRustCrateName(input: String): String { // Trim leading or trailing underscores. val trimmed = sanitized.trim('_') // Check if the resulting string is empty, purely numeric, or a reserved name - val finalName = when { - trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") - trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. - trimmed in rustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. - else -> trimmed - } + val finalName = + when { + trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") + trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. + trimmed in rustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. + else -> trimmed + } return finalName } - private fun generateSmithyBuild( projectDir: String, pluginName: String, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index ecc2c3132b..306a439a9e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorC import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import java.util.ServiceLoader import java.util.logging.Logger diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index dc49c5a6c6..f1a01edd6b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -35,15 +35,16 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { - val DefaultProtocols = mapOf( - AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), - AwsQueryTrait.ID to ClientAwsQueryFactory(), - Ec2QueryTrait.ID to ClientEc2QueryFactory(), - RestJson1Trait.ID to ClientRestJsonFactory(), - RestXmlTrait.ID to ClientRestXmlFactory(), - Rpcv2CborTrait.ID to ClientRpcV2CborFactory(), - ) + val DefaultProtocols = + mapOf( + AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), + AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), + AwsQueryTrait.ID to ClientAwsQueryFactory(), + Ec2QueryTrait.ID to ClientEc2QueryFactory(), + RestJson1Trait.ID to ClientRestJsonFactory(), + RestXmlTrait.ID to ClientRestXmlFactory(), + Rpcv2CborTrait.ID to ClientRpcV2CborFactory(), + ) val Default = ClientProtocolLoader(DefaultProtocols) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 05acbd0c33..f34921dfff 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -42,10 +42,11 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { ) + dependencies().flatMap { it.dependencies } } - open fun toDevDependency(): RustDependency = when (this) { - is CargoDependency -> this.toDevDependency() - is InlineDependency -> PANIC("it does not make sense for an inline dependency to be a dev-dependency") - } + open fun toDevDependency(): RustDependency = + when (this) { + is CargoDependency -> this.toDevDependency() + is InlineDependency -> PANIC("it does not make sense for an inline dependency to be a dev-dependency") + } companion object { private const val PROPERTY_KEY = "rustdep" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt index 5822ca4534..f8e3c75ec6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt @@ -96,11 +96,12 @@ open class CoreRustSettings( * @return Returns the found `Service` * @throws CodegenException if the service is invalid or not found */ - fun getService(model: Model): ServiceShape = model - .getShape(service) - .orElseThrow { CodegenException("Service shape not found: $service") } - .asServiceShape() - .orElseThrow { CodegenException("Shape is not a service: $service") } + fun getService(model: Model): ServiceShape = + model + .getShape(service) + .orElseThrow { CodegenException("Service shape not found: $service") } + .asServiceShape() + .orElseThrow { CodegenException("Shape is not a service: $service") } companion object { private val LOGGER: Logger = Logger.getLogger(CoreRustSettings::class.java.name) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 42a1b1edf5..cd20c22502 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -291,7 +291,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val PercentEncoding = CargoDependency.PercentEncoding.toType() val PrettyAssertions = CargoDependency.PrettyAssertions.toType() val Regex = CargoDependency.Regex.toType() - val Serde= CargoDependency.Serde.toType() + val Serde = CargoDependency.Serde.toType() val SerdeDeserialize = Serde.resolve("Deserialize") val SerdeSerialize = Serde.resolve("Serialize") val RegexLite = CargoDependency.RegexLite.toType() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index 2bcc7836d7..4e4b95897d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -32,6 +32,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.DefaultTrait import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpHeaderTrait import software.amazon.smithy.model.traits.HttpPayloadTrait @@ -61,12 +62,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectMember import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.letIf import java.math.BigDecimal -import software.amazon.smithy.model.traits.DefaultTrait -import software.amazon.smithy.rust.codegen.core.util.getTrait import kotlin.jvm.optionals.getOrNull /** @@ -218,12 +218,12 @@ open class Instantiator( ")", // The conditions are not commutative: note client builders always take in `Option`. conditional = - symbol.isOptional() || - ( - model.expectShape(memberShape.container) is StructureShape && - builderKindBehavior.doesSetterTakeInOption( - memberShape, - ) + symbol.isOptional() || + ( + model.expectShape(memberShape.container) is StructureShape && + builderKindBehavior.doesSetterTakeInOption( + memberShape, + ) ), *preludeScope, ) { @@ -431,12 +431,13 @@ open class Instantiator( } for ((key, value) in data.members) { - val memberShape = shape.getMember(key.value).getOrNull() - ?: if (ignoreMissingMembers) { - continue - } else { - throw CodegenException("Protocol test defines data for member shape `${key.value}`, but member shape was not found on structure shape ${shape.id}") - } + val memberShape = + shape.getMember(key.value).getOrNull() + ?: if (ignoreMissingMembers) { + continue + } else { + throw CodegenException("Protocol test defines data for member shape `${key.value}`, but member shape was not found on structure shape ${shape.id}") + } renderMemberHelper(memberShape, value) } @@ -465,8 +466,9 @@ open class Instantiator( */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { - is MemberShape -> shape.getTrait()?.toNode() - ?: fillDefaultValue(model.expectShape(shape.target)) + is MemberShape -> + shape.getTrait()?.toNode() + ?: fillDefaultValue(model.expectShape(shape.target)) // Aggregate shapes. is StructureShape -> Node.objectNode() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 033612dce9..0dd33c1e01 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -159,11 +159,12 @@ abstract class ProtocolTestGenerator { abstract fun RustWriter.renderAllTestCases(allTests: List) /** Filter out test cases that are disabled or don't match the service protocol. */ - private fun List.filterMatching(): List = if (runOnly.isEmpty()) { - this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } - } else { - this.filter { testCase -> runOnly.contains(testCase.id) } - } + private fun List.filterMatching(): List = + if (runOnly.isEmpty()) { + this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } + } else { + this.filter { testCase -> runOnly.contains(testCase.id) } + } private fun TestCase.toFailingTest(): FailingTest = when (this) { @@ -186,8 +187,9 @@ abstract class ProtocolTestGenerator { } fun requestTestCases(): List { - val requestTests = operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() - .map { TestCase.RequestTest(it) } + val requestTests = + operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() + .map { TestCase.RequestTest(it) } return requestTests.filterMatching() } @@ -430,7 +432,9 @@ sealed class FailingTest(open val serviceShapeId: String, open val id: String) { sealed class TestCaseKind { data object Request : TestCaseKind() + data object Response : TestCaseKind() + data object MalformedRequest : TestCaseKind() } @@ -497,30 +501,34 @@ sealed class TestCase { */ val id: String - get() = when (this) { - is RequestTest -> this.testCase.id - is MalformedRequestTest -> this.testCase.id - is ResponseTest -> this.testCase.id - } + get() = + when (this) { + is RequestTest -> this.testCase.id + is MalformedRequestTest -> this.testCase.id + is ResponseTest -> this.testCase.id + } val protocol: ShapeId - get() = when (this) { - is RequestTest -> this.testCase.protocol - is MalformedRequestTest -> this.testCase.protocol - is ResponseTest -> this.testCase.protocol - } + get() = + when (this) { + is RequestTest -> this.testCase.protocol + is MalformedRequestTest -> this.testCase.protocol + is ResponseTest -> this.testCase.protocol + } val kind: TestCaseKind - get() = when (this) { - is RequestTest -> TestCaseKind.Request - is ResponseTest -> TestCaseKind.Response - is MalformedRequestTest -> TestCaseKind.MalformedRequest - } + get() = + when (this) { + is RequestTest -> TestCaseKind.Request + is ResponseTest -> TestCaseKind.Response + is MalformedRequestTest -> TestCaseKind.MalformedRequest + } val documentation: String? - get() = when (this) { - is RequestTest -> this.testCase.documentation.orNull() - is ResponseTest -> this.testCase.documentation.orNull() - is MalformedRequestTest -> this.testCase.documentation.orNull() - } + get() = + when (this) { + is RequestTest -> this.testCase.documentation.orNull() + is ResponseTest -> this.testCase.documentation.orNull() + is MalformedRequestTest -> this.testCase.documentation.orNull() + } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index 955f330297..2d910d5b39 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -7,28 +7,22 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model -import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.PANIC -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isStreaming -import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.outputShape class RpcV2CborHttpBindingResolver( @@ -61,7 +55,9 @@ class RpcV2CborHttpBindingResolver( override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape) + override fun responseBindings(operationShape: OperationShape) = bindings(operationShape.outputShape) + override fun errorResponseBindings(errorShape: ToShapeId) = bindings(errorShape) /** @@ -94,26 +90,29 @@ class RpcV2CborHttpBindingResolver( open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), - "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() - .resolve("deserialize::error::DeserializeError"), - "Response" to RuntimeType.Http.resolve("Response"), - "json_errors" to RuntimeType.jsonErrors(runtimeConfig), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), + "JsonError" to + CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "Response" to RuntimeType.Http.resolve("Response"), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + ) private val jsonDeserModule = RustModule.private("json_deser") - override val httpBindingResolver: HttpBindingResolver = RpcV2CborHttpBindingResolver( - codegenContext.model, - ProtocolContentTypes( - requestDocument = "application/cbor", - responseDocument = "application/cbor", - eventStreamContentType = "application/vnd.amazon.eventstream", - eventStreamMessageContentType = "application/cbor", - ), - ) + override val httpBindingResolver: HttpBindingResolver = + RpcV2CborHttpBindingResolver( + codegenContext.model, + ProtocolContentTypes( + requestDocument = "application/cbor", + responseDocument = "application/cbor", + eventStreamContentType = "application/vnd.amazon.eventstream", + eventStreamMessageContentType = "application/cbor", + ), + ) // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps // using floating point seconds from the epoch. diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 237c8b988e..707b6a502c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -77,13 +77,14 @@ class CborParserGenerator( private val codegenTarget = codegenContext.target private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() private val protocolFunctions = ProtocolFunctions(codegenContext) - private val codegenScope = arrayOf( - "SmithyCbor" to smithyCbor, - "Decoder" to smithyCbor.resolve("Decoder"), - "Error" to smithyCbor.resolve("decode::DeserializeError"), - "HashMap" to RuntimeType.HashMap, - "Vec" to RuntimeType.Vec, - ) + private val codegenScope = + arrayOf( + "SmithyCbor" to smithyCbor, + "Decoder" to smithyCbor.resolve("Decoder"), + "Error" to smithyCbor.resolve("decode::DeserializeError"), + "HashMap" to RuntimeType.HashMap, + "Vec" to RuntimeType.Vec, + ) private fun listMemberParserFn( listSymbol: Symbol, @@ -193,7 +194,10 @@ class CborParserGenerator( } } - private fun structurePairParserFnWritable(builderSymbol: Symbol, includedMembers: Collection) = writable { + private fun structurePairParserFnWritable( + builderSymbol: Symbol, + includedMembers: Collection, + ) = writable { rustBlockTemplate( """ fn pair( @@ -207,34 +211,36 @@ class CborParserGenerator( withBlock("builder = match decoder.str()?.as_ref() {", "};") { for (member in includedMembers) { rustBlock("${member.memberName.dq()} =>") { - val callBuilderSetMemberFieldWritable = writable { - withBlock("builder.${member.setterName()}(", ")") { - conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { - val symbol = symbolProvider.toSymbol(member) - if (symbol.isRustBoxed()) { - rustBlock("") { + val callBuilderSetMemberFieldWritable = + writable { + withBlock("builder.${member.setterName()}(", ")") { + conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { + val symbol = symbolProvider.toSymbol(member) + if (symbol.isRustBoxed()) { + rustBlock("") { + rustTemplate( + "let v = #{DeserializeMember:W}?;", + "DeserializeMember" to deserializeMember(member), + ) + + for (customization in customizations) { + customization.section( + CborParserSection.BeforeBoxingDeserializedMember( + member, + ), + )(this) + } + rust("Box::new(v)") + } + } else { rustTemplate( - "let v = #{DeserializeMember:W}?;", - "DeserializeMember" to deserializeMember(member) + "#{DeserializeMember:W}?", + "DeserializeMember" to deserializeMember(member), ) - - for (customization in customizations) { - customization.section( - CborParserSection.BeforeBoxingDeserializedMember( - member - ) - )(this) - } - rust("Box::new(v)") } - } else { - rustTemplate("#{DeserializeMember:W}?", - "DeserializeMember" to deserializeMember(member) - ) } } } - } if (member.isOptional) { // Call `builder.set_member()` only if the value for the field on the wire is not null. @@ -245,10 +251,9 @@ class CborParserGenerator( })? """, *codegenScope, - "MemberSettingWritable" to callBuilderSetMemberFieldWritable + "MemberSettingWritable" to callBuilderSetMemberFieldWritable, ) - } - else { + } else { callBuilderSetMemberFieldWritable.invoke(this) } } @@ -260,77 +265,81 @@ class CborParserGenerator( decoder.skip()?; builder } - """) + """, + ) } rust("Ok(builder)") } } - private fun unionPairParserFnWritable(shape: UnionShape) = writable { - val returnSymbolToParse = returnSymbolToParse(shape) - rustBlockTemplate( - """ + private fun unionPairParserFnWritable(shape: UnionShape) = + writable { + val returnSymbolToParse = returnSymbolToParse(shape) + rustBlockTemplate( + """ fn pair( decoder: &mut #{Decoder} ) -> Result<#{UnionSymbol}, #{Error}> """, - *codegenScope, - "UnionSymbol" to returnSymbolToParse.symbol, - ) { - withBlock("Ok(match decoder.str()?.as_ref() {", "})") { - for (member in shape.members()) { - val variantName = symbolProvider.toMemberName(member) + *codegenScope, + "UnionSymbol" to returnSymbolToParse.symbol, + ) { + withBlock("Ok(match decoder.str()?.as_ref() {", "})") { + for (member in shape.members()) { + val variantName = symbolProvider.toMemberName(member) - if (member.isTargetUnit()) { - rust( - """ + if (member.isTargetUnit()) { + rust( + """ ${member.memberName.dq()} => { decoder.skip()?; #T::$variantName } """, - returnSymbolToParse.symbol - ) - } else { - withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { - deserializeMember(member).invoke(this) + returnSymbolToParse.symbol, + ) + } else { + withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { + deserializeMember(member).invoke(this) + } } } - } - when (codegenTarget.renderUnknownVariant()) { - // In client mode, resolve an unknown union variant to the unknown variant. - true -> - rustTemplate( - """ + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. + true -> + rustTemplate( + """ _ => { decoder.skip()?; Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) } """, - "Union" to returnSymbolToParse.symbol, - *codegenScope, - ) - // In server mode, use strict parsing. - // Consultation: https://github.com/awslabs/smithy/issues/1222 - false -> - rustTemplate( - "variant => return Err(#{Error}::unknown_union_variant(variant, decoder.position()))", - *codegenScope, - ) + "Union" to returnSymbolToParse.symbol, + *codegenScope, + ) + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + false -> + rustTemplate( + "variant => return Err(#{Error}::unknown_union_variant(variant, decoder.position()))", + *codegenScope, + ) + } } } } - } enum class CollectionKind { Map, - List; + List, + ; /** Method to invoke on the decoder to decode this collection kind. **/ - fun decoderMethodName() = when (this) { - Map -> "map" - List -> "list" - } + fun decoderMethodName() = + when (this) { + Map -> "map" + List -> "list" + } } /** @@ -371,7 +380,9 @@ class CborParserGenerator( } private fun decodeStructureMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "builder", "pair") + private fun decodeMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "map", "pair") + private fun decodeListLoopWritable() = decodeCollectionLoopWritable(CollectionKind.List, "list", "member") /** @@ -446,62 +457,66 @@ class CborParserGenerator( return structureParser(operationShape, symbolProvider.symbolForBuilder(inputShape), includedMembers) } - private fun deserializeMember(memberShape: MemberShape) = writable { - when (val target = model.expectShape(memberShape.target)) { - // Simple shapes: https://smithy.io/2.0/spec/simple-types.html - is BlobShape -> rust("decoder.blob()") - is BooleanShape -> rust("decoder.boolean()") + private fun deserializeMember(memberShape: MemberShape) = + writable { + when (val target = model.expectShape(memberShape.target)) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("decoder.blob()") + is BooleanShape -> rust("decoder.boolean()") - is StringShape -> deserializeString(target).invoke(this) + is StringShape -> deserializeString(target).invoke(this) - is ByteShape -> rust("decoder.byte()") - is ShortShape -> rust("decoder.short()") - is IntegerShape -> rust("decoder.integer()") - is LongShape -> rust("decoder.long()") + is ByteShape -> rust("decoder.byte()") + is ShortShape -> rust("decoder.short()") + is IntegerShape -> rust("decoder.integer()") + is LongShape -> rust("decoder.long()") - is FloatShape -> rust("decoder.float()") - is DoubleShape -> rust("decoder.double()") + is FloatShape -> rust("decoder.float()") + is DoubleShape -> rust("decoder.double()") - is TimestampShape -> rust("decoder.timestamp()") + is TimestampShape -> rust("decoder.timestamp()") - // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html - is StructureShape -> deserializeStruct(target) - is CollectionShape -> deserializeCollection(target) - is MapShape -> deserializeMap(target) - is UnionShape -> deserializeUnion(target) + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + is StructureShape -> deserializeStruct(target) + is CollectionShape -> deserializeCollection(target) + is MapShape -> deserializeMap(target) + is UnionShape -> deserializeUnion(target) - // Note that no protocol using CBOR serialization supports `document` shapes. - else -> PANIC("unexpected shape: $target") + // Note that no protocol using CBOR serialization supports `document` shapes. + else -> PANIC("unexpected shape: $target") + } } - } - private fun deserializeString(target: StringShape) = writable { - when (target.hasTrait()) { - true -> { - if (this@CborParserGenerator.returnSymbolToParse(target).isUnconstrained) { - rust("decoder.string()") - } else { - rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) + private fun deserializeString(target: StringShape) = + writable { + when (target.hasTrait()) { + true -> { + if (this@CborParserGenerator.returnSymbolToParse(target).isUnconstrained) { + rust("decoder.string()") + } else { + rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) + } } + false -> rust("decoder.string()") } - false -> rust("decoder.string()") } - } private fun RustWriter.deserializeCollection(shape: CollectionShape) { val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) - val parser = protocolFunctions.deserializeFn(shape) { fnName -> - val initContainerWritable = writable { - withBlock("let mut list = ", ";") { - conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { - rustTemplate("#{Vec}::new()", *codegenScope) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = + writable { + withBlock("let mut list = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{Vec}::new()", *codegenScope) + } + } } - } - } - rustTemplate( - """ + rustTemplate( + """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { #{ListMemberParserFn:W} @@ -512,18 +527,19 @@ class CborParserGenerator( Ok(list) } """, - "ReturnType" to returnSymbol, - "ListMemberParserFn" to listMemberParserFn( - returnSymbol, - isSparseList = shape.hasTrait(), - shape.member, - returnUnconstrainedType = returnUnconstrainedType, - ), - "InitContainerWritable" to initContainerWritable, - "DecodeListLoop" to decodeListLoopWritable(), - *codegenScope, - ) - } + "ReturnType" to returnSymbol, + "ListMemberParserFn" to + listMemberParserFn( + returnSymbol, + isSparseList = shape.hasTrait(), + shape.member, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeListLoop" to decodeListLoopWritable(), + *codegenScope, + ) + } rust("#T(decoder)", parser) } @@ -531,17 +547,19 @@ class CborParserGenerator( val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) - val parser = protocolFunctions.deserializeFn(shape) { fnName -> - val initContainerWritable = writable { - withBlock("let mut map = ", ";") { - conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { - rustTemplate("#{HashMap}::new()", *codegenScope) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = + writable { + withBlock("let mut map = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{HashMap}::new()", *codegenScope) + } + } } - } - } - rustTemplate( - """ + rustTemplate( + """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { #{MapPairParserFn:W} @@ -552,63 +570,66 @@ class CborParserGenerator( Ok(map) } """, - "ReturnType" to returnSymbol, - "MapPairParserFn" to mapPairParserFnWritable( - keyTarget, - shape.value, - isSparseMap = shape.hasTrait(), - returnSymbol, - returnUnconstrainedType = returnUnconstrainedType, - ), - "InitContainerWritable" to initContainerWritable, - "DecodeMapLoop" to decodeMapLoopWritable(), - *codegenScope, - ) - } + "ReturnType" to returnSymbol, + "MapPairParserFn" to + mapPairParserFnWritable( + keyTarget, + shape.value, + isSparseMap = shape.hasTrait(), + returnSymbol, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeMapLoop" to decodeMapLoopWritable(), + *codegenScope, + ) + } rust("#T(decoder)", parser) } private fun RustWriter.deserializeStruct(shape: StructureShape) { val returnSymbolToParse = returnSymbolToParse(shape) - val parser = protocolFunctions.deserializeFn(shape) { fnName -> - rustBlockTemplate( - "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}>", - "ReturnType" to returnSymbolToParse.symbol, - *codegenScope, - ) { - val builderSymbol = symbolProvider.symbolForBuilder(shape) - val includedMembers = shape.members() + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}>", + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + val builderSymbol = symbolProvider.symbolForBuilder(shape) + val includedMembers = shape.members() - rustTemplate( - """ + rustTemplate( + """ #{StructurePairParserFn:W} let mut builder = #{Builder}::default(); #{DecodeStructureMapLoop:W} """, - *codegenScope, - "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), - "Builder" to builderSymbol, - "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), - ) - - // Only call `build()` if the builder is not fallible. Otherwise, return the builder. - if (returnSymbolToParse.isUnconstrained) { - rust("Ok(builder)") - } else { - rust("Ok(builder.build())") + *codegenScope, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "Builder" to builderSymbol, + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + ) + + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(builder)") + } else { + rust("Ok(builder.build())") + } } } - } rust("#T(decoder)", parser) } private fun RustWriter.deserializeUnion(shape: UnionShape) { val returnSymbolToParse = returnSymbolToParse(shape) - val parser = protocolFunctions.deserializeFn(shape) { fnName -> - rustTemplate( - """ + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustTemplate( + """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{UnionSymbol}, #{Error}> { #{UnionPairParserFnWritable} @@ -633,11 +654,11 @@ class CborParserGenerator( } } """, - "UnionSymbol" to returnSymbolToParse.symbol, - "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), - *codegenScope, - ) - } + "UnionSymbol" to returnSymbolToParse.symbol, + "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), + *codegenScope, + ) + } rust("#T(decoder)", parser) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index ca8bdcd223..f5d6403a46 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -95,7 +95,10 @@ class CborSerializerGenerator( val writeNulls: Boolean = false, ) { companion object { - fun collectionMember(context: Context, itemName: String): MemberContext = + fun collectionMember( + context: Context, + itemName: String, + ): MemberContext = MemberContext( "encoder", ValueExpression.Reference(itemName), @@ -103,7 +106,11 @@ class CborSerializerGenerator( writeNulls = true, ) - fun mapMember(context: Context, key: String, value: String): MemberContext = + fun mapMember( + context: Context, + key: String, + value: String, + ): MemberContext = MemberContext( "encoder.str($key)", ValueExpression.Reference(value), @@ -133,8 +140,7 @@ class CborSerializerGenerator( ) /** Returns an expression to encode a key member **/ - private fun encodeKeyExpression(name: String): String = - "encoder.str(${name.dq()})" + private fun encodeKeyExpression(name: String): String = "encoder.str(${name.dq()})" } } @@ -149,14 +155,16 @@ class CborSerializerGenerator( private val codegenTarget = codegenContext.target private val runtimeConfig = codegenContext.runtimeConfig private val protocolFunctions = ProtocolFunctions(codegenContext) + // TODO Cleanup - private val codegenScope = arrayOf( - "String" to RuntimeType.String, - "Error" to runtimeConfig.serializationError(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), - "ByteSlab" to RuntimeType.ByteSlab, - ) + private val codegenScope = + arrayOf( + "String" to RuntimeType.String, + "Error" to runtimeConfig.serializationError(), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), + "ByteSlab" to RuntimeType.ByteSlab, + ) private val serializerUtil = SerializerUtil(model, symbolProvider) /** @@ -169,10 +177,11 @@ class CborSerializerGenerator( includedMembers: List, error: Boolean, ): RuntimeType { - val suffix = when (error) { - true -> "error" - else -> "output" - } + val suffix = + when (error) { + true -> "error" + else -> "output" + } return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> rustBlockTemplate( "pub fn $fnName(value: &#{target}) -> Result, #{Error}>", @@ -247,38 +256,39 @@ class CborSerializerGenerator( """ encoder.begin_map(); encoder.end(); - """ + """, ) return } - val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> - rustBlockTemplate( - "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> Result<(), #{Error}>", - "StructureSymbol" to symbolProvider.toSymbol(context.shape), - *codegenScope, - ) { - // TODO If all members are non-`Option`-al, we know AOT the map's size and can use `.map()` - // instead of `.begin_map()` for efficiency. Add test. - rust("encoder.begin_map();") - for (customization in customizations) { - customization.section( - CborSerializerSection.BeforeSerializingStructureMembers( - context.shape, - "encoder", - ), - )(this) - } - context.copy(localName = "input").also { inner -> - val members = includedMembers ?: inner.shape.members() - for (member in members) { - serializeMember(MemberContext.structMember(inner, member, symbolProvider)) + val structureSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> Result<(), #{Error}>", + "StructureSymbol" to symbolProvider.toSymbol(context.shape), + *codegenScope, + ) { + // TODO If all members are non-`Option`-al, we know AOT the map's size and can use `.map()` + // instead of `.begin_map()` for efficiency. Add test. + rust("encoder.begin_map();") + for (customization in customizations) { + customization.section( + CborSerializerSection.BeforeSerializingStructureMembers( + context.shape, + "encoder", + ), + )(this) } + context.copy(localName = "input").also { inner -> + val members = includedMembers ?: inner.shape.members() + for (member in members) { + serializeMember(MemberContext.structMember(inner, member, symbolProvider)) + } + } + rust("encoder.end();") + rust("Ok(())") } - rust("encoder.end();") - rust("Ok(())") } - } rust("#T(encoder, ${context.localName})?;", structureSerializer) } @@ -305,7 +315,10 @@ class CborSerializerGenerator( } } - private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { + private fun RustWriter.serializeMemberValue( + context: MemberContext, + target: Shape, + ) { val encoder = context.encoderBindingName val value = context.valueExpression val containerShape = model.expectShape(context.shape.container) @@ -359,7 +372,7 @@ class CborSerializerGenerator( encoder.array( (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") ); - """ + """, ) val itemName = safeName("item") rustBlock("for $itemName in ${context.valueExpression.asRef()}") { @@ -378,7 +391,7 @@ class CborSerializerGenerator( encoder.map( (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") ); - """ + """, ) rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { val keyExpression = "$keyName.as_str()" @@ -388,37 +401,39 @@ class CborSerializerGenerator( private fun RustWriter.serializeUnion(context: Context) { val unionSymbol = symbolProvider.toSymbol(context.shape) - val unionSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> - rustBlockTemplate( - "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> Result<(), #{Error}>", - "UnionSymbol" to unionSymbol, - *codegenScope, - ) { - // A union is serialized identically as a `structure` shape, but only a single member can be set to a - // non-null value. - rust("encoder.map(1);") - rustBlock("match input") { - for (member in context.shape.members()) { - val variantName = if (member.isTargetUnit()) { - symbolProvider.toMemberName(member) - } else { - "${symbolProvider.toMemberName(member)}(inner)" + val unionSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> Result<(), #{Error}>", + "UnionSymbol" to unionSymbol, + *codegenScope, + ) { + // A union is serialized identically as a `structure` shape, but only a single member can be set to a + // non-null value. + rust("encoder.map(1);") + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = + if (member.isTargetUnit()) { + symbolProvider.toMemberName(member) + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + rustBlock("#T::$variantName =>", unionSymbol) { + serializeMember(MemberContext.unionMember("inner", member)) + } } - rustBlock("#T::$variantName =>", unionSymbol) { - serializeMember(MemberContext.unionMember("inner", member)) + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) } } - if (codegenTarget.renderUnknownVariant()) { - rustTemplate( - "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", - "Union" to unionSymbol, - *codegenScope, - ) - } + rust("Ok(())") } - rust("Ok(())") } - } rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 14cf21d6cb..6580968749 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -224,7 +224,7 @@ class JsonSerializerGenerator( """ object.finish(); Ok(out) - """ + """, ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index dc2c84bd52..89d4512007 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -51,7 +51,10 @@ object OperationNormalizer { * Returns `true` if the user had originally modeled an operation input shape on the given [operation]; * `false` if the transform added a synthetic one. */ - fun hadUserModeledOperationInput(operation: OperationShape, model: Model): Boolean { + fun hadUserModeledOperationInput( + operation: OperationShape, + model: Model, + ): Boolean { val syntheticInputTrait = operation.inputShape(model).expectTrait() return syntheticInputTrait.originalId != null } @@ -60,7 +63,10 @@ object OperationNormalizer { * Returns `true` if the user had originally modeled an operation output shape on the given [operation]; * `false` if the transform added a synthetic one. */ - fun hadUserModeledOperationOutput(operation: OperationShape, model: Model): Boolean { + fun hadUserModeledOperationOutput( + operation: OperationShape, + model: Model, + ): Boolean { val syntheticOutputTrait = operation.outputShape(model).expectTrait() return syntheticOutputTrait.originalId != null } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index cf5c81df54..7063d58617 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -610,7 +610,10 @@ open class ServerCodegenVisitor( /** * Generate protocol tests. This method can be overridden by other languages such as Python. */ - open fun protocolTestsForOperation(writer: RustWriter, shape: OperationShape) { + open fun protocolTestsForOperation( + writer: RustWriter, + shape: OperationShape, + ) { codegenDecorator.protocolTestGenerator( codegenContext, ServerProtocolTestGenerator( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt index ce2e590519..54cd279e8d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext /** * Smithy RPC v2 CBOR requires errors to be serialized in server responses with an additional `__type` field. @@ -40,21 +39,22 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext * there shouldn't™️ be any harm in always including it, which simplifies the code generator. */ class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { - override fun section(section: CborSerializerSection): Writable = when (section) { - is CborSerializerSection.BeforeSerializingStructureMembers -> - if (section.structureShape.hasTrait()) { - writable { - rust( - """ + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeSerializingStructureMembers -> + if (section.structureShape.hasTrait()) { + writable { + rust( + """ ${section.encoderBindingName} .str("__type") .str("${escape(section.structureShape.id.toString())}"); - """ - ) + """, + ) + } + } else { + emptySection } - } else { - emptySection - } - else -> emptySection - } + else -> emptySection + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt index 42eb1f6843..a01d0076e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt @@ -21,19 +21,21 @@ import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstr * That value will be a `std::collections::HashMap` for map shapes, and a `std::vec::Vec` for collection shapes. */ class BeforeEncodingMapOrCollectionCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { - override fun section(section: CborSerializerSection): Writable = when (section) { - is CborSerializerSection.BeforeIteratingOverMapOrCollection -> writable { - check(section.shape is CollectionShape || section.shape is MapShape) - if (workingWithPublicConstrainedWrapperTupleType( - section.shape, - codegenContext.model, - codegenContext.settings.codegenConfig.publicConstrainedTypes, - ) - ) { - section.context.valueExpression = - ValueExpression.Reference("&${section.context.valueExpression.name}.0") - } + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeIteratingOverMapOrCollection -> + writable { + check(section.shape is CollectionShape || section.shape is MapShape) + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + else -> emptySection } - else -> emptySection - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 8dc0c3f8cf..ba9b13f38a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -76,17 +76,17 @@ class ServerInstantiator( ignoreMissingMembers: Boolean = false, ) : Instantiator( - codegenContext.symbolProvider, - codegenContext.model, - codegenContext.runtimeConfig, - ServerBuilderKindBehavior(codegenContext), - defaultsForRequiredFields = true, - customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), - // Construct with direct pattern to more closely replicate actual server customer usage - constructPattern = InstantiatorConstructPattern.DIRECT, - customWritable = customWritable, - ignoreMissingMembers = ignoreMissingMembers, - ) + codegenContext.symbolProvider, + codegenContext.model, + codegenContext.runtimeConfig, + ServerBuilderKindBehavior(codegenContext), + defaultsForRequiredFields = true, + customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), + // Construct with direct pattern to more closely replicate actual server customer usage + constructPattern = InstantiatorConstructPattern.DIRECT, + customWritable = customWritable, + ignoreMissingMembers = ignoreMissingMembers, + ) class ServerBuilderInstantiator( private val symbolProvider: RustSymbolProvider, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index e7379e8a44..1397502867 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -24,12 +24,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator @@ -175,9 +175,7 @@ class ServerAwsJsonProtocol( } // TODO This could technically be `&static str` right? - override fun serverRouterRequestSpecType( - requestSpecModule: RuntimeType, - ): RuntimeType = RuntimeType.String + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.String override fun serverRouterRuntimeConstructor() = when (version) { @@ -294,17 +292,19 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa */ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization(val codegenContext: ServerCodegenContext) : CborParserCustomization() { - override fun section(section: CborParserSection): Writable = when (section) { - is CborParserSection.BeforeBoxingDeserializedMember -> writable { - // We're only interested in _structure_ member shapes that can reach constrained shapes. - if ( - codegenContext.model.expectShape(section.shape.container) is StructureShape && - section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) - ) { - rust("let v = v.into();") - } + override fun section(section: CborParserSection): Writable = + when (section) { + is CborParserSection.BeforeBoxingDeserializedMember -> + writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust("let v = v.into();") + } + } } - } } class ServerRpcV2CborProtocol( @@ -337,8 +337,9 @@ class ServerRpcV2CborProtocol( override fun markerStruct() = ServerRuntimeType.protocol("RpcV2", "rpc_v2", runtimeConfig) - override fun routerType() = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - .resolve("protocol::rpc_v2::router::RpcV2Router") + override fun routerType() = + ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::rpc_v2::router::RpcV2Router") override fun serverRouterRequestSpec( operationShape: OperationShape, @@ -353,8 +354,7 @@ class ServerRpcV2CborProtocol( rust("$serviceName.$operationName".dq()) } - override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = - RuntimeType.StaticStr + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr override fun serverRouterRuntimeConstructor() = "rpc_v2_router" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 2ab9f8d2b9..4c5060eef7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors @@ -331,12 +330,15 @@ class ServerProtocolTestGenerator( * serialize said shape, the resulting HTTP response is of the form we expect, as defined in the test case. * [shape] is either an operation output shape or an error shape. */ - private fun RustWriter.renderHttpResponseTestCase(testCase: HttpResponseTestCase, shape: StructureShape) { + private fun RustWriter.renderHttpResponseTestCase( + testCase: HttpResponseTestCase, + shape: StructureShape, + ) { val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( !protocolSupport.errorSerialization && shape.hasTrait() - ) + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -436,19 +438,20 @@ class ServerProtocolTestGenerator( // proxy for `bodyMediaType`. This works because `rpcv2Cbor` happens to be the only protocol where // the body is base64-encoded in the protocol test, but checking `bodyMediaType` should be a more // resilient check. - val encodedBody = if (protocol.toShapeId() == ShapeId.from("smithy.protocols#rpcv2Cbor")) { - """ + val encodedBody = + if (protocol.toShapeId() == ShapeId.from("smithy.protocols#rpcv2Cbor")) { + """ #{Bytes}::from( #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( "`body` field of Smithy protocol test is not correctly base64 encoded" ) ) """ - } else { - """ + } else { + """ #{Bytes}::from_static($sanitizedBody.as_bytes()) """ - } + } "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 31a55130f9..49489aeb65 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -59,7 +59,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.smithy.wrapOptional @@ -123,9 +122,9 @@ class ServerHttpBoundProtocolGenerator( customizations: List = listOf(), additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( - protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), -) { + protocol, + ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), + ) { // TODO Delete, unused // Define suffixes for operation input / output / error wrappers companion object { @@ -575,14 +574,19 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun setResponseHeaderIfAbsent(writer: RustWriter, headerName: String, headerValue: String) { + private fun setResponseHeaderIfAbsent( + writer: RustWriter, + headerName: String, + headerValue: String, + ) { // We can be a tad more efficient if there's a `const` `HeaderName` in the `http` crate that matches. // https://docs.rs/http/latest/http/header/index.html#constants - val headerNameExpr = if (headerName == "content-type") { - "#{http}::header::CONTENT_TYPE" - } else { - "#{http}::header::HeaderName::from_static(\"$headerName\")" - } + val headerNameExpr = + if (headerName == "content-type") { + "#{http}::header::CONTENT_TYPE" + } else { + "#{http}::header::HeaderName::from_static(\"$headerName\")" + } writer.rustTemplate( """ @@ -596,7 +600,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - /** * Sets HTTP response headers for the operation's output shape or the operation's error shape. * It will generate response headers for the operation's output shape, unless [errorShape] is non-null, in which diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt index a51119978b..244fc282db 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRpcV2CborFactory.kt @@ -12,20 +12,19 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2CborProtocol class ServerRpcV2CborFactory : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ServerCodegenContext): Protocol = - ServerRpcV2CborProtocol(codegenContext) + override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRpcV2CborProtocol(codegenContext) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2CborProtocol(codegenContext)) override fun support(): ProtocolSupport { return ProtocolSupport( - /* Client support */ + // Client support requestSerialization = false, requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false, - /* Server support */ + // Server support requestDeserialization = true, requestBodyDeserialization = true, responseSerialization = true, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt index a40673eb4a..8c0254904e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt @@ -9,7 +9,6 @@ import software.amazon.smithy.build.PluginContext import software.amazon.smithy.build.SmithyBuildPlugin import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.codegenIntegrationTest import software.amazon.smithy.rust.codegen.server.smithy.RustServerCodegenPlugin diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index 71187f1516..b6b4661bf3 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -41,7 +41,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Proto import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE @@ -85,21 +84,31 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { override fun memberMeta(memberShape: MemberShape): RustMetadata { val baseMetadata = base.toSymbol(memberShape).expectRustMetadata() return baseMetadata.copy( - additionalAttributes = baseMetadata.additionalAttributes + Attribute( - """serde(rename = "${memberShape.memberName}")""", - isDeriveHelper = true, - ), + additionalAttributes = + baseMetadata.additionalAttributes + + Attribute( + """serde(rename = "${memberShape.memberName}")""", + isDeriveHelper = true, + ), ) } override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeSerializeDeserialize(structureShape) + override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeSerializeDeserialize(unionShape) + override fun enumMeta(stringShape: StringShape) = addDeriveSerdeSerializeDeserialize(stringShape) override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeSerializeDeserialize(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeSerializeDeserialize(mapShape) - override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveSerdeSerializeDeserialize(stringShape) - override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveSerdeSerializeDeserialize(numberShape) + + override fun stringMeta(stringShape: StringShape): RustMetadata = + addDeriveSerdeSerializeDeserialize(stringShape) + + override fun numberMeta(numberShape: NumberShape): RustMetadata = + addDeriveSerdeSerializeDeserialize(numberShape) + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeSerializeDeserialize(blobShape) } @@ -110,122 +119,129 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { // which we can't `#[derive(serde::Deserialize)]`. // Note we can't use `ModelTransformer.removeShapes` because it will leave the model in an inconsistent state // when removing list/set shape member shapes. - val removeTimestampAndBlobShapes: Predicate = Predicate { shape -> - when (shape) { - is MemberShape -> { - val targetShape = model.expectShape(shape.target) - targetShape is BlobShape || targetShape is TimestampShape - } - is BlobShape, is TimestampShape -> true - is CollectionShape -> { - val targetShape = model.expectShape(shape.member.target) - targetShape is BlobShape || targetShape is TimestampShape - } - else -> false - } - } - fun removeShapesByShapeId(shapeIds: Set): Predicate { - val predicate: Predicate = Predicate { shape -> + val removeTimestampAndBlobShapes: Predicate = + Predicate { shape -> when (shape) { is MemberShape -> { val targetShape = model.expectShape(shape.target) - shapeIds.contains(targetShape.id) + targetShape is BlobShape || targetShape is TimestampShape } - is CollectionShape -> { + is BlobShape, is TimestampShape -> true + is CollectionShape -> { val targetShape = model.expectShape(shape.member.target) - shapeIds.contains(targetShape.id) - } - else -> { - shapeIds.contains(shape.id) + targetShape is BlobShape || targetShape is TimestampShape } + else -> false } } + + fun removeShapesByShapeId(shapeIds: Set): Predicate { + val predicate: Predicate = + Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + shapeIds.contains(targetShape.id) + } + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + shapeIds.contains(targetShape.id) + } + else -> { + shapeIds.contains(shape.id) + } + } + } return predicate } val modelTransformer = ModelTransformer.create() - model = modelTransformer.removeShapesIf( - modelTransformer.removeShapesIf(model, removeTimestampAndBlobShapes), - // These enums do not serialize their variants using the Rust members' names. - // We'd have to tack on `#[serde(rename = "name")]` using the proper name defined in the Smithy enum definition. - // But we have no way of injecting that attribute on Rust enum variants in the code generator. - // So we just remove these problematic shapes. - removeShapesByShapeId( - setOf( - ShapeId.from("smithy.protocoltests.shared#FooEnum"), - ShapeId.from("smithy.protocoltests.rpcv2Cbor#TestEnum"), + model = + modelTransformer.removeShapesIf( + modelTransformer.removeShapesIf(model, removeTimestampAndBlobShapes), + // These enums do not serialize their variants using the Rust members' names. + // We'd have to tack on `#[serde(rename = "name")]` using the proper name defined in the Smithy enum definition. + // But we have no way of injecting that attribute on Rust enum variants in the code generator. + // So we just remove these problematic shapes. + removeShapesByShapeId( + setOf( + ShapeId.from("smithy.protocoltests.shared#FooEnum"), + ShapeId.from("smithy.protocoltests.rpcv2Cbor#TestEnum"), + ), ), - ), - ) + ) return model } @Test fun `serde_cbor round trip`() { - val addDeriveSerdeSerializeDeserializeDecorator = object : ServerCodegenDecorator { - override val name: String = "Add `#[derive(serde::Serialize, serde::Deserialize)]`" - override val order: Byte = 0 + val addDeriveSerdeSerializeDeserializeDecorator = + object : ServerCodegenDecorator { + override val name: String = "Add `#[derive(serde::Serialize, serde::Deserialize)]`" + override val order: Byte = 0 - override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = - DeriveSerdeSerializeDeserializeSymbolMetadataProvider(base) - } + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = + DeriveSerdeSerializeDeserializeSymbolMetadataProvider(base) + } // Don't generate protocol tests, because it'll attempt to pull out `params` for member shapes we'll remove // from the model. - val noProtocolTestsDecorator = object : ServerCodegenDecorator { - override val name: String = "Don't generate protocol tests" - override val order: Byte = 0 - - override fun protocolTestGenerator( - codegenContext: ServerCodegenContext, - baseGenerator: ProtocolTestGenerator - ): ProtocolTestGenerator { - val noOpProtocolTestsGenerator = object : ProtocolTestGenerator() { - override val codegenContext: CodegenContext - get() = baseGenerator.codegenContext - override val protocolSupport: ProtocolSupport - get() = baseGenerator.protocolSupport - override val operationShape: OperationShape - get() = baseGenerator.operationShape - override val appliesTo: AppliesTo - get() = baseGenerator.appliesTo - override val logger: Logger - get() = Logger.getLogger(javaClass.name) - override val expectFail: Set - get() = baseGenerator.expectFail - override val brokenTests: Set - get() = emptySet() - override val runOnly: Set - get() = baseGenerator.runOnly - override val disabledTests: Set - get() = baseGenerator.disabledTests - - override fun RustWriter.renderAllTestCases(allTests: List) { - // No-op. - } + val noProtocolTestsDecorator = + object : ServerCodegenDecorator { + override val name: String = "Don't generate protocol tests" + override val order: Byte = 0 + override fun protocolTestGenerator( + codegenContext: ServerCodegenContext, + baseGenerator: ProtocolTestGenerator, + ): ProtocolTestGenerator { + val noOpProtocolTestsGenerator = + object : ProtocolTestGenerator() { + override val codegenContext: CodegenContext + get() = baseGenerator.codegenContext + override val protocolSupport: ProtocolSupport + get() = baseGenerator.protocolSupport + override val operationShape: OperationShape + get() = baseGenerator.operationShape + override val appliesTo: AppliesTo + get() = baseGenerator.appliesTo + override val logger: Logger + get() = Logger.getLogger(javaClass.name) + override val expectFail: Set + get() = baseGenerator.expectFail + override val brokenTests: Set + get() = emptySet() + override val runOnly: Set + get() = baseGenerator.runOnly + override val disabledTests: Set + get() = baseGenerator.disabledTests + + override fun RustWriter.renderAllTestCases(allTests: List) { + // No-op. + } + } + return noOpProtocolTestsGenerator } - return noOpProtocolTestsGenerator } - } val model = prepareRpcV2CborModel() val serviceShape = model.expectShape(ShapeId.from(RPC_V2_CBOR)) serverIntegrationTest( model, additionalDecorators = listOf(addDeriveSerdeSerializeDeserializeDecorator, noProtocolTestsDecorator), - params = IntegrationTestParams(service = serviceShape.id.toString()) + params = IntegrationTestParams(service = serviceShape.id.toString()), ) { codegenContext, rustCrate -> // TODO(https://github.com/smithy-lang/smithy-rs/issues/1147): NaN != NaN. Ideally we when we address // this issue, we'd re-use the structure shape comparison code that both client and server protocol test // generators would use. val expectFail = setOf("RpcV2CborSupportsNaNFloatInputs", "RpcV2CborSupportsNaNFloatOutputs") - val codegenScope = arrayOf( - "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), - "SerdeCbor" to CargoDependency.SerdeCbor.toType(), - ) + val codegenScope = + arrayOf( + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "SerdeCbor" to CargoDependency.SerdeCbor.toType(), + ) val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true) val rpcV2 = ServerRpcV2CborProtocol(codegenContext) @@ -247,11 +263,12 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { val targetShape = test.targetShape val params = test.testCase.params - val serializeFn = if (targetShape.hasTrait()) { - rpcV2.structuredDataSerializer().serverErrorSerializer(targetShape.id) - } else { - rpcV2.structuredDataSerializer().operationOutputSerializer(operationShape) - } + val serializeFn = + if (targetShape.hasTrait()) { + rpcV2.structuredDataSerializer().serverErrorSerializer(targetShape.id) + } else { + rpcV2.structuredDataSerializer().operationOutputSerializer(operationShape) + } if (serializeFn == null) { // Skip if there's nothing to serialize. @@ -305,9 +322,10 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { val targetShape = operationShape.inputShape(codegenContext.model) val params = test.testCase.params - val deserializeFn = rpcV2.structuredDataParser().serverInputParser(operationShape) - ?: // Skip if there's nothing to serialize. - continue + val deserializeFn = + rpcV2.structuredDataParser().serverInputParser(operationShape) + ?: // Skip if there's nothing to serialize. + continue if (expectFail.contains(test.id)) { writeWithNoFormatting("#[should_panic]") From ed17a10d4e7ebcf71b15fa1f2d15f64135768321 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 2 Jul 2024 18:10:01 +0200 Subject: [PATCH 32/77] Remove TODO I don't understand --- rust-runtime/aws-smithy-types/src/error/operation.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rust-runtime/aws-smithy-types/src/error/operation.rs b/rust-runtime/aws-smithy-types/src/error/operation.rs index ad5032fc4c..5b61f1acfe 100644 --- a/rust-runtime/aws-smithy-types/src/error/operation.rs +++ b/rust-runtime/aws-smithy-types/src/error/operation.rs @@ -21,7 +21,6 @@ pub struct SerializationError { kind: SerializationErrorKind, } -// TODO The docs in `main` are wrong. impl SerializationError { /// An error that occurs when serialization of an operation fails for an unknown reason. pub fn unknown_variant(union: &'static str) -> Self { From 9d2e56ad9f490410b2c174bd3bb9d0ac8d555910 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 3 Jul 2024 15:08:38 +0200 Subject: [PATCH 33/77] Adjust error message when bodies don't match in `aws-smithy-protocol-test` The left parameter is the expected one, see . --- rust-runtime/aws-smithy-protocol-test/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 63328ebba4..55e6d87b75 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -84,7 +84,7 @@ pub enum ProtocolTestFailure { #[error("Header `{forbidden}` was forbidden but found: `{found}`")] ForbiddenHeader { forbidden: String, found: String }, #[error( - "body did not match. left=actual, right=expected\n{comparison:?} \n == hint:\n{hint}." + "body did not match. left=expected, right=actual\n{comparison:?} \n == hint:\n{hint}." )] BodyDidNotMatch { // the comparison includes colorized escapes. PrettyString ensures that even during From 974f214d75cda919190034273b279d906d51930a Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 3 Jul 2024 16:12:25 +0200 Subject: [PATCH 34/77] Instantiate with dev-dependencies appropriately --- .../smithy/generators/ClientInstantiator.kt | 3 +- .../protocol/ClientProtocolTestGenerator.kt | 2 +- .../core/smithy/generators/Instantiator.kt | 60 ++++++++++++------- .../smithy/generators/ServerInstantiator.kt | 2 + .../protocol/ServerProtocolTestGenerator.kt | 2 +- 5 files changed, 45 insertions(+), 24 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt index ca39264a1c..992d85dd72 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -28,11 +28,12 @@ class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true } -class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator( +class ClientInstantiator(private val codegenContext: ClientCodegenContext, withinTest: Boolean = false) : Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, ClientBuilderKindBehavior(codegenContext), + withinTest = false, ) { fun renderFluentCall( writer: RustWriter, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 9d9b4c5fac..b48c3d336a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -94,7 +94,7 @@ class ClientProtocolTestGenerator( private val inputShape = operationShape.inputShape(codegenContext.model) private val outputShape = operationShape.outputShape(codegenContext.model) - private val instantiator = ClientInstantiator(codegenContext) + private val instantiator = ClientInstantiator(codegenContext, withinTest = true) private val codegenScope = arrayOf( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index 4e4b95897d..d6ac5ac109 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -98,10 +98,12 @@ open class Instantiator( private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER, private val customWritable: CustomWritable = NoCustomWritable(), /** - * The protocol test may provide data for missing members (because we transformed the model). + * A protocol test may provide data for missing members (because we transformed the model). * This flag makes it so that it is simply ignored, and code generation continues. **/ private val ignoreMissingMembers: Boolean = false, + /** Whether we're rendering within a test, in which case we should use dev-dependencies. */ + private val withinTest: Boolean = false, ) { data class Ctx( // The `http` crate requires that headers be lowercase, but Smithy protocol tests @@ -179,7 +181,7 @@ open class Instantiator( is MemberShape -> renderMember(writer, shape, data, ctx) is SimpleShape -> - PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate( + PrimitiveInstantiator(runtimeConfig, symbolProvider, withinTest).instantiate( shape, data, customWritable, @@ -487,7 +489,25 @@ open class Instantiator( } } -class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) { +class PrimitiveInstantiator( + private val runtimeConfig: RuntimeConfig, + private val symbolProvider: SymbolProvider, + withinTest: Boolean = false, +) { + val codegenScope = listOf( + "DateTime" to RuntimeType.dateTime(runtimeConfig), + "ByteStream" to RuntimeType.byteStream(runtimeConfig), + "Blob" to RuntimeType.blob(runtimeConfig), + "SmithyJson" to RuntimeType.smithyJson(runtimeConfig), + "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), + ).map { it.first to + if (withinTest) { + it.second.toDevDependencyType() + } else { + it.second + } + }.toTypedArray() + fun instantiate( shape: SimpleShape, data: Node, @@ -501,9 +521,9 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val num = BigDecimal(node.toString()) val wholePart = num.toInt() val fractionalPart = num.remainder(BigDecimal.ONE) - rust( - "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig).toDevDependencyType(), + rustTemplate( + "#{DateTime}::from_fractional_secs($wholePart, ${fractionalPart}_f64)", + *codegenScope, ) } @@ -514,14 +534,14 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va */ is BlobShape -> if (shape.hasTrait()) { - rust( - "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig).toDevDependencyType(), + rustTemplate( + "#{Bytestream}::from_static(b${(data as StringNode).value.dq()})", + *codegenScope, ) } else { - rust( - "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig).toDevDependencyType(), + rustTemplate( + "#{Blob}::new(${(data as StringNode).value.dq()})", + *codegenScope, ) } @@ -531,11 +551,10 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va is StringNode -> { val numberSymbol = symbolProvider.toSymbol(shape) // support Smithy custom values, such as Infinity - rust( - """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", - numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).toDevDependencyType() - .resolve("primitive::Parse"), + rustTemplate( + """<#{NumberSymbol} as #{SmithyTypes}::primitive::Parse>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", + "NumberSymbol" to numberSymbol, + *codegenScope, ) } @@ -554,11 +573,10 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va rustTemplate( """ let json_bytes = br##"${Node.prettyPrintJson(data)}"##; - let mut tokens = #{json_token_iter}(json_bytes).peekable(); - #{expect_document}(&mut tokens).expect("well formed json") + let mut tokens = #{SmithyJson}::deserialize::json_token_iter(json_bytes).peekable(); + #{SmithyJson}::deserialize::token::expect_document(&mut tokens).expect("well formed json") """, - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + *codegenScope, ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index ba9b13f38a..5b4860c26d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -74,6 +74,7 @@ class ServerInstantiator( codegenContext: CodegenContext, customWritable: CustomWritable = NoCustomWritable(), ignoreMissingMembers: Boolean = false, + withinTest: Boolean = false, ) : Instantiator( codegenContext.symbolProvider, @@ -86,6 +87,7 @@ class ServerInstantiator( constructPattern = InstantiatorConstructPattern.DIRECT, customWritable = customWritable, ignoreMissingMembers = ignoreMissingMembers, + withinTest = withinTest, ) class ServerBuilderInstantiator( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 4c5060eef7..52e572d29e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -260,7 +260,7 @@ class ServerProtocolTestGenerator( inputT to outputT } - private val instantiator = ServerInstantiator(codegenContext) + private val instantiator = ServerInstantiator(codegenContext, withinTest = true) private val codegenScope = arrayOf( From fecbb11c4e1d666fdf3ef48a564a543862d5cd27 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 3 Jul 2024 16:46:17 +0200 Subject: [PATCH 35/77] Specify service shape when loading model explicitly; revert this --- .../client/smithy/NamingObstacleCourseTest.kt | 31 +++++++++++++++---- .../smithy/ConstraintsMemberShapeTest.kt | 1 + .../RecursiveConstraintViolationsTest.kt | 6 +++- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt index addc4d8f1e..cff241e219 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt @@ -9,6 +9,7 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTestModels.reusedInputOutputShapesModel import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTestModels.rustPreludeEnumVariantsModel import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTestModels.rustPreludeEnumsModel @@ -18,31 +19,49 @@ import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTes class NamingObstacleCourseTest { @Test fun `test Rust prelude operation names compile`() { - clientIntegrationTest(rustPreludeOperationsModel()) { _, _ -> } + clientIntegrationTest( + rustPreludeOperationsModel(), + params = IntegrationTestParams(service = "crate#Config"), + ) { _, _ -> } } @Test fun `test Rust prelude structure names compile`() { - clientIntegrationTest(rustPreludeStructsModel()) { _, _ -> } + clientIntegrationTest( + rustPreludeStructsModel(), + params = IntegrationTestParams(service = "crate#Config"), + ) { _, _ -> } } @Test fun `test Rust prelude enum names compile`() { - clientIntegrationTest(rustPreludeEnumsModel()) { _, _ -> } + clientIntegrationTest( + rustPreludeEnumsModel(), + params = IntegrationTestParams(service = "crate#Config"), + ) { _, _ -> } } @Test fun `test Rust prelude enum variant names compile`() { - clientIntegrationTest(rustPreludeEnumVariantsModel()) { _, _ -> } + clientIntegrationTest( + rustPreludeEnumVariantsModel(), + params = IntegrationTestParams(service = "crate#Config"), + ) { _, _ -> } } @Test fun `test reuse of input and output shapes json`() { - clientIntegrationTest(reusedInputOutputShapesModel(RestJson1Trait.builder().build())) + clientIntegrationTest( + reusedInputOutputShapesModel(RestJson1Trait.builder().build()), + params = IntegrationTestParams(service = "test#Service"), + ) } @Test fun `test reuse of input and output shapes xml`() { - clientIntegrationTest(reusedInputOutputShapesModel(RestXmlTrait.builder().build())) + clientIntegrationTest( + reusedInputOutputShapesModel(RestXmlTrait.builder().build()), + params = IntegrationTestParams(service = "test#Service"), + ) } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt index b97328eb8d..adf392e3a4 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt @@ -297,6 +297,7 @@ class ConstraintsMemberShapeTest { model, runtimeConfig = runtimeConfig, overrideTestDir = dirToUse, + service = "constrainedMemberShape#ConstrainedService", ) val codegenDecorator = CombinedServerCodegenDecorator.fromClasspath( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt index 432ad64fde..8c2231ff3f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt @@ -11,6 +11,7 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import java.util.stream.Stream @@ -198,6 +199,9 @@ internal class RecursiveConstraintViolationsTest { @ParameterizedTest @ArgumentsSource(RecursiveConstraintViolationsTestProvider::class) fun `recursive constraint violation code generation test`(testCase: TestCase) { - serverIntegrationTest(testCase.model) + serverIntegrationTest( + testCase.model, + params = IntegrationTestParams(service = "com.amazonaws.recursiveconstraintviolations#RecursiveConstraintViolations"), + ) } } From df55249760ab6e4c00ab020f56f48be126bca7ae Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 3 Jul 2024 16:46:18 +0200 Subject: [PATCH 36/77] Revert "Specify service shape when loading model explicitly; revert this" This reverts commit fecbb11c4e1d666fdf3ef48a564a543862d5cd27. --- .../client/smithy/NamingObstacleCourseTest.kt | 31 ++++--------------- .../smithy/ConstraintsMemberShapeTest.kt | 1 - .../RecursiveConstraintViolationsTest.kt | 6 +--- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt index cff241e219..addc4d8f1e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/NamingObstacleCourseTest.kt @@ -9,7 +9,6 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest -import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTestModels.reusedInputOutputShapesModel import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTestModels.rustPreludeEnumVariantsModel import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTestModels.rustPreludeEnumsModel @@ -19,49 +18,31 @@ import software.amazon.smithy.rust.codegen.core.testutil.NamingObstacleCourseTes class NamingObstacleCourseTest { @Test fun `test Rust prelude operation names compile`() { - clientIntegrationTest( - rustPreludeOperationsModel(), - params = IntegrationTestParams(service = "crate#Config"), - ) { _, _ -> } + clientIntegrationTest(rustPreludeOperationsModel()) { _, _ -> } } @Test fun `test Rust prelude structure names compile`() { - clientIntegrationTest( - rustPreludeStructsModel(), - params = IntegrationTestParams(service = "crate#Config"), - ) { _, _ -> } + clientIntegrationTest(rustPreludeStructsModel()) { _, _ -> } } @Test fun `test Rust prelude enum names compile`() { - clientIntegrationTest( - rustPreludeEnumsModel(), - params = IntegrationTestParams(service = "crate#Config"), - ) { _, _ -> } + clientIntegrationTest(rustPreludeEnumsModel()) { _, _ -> } } @Test fun `test Rust prelude enum variant names compile`() { - clientIntegrationTest( - rustPreludeEnumVariantsModel(), - params = IntegrationTestParams(service = "crate#Config"), - ) { _, _ -> } + clientIntegrationTest(rustPreludeEnumVariantsModel()) { _, _ -> } } @Test fun `test reuse of input and output shapes json`() { - clientIntegrationTest( - reusedInputOutputShapesModel(RestJson1Trait.builder().build()), - params = IntegrationTestParams(service = "test#Service"), - ) + clientIntegrationTest(reusedInputOutputShapesModel(RestJson1Trait.builder().build())) } @Test fun `test reuse of input and output shapes xml`() { - clientIntegrationTest( - reusedInputOutputShapesModel(RestXmlTrait.builder().build()), - params = IntegrationTestParams(service = "test#Service"), - ) + clientIntegrationTest(reusedInputOutputShapesModel(RestXmlTrait.builder().build())) } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt index adf392e3a4..b97328eb8d 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt @@ -297,7 +297,6 @@ class ConstraintsMemberShapeTest { model, runtimeConfig = runtimeConfig, overrideTestDir = dirToUse, - service = "constrainedMemberShape#ConstrainedService", ) val codegenDecorator = CombinedServerCodegenDecorator.fromClasspath( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt index 8c2231ff3f..432ad64fde 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt @@ -11,7 +11,6 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.model.Model -import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import java.util.stream.Stream @@ -199,9 +198,6 @@ internal class RecursiveConstraintViolationsTest { @ParameterizedTest @ArgumentsSource(RecursiveConstraintViolationsTestProvider::class) fun `recursive constraint violation code generation test`(testCase: TestCase) { - serverIntegrationTest( - testCase.model, - params = IntegrationTestParams(service = "com.amazonaws.recursiveconstraintviolations#RecursiveConstraintViolations"), - ) + serverIntegrationTest(testCase.model) } } From b0309adc60a5b68163b73a5391e7a50d04daf687 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 3 Jul 2024 16:46:51 +0200 Subject: [PATCH 37/77] Do not discover smithy-protocol-tests Smithy models when loading model with asSmithyModel --- .../rust/codegen/core/testutil/TestHelpers.kt | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 13a59f3bae..10e051d49e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.testutil import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.loader.ModelDiscovery import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape @@ -151,7 +152,17 @@ fun String.asSmithyModel( disableValidation: Boolean = false, ): Model { val processed = letIf(!this.trimStart().startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } - val assembler = Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed) + val denyModelsContaining = arrayOf( + // If Smithy protocol test models are in our classpath, don't load them, since they are fairly large and we + // almost never need them. + "smithy-protocol-tests" + ) + val urls = ModelDiscovery.findModels().filter { modelUrl -> denyModelsContaining.none { modelUrl.toString().contains(it) } } + val assembler = Model.assembler() + for (url in urls) { + assembler.addImport(url) + } + assembler.addUnparsedModel(sourceLocation ?: "test.smithy", processed) if (disableValidation) { assembler.disableValidation() } From b6ff29319f5c5a5bbedf3e2a5a59278d45653265 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 3 Jul 2024 16:57:12 +0200 Subject: [PATCH 38/77] Fix CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest, we're instantiating within test --- ...SerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index b6b4661bf3..d025be05cc 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -243,7 +243,7 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { "SerdeCbor" to CargoDependency.SerdeCbor.toType(), ) - val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true) + val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true, withinTest = true) val rpcV2 = ServerRpcV2CborProtocol(codegenContext) for (operationShape in codegenContext.model.operationShapes) { From 2be4a16abc4851c32ded189c28af32c0a912279e Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 15:36:12 +0200 Subject: [PATCH 39/77] Misc changes in protocol test generation --- .../protocol/ClientProtocolTestGenerator.kt | 6 +++++- .../protocol/ProtocolTestGenerator.kt | 17 +++++++++++------ .../protocol/ServerProtocolTestGenerator.kt | 16 +++++++++++----- ...serGeneratorSerdeRoundTripIntegrationTest.kt | 4 ++-- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index b48c3d336a..f88f5b1ae2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -80,7 +80,7 @@ class ClientProtocolTestGenerator( get() = AppliesTo.CLIENT override val expectFail: Set get() = ExpectFail - override val runOnly: Set + override val generateOnly: Set get() = emptySet() override val disabledTests: Set get() = emptySet() @@ -115,6 +115,8 @@ class ClientProtocolTestGenerator( } private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { + logger.info("Generating request test: ${httpRequestTestCase.id}") + if (!protocolSupport.requestSerialization) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -200,6 +202,8 @@ class ClientProtocolTestGenerator( testCase: HttpResponseTestCase, expectedShape: StructureShape, ) { + logger.info("Generating response test: ${testCase.id}") + if (!protocolSupport.responseDeserialization || ( !protocolSupport.errorDeserialization && expectedShape.hasTrait( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 0dd33c1e01..05fd3daf10 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -64,7 +64,7 @@ abstract class ProtocolTestGenerator { abstract val brokenTests: Set /** Only generate these tests; useful to temporarily set and shorten development cycles */ - abstract val runOnly: Set + abstract val generateOnly: Set /** * These tests are not even attempted to be generated, either because they will not compile @@ -89,6 +89,8 @@ abstract class ProtocolTestGenerator { allMatchingTestCases().flatMap { fixBrokenTestCase(it) } + // Filter afterward in case a fixed broken test is disabled. + .filterMatching() if (allTests.isEmpty()) { return } @@ -109,6 +111,8 @@ abstract class ProtocolTestGenerator { if (!it.isBroken()) { listOf(it) } else { + logger.info("Fixing ${it.kind} test case ${it.id}") + assert(it.expectFail()) val brokenTest = it.findInBroken()!! @@ -160,10 +164,11 @@ abstract class ProtocolTestGenerator { /** Filter out test cases that are disabled or don't match the service protocol. */ private fun List.filterMatching(): List = - if (runOnly.isEmpty()) { + if (generateOnly.isEmpty()) { this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } } else { - this.filter { testCase -> runOnly.contains(testCase.id) } + logger.warning("Generating only specified tests") + this.filter { testCase -> generateOnly.contains(testCase.id) } } private fun TestCase.toFailingTest(): FailingTest = @@ -190,7 +195,7 @@ abstract class ProtocolTestGenerator { val requestTests = operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() .map { TestCase.RequestTest(it) } - return requestTests.filterMatching() + return requestTests } fun responseTestCases(): List { @@ -208,7 +213,7 @@ abstract class ProtocolTestGenerator { ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.ResponseTest(it, error) } } - return (responseTestsOnOperations + responseTestsOnErrors).filterMatching() + return (responseTestsOnOperations + responseTestsOnErrors) } fun malformedRequestTestCases(): List { @@ -220,7 +225,7 @@ abstract class ProtocolTestGenerator { } else { emptyList() } - return malformedRequestTests.filterMatching() + return malformedRequestTests } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 52e572d29e..88e65e8581 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -225,7 +225,7 @@ class ServerProtocolTestGenerator( get() = ExpectFail override val brokenTests: Set get() = BrokenTests - override val runOnly: Set + override val generateOnly: Set get() = emptySet() override val disabledTests: Set get() = DisabledTests @@ -291,6 +291,8 @@ class ServerProtocolTestGenerator( * an operation's input shape, the resulting shape is of the form we expect, as defined in the test case. */ private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { + logger.info("Generating request test: ${httpRequestTestCase.id}") + if (!protocolSupport.requestDeserialization) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -334,6 +336,8 @@ class ServerProtocolTestGenerator( testCase: HttpResponseTestCase, shape: StructureShape, ) { + logger.info("Generating response test: ${testCase.id}") + val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( @@ -366,6 +370,8 @@ class ServerProtocolTestGenerator( * with the given response. */ private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) { + logger.info("Generating malformed request test: ${testCase.id}") + val (_, outputT) = operationInputOutputTypes[operationShape]!! val panicMessage = "request should have been rejected, but we accepted it; we parsed operation input `{:?}`" @@ -470,7 +476,7 @@ class ServerProtocolTestGenerator( } } - /** Returns the body of the request test. */ + /** Returns the body of the operation handler in a request test. */ private fun checkRequestHandler( operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase, @@ -478,7 +484,7 @@ class ServerProtocolTestGenerator( val inputShape = operationShape.inputShape(codegenContext.model) val outputShape = operationShape.outputShape(codegenContext.model) - // Construct expected request. + // Construct expected operation input. withBlock("let expected = ", ";") { instantiator.render(this, inputShape, httpRequestTestCase.params, httpRequestTestCase.headers) } @@ -491,9 +497,9 @@ class ServerProtocolTestGenerator( } if (operationShape.errors.isEmpty()) { - rust("response") + rust("output") } else { - rust("Ok(response)") + rust("Ok(output)") } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index d025be05cc..46d4f92f79 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -212,8 +212,8 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { get() = baseGenerator.expectFail override val brokenTests: Set get() = emptySet() - override val runOnly: Set - get() = baseGenerator.runOnly + override val generateOnly: Set + get() = baseGenerator.generateOnly override val disabledTests: Set get() = baseGenerator.disabledTests From 909f8f1d8e54c5795b39cbdf7f81339487405be9 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 15:47:39 +0200 Subject: [PATCH 40/77] unfix json_rpc10 test --- .../rust/codegen/core/smithy/generators/Instantiator.kt | 4 +--- .../smithy/generators/protocol/ServerProtocolTestGenerator.kt | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index d6ac5ac109..fefd40e13e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -468,9 +468,7 @@ open class Instantiator( */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { - is MemberShape -> - shape.getTrait()?.toNode() - ?: fillDefaultValue(model.expectShape(shape.target)) + is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) // Aggregate shapes. is StructureShape -> Node.objectNode() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 88e65e8581..2a3aa5570b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -492,7 +492,7 @@ class ServerProtocolTestGenerator( checkRequestParams(inputShape, this) // Construct a dummy response. - withBlock("let response = ", ";") { + withBlock("let output = ", ";") { instantiator.render(this, outputShape, Node.objectNode()) } From 8dc346b694eed083c69a475ced9cac9652605d73 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 16:08:41 +0200 Subject: [PATCH 41/77] fixes --- .../smithy/rust/codegen/core/smithy/generators/Instantiator.kt | 2 +- rust-runtime/aws-smithy-http-server/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index fefd40e13e..a5b09e1192 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -494,7 +494,7 @@ class PrimitiveInstantiator( ) { val codegenScope = listOf( "DateTime" to RuntimeType.dateTime(runtimeConfig), - "ByteStream" to RuntimeType.byteStream(runtimeConfig), + "Bytestream" to RuntimeType.byteStream(runtimeConfig), "Blob" to RuntimeType.blob(runtimeConfig), "SmithyJson" to RuntimeType.smithyJson(runtimeConfig), "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 2a6a4b8efd..ace6d0b692 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.62.1" +version = "0.62.2" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" From 665c8b53b30a20c72103e69a69427400b1c6af98 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 16:24:57 +0200 Subject: [PATCH 42/77] address TODO --- .../serialize/CborSerializerGenerator.kt | 6 +--- rust-runtime/aws-smithy-cbor/src/encode.rs | 32 +++++++++++++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index f5d6403a46..d91ed7de58 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -363,14 +363,10 @@ class CborSerializerGenerator( for (customization in customizations) { customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) } - // `.expect()` safety: `From for usize` is not in the standard library, but the conversion should be - // infallible (unless we ever have 128-bit machines I guess). - // See https://users.rust-lang.org/t/cant-convert-usize-to-u64/6243. - // TODO Point to a `static` to not inflate the binary. rust( """ encoder.array( - (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") + (${context.valueExpression.asValue()}).len() ); """, ) diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs index 60663a9a7b..20ba397436 100644 --- a/rust-runtime/aws-smithy-cbor/src/encode.rs +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -31,7 +31,8 @@ pub struct Encoder { encoder: minicbor::Encoder>, } -// TODO docs +/// We always write to a `Vec`, which is infallible in `minicbor`. +/// https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E const INFALLIBLE_WRITE: &str = "write failed"; impl Encoder { @@ -42,12 +43,6 @@ impl Encoder { } delegate_method! { - /// Writes a fixed length array of given length. - array => array(len: u64); - /// Used when we know the size in advance, i.e.: - /// - when a struct has all non-`Option`al members. - /// - when serializing `union` shapes (they can only have one member set). - map => map(len: u64); /// Used when it's not cheap to calculate the size, i.e. when the struct has one or more /// `Option`al members. begin_map => begin_map(); @@ -78,6 +73,29 @@ impl Encoder { self } + /// Writes a fixed length array of given length. + pub fn array(&mut self, len: usize) -> &mut Self { + self.encoder + // `.expect()` safety: `From for usize` is not in the standard library, + // but the conversion should be infallible (unless we ever have 128-bit machines I + // guess). . + .array(len.try_into().expect("`usize` to `u64` conversion failed")) + .expect(INFALLIBLE_WRITE); + self + } + + /// Writes a fixed length map of given length. + /// Used when we know the size in advance, i.e.: + /// - when a struct has all non-`Option`al members. + /// - when serializing `union` shapes (they can only have one member set). + /// - when serializing a `map` shape. + pub fn map(&mut self, len: usize) -> &mut Self { + self.encoder + .array(len.try_into().expect("`usize` to `u64` conversion failed")) + .expect(INFALLIBLE_WRITE); + self + } + pub fn timestamp(&mut self, x: &DateTime) -> &mut Self { self.encoder .tag(minicbor::data::Tag::Timestamp) From d75c9ec78acec673dd1b3cf1fb29b6b9c8e5c177 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 17:05:01 +0200 Subject: [PATCH 43/77] address Clippy warnings --- .../protocols/parse/CborParserGenerator.kt | 1 + .../serialize/CborSerializerGenerator.kt | 16 ++-------------- .../src/protocol/rpc_v2/router.rs | 10 ++++------ 3 files changed, 7 insertions(+), 20 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 707b6a502c..4a9db63091 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -277,6 +277,7 @@ class CborParserGenerator( val returnSymbolToParse = returnSymbolToParse(shape) rustBlockTemplate( """ + ##[allow(clippy::match_single_binding)] fn pair( decoder: &mut #{Decoder} ) -> Result<#{UnionSymbol}, #{Error}> diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index d91ed7de58..86ceece652 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -363,13 +363,7 @@ class CborSerializerGenerator( for (customization in customizations) { customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) } - rust( - """ - encoder.array( - (${context.valueExpression.asValue()}).len() - ); - """, - ) + rust("encoder.array((${context.valueExpression.asValue()}).len());") val itemName = safeName("item") rustBlock("for $itemName in ${context.valueExpression.asRef()}") { serializeMember(MemberContext.collectionMember(context, itemName)) @@ -382,13 +376,7 @@ class CborSerializerGenerator( for (customization in customizations) { customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) } - rust( - """ - encoder.map( - (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") - ); - """, - ) + rust("encoder.map((${context.valueExpression.asValue()}).len());") rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { val keyExpression = "$keyName.as_str()" serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs index 3e73b8b9e0..60e3b7c286 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -55,11 +55,11 @@ pub struct RpcV2Router { /// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` /// header. An `rpcv2` request is malformed if it contains either of these headers. Server-side /// implementations MUST reject such requests for security reasons. -const FORBIDDEN_HEADERS: &'static [&'static str] = &["x-amz-target", "x-amzn-target"]; +const FORBIDDEN_HEADERS: &[&str] = &["x-amz-target", "x-amzn-target"]; /// Matches the `Identifier` ABNF rule in /// . -const IDENTIFIER_PATTERN: &'static str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; +const IDENTIFIER_PATTERN: &str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; impl RpcV2Router { // TODO Consider building a nom parser @@ -171,9 +171,7 @@ pub enum WireFormatError { /// by the protocol (see [`WireFormat`]). fn parse_wire_format_from_header(headers: &HeaderMap) -> Result { let header = headers.get("smithy-protocol").ok_or(WireFormatError::HeaderNotFound)?; - let header = header - .to_str() - .map_err(|e| WireFormatError::HeaderValueNotVisibleAscii(e))?; + let header = header.to_str().map_err(WireFormatError::HeaderValueNotVisibleAscii)?; let captures = RpcV2Router::<()>::wire_format_regex() .captures(header) .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; @@ -217,7 +215,7 @@ impl Router for RpcV2Router { // Some headers are not allowed. let request_has_forbidden_header = FORBIDDEN_HEADERS - .into_iter() + .iter() .any(|&forbidden_header| request.headers().contains_key(forbidden_header)); if request_has_forbidden_header { return Err(Error::ForbiddenHeaders); From 96c740b9c13a736241d6929b1ae18da972521e75 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 17:21:52 +0200 Subject: [PATCH 44/77] Set aws-smithy-cbor version to 0.60.0 --- rust-runtime/aws-smithy-cbor/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml index 49c176a8f2..913768e658 100644 --- a/rust-runtime/aws-smithy-cbor/Cargo.toml +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-cbor" -version = "0.0.0-smithy-rs-head" +version = "0.60.0" authors = [ "AWS Rust SDK Team ", "David Pérez ", From 83bc91ab9735bdd82a20e2fc25aa2f923748ee2e Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 17:35:14 +0200 Subject: [PATCH 45/77] fix bug encoding cbor maps --- rust-runtime/aws-smithy-cbor/src/encode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs index 20ba397436..b01628eaf2 100644 --- a/rust-runtime/aws-smithy-cbor/src/encode.rs +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -91,7 +91,7 @@ impl Encoder { /// - when serializing a `map` shape. pub fn map(&mut self, len: usize) -> &mut Self { self.encoder - .array(len.try_into().expect("`usize` to `u64` conversion failed")) + .map(len.try_into().expect("`usize` to `u64` conversion failed")) .expect(INFALLIBLE_WRITE); self } From 65b841a5931f39b93fe54a3b43c8d3c0ae0f6171 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 18:17:16 +0200 Subject: [PATCH 46/77] Failing tests, broken hyperlink --- .../generators/protocol/ServerProtocolTestGenerator.kt | 5 +++++ rust-runtime/aws-smithy-cbor/src/encode.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 2a3aa5570b..3050688fbd 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors @@ -151,6 +152,10 @@ class ServerProtocolTestGenerator( ), // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): This affects all protocols FailingTest.MalformedRequestTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct"), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3339) + FailingTest.ResponseTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsInResponseWhenMissingInParams"), + // TODO: We need to be able to configure instantiator so that it uses default _modeled_ values; `""` is not a valid enum value for `defaultEnum`. + FailingTest.ResponseTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsWhenMissingInRequestBody"), ) private val BrokenTests: diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs index b01628eaf2..3d7cfd5f6e 100644 --- a/rust-runtime/aws-smithy-cbor/src/encode.rs +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -32,7 +32,7 @@ pub struct Encoder { } /// We always write to a `Vec`, which is infallible in `minicbor`. -/// https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E +/// const INFALLIBLE_WRITE: &str = "write failed"; impl Encoder { From 72daede77a530663ac9955da8285b303adfa1483 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 18:21:57 +0200 Subject: [PATCH 47/77] clippy, formatting --- .../protocols/parse/CborParserGenerator.kt | 118 +++++++++--------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 4a9db63091..9473245a60 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -200,6 +200,7 @@ class CborParserGenerator( ) = writable { rustBlockTemplate( """ + ##[allow(clippy::match_single_binding)] fn pair( mut builder: #{Builder}, decoder: &mut #{Decoder} @@ -277,11 +278,10 @@ class CborParserGenerator( val returnSymbolToParse = returnSymbolToParse(shape) rustBlockTemplate( """ - ##[allow(clippy::match_single_binding)] - fn pair( - decoder: &mut #{Decoder} - ) -> Result<#{UnionSymbol}, #{Error}> - """, + fn pair( + decoder: &mut #{Decoder} + ) -> Result<#{UnionSymbol}, #{Error}> + """, *codegenScope, "UnionSymbol" to returnSymbolToParse.symbol, ) { @@ -292,11 +292,11 @@ class CborParserGenerator( if (member.isTargetUnit()) { rust( """ - ${member.memberName.dq()} => { - decoder.skip()?; - #T::$variantName - } - """, + ${member.memberName.dq()} => { + decoder.skip()?; + #T::$variantName + } + """, returnSymbolToParse.symbol, ) } else { @@ -310,11 +310,11 @@ class CborParserGenerator( true -> rustTemplate( """ - _ => { - decoder.skip()?; - Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) - } - """, + _ => { + decoder.skip()?; + Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) + } + """, "Union" to returnSymbolToParse.symbol, *codegenScope, ) @@ -518,16 +518,16 @@ class CborParserGenerator( rustTemplate( """ - pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { - #{ListMemberParserFn:W} - - #{InitContainerWritable:W} + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + #{ListMemberParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeListLoop:W} - #{DecodeListLoop:W} - - Ok(list) - } - """, + Ok(list) + } + """, "ReturnType" to returnSymbol, "ListMemberParserFn" to listMemberParserFn( @@ -561,16 +561,16 @@ class CborParserGenerator( rustTemplate( """ - pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { - #{MapPairParserFn:W} + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + #{MapPairParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeMapLoop:W} - #{InitContainerWritable:W} - - #{DecodeMapLoop:W} - - Ok(map) - } - """, + Ok(map) + } + """, "ReturnType" to returnSymbol, "MapPairParserFn" to mapPairParserFnWritable( @@ -602,12 +602,12 @@ class CborParserGenerator( rustTemplate( """ - #{StructurePairParserFn:W} - - let mut builder = #{Builder}::default(); - - #{DecodeStructureMapLoop:W} - """, + #{StructurePairParserFn:W} + + let mut builder = #{Builder}::default(); + + #{DecodeStructureMapLoop:W} + """, *codegenScope, "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), "Builder" to builderSymbol, @@ -631,30 +631,30 @@ class CborParserGenerator( protocolFunctions.deserializeFn(shape) { fnName -> rustTemplate( """ - pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{UnionSymbol}, #{Error}> { - #{UnionPairParserFnWritable} - - match decoder.map()? { - None => { - let variant = pair(decoder)?; - match decoder.datatype()? { - #{SmithyCbor}::data::Type::Break => { - decoder.skip()?; - Ok(variant) - } - ty => Err( - #{Error}::unexpected_union_variant( - ty, - decoder.position(), + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{UnionSymbol}, #{Error}> { + #{UnionPairParserFnWritable} + + match decoder.map()? { + None => { + let variant = pair(decoder)?; + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + Ok(variant) + } + ty => Err( + #{Error}::unexpected_union_variant( + ty, + decoder.position(), + ), ), - ), + } } + Some(1) => pair(decoder), + Some(_) => Err(#{Error}::mixed_union_variants(decoder.position())) } - Some(1) => pair(decoder), - Some(_) => Err(#{Error}::mixed_union_variants(decoder.position())) } - } - """, + """, "UnionSymbol" to returnSymbolToParse.symbol, "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), *codegenScope, From 6853b96928693cee10fbab8f387d6671c1012e01 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 18:23:18 +0200 Subject: [PATCH 48/77] ./gradlew ktlintFormat --- .../core/smithy/generators/Instantiator.kt | 30 +++++++++---------- .../rust/codegen/core/testutil/TestHelpers.kt | 18 +++++++---- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index a5b09e1192..c09bc545fc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -32,7 +32,6 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.DefaultTrait import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpHeaderTrait import software.amazon.smithy.model.traits.HttpPayloadTrait @@ -62,7 +61,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectMember import software.amazon.smithy.rust.codegen.core.util.expectTrait -import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.letIf @@ -492,19 +490,21 @@ class PrimitiveInstantiator( private val symbolProvider: SymbolProvider, withinTest: Boolean = false, ) { - val codegenScope = listOf( - "DateTime" to RuntimeType.dateTime(runtimeConfig), - "Bytestream" to RuntimeType.byteStream(runtimeConfig), - "Blob" to RuntimeType.blob(runtimeConfig), - "SmithyJson" to RuntimeType.smithyJson(runtimeConfig), - "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), - ).map { it.first to - if (withinTest) { - it.second.toDevDependencyType() - } else { - it.second - } - }.toTypedArray() + val codegenScope = + listOf( + "DateTime" to RuntimeType.dateTime(runtimeConfig), + "Bytestream" to RuntimeType.byteStream(runtimeConfig), + "Blob" to RuntimeType.blob(runtimeConfig), + "SmithyJson" to RuntimeType.smithyJson(runtimeConfig), + "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), + ).map { + it.first to + if (withinTest) { + it.second.toDevDependencyType() + } else { + it.second + } + }.toTypedArray() fun instantiate( shape: SimpleShape, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 10e051d49e..42997a8012 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -152,12 +152,18 @@ fun String.asSmithyModel( disableValidation: Boolean = false, ): Model { val processed = letIf(!this.trimStart().startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } - val denyModelsContaining = arrayOf( - // If Smithy protocol test models are in our classpath, don't load them, since they are fairly large and we - // almost never need them. - "smithy-protocol-tests" - ) - val urls = ModelDiscovery.findModels().filter { modelUrl -> denyModelsContaining.none { modelUrl.toString().contains(it) } } + val denyModelsContaining = + arrayOf( + // If Smithy protocol test models are in our classpath, don't load them, since they are fairly large and we + // almost never need them. + "smithy-protocol-tests", + ) + val urls = + ModelDiscovery.findModels().filter { modelUrl -> + denyModelsContaining.none { + modelUrl.toString().contains(it) + } + } val assembler = Model.assembler() for (url in urls) { assembler.addImport(url) From 8c65e7aa79a6fbc851067de42fcc901a6b70c1bd Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 18:35:20 +0200 Subject: [PATCH 49/77] AWS JSON 1.x server request specs can be `&'static str`s This is technically a breaking change because we stop implementing `FromIterator<(String, S)>`. --- .../server/smithy/generators/protocol/ServerProtocol.kt | 5 ++--- .../aws-smithy-http-server/src/protocol/aws_json/router.rs | 6 +++--- .../aws-smithy-http-server/src/protocol/rpc_v2/router.rs | 1 + 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 1397502867..bc12e70b18 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -171,11 +171,10 @@ class ServerAwsJsonProtocol( serviceName: String, requestSpecModule: RuntimeType, ) = writable { - rust("""String::from("$serviceName.$operationName")""") + rust(""""$serviceName.$operationName"""") } - // TODO This could technically be `&static str` right? - override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.String + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr override fun serverRouterRuntimeConstructor() = when (version) { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index 74fc44b8df..298d75dd75 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -47,7 +47,7 @@ pub(crate) const ROUTE_CUTOFF: usize = 15; /// [AWS JSON 1.1]: https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html #[derive(Debug, Clone)] pub struct AwsJsonRouter { - routes: TinyMap, + routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, } impl AwsJsonRouter { @@ -107,9 +107,9 @@ where } } -impl FromIterator<(String, S)> for AwsJsonRouter { +impl FromIterator<(&'static str, S)> for AwsJsonRouter { #[inline] - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { Self { routes: iter.into_iter().collect(), } diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs index 60e3b7c286..1d16634b57 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -243,6 +243,7 @@ impl Router for RpcV2Router { } impl FromIterator<(&'static str, S)> for RpcV2Router { + #[inline] fn from_iter>(iter: T) -> Self { Self { routes: iter.into_iter().collect(), From 61f927aff45e818e75ccc4ff52f7ac6568418ccc Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 4 Jul 2024 18:35:20 +0200 Subject: [PATCH 50/77] AWS JSON 1.x server request specs can be `&'static str`s This is technically a breaking change because we stop implementing `FromIterator<(String, S)>`. --- .../server/smithy/generators/protocol/ServerProtocol.kt | 4 ++-- .../aws-smithy-http-server/src/protocol/aws_json/router.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 2fb76bf879..6434c0290c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -163,10 +163,10 @@ class ServerAwsJsonProtocol( serviceName: String, requestSpecModule: RuntimeType, ) = writable { - rust("""String::from("$serviceName.$operationName")""") + rust(""""$serviceName.$operationName"""") } - override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.String + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr override fun serverRouterRuntimeConstructor() = when (version) { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index 295c670b27..1d09064303 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -47,7 +47,7 @@ const ROUTE_CUTOFF: usize = 15; /// [AWS JSON 1.1]: https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html #[derive(Debug, Clone)] pub struct AwsJsonRouter { - routes: TinyMap, + routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, } impl AwsJsonRouter { @@ -106,9 +106,9 @@ where } } -impl FromIterator<(String, S)> for AwsJsonRouter { +impl FromIterator<(&'static str, S)> for AwsJsonRouter { #[inline] - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { Self { routes: iter.into_iter().collect(), } From ec503322a216c2c7da6d01c03e6418a28d2c2f03 Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 5 Jul 2024 12:09:57 +0200 Subject: [PATCH 51/77] Update cbor-diag to avoid minimal-versions issues --- rust-runtime/aws-smithy-protocol-test/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index e50222b200..c12c5aeca5 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -11,7 +11,7 @@ repository = "https://github.com/smithy-lang/smithy-rs" # Not perfect for our needs, but good for now assert-json-diff = "1.1" base64-simd = "0.8" -cbor-diag = "0.1" +cbor-diag = "0.1.12" serde_cbor = "0.11" http = "0.2.1" pretty_assertions = "1.3" From 46671d1a390088a0ccc91d0eb53532fe7afa76ea Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 5 Jul 2024 12:41:35 +0200 Subject: [PATCH 52/77] Add aws-smithy-cbor to smithy runtime, start versioning at 0.60.6 --- buildSrc/src/main/kotlin/CrateSet.kt | 1 + rust-runtime/aws-smithy-cbor/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/buildSrc/src/main/kotlin/CrateSet.kt b/buildSrc/src/main/kotlin/CrateSet.kt index bc90115443..253bfa08ca 100644 --- a/buildSrc/src/main/kotlin/CrateSet.kt +++ b/buildSrc/src/main/kotlin/CrateSet.kt @@ -56,6 +56,7 @@ object CrateSet { val SMITHY_RUNTIME_COMMON = listOf( "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml index 913768e658..aa977912f7 100644 --- a/rust-runtime/aws-smithy-cbor/Cargo.toml +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-cbor" -version = "0.60.0" +version = "0.60.6" authors = [ "AWS Rust SDK Team ", "David Pérez ", From a1a823c25bc3ca07fb6a8ecc2890ded6ea8836f9 Mon Sep 17 00:00:00 2001 From: david-perez Date: Fri, 5 Jul 2024 12:42:24 +0200 Subject: [PATCH 53/77] Remove sample RPCv2 service --- .../examples/rpcv2-service/Cargo.toml | 14 -------- .../examples/rpcv2-service/src/main.rs | 32 ------------------- 2 files changed, 46 deletions(-) delete mode 100644 rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml delete mode 100644 rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs diff --git a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml deleted file mode 100644 index e6ae4c9f6c..0000000000 --- a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "rpcv2-service" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -# Local paths -aws-smithy-http-server = { path = "../../" } -hyper = "0.14.24" -tokio = "1.25.0" -rpcv2-server-sdk = { path = "../rpcv2-server-sdk/", package = "rpcv2" } -rpcv2-pokemon-client = { path = "../rpcv2-pokemon-client/" } diff --git a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs deleted file mode 100644 index 96e2eb0221..0000000000 --- a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -use std::net::SocketAddr; - -use aws_smithy_http_server::{body::Body, routing::Route}; -use rpcv2_server_sdk::{error, input, output, RpcV2Service}; - -async fn handler( - input: input::RpcV2OperationInput, -) -> Result { - println!("{input:#?}"); - - todo!() -} - -#[tokio::main] -async fn main() { - let service: RpcV2Service> = rpcv2_server_sdk::RpcV2Service::builder_without_plugins() - .rpc_v2_operation(handler) - .build() - .unwrap(); - - let server = service.into_make_service(); - let bind: SocketAddr = "127.0.0.1:6969" - .parse() - .expect("unable to parse the server bind address and port"); - - println!("Binding {bind}"); - hyper::Server::bind(&bind).serve(server).await.unwrap(); -} From bdb57473a9194fd3b6795a795540510158add7b3 Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 8 Jul 2024 15:54:09 +0200 Subject: [PATCH 54/77] fix --- .../amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index da5d742647..e508b29a48 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -272,6 +272,9 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val U64 = std.resolve("primitive::u64") val Vec = std.resolve("vec::Vec") + // primitve types + val StaticStr = RuntimeType("&'static str") + // external cargo dependency types val Bytes = CargoDependency.Bytes.toType().resolve("Bytes") val Http = CargoDependency.Http.toType() From 9f40fa51efd7ea680942aa56dab5c52d6481ea69 Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 8 Jul 2024 15:55:38 +0200 Subject: [PATCH 55/77] fix typo --- .../amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index e508b29a48..903bc85f86 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -272,7 +272,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val U64 = std.resolve("primitive::u64") val Vec = std.resolve("vec::Vec") - // primitve types + // primitive types val StaticStr = RuntimeType("&'static str") // external cargo dependency types From 2f89061fefe64eba014132bb3b6da5357b9f7a7e Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 8 Jul 2024 16:25:04 +0200 Subject: [PATCH 56/77] RpcV2CborServerPopulatesDefaultsWhenMissingInRequestBody expect fail --- .../generators/protocol/ServerProtocolTestGenerator.kt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 3050688fbd..6e1d6eb4e2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -154,8 +154,9 @@ class ServerProtocolTestGenerator( FailingTest.MalformedRequestTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct"), // TODO(https://github.com/smithy-lang/smithy-rs/issues/3339) FailingTest.ResponseTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsInResponseWhenMissingInParams"), - // TODO: We need to be able to configure instantiator so that it uses default _modeled_ values; `""` is not a valid enum value for `defaultEnum`. - FailingTest.ResponseTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsWhenMissingInRequestBody"), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3743): We need to be able to configure + // instantiator so that it uses default _modeled_ values; `""` is not a valid enum value for `defaultEnum`. + FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsWhenMissingInRequestBody"), ) private val BrokenTests: From bc65c638ff1db8127289fd7832021e213e394cea Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 8 Jul 2024 16:28:32 +0200 Subject: [PATCH 57/77] fix --- .../aws-smithy-http-server/src/protocol/aws_json/router.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index 1d09064303..df304d823f 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -126,11 +126,7 @@ mod tests { #[tokio::test] async fn simple_routing() { let routes = vec![("Service.Operation")]; - let router: AwsJsonRouter<_> = routes - .clone() - .into_iter() - .map(|operation| (operation.to_string(), ())) - .collect(); + let router: AwsJsonRouter<_> = routes.clone().into_iter().map(|operation| (operation, ())).collect(); let mut headers = HeaderMap::new(); headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation")); From 792c21582156d122bc630584c97ff3410cceba90 Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 8 Jul 2024 16:31:03 +0200 Subject: [PATCH 58/77] bump aws-smithy-http-server --- rust-runtime/aws-smithy-http-server/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index c8b619dcfc..a63a125941 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.63.0" +version = "0.63.1" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" From f4ab98bc6d7992cc1e472884606778684c324e1d Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 8 Jul 2024 16:59:33 +0200 Subject: [PATCH 59/77] derive thiserror::Error for RPC v2 runtime errors --- .../src/protocol/rpc_v2/runtime_error.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs index 8375611639..9b177a71de 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs @@ -11,12 +11,24 @@ use http::StatusCode; use super::rejection::{RequestRejection, ResponseRejection}; -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum RuntimeError { + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::Serialization`] + #[error("request failed to deserialize or response failed to serialize: {0}")] Serialization(crate::Error), + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::InternalFailure`] + #[error("internal failure: {0}")] InternalFailure(crate::Error), + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::NotAcceptable`] + #[error("not acceptable request: request contains an `Accept` header with a MIME type, and the server cannot return a response body adhering to that MIME type")] NotAcceptable, + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::UnsupportedMediaType`] + #[error("unsupported media type: request does not contain the expected `Content-Type` header value")] UnsupportedMediaType, + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::Validation`] + #[error( + "validation failure: operation input contains data that does not adhere to the modeled constraints: {0:?}" + )] Validation(Vec), } From 6c6384e97e452ab5279d950f741999e6d98e2ce1 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 9 Jul 2024 14:11:42 +0200 Subject: [PATCH 60/77] Add some context to TODOs --- .../serialize/CborSerializerGenerator.kt | 4 ++-- .../src/protocol/aws_json_10/router.rs | 2 ++ .../src/protocol/aws_json_11/router.rs | 2 ++ .../src/protocol/rest_json_1/router.rs | 2 ++ .../src/protocol/rest_xml/router.rs | 3 ++- .../src/protocol/rpc_v2/router.rs | 18 +++++++----------- 6 files changed, 17 insertions(+), 14 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 86ceece652..9a72d8f564 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -268,8 +268,8 @@ class CborSerializerGenerator( "StructureSymbol" to symbolProvider.toSymbol(context.shape), *codegenScope, ) { - // TODO If all members are non-`Option`-al, we know AOT the map's size and can use `.map()` - // instead of `.begin_map()` for efficiency. Add test. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3745) If all members are non-`Option`-al, + // we know AOT the map's size and can use `.map()` instead of `.begin_map()` for efficiency. rust("encoder.begin_map();") for (customization in customizations) { customization.section( diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs index 30a28d6255..ac963ffe51 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs @@ -12,6 +12,8 @@ use super::AwsJson1_0; pub use crate::protocol::aws_json::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs index 5ebd1002f2..2e3e16d8ad 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs @@ -12,6 +12,8 @@ use super::AwsJson1_1; pub use crate::protocol::aws_json::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs index 023b43031c..939b1bb6ec 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs @@ -12,6 +12,8 @@ use super::RestJson1; pub use crate::protocol::rest::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs index 529a3d19a2..e684ced4de 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs @@ -13,7 +13,8 @@ use super::RestXml; pub use crate::protocol::rest::router::*; -/// An AWS REST routing error. +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs index 1d16634b57..f2798f0b69 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -62,7 +62,7 @@ const FORBIDDEN_HEADERS: &[&str] = &["x-amz-target", "x-amzn-target"]; const IDENTIFIER_PATTERN: &str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; impl RpcV2Router { - // TODO Consider building a nom parser + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3748) Consider building a nom parser. fn uri_path_regex() -> &'static Regex { // Every request for the `rpcv2Cbor` protocol MUST be sent to a URL with the // following form: `{prefix?}/service/{serviceName}/operation/{operationName}` @@ -124,24 +124,20 @@ impl RpcV2Router { } } -// TODO: Implement (current body copied from the rest xml impl) -// and document. -/// A Smithy RPC V2 routing error. +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { - Error::NotFound => http::Response::builder() + Error::MethodNotAllowed => method_disallowed(), + _ => http::Response::builder() .status(http::StatusCode::NOT_FOUND) - // TODO - .header(http::header::CONTENT_TYPE, "application/xml") + .header(http::header::CONTENT_TYPE, "application/cbor") .extension(RuntimeErrorExtension::new( UNKNOWN_OPERATION_EXCEPTION.to_string(), )) .body(empty()) - .expect("invalid HTTP response for REST XML routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"), - Error::MethodNotAllowed => method_disallowed(), - // TODO - _ => todo!(), + .expect("invalid HTTP response for RPCv2 CBOR routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"), } } } From 132f8b46d6e094e354be6666ae95d66a96a809b3 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 9 Jul 2024 19:42:39 +0200 Subject: [PATCH 61/77] Address more TODOs --- .../smithy/protocols/PythonServerProtocolLoader.kt | 4 ++-- .../server/smithy/protocols/ServerProtocolLoader.kt | 7 +++++-- .../server/smithy/protocols/ServerRpcV2CborFactory.kt | 10 ++++++++-- rust-runtime/aws-smithy-cbor/Cargo.toml | 2 +- rust-runtime/aws-smithy-cbor/README.md | 3 ++- rust-runtime/aws-smithy-cbor/src/decode.rs | 10 +++++++--- rust-runtime/aws-smithy-cbor/src/encode.rs | 4 +++- 7 files changed, 28 insertions(+), 12 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index 4e41cc8ed3..f50e236337 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -56,10 +56,10 @@ class PythonServerAfterDeserializedMemberJsonParserCustomization(private val run } /** - * Customization class used to force casting a non primitive type into one overriden by a new symbol provider, + * Customization class used to force casting a non-primitive type into one overridden by a new symbol provider, * by explicitly calling `into()` on it. */ -class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : +class PythonServerAfterDeserializedMemberServerHttpBoundCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index 446f805851..ba7c0104ea 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -80,8 +80,11 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap { +class ServerRpcV2CborFactory( + private val additionalServerHttpBoundProtocolCustomizations: List = emptyList(), +) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRpcV2CborProtocol(codegenContext) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = - ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2CborProtocol(codegenContext)) + ServerHttpBoundProtocolGenerator( + codegenContext, + ServerRpcV2CborProtocol(codegenContext), + additionalServerHttpBoundProtocolCustomizations, + ) override fun support(): ProtocolSupport { return ProtocolSupport( diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml index aa977912f7..fe660a1338 100644 --- a/rust-runtime/aws-smithy-cbor/Cargo.toml +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" repository = "https://github.com/awslabs/smithy-rs" [dependencies.minicbor] -version = "0.19.1" # TODO Update +version = "0.24.2" features = [ # To write to a `Vec`: https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E "alloc", diff --git a/rust-runtime/aws-smithy-cbor/README.md b/rust-runtime/aws-smithy-cbor/README.md index 4d0ad07413..29b59ecd3f 100644 --- a/rust-runtime/aws-smithy-cbor/README.md +++ b/rust-runtime/aws-smithy-cbor/README.md @@ -1,6 +1,7 @@ # aws-smithy-cbor -TODO +JSON serialization and deserialization primitives for clients and servers +generated by [smithy-rs](https://github.com/smithy-lang/smithy-rs). This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/awslabs/smithy-rs) code generator. In most cases, it should not be used directly. diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index 27ada4f7c9..f3123e229d 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -50,8 +50,11 @@ impl DeserializeError { /// Unknown union variant was detected. Servers reject unknown union varaints. pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self { Self { - _inner: Error::message(format!("encountered unknown union variant {}", variant_name)) - .at(at), + _inner: Error::message(format!( + "encountered unknown union variant {}", + variant_name + )) + .at(at), } } @@ -206,8 +209,9 @@ impl<'b> Decoder<'b> { /// a `DeserializeError` error is returned. pub fn timestamp(&mut self) -> Result { let tag = self.decoder.tag().map_err(DeserializeError::new)?; + let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp); - if !matches!(tag, minicbor::data::Tag::Timestamp) { + if tag != timestamp_tag { Err(DeserializeError::new(Error::message( "expected timestamp tag", ))) diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs index 3d7cfd5f6e..f9817df984 100644 --- a/rust-runtime/aws-smithy-cbor/src/encode.rs +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -98,7 +98,9 @@ impl Encoder { pub fn timestamp(&mut self, x: &DateTime) -> &mut Self { self.encoder - .tag(minicbor::data::Tag::Timestamp) + .tag(minicbor::data::Tag::from( + minicbor::data::IanaTag::Timestamp, + )) .expect(INFALLIBLE_WRITE); self.encoder.f64(x.as_secs_f64()).expect(INFALLIBLE_WRITE); self From c5794c7152465077784f4ab043ebb92c2b5a72ae Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 12:11:40 +0200 Subject: [PATCH 62/77] Bump crate version numbers --- rust-runtime/aws-smithy-http-server/Cargo.toml | 2 +- rust-runtime/aws-smithy-protocol-test/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 922410592c..73c2ba42d1 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.63.1" +version = "0.63.2" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index c12c5aeca5..64b99c4869 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.60.8" +version = "0.61.0" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" From e33b6b98dd49d99f65f24dbf6a27d536f08e8efa Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 12:33:11 +0200 Subject: [PATCH 63/77] Update .cargo-deny-config.toml --- .cargo-deny-config.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.cargo-deny-config.toml b/.cargo-deny-config.toml index 3a9407f00c..abb8c8c067 100644 --- a/.cargo-deny-config.toml +++ b/.cargo-deny-config.toml @@ -25,6 +25,9 @@ exceptions = [ { allow = ["OpenSSL"], name = "ring", version = "*" }, { allow = ["OpenSSL"], name = "aws-lc-sys", version = "*" }, { allow = ["OpenSSL"], name = "aws-lc-fips-sys", version = "*" }, + { allow = ["BlueOak-1.0.0"], name = "minicbor", version = "<=0.24.2" }, + # Safe to bump as long as license does not change -------------^ + # See D105255799. ] [[licenses.clarify]] From 6781dca59888af5cf0c1f21b6635bab4cd7174b0 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 12:57:09 +0200 Subject: [PATCH 64/77] Add and modify router docs --- .../src/protocol/aws_json/router.rs | 3 +-- .../src/protocol/aws_json_10/mod.rs | 2 +- .../src/protocol/aws_json_11/mod.rs | 2 +- .../aws-smithy-http-server/src/protocol/rest/router.rs | 6 +++--- .../src/protocol/rest_json_1/mod.rs | 2 +- .../aws-smithy-http-server/src/protocol/rest_xml/mod.rs | 2 +- .../aws-smithy-http-server/src/protocol/rpc_v2/mod.rs | 7 +++---- .../aws-smithy-http-server/src/protocol/rpc_v2/router.rs | 8 +++++--- .../src/protocol/rpc_v2/runtime_error.rs | 8 ++++---- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index eff05c5d9c..38538fe1e9 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -41,7 +41,7 @@ pub enum Error { // https://github.com/smithy-lang/smithy-rs/pull/1429#issuecomment-1147516546 pub(crate) const ROUTE_CUTOFF: usize = 15; -/// A [`Router`] supporting [`AWS JSON 1.0`] and [`AWS JSON 1.1`] protocols. +/// A [`Router`] supporting [AWS JSON 1.0] and [AWS JSON 1.1] protocols. /// /// [AWS JSON 1.0]: https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html /// [AWS JSON 1.1]: https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html @@ -65,7 +65,6 @@ impl AwsJsonRouter { } } - // TODO This function is not used? Codegen should probably delegate to this function. /// Applies type erasure to the inner route using [`Route::new`]. pub fn boxed(self) -> AwsJsonRouter> where diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs index 8dfd48e466..a3e8f2c919 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs @@ -5,5 +5,5 @@ pub mod router; -/// [AWS JSON 1.0 Protocol](https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html). +/// [AWS JSON 1.0](https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html) protocol. pub struct AwsJson1_0; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs index 6fb09920a0..697aae52d3 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs @@ -5,5 +5,5 @@ pub mod router; -/// [AWS JSON 1.1 Protocol](https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html). +/// [AWS JSON 1.1](https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html) protocol. pub struct AwsJson1_1; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs index 19a644e7e4..94f99a98df 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs @@ -26,10 +26,10 @@ pub enum Error { MethodNotAllowed, } -/// A [`Router`] supporting [`AWS REST JSON 1.0`] and [`AWS REST XML`] protocols. +/// A [`Router`] supporting [AWS restJson1] and [AWS restXml] protocols. /// -/// [AWS REST JSON 1.0]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html -/// [AWS REST XML]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html +/// [AWS restJson1]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html +/// [AWS restXml]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html #[derive(Debug, Clone)] pub struct RestRouter { routes: Vec<(RequestSpec, S)>, diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs index f8384578d2..695d995ce1 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs @@ -7,5 +7,5 @@ pub mod rejection; pub mod router; pub mod runtime_error; -/// [AWS REST JSON 1.0 Protocol](https://smithy.io/2.0/aws/protocols/aws-restjson1-protocol.html). +/// [AWS restJson1](https://smithy.io/2.0/aws/protocols/aws-restjson1-protocol.html) protocol. pub struct RestJson1; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs index 0b16df11e3..e16570567e 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs @@ -7,5 +7,5 @@ pub mod rejection; pub mod router; pub mod runtime_error; -/// [AWS REST XML Protocol](https://smithy.io/2.0/aws/protocols/aws-restxml-protocol.html). +/// [AWS restXml](https://smithy.io/2.0/aws/protocols/aws-restxml-protocol.html) protocol. pub struct RestXml; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs index cec31e5aee..287a756446 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs @@ -7,7 +7,6 @@ pub mod rejection; pub mod router; pub mod runtime_error; -// TODO Rename to RpcV2Cbor -// TODO: Fill link -/// [Smithy RPC V2](). -pub struct RpcV2; +/// [Smithy RPC v2 CBOR](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) +/// protocol. +pub struct RpcV2Cbor; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs index f2798f0b69..f31f404efb 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -24,7 +24,7 @@ use crate::routing::Route; use crate::routing::Router; use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; -use super::RpcV2; +use super::RpcV2Cbor; pub use crate::protocol::rest::router::*; @@ -46,7 +46,9 @@ pub enum Error { NotFound, } -// TODO Docs +/// A [`Router`] supporting the [Smithy RPC v2 CBOR] protocol. +/// +/// [Smithy RPC v2 CBOR]: https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html #[derive(Debug, Clone)] pub struct RpcV2Router { routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, @@ -126,7 +128,7 @@ impl RpcV2Router { // TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but // we have no tests to pin our implemenation against! -impl IntoResponse for Error { +impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { Error::MethodNotAllowed => method_disallowed(), diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs index 9b177a71de..c496731097 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs @@ -5,7 +5,7 @@ use crate::response::IntoResponse; use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; -use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2::RpcV2}; +use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2::RpcV2Cbor}; use bytes::Bytes; use http::StatusCode; @@ -54,13 +54,13 @@ impl RuntimeError { } } -impl IntoResponse for InternalFailureException { +impl IntoResponse for InternalFailureException { fn into_response(self) -> http::Response { - IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) } } -impl IntoResponse for RuntimeError { +impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { let res = http::Response::builder() .status(self.status_code()) From 8f14505ed6b6b28212f2071468a2e8da1568566c Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 12:58:12 +0200 Subject: [PATCH 65/77] Rename RpcV2Router to RpcV2CborRouter --- .../src/protocol/rpc_v2/router.rs | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs index f31f404efb..7d3b1027b7 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -50,7 +50,7 @@ pub enum Error { /// /// [Smithy RPC v2 CBOR]: https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html #[derive(Debug, Clone)] -pub struct RpcV2Router { +pub struct RpcV2CborRouter { routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, } @@ -63,7 +63,7 @@ const FORBIDDEN_HEADERS: &[&str] = &["x-amz-target", "x-amzn-target"]; /// . const IDENTIFIER_PATTERN: &str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; -impl RpcV2Router { +impl RpcV2CborRouter { // TODO(https://github.com/smithy-lang/smithy-rs/issues/3748) Consider building a nom parser. fn uri_path_regex() -> &'static Regex { // Every request for the `rpcv2Cbor` protocol MUST be sent to a URL with the @@ -100,23 +100,23 @@ impl RpcV2Router { &SMITHY_PROTOCOL_REGEX } - pub fn boxed(self) -> RpcV2Router> + pub fn boxed(self) -> RpcV2CborRouter> where S: Service, Response = http::Response, Error = Infallible>, S: Send + Clone + 'static, S::Future: Send + 'static, { - RpcV2Router { + RpcV2CborRouter { routes: self.routes.into_iter().map(|(key, s)| (key, Route::new(s))).collect(), } } /// Applies a [`Layer`] uniformly to all routes. - pub fn layer(self, layer: L) -> RpcV2Router + pub fn layer(self, layer: L) -> RpcV2CborRouter where L: Layer, { - RpcV2Router { + RpcV2CborRouter { routes: self .routes .into_iter() @@ -170,7 +170,7 @@ pub enum WireFormatError { fn parse_wire_format_from_header(headers: &HeaderMap) -> Result { let header = headers.get("smithy-protocol").ok_or(WireFormatError::HeaderNotFound)?; let header = header.to_str().map_err(WireFormatError::HeaderValueNotVisibleAscii)?; - let captures = RpcV2Router::<()>::wire_format_regex() + let captures = RpcV2CborRouter::<()>::wire_format_regex() .captures(header) .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; @@ -200,7 +200,7 @@ impl FromStr for WireFormat { } } -impl Router for RpcV2Router { +impl Router for RpcV2CborRouter { type Service = S; type Error = Error; @@ -240,7 +240,7 @@ impl Router for RpcV2Router { } } -impl FromIterator<(&'static str, S)> for RpcV2Router { +impl FromIterator<(&'static str, S)> for RpcV2CborRouter { #[inline] fn from_iter>(iter: T) -> Self { Self { @@ -256,7 +256,7 @@ mod tests { use crate::protocol::test_helpers::req; - use super::{Error, Router, RpcV2Router}; + use super::{Error, Router, RpcV2CborRouter}; fn identifier_regex() -> Regex { Regex::new(&format!("^{}$", super::IDENTIFIER_PATTERN)).unwrap() @@ -290,7 +290,7 @@ mod tests { #[test] fn uri_regex_works_accepts() { - let regex = RpcV2Router::<()>::uri_path_regex(); + let regex = RpcV2CborRouter::<()>::uri_path_regex(); for uri in [ "/service/Service/operation/Operation", @@ -313,7 +313,7 @@ mod tests { #[test] fn uri_regex_works_rejects() { - let regex = RpcV2Router::<()>::uri_path_regex(); + let regex = RpcV2CborRouter::<()>::uri_path_regex(); for uri in [ "", @@ -333,7 +333,7 @@ mod tests { #[test] fn wire_format_regex_works() { - let regex = RpcV2Router::<()>::wire_format_regex(); + let regex = RpcV2CborRouter::<()>::wire_format_regex(); let captures = regex.captures("rpc-v2-something").unwrap(); assert_eq!("something", &captures["format"]); @@ -354,7 +354,7 @@ mod tests { #[test] fn simple_routing() { - let router: RpcV2Router<_> = ["Service.Operation"].into_iter().map(|op| (op, ())).collect(); + let router: RpcV2CborRouter<_> = ["Service.Operation"].into_iter().map(|op| (op, ())).collect(); let good_uri = "/prefix/service/Service/operation/Operation"; // The request should match. From a98524b4fbc5a32855810b3c58fb9f08ebdb210f Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 13:02:00 +0200 Subject: [PATCH 66/77] rpc_v2 module -> rpc_v2_cbor module --- rust-runtime/aws-smithy-http-server/src/protocol/mod.rs | 2 +- .../src/protocol/{rpc_v2 => rpc_v2_cbor}/mod.rs | 0 .../src/protocol/{rpc_v2 => rpc_v2_cbor}/rejection.rs | 0 .../src/protocol/{rpc_v2 => rpc_v2_cbor}/router.rs | 0 .../src/protocol/{rpc_v2 => rpc_v2_cbor}/runtime_error.rs | 2 +- 5 files changed, 2 insertions(+), 2 deletions(-) rename rust-runtime/aws-smithy-http-server/src/protocol/{rpc_v2 => rpc_v2_cbor}/mod.rs (100%) rename rust-runtime/aws-smithy-http-server/src/protocol/{rpc_v2 => rpc_v2_cbor}/rejection.rs (100%) rename rust-runtime/aws-smithy-http-server/src/protocol/{rpc_v2 => rpc_v2_cbor}/router.rs (100%) rename rust-runtime/aws-smithy-http-server/src/protocol/{rpc_v2 => rpc_v2_cbor}/runtime_error.rs (97%) diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs index cc5afab110..6d6bbf3b65 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs @@ -9,7 +9,7 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; -pub mod rpc_v2; +pub mod rpc_v2_cbor; use crate::rejection::MissingContentTypeReason; use aws_smithy_runtime_api::http::Headers as SmithyHeaders; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs similarity index 100% rename from rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs rename to rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs similarity index 100% rename from rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs rename to rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs similarity index 100% rename from rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs rename to rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs similarity index 97% rename from rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs rename to rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs index c496731097..b3f01da351 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs @@ -5,7 +5,7 @@ use crate::response::IntoResponse; use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; -use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2::RpcV2Cbor}; +use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2_cbor::RpcV2Cbor}; use bytes::Bytes; use http::StatusCode; From e36704e1714fcad19799216e7ab863a320e4d905 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 13:10:02 +0200 Subject: [PATCH 67/77] More renaming to be rpcv2Cbor, not rpcv2 --- .../common-test-models/rpcv2Cbor-extras.smithy | 8 ++++---- codegen-server-test/build.gradle.kts | 2 +- .../smithy/generators/protocol/ServerProtocol.kt | 4 ++-- ...rAndParserGeneratorSerdeRoundTripIntegrationTest.kt | 8 ++++---- .../src/protocol/rpc_v2_cbor/router.rs | 10 +++++----- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index 37ad8f21d1..c60b93736d 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -8,7 +8,7 @@ use smithy.test#httpResponseTests use smithy.test#httpMalformedRequestTests @rpcv2Cbor -service RpcV2Service { +service RpcV2CborService { operations: [ SimpleStructOperation ErrorSerializationOperation @@ -58,7 +58,7 @@ apply EmptyStructOperation @httpMalformedRequestTests([ protocol: rpcv2Cbor, request: { method: "POST", - uri: "/service/RpcV2Service/operation/EmptyStructOperation", + uri: "/service/RpcV2CborService/operation/EmptyStructOperation", headers: { "smithy-protocol": "rpc-v2-cbor", "Accept": "application/cbor", @@ -91,7 +91,7 @@ apply SingleMemberStructOperation @httpMalformedRequestTests([ protocol: rpcv2Cbor, request: { method: "POST", - uri: "/service/RpcV2Service/operation/SingleMemberStructOperation", + uri: "/service/RpcV2CborService/operation/SingleMemberStructOperation", headers: { "smithy-protocol": "rpc-v2-cbor", "Accept": "application/cbor", @@ -123,7 +123,7 @@ apply ErrorSerializationOperation @httpMalformedRequestTests([ protocol: rpcv2Cbor, request: { method: "POST", - uri: "/service/RpcV2Service/operation/ErrorSerializationOperation", + uri: "/service/RpcV2CborService/operation/ErrorSerializationOperation", headers: { "smithy-protocol": "rpc-v2-cbor", "Accept": "application/cbor", diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 27f09c5134..088a459a89 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -47,7 +47,7 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> // CodegenTest("aws.protocoltests.restxml#RestXml", "restXml"), CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), CodegenTest( - "smithy.protocoltests.rpcv2Cbor#RpcV2Service", + "smithy.protocoltests.rpcv2Cbor#RpcV2CborService", "rpcv2Cbor_extras", imports = listOf("$commonModels/rpcv2Cbor-extras.smithy") ), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index bc12e70b18..42527862f6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -334,11 +334,11 @@ class ServerRpcV2CborProtocol( ) } - override fun markerStruct() = ServerRuntimeType.protocol("RpcV2", "rpc_v2", runtimeConfig) + override fun markerStruct() = ServerRuntimeType.protocol("RpcV2Cbor", "rpc_v2_cbor", runtimeConfig) override fun routerType() = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - .resolve("protocol::rpc_v2::router::RpcV2Router") + .resolve("protocol::rpc_v2::router::RpcV2CborRouter") override fun serverRouterRequestSpec( operationShape: OperationShape, diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt index 46d4f92f79..2e92cde4e2 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -244,7 +244,7 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { ) val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true, withinTest = true) - val rpcV2 = ServerRpcV2CborProtocol(codegenContext) + val rpcv2Cbor = ServerRpcV2CborProtocol(codegenContext) for (operationShape in codegenContext.model.operationShapes) { val serverProtocolTestGenerator = @@ -265,9 +265,9 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { val serializeFn = if (targetShape.hasTrait()) { - rpcV2.structuredDataSerializer().serverErrorSerializer(targetShape.id) + rpcv2Cbor.structuredDataSerializer().serverErrorSerializer(targetShape.id) } else { - rpcV2.structuredDataSerializer().operationOutputSerializer(operationShape) + rpcv2Cbor.structuredDataSerializer().operationOutputSerializer(operationShape) } if (serializeFn == null) { @@ -323,7 +323,7 @@ internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { val params = test.testCase.params val deserializeFn = - rpcV2.structuredDataParser().serverInputParser(operationShape) + rpcv2Cbor.structuredDataParser().serverInputParser(operationShape) ?: // Skip if there's nothing to serialize. continue diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs index 7d3b1027b7..53d6e31483 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs @@ -28,13 +28,13 @@ use super::RpcV2Cbor; pub use crate::protocol::rest::router::*; -/// An RPC v2 routing error. +/// An RPC v2 CBOR routing error. #[derive(Debug, Error)] pub enum Error { /// Method was not `POST`. #[error("method not POST")] MethodNotAllowed, - /// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` + /// Requests for the `rpcv2Cbor` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` /// header. #[error("contains forbidden headers")] ForbiddenHeaders, @@ -54,8 +54,8 @@ pub struct RpcV2CborRouter { routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, } -/// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` -/// header. An `rpcv2` request is malformed if it contains either of these headers. Server-side +/// Requests for the `rpcv2Cbor` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` +/// header. An `rpcv2Cbor` request is malformed if it contains either of these headers. Server-side /// implementations MUST reject such requests for security reasons. const FORBIDDEN_HEADERS: &[&str] = &["x-amz-target", "x-amzn-target"]; @@ -302,7 +302,7 @@ mod tests { // if only the name was specified. For example, if the `service`'s absolute shape ID is // `com.example#TheService`, a service should accept both `TheService` and // `com.example.TheService` as values for the `serviceName` segment. - "/service/aws.protocoltests.rpcv2.Service/operation/Operation", + "/service/aws.protocoltests.rpcv2Cbor.Service/operation/Operation", "/service/namespace.Service/operation/Operation", ] { let captures = regex.captures(uri).unwrap(); From c5440beba38541489c3846f5ecdcf07d3395146c Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 13:23:39 +0200 Subject: [PATCH 68/77] Address last TODO? --- .../smithy/protocols/parse/CborParserGenerator.kt | 4 ++-- .../protocols/serialize/CborSerializerGenerator.kt | 10 ++-------- rust-runtime/aws-smithy-protocol-test/src/lib.rs | 13 ++++++++----- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 9473245a60..04531cca1f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -69,7 +69,7 @@ class CborParserGenerator( private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) }, - private val customizations: List = listOf(), + private val customizations: List = emptyList(), ) : StructuredDataParserGenerator { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider @@ -445,7 +445,7 @@ class CborParserGenerator( errorShape, symbolProvider.symbolForBuilder(errorShape), errorShape.members().toList(), - fnNameSuffix = "json_err", + fnNameSuffix = "cbor_err", ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 9a72d8f564..d839e4ee7f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -48,10 +48,8 @@ import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.isUnit import software.amazon.smithy.rust.codegen.core.util.outputShape -// TODO Cleanup commented and unused code. - /** - * Class describing a JSON serializer section that can be used in a customization. + * Class describing a CBOR serializer section that can be used in a customization. */ sealed class CborSerializerSection(name: String) : Section(name) { /** @@ -81,7 +79,7 @@ class CborSerializerGenerator( data class Context( /** Expression representing the value to write to the encoder */ var valueExpression: ValueExpression, - /** Path in the JSON to get here, used for errors */ + /** Path in the CBOR to get here, used for errors */ val shape: T, ) @@ -156,14 +154,10 @@ class CborSerializerGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val protocolFunctions = ProtocolFunctions(codegenContext) - // TODO Cleanup private val codegenScope = arrayOf( - "String" to RuntimeType.String, "Error" to runtimeConfig.serializationError(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), - "ByteSlab" to RuntimeType.ByteSlab, ) private val serializerUtil = SerializerUtil(model, symbolProvider) diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 55e6d87b75..cfd841b189 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -331,7 +331,7 @@ impl> From for MediaType { } } -pub fn validate_body>( +pub fn validate_body + Debug>( actual_body: T, expected_body: &str, media_type: MediaType, @@ -414,7 +414,7 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure> } } -fn try_cbor_eq>( +fn try_cbor_eq + Debug>( actual_body: T, expected_body: &str, ) -> Result<(), ProtocolTestFailure> { @@ -422,9 +422,12 @@ fn try_cbor_eq>( .decode_to_vec(expected_body) .expect("smithy protocol test `body` property is not properly base64 encoded"); let expected_cbor_value: serde_cbor::Value = - serde_cbor::from_slice(decoded.as_slice()).unwrap(); - let actual_cbor_value: serde_cbor::Value = - serde_cbor::from_slice(actual_body.as_ref()).unwrap(); // TODO Don't panic + serde_cbor::from_slice(decoded.as_slice()).expect("expected value must be valid CBOR"); + let actual_cbor_value: serde_cbor::Value = serde_cbor::from_slice(actual_body.as_ref()) + .map_err(|e| ProtocolTestFailure::InvalidBodyFormat { + expected: "cbor".to_owned(), + found: format!("{} {:?}", e.to_string(), actual_body), + })?; let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body); if expected_cbor_value != actual_cbor_value { From 805c17787a30bc29b875917e7269b107a1f72849 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 16:11:57 +0200 Subject: [PATCH 69/77] fixes --- .../core/smithy/generators/protocol/ProtocolTestGenerator.kt | 2 +- .../server/smithy/generators/protocol/ServerProtocol.kt | 4 ++-- rust-runtime/aws-smithy-cbor/src/lib.rs | 4 ++++ rust-runtime/aws-smithy-protocol-test/Cargo.toml | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 05fd3daf10..3c7950ef34 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -417,7 +417,7 @@ object ServiceShapeId { const val AWS_JSON_11 = "aws.protocoltests.json#JsonProtocol" const val REST_JSON = "aws.protocoltests.restjson#RestJson" const val RPC_V2_CBOR = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol" - const val RPC_V2_CBOR_EXTRAS = "smithy.protocoltests.rpcv2Cbor#RpcV2Service" + const val RPC_V2_CBOR_EXTRAS = "smithy.protocoltests.rpcv2Cbor#RpcV2CborService" const val REST_XML = "aws.protocoltests.restxml#RestXml" const val AWS_QUERY = "aws.protocoltests.query#AwsQuery" const val EC2_QUERY = "aws.protocoltests.ec2#AwsEc2" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 42527862f6..e43276a8a2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -311,7 +311,7 @@ class ServerRpcV2CborProtocol( ) : RpcV2Cbor(serverCodegenContext), ServerProtocol { val runtimeConfig = codegenContext.runtimeConfig - override val protocolModulePath = "rpc_v2" + override val protocolModulePath = "rpc_v2_cbor" override fun structuredDataParser(): StructuredDataParserGenerator = CborParserGenerator( @@ -338,7 +338,7 @@ class ServerRpcV2CborProtocol( override fun routerType() = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - .resolve("protocol::rpc_v2::router::RpcV2CborRouter") + .resolve("protocol::rpc_v2_cbor::router::RpcV2CborRouter") override fun serverRouterRequestSpec( operationShape: OperationShape, diff --git a/rust-runtime/aws-smithy-cbor/src/lib.rs b/rust-runtime/aws-smithy-cbor/src/lib.rs index 1a547fedee..6db4813980 100644 --- a/rust-runtime/aws-smithy-cbor/src/lib.rs +++ b/rust-runtime/aws-smithy-cbor/src/lib.rs @@ -5,6 +5,10 @@ //! CBOR abstractions for Smithy. +/* Automatically managed default lints */ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +/* End of automatically managed default lints */ + pub mod data; pub mod decode; pub mod encode; diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index 64b99c4869..2b816d96bf 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.61.0" +version = "0.61.1" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" From ccde8601da9df6d942bfb4cd58c6eccaf6e57654 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 17:49:56 +0200 Subject: [PATCH 70/77] fixes --- aws/sdk/build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index b8d4eab906..2f977bd039 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -453,7 +453,6 @@ tasks["assemble"].apply { outputs.upToDateWhen { false } } -project.registerCargoCommandsTasks(outputDir.asFile) tasks.register("copyCheckedInCargoLock") { description = "Copy the checked in Cargo.lock file back to the build directory" this.outputs.upToDateWhen { false } @@ -461,6 +460,7 @@ tasks.register("copyCheckedInCargoLock") { into(outputDir) } +project.registerCargoCommandsTasks(outputDir.asFile) project.registerGenerateCargoConfigTomlTask(outputDir.asFile) //The task name "test" is already registered by one of our plugins From c31356ca644c657b782fc636e180d2c8164a0f4b Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 18:05:41 +0200 Subject: [PATCH 71/77] fixes --- .../common-test-models/pokemon-awsjson.smithy | 1 - codegen-server-test/build.gradle.kts | 1 - .../generators/protocol/ServerProtocol.kt | 116 ++++++++++-------- 3 files changed, 62 insertions(+), 56 deletions(-) diff --git a/codegen-core/common-test-models/pokemon-awsjson.smithy b/codegen-core/common-test-models/pokemon-awsjson.smithy index 77e78a58d7..16eab7df90 100644 --- a/codegen-core/common-test-models/pokemon-awsjson.smithy +++ b/codegen-core/common-test-models/pokemon-awsjson.smithy @@ -27,7 +27,6 @@ service PokemonService { } /// Capture Pokémons via event streams. -@http(uri: "/simple-struct-operation", method: "POST") operation CapturePokemon { input := { events: AttemptCapturingPokemonEvent diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 088a459a89..808d476058 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -44,7 +44,6 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), ), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), - // CodegenTest("aws.protocoltests.restxml#RestXml", "restXml"), CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), CodegenTest( "smithy.protocoltests.rpcv2Cbor#RpcV2CborService", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index bbc49c83f1..d4984d65e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -306,6 +306,68 @@ class ServerRestXmlProtocol( ) } +class ServerRpcV2CborProtocol( + private val serverCodegenContext: ServerCodegenContext, +) : RpcV2Cbor(serverCodegenContext), ServerProtocol { + val runtimeConfig = codegenContext.runtimeConfig + + override val protocolModulePath = "rpc_v2_cbor" + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator( + serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( + serverCodegenContext, + ), + ), + ) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator { + return CborSerializerGenerator( + codegenContext, + httpBindingResolver, + listOf( + BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), + AddTypeFieldToServerErrorsCborCustomization(), + ), + ) + } + + override fun markerStruct() = ServerRuntimeType.protocol("RpcV2Cbor", "rpc_v2_cbor", runtimeConfig) + + override fun routerType() = + ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::rpc_v2_cbor::router::RpcV2CborRouter") + + override fun serverRouterRequestSpec( + operationShape: OperationShape, + operationName: String, + serviceName: String, + requestSpecModule: RuntimeType, + ) = writable { + // This is just the key used by the router's map to store and look up operations, it's completely arbitrary. + // We use the same key used by the awsJson1.x routers for simplicity. + // The router will extract the service name and the operation name from the URI, build this key, and lookup the + // operation stored there. + rust("$serviceName.$operationName".dq()) + } + + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr + + override fun serverRouterRuntimeConstructor() = "rpc_v2_router" + + override fun serverContentTypeCheckNoModeledInput() = false + + override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType = + deserializePayloadErrorType( + codegenContext, + binding, + requestRejection(runtimeConfig), + RuntimeType.smithyCbor(codegenContext.runtimeConfig).resolve("decode::DeserializeError"), + ) +} + /** Just a common function to keep things DRY. **/ fun deserializePayloadErrorType( codegenContext: CodegenContext, @@ -367,57 +429,3 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborPa } } } - -class ServerRpcV2CborProtocol( - private val serverCodegenContext: ServerCodegenContext, -) : RpcV2Cbor(serverCodegenContext), ServerProtocol { - val runtimeConfig = codegenContext.runtimeConfig - - override val protocolModulePath = "rpc_v2_cbor" - - override fun structuredDataParser(): StructuredDataParserGenerator = - CborParserGenerator( - serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), - listOf( - ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( - serverCodegenContext, - ), - ), - ) - - override fun structuredDataSerializer(): StructuredDataSerializerGenerator { - return CborSerializerGenerator( - codegenContext, - httpBindingResolver, - listOf( - BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), - AddTypeFieldToServerErrorsCborCustomization(), - ), - ) - } - - override fun markerStruct() = ServerRuntimeType.protocol("RpcV2Cbor", "rpc_v2_cbor", runtimeConfig) - - override fun routerType() = - ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - .resolve("protocol::rpc_v2_cbor::router::RpcV2CborRouter") - - override fun serverRouterRequestSpec( - operationShape: OperationShape, - operationName: String, - serviceName: String, - requestSpecModule: RuntimeType, - ) = writable { - // This is just the key used by the router's map to store and lookup operations, it's completely arbitrary. - // We use the same key used by the awsJson1.x routers for simplicity. - // The router will extract the service name and the operation name from the URI, build this key, and lookup the - // operation stored there. - rust("$serviceName.$operationName".dq()) - } - - override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr - - override fun serverRouterRuntimeConstructor() = "rpc_v2_router" - - override fun serverContentTypeCheckNoModeledInput() = false -} From 48def0563101da3c36b06dcac6cc68e9bec0a5af Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 18:16:18 +0200 Subject: [PATCH 72/77] copyright headers and other lint fixes --- .../smithy/protocols/parse/ReturnSymbolToParse.kt | 5 +++++ rust-runtime/aws-smithy-cbor/Cargo.toml | 13 +++++++------ rust-runtime/aws-smithy-cbor/README.md | 4 ++-- rust-runtime/aws-smithy-cbor/benches/blob.rs | 5 +++++ rust-runtime/aws-smithy-cbor/benches/string.rs | 5 +++++ rust-runtime/aws-smithy-cbor/src/data.rs | 5 +++++ rust-runtime/aws-smithy-cbor/src/decode.rs | 5 +++++ rust-runtime/aws-smithy-cbor/src/encode.rs | 5 +++++ rust-runtime/aws-smithy-protocol-test/src/lib.rs | 2 +- 9 files changed, 40 insertions(+), 9 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt index d78f2d98fd..b095aa8879 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.codegen.core.Symbol diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml index fe660a1338..b87366d6ef 100644 --- a/rust-runtime/aws-smithy-cbor/Cargo.toml +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -22,12 +22,6 @@ features = [ [dependencies] aws-smithy-types = { path = "../aws-smithy-types" } -[package.metadata.docs.rs] -all-features = true -targets = ["x86_64-unknown-linux-gnu"] -rustdoc-args = ["--cfg", "docsrs"] -# End of docs.rs metadata - [dev-dependencies] criterion = "0.5.1" @@ -38,3 +32,10 @@ harness = false [[bench]] name = "blob" harness = false + +[package.metadata.docs.rs] +all-features = true +targets = ["x86_64-unknown-linux-gnu"] +cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] +rustdoc-args = ["--cfg", "docsrs"] +# End of docs.rs metadata diff --git a/rust-runtime/aws-smithy-cbor/README.md b/rust-runtime/aws-smithy-cbor/README.md index 29b59ecd3f..367577b3e5 100644 --- a/rust-runtime/aws-smithy-cbor/README.md +++ b/rust-runtime/aws-smithy-cbor/README.md @@ -1,8 +1,8 @@ # aws-smithy-cbor -JSON serialization and deserialization primitives for clients and servers +CBOR serialization and deserialization primitives for clients and servers generated by [smithy-rs](https://github.com/smithy-lang/smithy-rs). -This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/awslabs/smithy-rs) code generator. In most cases, it should not be used directly. +This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/smithy-lang/smithy-rs) code generator. In most cases, it should not be used directly. diff --git a/rust-runtime/aws-smithy-cbor/benches/blob.rs b/rust-runtime/aws-smithy-cbor/benches/blob.rs index dd4960da0d..221940bb98 100644 --- a/rust-runtime/aws-smithy-cbor/benches/blob.rs +++ b/rust-runtime/aws-smithy-cbor/benches/blob.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use aws_smithy_cbor::decode::Decoder; use criterion::{black_box, criterion_group, criterion_main, Criterion}; diff --git a/rust-runtime/aws-smithy-cbor/benches/string.rs b/rust-runtime/aws-smithy-cbor/benches/string.rs index 18d6b2ceee..f60ff353e0 100644 --- a/rust-runtime/aws-smithy-cbor/benches/string.rs +++ b/rust-runtime/aws-smithy-cbor/benches/string.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use std::borrow::Cow; use aws_smithy_cbor::decode::Decoder; diff --git a/rust-runtime/aws-smithy-cbor/src/data.rs b/rust-runtime/aws-smithy-cbor/src/data.rs index a6eab6c549..e3bfdad2d9 100644 --- a/rust-runtime/aws-smithy-cbor/src/data.rs +++ b/rust-runtime/aws-smithy-cbor/src/data.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + #[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)] pub enum Type { Bool, diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index f3123e229d..57a844ab6d 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use std::borrow::Cow; use aws_smithy_types::{Blob, DateTime}; diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs index f9817df984..1651c37f9b 100644 --- a/rust-runtime/aws-smithy-cbor/src/encode.rs +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use aws_smithy_types::{Blob, DateTime}; /// Macro for delegating method calls to the encoder. diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index cfd841b189..06cdbc2ff2 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -426,7 +426,7 @@ fn try_cbor_eq + Debug>( let actual_cbor_value: serde_cbor::Value = serde_cbor::from_slice(actual_body.as_ref()) .map_err(|e| ProtocolTestFailure::InvalidBodyFormat { expected: "cbor".to_owned(), - found: format!("{} {:?}", e.to_string(), actual_body), + found: format!("{} {:?}", e, actual_body), })?; let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body); From 03ba0a88d81880980aef6722af1b4241e7ec5a52 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 10 Jul 2024 18:28:22 +0200 Subject: [PATCH 73/77] ./gradlew ktlintFormat --- .../server/smithy/protocols/ServerProtocolLoader.kt | 10 ++++++---- .../server/smithy/protocols/ServerRpcV2CborFactory.kt | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ba7c0104ea..a121697eb7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -80,11 +80,13 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap = emptyList(), + private val additionalServerHttpBoundProtocolCustomizations: List = + emptyList(), ) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRpcV2CborProtocol(codegenContext) From ce7d83061b54e17590f32c7571023363275f7044 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 16 Jul 2024 16:50:16 +0200 Subject: [PATCH 74/77] Address comments --- buildSrc/src/main/kotlin/CodegenTestCommon.kt | 78 ++++++++++++++----- .../core/rustlang/RustReservedWords.kt | 2 +- .../core/smithy/protocols/RpcV2Cbor.kt | 16 ---- .../protocols/parse/CborParserGenerator.kt | 21 ++--- .../protocols/parse/ReturnSymbolToParse.kt | 3 +- .../serialize/CborSerializerGenerator.kt | 12 +-- .../smithy/rust/codegen/core/util/Smithy.kt | 78 +++++++------------ ...ypeFieldToServerErrorsCborCustomization.kt | 8 +- rust-runtime/Cargo.lock | 6 +- rust-runtime/aws-smithy-cbor/src/decode.rs | 2 + 10 files changed, 114 insertions(+), 112 deletions(-) diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index 202e6a0c83..8e0fd36447 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -26,29 +26,65 @@ fun generateImports(imports: List): String = if (imports.isEmpty()) { "" } else { - "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," + "\"imports\": [${imports.joinToString(", ") { "\"$it\"" }}]," } -fun toRustCrateName(input: String): String { - val rustKeywords = - setOf( - // Strict Keywords. - "as", "break", "const", "continue", "crate", "else", "enum", "extern", - "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", - "move", "mut", "pub", "ref", "return", "Self", "self", "static", "struct", - "super", "trait", "true", "type", "unsafe", "use", "where", "while", - // Weak Keywords. - "dyn", "async", "await", "try", - // Reserved for Future Use. - "abstract", "become", "box", "do", "final", "macro", "override", "priv", - "typeof", "unsized", "virtual", "yield", - // Primitive Types. - "bool", "char", "i8", "i16", "i32", "i64", "i128", "isize", - "u8", "u16", "u32", "u64", "u128", "usize", "f32", "f64", "str", - // Additional significant identifiers. - "proc_macro", - ) +val RustKeywords = + setOf( + "as", + "break", + "const", + "continue", + "crate", + "else", + "enum", + "extern", + "false", + "fn", + "for", + "if", + "impl", + "in", + "let", + "loop", + "match", + "mod", + "move", + "mut", + "pub", + "ref", + "return", + "self", + "Self", + "static", + "struct", + "super", + "trait", + "true", + "type", + "unsafe", + "use", + "where", + "while", + "async", + "await", + "dyn", + "abstract", + "become", + "box", + "do", + "final", + "macro", + "override", + "priv", + "typeof", + "unsized", + "virtual", + "yield", + "try", + ) +fun toRustCrateName(input: String): String { if (input.isBlank()) { throw IllegalArgumentException("Rust crate name cannot be empty") } @@ -62,7 +98,7 @@ fun toRustCrateName(input: String): String { when { trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. - trimmed in rustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. + trimmed in RustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. else -> trimmed } return finalName diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt index cc37f891ed..3e1e8ab806 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt @@ -97,6 +97,7 @@ enum class EscapeFor { } object RustReservedWords : ReservedWords { + // This is the same list defined in `CodegenTestCommon` from the `buildSrc` Gradle subproject. private val RustKeywords = setOf( "as", @@ -151,7 +152,6 @@ object RustReservedWords : ReservedWords { "yield", "try", ) - // Some things can't be used as a raw identifier, so we can't use the normal escaping strategy // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4 private val keywordEscapingMap = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index 2d910d5b39..d1af7ae72c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -11,8 +11,6 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.TimestampFormatTrait -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator @@ -89,20 +87,6 @@ class RpcV2CborHttpBindingResolver( } open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { - private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = - arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), - "JsonError" to - CargoDependency.smithyJson(runtimeConfig).toType() - .resolve("deserialize::error::DeserializeError"), - "Response" to RuntimeType.Http.resolve("Response"), - "json_errors" to RuntimeType.jsonErrors(runtimeConfig), - ) - private val jsonDeserModule = RustModule.private("json_deser") - override val httpBindingResolver: HttpBindingResolver = RpcV2CborHttpBindingResolver( codegenContext.model, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 04531cca1f..99208b0b9a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -36,6 +36,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator @@ -83,7 +84,7 @@ class CborParserGenerator( "Decoder" to smithyCbor.resolve("Decoder"), "Error" to smithyCbor.resolve("decode::DeserializeError"), "HashMap" to RuntimeType.HashMap, - "Vec" to RuntimeType.Vec, + *preludeScope, ) private fun listMemberParserFn( @@ -97,7 +98,7 @@ class CborParserGenerator( fn member( mut list: #{ListSymbol}, decoder: &mut #{Decoder}, - ) -> Result<#{ListSymbol}, #{Error}> + ) -> #{Result}<#{ListSymbol}, #{Error}> """, *codegenScope, "ListSymbol" to listSymbol, @@ -148,7 +149,7 @@ class CborParserGenerator( fn pair( mut map: #{MapSymbol}, decoder: &mut #{Decoder}, - ) -> Result<#{MapSymbol}, #{Error}> + ) -> #{Result}<#{MapSymbol}, #{Error}> """, *codegenScope, "MapSymbol" to mapSymbol, @@ -204,7 +205,7 @@ class CborParserGenerator( fn pair( mut builder: #{Builder}, decoder: &mut #{Decoder} - ) -> Result<#{Builder}, #{Error}> + ) -> #{Result}<#{Builder}, #{Error}> """, *codegenScope, "Builder" to builderSymbol, @@ -280,7 +281,7 @@ class CborParserGenerator( """ fn pair( decoder: &mut #{Decoder} - ) -> Result<#{UnionSymbol}, #{Error}> + ) -> #{Result}<#{UnionSymbol}, #{Error}> """, *codegenScope, "UnionSymbol" to returnSymbolToParse.symbol, @@ -401,7 +402,7 @@ class CborParserGenerator( return protocolFunctions.deserializeFn(shape, fnNameSuffix) { fnName -> rustTemplate( """ - pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> Result<#{Builder}, #{Error}> { + pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> #{Result}<#{Builder}, #{Error}> { #{StructurePairParserFn:W} let decoder = &mut #{Decoder}::new(value); @@ -518,7 +519,7 @@ class CborParserGenerator( rustTemplate( """ - pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { #{ListMemberParserFn:W} #{InitContainerWritable:W} @@ -561,7 +562,7 @@ class CborParserGenerator( rustTemplate( """ - pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { #{MapPairParserFn:W} #{InitContainerWritable:W} @@ -593,7 +594,7 @@ class CborParserGenerator( val parser = protocolFunctions.deserializeFn(shape) { fnName -> rustBlockTemplate( - "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}>", + "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}>", "ReturnType" to returnSymbolToParse.symbol, *codegenScope, ) { @@ -631,7 +632,7 @@ class CborParserGenerator( protocolFunctions.deserializeFn(shape) { fnName -> rustTemplate( """ - pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{UnionSymbol}, #{Error}> { + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> { #{UnionPairParserFnWritable} match decoder.map()? { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt index b095aa8879..4b69e87328 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.codegen.core.Symbol /** - * Given a shape, parsers need to know the symbol to parse and return, and whether it's unconstrained or not. + * Parsers need to know what symbol to parse and return, and whether it's unconstrained or not. + * This data class holds this information that the parsers fill out from a shape. */ data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index d839e4ee7f..be5af1a11e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section @@ -59,7 +60,7 @@ sealed class CborSerializerSection(name: String) : Section(name) { data class BeforeSerializingStructureMembers( val structureShape: StructureShape, val encoderBindingName: String, - ) : CborSerializerSection("ServerError") + ) : CborSerializerSection("BeforeSerializingStructureMembers") /** Manipulate the serializer context for a map prior to it being serialized. **/ data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : @@ -79,7 +80,7 @@ class CborSerializerGenerator( data class Context( /** Expression representing the value to write to the encoder */ var valueExpression: ValueExpression, - /** Path in the CBOR to get here, used for errors */ + /** Shape to serialize */ val shape: T, ) @@ -158,6 +159,7 @@ class CborSerializerGenerator( arrayOf( "Error" to runtimeConfig.serializationError(), "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), + *preludeScope, ) private val serializerUtil = SerializerUtil(model, symbolProvider) @@ -178,11 +180,11 @@ class CborSerializerGenerator( } return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> rustBlockTemplate( - "pub fn $fnName(value: &#{target}) -> Result, #{Error}>", + "pub fn $fnName(value: &#{target}) -> #{Result}<#{Vec}, #{Error}>", *codegenScope, "target" to symbolProvider.toSymbol(structureShape), ) { - rustTemplate("let mut encoder = #{Encoder}::new(Vec::new());", *codegenScope) + rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference. rustBlock("") { rust("let encoder = &mut encoder;") @@ -382,7 +384,7 @@ class CborSerializerGenerator( val unionSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> rustBlockTemplate( - "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> Result<(), #{Error}>", + "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> #{Result}<(), #{Error}>", "UnionSymbol" to unionSymbol, *codegenScope, ) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt index cc036b7527..f6d6ddf84f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt @@ -24,19 +24,15 @@ import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait -inline fun Model.lookup(shapeId: String): T { - return this.expectShape(ShapeId.from(shapeId), T::class.java) -} +inline fun Model.lookup(shapeId: String): T = this.expectShape(ShapeId.from(shapeId), T::class.java) -fun OperationShape.inputShape(model: Model): StructureShape { +fun OperationShape.inputShape(model: Model): StructureShape = // The Rust Smithy generator adds an input to all shapes automatically - return model.expectShape(this.input.get(), StructureShape::class.java) -} + model.expectShape(this.input.get(), StructureShape::class.java) -fun OperationShape.outputShape(model: Model): StructureShape { +fun OperationShape.outputShape(model: Model): StructureShape = // The Rust Smithy generator adds an output to all shapes automatically - return model.expectShape(this.output.get(), StructureShape::class.java) -} + model.expectShape(this.output.get(), StructureShape::class.java) fun StructureShape.expectMember(member: String): MemberShape = this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") } @@ -55,47 +51,32 @@ fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait() -} +fun MemberShape.isInputEventStream(model: Model): Boolean = + isEventStream(model) && model.expectShape(container).hasTrait() -fun MemberShape.isOutputEventStream(model: Model): Boolean { - return isEventStream(model) && model.expectShape(container).hasTrait() -} +fun MemberShape.isOutputEventStream(model: Model): Boolean = + isEventStream(model) && model.expectShape(container).hasTrait() private val unitShapeId = ShapeId.from("smithy.api#Unit") -fun Shape.isUnit(): Boolean { - return this.id == unitShapeId -} +fun Shape.isUnit(): Boolean = this.id == unitShapeId -fun MemberShape.isTargetUnit(): Boolean { - return this.target == unitShapeId -} +fun MemberShape.isTargetUnit(): Boolean = this.target == unitShapeId -fun Shape.hasEventStreamMember(model: Model): Boolean { - return members().any { it.isEventStream(model) } -} +fun Shape.hasEventStreamMember(model: Model): Boolean = members().any { it.isEventStream(model) } -fun OperationShape.isInputEventStream(model: Model): Boolean { - return input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -} +fun OperationShape.isInputEventStream(model: Model): Boolean = + input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -fun OperationShape.isOutputEventStream(model: Model): Boolean { - return output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -} +fun OperationShape.isOutputEventStream(model: Model): Boolean = + output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -fun OperationShape.isEventStream(model: Model): Boolean { - return isInputEventStream(model) || isOutputEventStream(model) -} +fun OperationShape.isEventStream(model: Model): Boolean = isInputEventStream(model) || isOutputEventStream(model) fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = operations.any { id -> @@ -129,17 +110,13 @@ fun Shape.redactIfNecessary( * * A structure must have at most one streaming member. */ -fun StructureShape.findStreamingMember(model: Model): MemberShape? { - return this.findMemberWithTrait(model) -} +fun StructureShape.findStreamingMember(model: Model): MemberShape? = this.findMemberWithTrait(model) -inline fun StructureShape.findMemberWithTrait(model: Model): MemberShape? { - return this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -} +inline fun StructureShape.findMemberWithTrait(model: Model): MemberShape? = + this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -inline fun UnionShape.findMemberWithTrait(model: Model): MemberShape? { - return this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -} +inline fun UnionShape.findMemberWithTrait(model: Model): MemberShape? = + this.members().find { it.getMemberTrait(model, T::class.java).isPresent } /** * If is member shape returns target, otherwise returns self. @@ -160,12 +137,11 @@ inline fun Shape.expectTrait(): T = expectTrait(T::class.jav /** Kotlin sugar for getTrait() check. e.g. shape.getTrait() instead of shape.getTrait(EnumTrait::class.java) */ inline fun Shape.getTrait(): T? = getTrait(T::class.java).orNull() -fun Shape.isPrimitive(): Boolean { - return when (this) { +fun Shape.isPrimitive(): Boolean = + when (this) { is NumberShape, is BooleanShape -> true else -> false } -} /** Convert a string to a ShapeId */ fun String.shapeId() = ShapeId.from(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt index 54cd279e8d..464a52dc46 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -46,10 +46,10 @@ class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization( writable { rust( """ - ${section.encoderBindingName} - .str("__type") - .str("${escape(section.structureShape.id.toString())}"); - """, + ${section.encoderBindingName} + .str("__type") + .str("${escape(section.structureShape.id.toString())}"); + """, ) } } else { diff --git a/rust-runtime/Cargo.lock b/rust-runtime/Cargo.lock index f44ad134ac..cf463b65a5 100644 --- a/rust-runtime/Cargo.lock +++ b/rust-runtime/Cargo.lock @@ -592,7 +592,7 @@ dependencies = [ [[package]] name = "aws-smithy-protocol-test" -version = "0.61.0" +version = "0.61.1" dependencies = [ "assert-json-diff", "aws-smithy-runtime-api 1.7.1", @@ -653,7 +653,7 @@ dependencies = [ "approx", "aws-smithy-async 1.2.1", "aws-smithy-http 0.60.9", - "aws-smithy-protocol-test 0.61.0", + "aws-smithy-protocol-test 0.61.1", "aws-smithy-runtime-api 1.7.1", "aws-smithy-types 1.2.0", "bytes", @@ -802,7 +802,7 @@ dependencies = [ name = "aws-smithy-xml" version = "0.60.8" dependencies = [ - "aws-smithy-protocol-test 0.61.0", + "aws-smithy-protocol-test 0.61.1", "base64 0.13.1", "proptest", "xmlparser", diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index 57a844ab6d..3cfe070397 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -64,6 +64,8 @@ impl DeserializeError { } /// More than one union variant was detected, but we never even got to parse the first one. + /// We immediately raise this error when detecting a union serialized as a fixed-length CBOR + /// map whose length (specified upfront) is a value different than 1. pub fn mixed_union_variants(at: usize) -> Self { Self { _inner: Error::message( From 7684ec1af56db126fb26d00d1b9260f24fd14332 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 17 Jul 2024 11:17:07 +0200 Subject: [PATCH 75/77] Address comments round 2 --- .../smithy/protocols/serialize/CborSerializerGenerator.kt | 8 ++++---- .../smithy/protocols/serialize/JsonSerializerGenerator.kt | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index be5af1a11e..f96a8b7cbc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -193,7 +193,7 @@ class CborSerializerGenerator( includedMembers, ) } - rust("Ok(encoder.into_writer())") + rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) } } } @@ -260,7 +260,7 @@ class CborSerializerGenerator( val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> rustBlockTemplate( - "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> Result<(), #{Error}>", + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> #{Result}<(), #{Error}>", "StructureSymbol" to symbolProvider.toSymbol(context.shape), *codegenScope, ) { @@ -405,13 +405,13 @@ class CborSerializerGenerator( } if (codegenTarget.renderUnknownVariant()) { rustTemplate( - "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return #{Err}(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", "Union" to unionSymbol, *codegenScope, ) } } - rust("Ok(())") + rustTemplate("#{Ok}(())", *codegenScope) } } rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 6580968749..69ec11fdd2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -213,7 +213,7 @@ class JsonSerializerGenerator( ) { rustTemplate( """ - let mut out = String::new(); + let mut out = #{String}::new(); let mut object = #{JsonObjectWriter}::new(&mut out); """, *codegenScope, From 1bf474d7a434ed68e559fabfab334b552c2588b3 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 17 Jul 2024 11:20:03 +0200 Subject: [PATCH 76/77] Bump aws-smithy-protocol-test to 0.62.0 --- rust-runtime/Cargo.lock | 8 ++++---- rust-runtime/aws-smithy-protocol-test/Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rust-runtime/Cargo.lock b/rust-runtime/Cargo.lock index cf463b65a5..602ebd4502 100644 --- a/rust-runtime/Cargo.lock +++ b/rust-runtime/Cargo.lock @@ -505,7 +505,7 @@ dependencies = [ [[package]] name = "aws-smithy-http-server-python" -version = "0.62.1" +version = "0.63.1" dependencies = [ "aws-smithy-http 0.60.9", "aws-smithy-http-server", @@ -592,7 +592,7 @@ dependencies = [ [[package]] name = "aws-smithy-protocol-test" -version = "0.61.1" +version = "0.62.0" dependencies = [ "assert-json-diff", "aws-smithy-runtime-api 1.7.1", @@ -653,7 +653,7 @@ dependencies = [ "approx", "aws-smithy-async 1.2.1", "aws-smithy-http 0.60.9", - "aws-smithy-protocol-test 0.61.1", + "aws-smithy-protocol-test 0.62.0", "aws-smithy-runtime-api 1.7.1", "aws-smithy-types 1.2.0", "bytes", @@ -802,7 +802,7 @@ dependencies = [ name = "aws-smithy-xml" version = "0.60.8" dependencies = [ - "aws-smithy-protocol-test 0.61.1", + "aws-smithy-protocol-test 0.62.0", "base64 0.13.1", "proptest", "xmlparser", diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index 2b816d96bf..9f5189079a 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.61.1" +version = "0.62.0" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" From d485266adf957d508080927ad661470410088ef4 Mon Sep 17 00:00:00 2001 From: david-perez Date: Wed, 17 Jul 2024 11:22:25 +0200 Subject: [PATCH 77/77] ./gradlew ktlintFormat --- .../smithy/rust/codegen/core/rustlang/RustReservedWords.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt index 3e1e8ab806..2c143d8009 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt @@ -152,6 +152,7 @@ object RustReservedWords : ReservedWords { "yield", "try", ) + // Some things can't be used as a raw identifier, so we can't use the normal escaping strategy // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4 private val keywordEscapingMap =