diff --git a/aws/sdk-adhoc-test/build.gradle.kts b/aws/sdk-adhoc-test/build.gradle.kts index 1f2b7fde5a..3bdc663dad 100644 --- a/aws/sdk-adhoc-test/build.gradle.kts +++ b/aws/sdk-adhoc-test/build.gradle.kts @@ -20,7 +20,6 @@ java { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val pluginName = "rust-client-codegen" @@ -78,7 +77,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index 6e2165eb12..95b27bad1e 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -31,7 +31,6 @@ configure { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val crateHasherToolPath = rootProject.projectDir.resolve("tools/ci-build/crate-hasher") @@ -442,7 +441,7 @@ tasks["assemble"].apply { outputs.upToDateWhen { false } } -project.registerCargoCommandsTasks(outputDir.asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(outputDir.asFile) project.registerGenerateCargoConfigTomlTask(outputDir.asFile) //The task name "test" is already registered by one of our plugins diff --git a/build.gradle.kts b/build.gradle.kts index 20f2d9e400..5e11e0ab02 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ allprojects { val allowLocalDeps: String by project repositories { if (allowLocalDeps.toBoolean()) { - mavenLocal() + mavenLocal() } mavenCentral() google() diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index b7681a5b56..1977be9b11 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -29,6 +29,49 @@ fun generateImports(imports: List): String = "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," } +fun toRustCrateName(input: String): String { + val rustKeywords = setOf( + // Strict Keywords. + "as", "break", "const", "continue", "crate", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", + "move", "mut", "pub", "ref", "return", "Self", "self", "static", "struct", + "super", "trait", "true", "type", "unsafe", "use", "where", "while", + + // Weak Keywords. + "dyn", "async", "await", "try", + + // Reserved for Future Use. + "abstract", "become", "box", "do", "final", "macro", "override", "priv", + "typeof", "unsized", "virtual", "yield", + + // Primitive Types. + "bool", "char", "i8", "i16", "i32", "i64", "i128", "isize", + "u8", "u16", "u32", "u64", "u128", "usize", "f32", "f64", "str", + + // Additional significant identifiers. + "proc_macro" + ) + + // Then within your function, you could include a check against this set + if (input.isBlank()) { + throw IllegalArgumentException("Rust crate name cannot be empty") + } + val lowerCased = input.lowercase() + // Replace any sequence of characters that are not lowercase letters, numbers, or underscores with a single underscore. + val sanitized = lowerCased.replace(Regex("[^a-z0-9_]+"), "_") + // Trim leading or trailing underscores. + val trimmed = sanitized.trim('_') + // Check if the resulting string is empty, purely numeric, or a reserved name + val finalName = when { + trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") + trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. + trimmed in rustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. + else -> trimmed + } + return finalName +} + + private fun generateSmithyBuild( projectDir: String, pluginName: String, @@ -48,7 +91,7 @@ private fun generateSmithyBuild( ${it.extraCodegenConfig ?: ""} }, "service": "${it.service}", - "module": "${it.module}", + "module": "${toRustCrateName(it.module)}", "moduleVersion": "0.0.1", "moduleDescription": "test", "moduleAuthors": ["protocoltest@example.com"] @@ -205,13 +248,15 @@ fun Project.registerGenerateCargoWorkspaceTask( fun Project.registerGenerateCargoConfigTomlTask(outputDir: File) { this.tasks.register("generateCargoConfigToml") { description = "generate `.cargo/config.toml`" + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1068): Once doc normalization + // is completed, warnings can be prohibited in rustdoc by setting `rustdocflags` to `-D warnings`. doFirst { outputDir.resolve(".cargo").mkdirs() outputDir.resolve(".cargo/config.toml") .writeText( """ [build] - rustflags = ["--deny", "warnings"] + rustflags = ["--deny", "warnings", "--cfg", "aws_sdk_unstable"] """.trimIndent(), ) } @@ -255,10 +300,7 @@ fun Project.registerModifyMtimeTask() { } } -fun Project.registerCargoCommandsTasks( - outputDir: File, - defaultRustDocFlags: String, -) { +fun Project.registerCargoCommandsTasks(outputDir: File) { val dependentTasks = listOfNotNull( "assemble", @@ -269,29 +311,24 @@ fun Project.registerCargoCommandsTasks( this.tasks.register(Cargo.CHECK.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "check", "--lib", "--tests", "--benches", "--all-features") } this.tasks.register(Cargo.TEST.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "test", "--all-features", "--no-fail-fast") } this.tasks.register(Cargo.DOCS.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTDOCFLAGS", defaultRustDocFlags) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "doc", "--no-deps", "--document-private-items") } this.tasks.register(Cargo.CLIPPY.toString) { dependsOn(dependentTasks) workingDir(outputDir) - environment("RUSTFLAGS", "--cfg aws_sdk_unstable") commandLine("cargo", "clippy") } } diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index de4b54d5b3..a1c6de5c23 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -15,7 +15,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "orchestrator" @@ -112,6 +111,11 @@ val allCodegenTests = listOf( "pokemon-service-awsjson-client", dependsOn = listOf("pokemon-awsjson.smithy", "pokemon-common.smithy"), ), + ClientTest( + "com.amazonaws.simple#RpcV2Service", + "rpcv2-pokemon-client", + dependsOn = listOf("rpcv2.smithy") + ), ClientTest("aws.protocoltests.misc#QueryCompatService", "query-compat-test", dependsOn = listOf("aws-json-query-compat.smithy")), ).map(ClientTest::toCodegenTest) @@ -125,7 +129,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index 485a656d7b..3e1f1ec580 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -27,9 +27,10 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in event stream -// marshalling/unmarshalling tests. + // marshalling/unmarshalling tests. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index cb961cbd19..47ea9f5626 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -33,7 +33,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -173,13 +175,19 @@ class DefaultProtocolTestGenerator( } testModuleWriter.write("Test ID: ${testCase.id}") testModuleWriter.newlinePrefix = "" + Attribute.TokioTest.render(testModuleWriter) - val action = - when (testCase) { - is HttpResponseTestCase -> Action.Response - is HttpRequestTestCase -> Action.Request - else -> throw CodegenException("unknown test case type") - } + Attribute.TracedTest.render(testModuleWriter) + // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in + // case the code rendered by the test does not make use of `tracing` at all. + val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } + testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) + + val action = when (testCase) { + is HttpResponseTestCase -> Action.Response + is HttpRequestTestCase -> Action.Request + else -> throw CodegenException("unknown test case type") + } if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } @@ -415,8 +423,8 @@ class DefaultProtocolTestGenerator( if (body == "") { rustWriter.rustTemplate( """ - // No body - #{AssertEq}(::std::str::from_utf8(body).unwrap(), ""); + // No body. + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index 5c0f7e6e1a..2af64bbad8 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -28,20 +29,21 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { - val DefaultProtocols = - mapOf( - AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), - AwsQueryTrait.ID to ClientAwsQueryFactory(), - Ec2QueryTrait.ID to ClientEc2QueryFactory(), - RestJson1Trait.ID to ClientRestJsonFactory(), - RestXmlTrait.ID to ClientRestXmlFactory(), - ) + val DefaultProtocols = mapOf( + AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), + AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), + AwsQueryTrait.ID to ClientAwsQueryFactory(), + Ec2QueryTrait.ID to ClientEc2QueryFactory(), + RestJson1Trait.ID to ClientRestJsonFactory(), + RestXmlTrait.ID to ClientRestXmlFactory(), + Rpcv2CborTrait.ID to ClientRpcV2CborFactory(), + ) val Default = ClientProtocolLoader(DefaultProtocols) } } @@ -117,3 +119,12 @@ class ClientRestXmlFactory( override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT } + +class ClientRpcV2CborFactory : ProtocolGeneratorFactory { + override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = + OperationGenerator(codegenContext, protocol(codegenContext)) + + override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT +} diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index 2fdd74abfb..eff612be35 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -28,6 +28,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") } fun gitCommitHash(): String { diff --git a/codegen-core/common-test-models/adwait-cbor-structs.smithy b/codegen-core/common-test-models/adwait-cbor-structs.smithy new file mode 100644 index 0000000000..b88d930967 --- /dev/null +++ b/codegen-core/common-test-models/adwait-cbor-structs.smithy @@ -0,0 +1,142 @@ +$version: "2.0" + +namespace aws.protocoltests.rpcv2 + +use aws.protocoltests.shared#StringList +use smithy.protocols#rpcv2 +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + + +@httpRequestTests([ + { + id: "RpcV2CborSimpleScalarProperties", + protocol: rpcv2, + documentation: "Serializes simple scalar properties", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + method: "POST", + bodyMediaType: "application/cbor", + uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", + body: "v2lieXRlVmFsdWUFa2RvdWJsZVZhbHVl+z/+OVgQYk3TcWZhbHNlQm9vbGVhblZhbHVl9GpmbG9hdFZhbHVl+kDz989saW50ZWdlclZhbHVlGQEAaWxvbmdWYWx1ZRkmkWpzaG9ydFZhbHVlGSaqa3N0cmluZ1ZhbHVlZnNpbXBsZXB0cnVlQm9vbGVhblZhbHVl9f8=" + params: { + trueBooleanValue: true, + falseBooleanValue: false, + byteValue: 5, + doubleValue: 1.889, + floatValue: 7.624, + integerValue: 256, + shortValue: 9898, + longValue: 9873 + stringValue: "simple" + } + }, + { + id: "RpcV2CborClientDoesntSerializeNullStructureValues", + documentation: "RpcV2 Cbor should not serialize null structure values", + protocol: rpcv2, + method: "POST", + uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", + body: "v/8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + params: { + stringValue: null + }, + appliesTo: "client" + }, + { + id: "RpcV2CborServerDoesntDeSerializeNullStructureValues", + documentation: "RpcV2 Cbor should not deserialize null structure values", + protocol: rpcv2, + method: "POST", + uri: "/service/RpcV2Protocol/operation/SimpleScalarProperties", + body: "v2tzdHJpbmdWYWx1Zfb/", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + params: {}, + appliesTo: "server" + }, +]) +@httpResponseTests([ + { + id: "simple_scalar_structure", + protocol: rpcv2, + documentation: "Serializes simple scalar properties", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + bodyMediaType: "application/cbor", + body: "v2lieXRlVmFsdWUFa2RvdWJsZVZhbHVl+z/+OVgQYk3TcWZhbHNlQm9vbGVhblZhbHVl9GpmbG9hdFZhbHVl+kDz989saW50ZWdlclZhbHVlGQEAaWxvbmdWYWx1ZRkmkWpzaG9ydFZhbHVlGSaqa3N0cmluZ1ZhbHVlZnNpbXBsZXB0cnVlQm9vbGVhblZhbHVl9f8=", + code: 200, + params: { + trueBooleanValue: true, + falseBooleanValue: false, + byteValue: 5, + doubleValue: 1.889, + floatValue: 7.624, + integerValue: 256, + shortValue: 9898, + stringValue: "simple" + } + }, + { + id: "RpcV2CborClientDoesntDeSerializeNullStructureValues", + documentation: "RpcV2 Cbor should deserialize null structure values", + protocol: rpcv2, + body: "v2tzdHJpbmdWYWx1Zfb/", + code: 200, + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + params: {} + appliesTo: "client" + }, + { + id: "RpcV2CborServerDoesntSerializeNullStructureValues", + documentation: "RpcV2 Cbor should not serialize null structure values", + protocol: rpcv2, + body: "v/8=", + code: 200, + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + params: { + stringValue: null + }, + appliesTo: "server" + }, +]) +operation SimpleScalarProperties { + input: SimpleScalarStructure, + output: SimpleScalarStructure +} + + +structure SimpleScalarStructure { + trueBooleanValue: Boolean, + falseBooleanValue: Boolean, + byteValue: Byte, + doubleValue: Double, + floatValue: Float, + integerValue: Integer, + longValue: Long, + shortValue: Short, + stringValue: String, +} diff --git a/codegen-core/common-test-models/adwait-empty-input-output.smithy b/codegen-core/common-test-models/adwait-empty-input-output.smithy new file mode 100644 index 0000000000..186cbaaa30 --- /dev/null +++ b/codegen-core/common-test-models/adwait-empty-input-output.smithy @@ -0,0 +1,174 @@ +$version: "2.0" + +namespace aws.protocoltests.rpcv2 + +use smithy.protocols#rpcv2 +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + + +@httpRequestTests([ + { + id: "no_input", + protocol: rpcv2, + documentation: "Body is empty and no Content-Type header if no input", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + }, + forbidHeaders: [ + "Content-Type", + "X-Amz-Target" + ] + method: "POST", + uri: "/service/RpcV2Protocol/operation/NoInputOutput", + body: "" + }, + { + id: "no_input_server_allows_accept", + protocol: rpcv2, + documentation: "Servers should allow the Accept header to be set to the default content-type.", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + method: "POST", + uri: "/service/RpcV2Protocol/operation/NoInputOutput", + body: "", + appliesTo: "server" + }, + { + id: "no_input_server_allows_empty_cbor", + protocol: rpcv2, + documentation: "Servers should accept CBOR empty struct if no input.", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + method: "POST", + uri: "/service/RpcV2Protocol/operation/NoInputOutput", + body: "v/8=", + appliesTo: "server" + } +]) +@httpResponseTests([ + { + id: "no_output", + protocol: rpcv2, + documentation: "Body is empty and no Content-Type header if no response", + body: "", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + }, + forbidHeaders: [ + "Content-Type" + ] + code: 200, + }, + { + id: "no_output_client_allows_accept", + protocol: rpcv2, + documentation: "Servers should allow the accept header to be set to the default content-type.", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + body: "", + code: 200, + appliesTo: "client", + }, + { + id: "no_input_client_allows_empty_cbor", + protocol: rpcv2, + documentation: "Client should accept CBOR empty struct if no output", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + body: "v/8=", + code: 200, + appliesTo: "client", + } +]) +operation NoInputOutput {} + + +@httpRequestTests([ + { + id: "empty_input", + protocol: rpcv2, + documentation: "When Input structure is empty we write CBOR equivalent of {}", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + }, + forbidHeaders: [ + "X-Amz-Target" + ] + method: "POST", + uri: "/service/RpcV2Protocol/operation/EmptyInputOutput", + body: "v/8=", + }, +]) +@httpResponseTests([ + { + id: "empty_output", + protocol: rpcv2, + documentation: "When output structure is empty we write CBOR equivalent of {}", + body: "v/8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + code: 200, + }, +]) +operation EmptyInputOutput { + input: EmptyStructure, + output: EmptyStructure +} + +@httpRequestTests([ + { + id: "optional_input", + protocol: rpcv2, + documentation: "When input is empty we write CBOR equivalent of {}", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + }, + forbidHeaders: [ + "X-Amz-Target" + ] + method: "POST", + uri: "/service/RpcV2Protocol/operation/OptionalInputOutput", + body: "v/8=", + bodyMediaType: "application/cbor", + }, +]) +@httpResponseTests([ + { + id: "optional_output", + protocol: rpcv2, + documentation: "When output is empty we write CBOR equivalent of {}", + body: "v/8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + } + code: 200, + }, +]) +operation OptionalInputOutput { + input: SimpleStructure, + output: SimpleStructure +} diff --git a/codegen-core/common-test-models/adwait-main.smithy b/codegen-core/common-test-models/adwait-main.smithy new file mode 100644 index 0000000000..c827b1dd79 --- /dev/null +++ b/codegen-core/common-test-models/adwait-main.smithy @@ -0,0 +1,30 @@ +$version: "2.0" + +namespace aws.protocoltests.rpcv2 +use aws.api#service +use smithy.protocols#rpcv2 +use smithy.test#httpRequestTests +use smithy.test#httpResponseTests + +@service(sdkId: "Sample RpcV2 Protocol") +@rpcv2(format: ["cbor"]) +@title("RpcV2 Protocol Service") +service RpcV2Protocol { + version: "2020-07-14", + operations: [ + //Basic input/output tests + NoInputOutput, + EmptyInputOutput, + OptionalInputOutput, + + SimpleScalarProperties, + ] +} + +structure EmptyStructure { + +} + +structure SimpleStructure { + value: String, +} \ No newline at end of file diff --git a/codegen-core/common-test-models/rpcv2.smithy b/codegen-core/common-test-models/rpcv2.smithy new file mode 100644 index 0000000000..48a4fb694b --- /dev/null +++ b/codegen-core/common-test-models/rpcv2.smithy @@ -0,0 +1,206 @@ +$version: "2.0" + +// TODO Update namespace +namespace com.amazonaws.simple + +use smithy.framework#ValidationException +use smithy.protocols#rpcv2 +use smithy.test#httpResponseTests + +@rpcv2(format: ["cbor"]) +service RpcV2Service { + operations: [ + SimpleStructOperation, + ComplexStructOperation + ] +} + +// TODO RpcV2 should not use the `@http` trait. +@http(uri: "/simple-struct-operation", method: "POST") +operation SimpleStructOperation { + input: SimpleStruct + output: SimpleStruct + errors: [ValidationException] +} + +// TODO RpcV2 should not use the `@http` trait. +@http(uri: "/complex-struct-operation", method: "POST") +operation ComplexStructOperation { + input: ComplexStruct + output: ComplexStruct + errors: [ValidationException] +} + +apply SimpleStructOperation @httpResponseTests([ + { + id: "SimpleStruct", + protocol: "smithy.protocols#rpcv2", + code: 200, // Not used. + params: { + blob: "blobby blob", + boolean: false, + + string: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + byte: 69, + short: 70, + integer: 71, + long: 72, + + float: 0.69, + double: 0.6969, + + timestamp: 1546300800, + // document: { + // documentInteger: 69 + // } + enum: "DIAMOND" + + // With `@required`. + + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + // document: { + // documentInteger: 69 + // } + requiredEnum: "DIAMOND" + } + }, + // Same test, but leave optional types empty + { + id: "SimpleStructWithOptionsSetToNone", + protocol: "smithy.protocols#rpcv2", + code: 200, // Not used. + params: { + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + // document: { + // documentInteger: 69 + // } + requiredEnum: "DIAMOND" + } + } +]) + +structure SimpleStruct { + blob: Blob + boolean: Boolean + + string: String + + byte: Byte + short: Short + integer: Integer + long: Long + + float: Float + double: Double + + timestamp: Timestamp + // document: MyDocument + enum: Suit + + // With `@required`. + + @required requiredBlob: Blob + @required requiredBoolean: Boolean + + @required requiredString: String + + @required requiredByte: Byte + @required requiredShort: Short + @required requiredInteger: Integer + @required requiredLong: Long + + @required requiredFloat: Float + @required requiredDouble: Double + + @required requiredTimestamp: Timestamp + // @required requiredDocument: MyDocument + @required requiredEnum: Suit +} + +structure ComplexStruct { + structure: SimpleStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion + + structureList: StructureList + + // `@required` for good measure here. + @required complexList: ComplexList + @required complexMap: ComplexMap + @required complexUnion: ComplexUnion +} + +list StructureList { + member: SimpleStruct +} + +list SimpleList { + member: String +} + +map SimpleMap { + key: String + value: Integer +} + +union SimpleUnion { + blob: Blob + boolean: Boolean + string: String +} + +list ComplexList { + member: ComplexMap +} + +map ComplexMap { + key: String + value: ComplexUnion +} + +union ComplexUnion { + // Recursive path here. + complexStruct: ComplexStruct + + structure: SimpleStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion +} + +enum Suit { + DIAMOND + CLUB + HEART + SPADE +} + +// document MyDocument diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 8fc0f9b6a7..b4207eeb99 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -294,6 +294,7 @@ data class CargoDependency( val Hound: CargoDependency = CargoDependency("hound", CratesIo("3.4.0"), DependencyScope.Dev) val PrettyAssertions: CargoDependency = CargoDependency("pretty_assertions", CratesIo("1.3.0"), DependencyScope.Dev) + val SerdeCbor: CargoDependency = CargoDependency("serde_cbor", CratesIo("0.11"), DependencyScope.Dev) val SerdeJson: CargoDependency = CargoDependency("serde_json", CratesIo("1.0.0"), DependencyScope.Dev) val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev) val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev) @@ -327,6 +328,8 @@ data class CargoDependency( fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async") + fun smithyCbor(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-cbor") + fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums") fun smithyCompression(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-compression") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 10c8def399..6a98294e1b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -570,6 +570,7 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { val Test = Attribute("test") val TokioTest = Attribute(RuntimeType.Tokio.resolve("test").writable) + val TracedTest = Attribute(RuntimeType.TracingTest.resolve("traced_test").writable) val AwsSdkUnstableAttribute = Attribute(cfg("aws_sdk_unstable")) /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index da5d742647..c909d411f6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -272,6 +272,9 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val U64 = std.resolve("primitive::u64") val Vec = std.resolve("vec::Vec") + // primitive types + val StaticStr = RuntimeType("&'static str") + // external cargo dependency types val Bytes = CargoDependency.Bytes.toType().resolve("Bytes") val Http = CargoDependency.Http.toType() @@ -288,6 +291,9 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val PercentEncoding = CargoDependency.PercentEncoding.toType() val PrettyAssertions = CargoDependency.PrettyAssertions.toType() val Regex = CargoDependency.Regex.toType() + val Serde= CargoDependency.Serde.toType() + val SerdeDeserialize = Serde.resolve("Deserialize") + val SerdeSerialize = Serde.resolve("Serialize") val RegexLite = CargoDependency.RegexLite.toType() val Tokio = CargoDependency.Tokio.toType() val TokioStream = CargoDependency.TokioStream.toType() @@ -299,14 +305,11 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val ConstrainedTrait = RuntimeType("crate::constrained::Constrained", InlineDependency.constrained()) val MaybeConstrained = RuntimeType("crate::constrained::MaybeConstrained", InlineDependency.constrained()) - // serde types. Gated behind `CfgUnstable`. - val Serde = CargoDependency.Serde.toType() - val SerdeSerialize = Serde.resolve("Serialize") - val SerdeDeserialize = Serde.resolve("Deserialize") - // smithy runtime types fun smithyAsync(runtimeConfig: RuntimeConfig) = CargoDependency.smithyAsync(runtimeConfig).toType() + fun smithyCbor(runtimeConfig: RuntimeConfig) = CargoDependency.smithyCbor(runtimeConfig).toType() + fun smithyChecksums(runtimeConfig: RuntimeConfig) = CargoDependency.smithyChecksums(runtimeConfig).toType() fun smithyCompression(runtimeConfig: RuntimeConfig) = CargoDependency.smithyCompression(runtimeConfig).toType() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index de73ab760d..db51633fef 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -65,6 +66,8 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.letIf import java.math.BigDecimal +import software.amazon.smithy.model.traits.DefaultTrait +import software.amazon.smithy.rust.codegen.core.util.getTrait /** * Class describing an instantiator section that can be used in a customization. @@ -452,7 +455,7 @@ open class Instantiator( */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { - is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) + is MemberShape -> shape.getTrait()?.toNode() ?: fillDefaultValue(model.expectShape(shape.target)) // Aggregate shapes. is StructureShape -> Node.objectNode() @@ -487,7 +490,9 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val fractionalPart = num.remainder(BigDecimal.ONE) rust( "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), +// RuntimeType.dateTime(runtimeConfig), + // TODO + runtimeConfig.smithyRuntimeCrate("smithy-types", scope = DependencyScope.Dev).toType().resolve("DateTime"), ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 486b443a6a..b44bdfb84d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -174,10 +174,10 @@ open class AwsJson( override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index 236c297db9..c7b139bfd5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -27,11 +27,28 @@ interface Protocol { /** The timestamp format that should be used if no override is specified in the model */ val defaultTimestampFormat: TimestampFormatTrait.Format - /** Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. */ + /** + * Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ fun additionalRequestHeaders(operationShape: OperationShape): List> = emptyList() + /** + * Returns additional HTTP headers that should be included in HTTP responses for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ + fun additionalResponseHeaders(operationShape: OperationShape): List> = emptyList() + /** * Returns additional HTTP headers that should be included in HTTP responses for the given error shape. + * These headers are added to responses _in addition_ to those returned by `additionalResponseHeaders`; if a header + * added by this function has the same header name as one added by `additionalResponseHeaders`, the one added by + * `additionalResponseHeaders` takes precedence. + * * These MUST all be lowercase, or the application will panic, as per * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static */ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt index e40046f1d8..cb9b766771 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt @@ -39,7 +39,7 @@ class ProtocolFunctions( private val codegenContext: CodegenContext, ) { companion object { - private val serDeModule = RustModule.pubCrate("protocol_serde") + val serDeModule = RustModule.pubCrate("protocol_serde") fun crossOperationFn( fnName: String, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 641548fc11..c4e3980668 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -56,6 +56,12 @@ class RestJsonHttpBindingResolver( } } } + + // The spec does not mention whether we should set the `Content-Type` header when there is no modeled output. + // The protocol tests indicate it's optional: + // + // + // In our implementation, we opt to always set it to `application/json`. return super.responseContentType(operationShape) ?: "application/json" } } @@ -124,10 +130,10 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt new file mode 100644 index 0000000000..75b275fc8f --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2.kt @@ -0,0 +1,141 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.core.util.outputShape + +// TODO Rename these to RpcV2Cbor +class RpcV2HttpBindingResolver( + private val model: Model, +) : HttpBindingResolver { + private fun bindings(shape: ToShapeId): List { + val members = shape.let { model.expectShape(it.toShapeId()) }.members() + // TODO(https://github.com/awslabs/smithy-rs/issues/2237): support non-streaming members too + if (members.size > 1 && members.any { it.isStreaming(model) }) { + throw CodegenException( + "We only support one payload member if that payload contains a streaming member." + + "Tracking issue to relax this constraint: https://github.com/awslabs/smithy-rs/issues/2237", + ) + } + + return members.map { + if (it.isStreaming(model)) { + HttpBindingDescriptor(it, HttpLocation.PAYLOAD, "document") + } else { + HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") + } + } + .toList() + } + + // TODO + // In the server, this is only used when the protocol actually supports the `@http` trait. + // However, we will have to do this for client support. Perhaps this method deserves a rename. + override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") + + override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape) + override fun responseBindings(operationShape: OperationShape) = bindings(operationShape.outputShape) + override fun errorResponseBindings(errorShape: ToShapeId) = bindings(errorShape) + + // TODO This should return null when operationShape has no members, and we should not rely on our janky + // `serverContentTypeCheckNoModeledInput`. Same goes for restJson1 protocol. + override fun requestContentType(operationShape: OperationShape): String = "application/cbor" + + /** + * > Responses for operations with no defined output type MUST NOT contain bodies in their HTTP responses. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun responseContentType(operationShape: OperationShape): String? { + // When `syntheticOutputTrait.originalId == null` it implies that the operation had no output defined + // in the Smithy model. + val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() + if (syntheticOutputTrait.originalId == null) { + return null + } + return requestContentType(operationShape) + } + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") +} + +/** + * TODO: Docs. + */ +open class RpcV2(val codegenContext: CodegenContext) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig + private val errorScope = arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), + "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "Response" to RuntimeType.Http.resolve("Response"), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + ) + private val jsonDeserModule = RustModule.private("json_deser") + + override val httpBindingResolver: HttpBindingResolver = RpcV2HttpBindingResolver(codegenContext.model) + + // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps + // using floating point seconds from the epoch. + override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS + + override fun additionalResponseHeaders(operationShape: OperationShape): List> = + listOf("smithy-protocol" to "rpc-v2-cbor") + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator(codegenContext, httpBindingResolver) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = + CborSerializerGenerator(codegenContext, httpBindingResolver) + + // TODO: Implement `RpcV2.parseHttpErrorMetadata` + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) { + rustTemplate( + """ + pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + #{json_errors}::parse_error_metadata(response.body(), response.headers()) + } + """, + *errorScope, + ) + } + + // TODO: Implement `RpcV2.parseEventStreamErrorMetadata` + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", jsonDeserModule) { + // `HeaderMap::new()` doesn't allocate. + rustTemplate( + """ + pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + #{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new()) + } + """, + *errorScope, + ) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt new file mode 100644 index 0000000000..012c695d16 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -0,0 +1,657 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.SparseTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.outputShape + +/** + * Class describing a CBOR parser section that can be used in a customization. + */ +sealed class CborParserSection(name: String) : Section(name) { + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember") +} + +/** + * Customization for the CBOR parser. + */ +typealias CborParserCustomization = NamedCustomization + +// TODO Add a `CborParserGeneratorTest` a la `CborSerializerGeneratorTest`. +class CborParserGenerator( + private val codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + /** See docs for this parameter in [JsonParserGenerator]. */ + private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + }, + private val customizations: List = listOf(), +) : StructuredDataParserGenerator { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + // TODO Use? + private val codegenTarget = codegenContext.target + private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() + private val protocolFunctions = ProtocolFunctions(codegenContext) + private val codegenScope = arrayOf( + "SmithyCbor" to smithyCbor, + "Decoder" to smithyCbor.resolve("Decoder"), + "Error" to smithyCbor.resolve("decode::DeserializeError"), + "HashMap" to RuntimeType.HashMap, + "Vec" to RuntimeType.Vec, + ) + + private fun listMemberParserFn( + listSymbol: Symbol, + isSparseList: Boolean, + memberShape: MemberShape, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn member( + mut list: #{ListSymbol}, + decoder: &mut #{Decoder}, + ) -> Result<#{ListSymbol}, #{Error}> + """, + *codegenScope, + "ListSymbol" to listSymbol, + ) { + val deserializeMemberWritable = deserializeMember(memberShape) + if (isSparseList) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeMember:W}?), + }; + """, + *codegenScope, + "DeserializeMember" to deserializeMemberWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeMember:W}?; + """, + "DeserializeMember" to deserializeMemberWritable, + ) + } + + if (returnUnconstrainedType) { + rust("list.0.push(value);") + } else { + rust("list.push(value);") + } + + rust("Ok(list)") + } + } + + private fun mapPairParserFnWritable( + keyTarget: StringShape, + valueShape: MemberShape, + isSparseMap: Boolean, + mapSymbol: Symbol, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn pair( + mut map: #{MapSymbol}, + decoder: &mut #{Decoder}, + ) -> Result<#{MapSymbol}, #{Error}> + """, + *codegenScope, + "MapSymbol" to mapSymbol, + ) { + val deserializeKeyWritable = deserializeString(keyTarget) + rustTemplate( + """ + let key = #{DeserializeKey:W}?; + """, + "DeserializeKey" to deserializeKeyWritable, + ) + val deserializeValueWritable = deserializeMember(valueShape) + if (isSparseMap) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeValue:W}?), + }; + """, + *codegenScope, + "DeserializeValue" to deserializeValueWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeValue:W}?; + """, + "DeserializeValue" to deserializeValueWritable, + ) + } + + if (returnUnconstrainedType) { + rust("map.0.insert(key, value);") + } else { + rust("map.insert(key, value);") + } + + rust("Ok(map)") + } + } + + private fun structurePairParserFnWritable(builderSymbol: Symbol, includedMembers: Collection) = writable { + rustBlockTemplate( + """ + fn pair( + mut builder: #{Builder}, + decoder: &mut #{Decoder} + ) -> Result<#{Builder}, #{Error}> + """, + *codegenScope, + "Builder" to builderSymbol, + ) { + withBlock("builder = match decoder.str()?.as_ref() {", "};") { + for (member in includedMembers) { + rustBlock("${member.memberName.dq()} =>") { + val callBuilderSetMemberFieldWritable = writable { + withBlock("builder.${member.setterName()}(", ")") { + conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { + val symbol = symbolProvider.toSymbol(member) + if (symbol.isRustBoxed()) { + rustBlock("") { + rustTemplate("let v = #{DeserializeMember:W}?;", + "DeserializeMember" to deserializeMember(member)) + + for (customization in customizations) { + customization.section( + CborParserSection.BeforeBoxingDeserializedMember( + member + ) + )(this) + } + rust("Box::new(v)") + } + } else { + rustTemplate("#{DeserializeMember:W}?", + "DeserializeMember" to deserializeMember(member)) + } + } + } + } + + if (member.isOptional) { + // Call `builder.set_member()` only if the value for the field on the wire is not null. + rustTemplate( + """ + ::aws_smithy_cbor::decode::set_optional(builder, decoder, |builder, decoder| { + Ok(#{MemberSettingWritable:W}) + })? + """, + "MemberSettingWritable" to callBuilderSetMemberFieldWritable + ) + } + else { + callBuilderSetMemberFieldWritable.invoke(this) + } + } + } + + rust( + """ + _ => { + decoder.skip()?; + builder + } + """) + } + rust("Ok(builder)") + } + } + + private fun unionPairParserFnWritable(shape: UnionShape) = writable { + val returnSymbolToParse = returnSymbolToParse(shape) + // TODO Test with unit variants + // TODO Test with all unit variants + rustBlockTemplate( + """ + fn pair( + decoder: &mut #{Decoder} + ) -> Result<#{UnionSymbol}, #{Error}> + """, + *codegenScope, + "UnionSymbol" to returnSymbolToParse.symbol, + ) { + withBlock("Ok(match decoder.str()? {", "})") { + for (member in shape.members()) { + val variantName = symbolProvider.toMemberName(member) + + withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { + deserializeMember(member).invoke(this) + } + } + // TODO Test client mode (parse unknown variant) and server mode (reject unknown variant). + // In client mode, resolve an unknown union variant to the unknown variant. + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + rust("_ => { todo!() }") + } + } + } + + private fun decodeStructureMapLoopWritable() = writable { + rustTemplate( + """ + match decoder.map()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + builder = pair(builder, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + builder = pair(builder, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + // TODO This should be DRYed up with `decodeStructureMapLoopWritable`. + private fun decodeMapLoopWritable() = writable { + rustTemplate( + """ + match decoder.map()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + map = pair(map, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + map = pair(map, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + // TODO This should be DRYed up with `decodeStructureMapLoopWritable`. + private fun decodeListLoop() = writable { + rustTemplate( + """ + match decoder.list()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + list = member(list, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + list = member(list, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + /** + * Reusable structure parser implementation that can be used to generate parsing code for + * operation, error and structure shapes. + * We still generate the parser symbol even if there are no included members because the server + * generation requires parsers for all input structures. + */ + private fun structureParser( + shape: Shape, + builderSymbol: Symbol, + includedMembers: List, + fnNameSuffix: String? = null, + ): RuntimeType { + return protocolFunctions.deserializeFn(shape, fnNameSuffix) { fnName -> + // TODO Test no members. +// val unusedMut = if (includedMembers.isEmpty()) "##[allow(unused_mut)] " else "" + // TODO Assert token stream ended. + rustTemplate( + """ + pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> Result<#{Builder}, #{Error}> { + #{StructurePairParserFn:W} + + let decoder = &mut #{Decoder}::new(value); + + #{DecodeStructureMapLoop:W} + + if decoder.position() != value.len() { + todo!() + } + + Ok(builder) + } + """, + "Builder" to builderSymbol, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + *codegenScope, + ) + } + } + + override fun payloadParser(member: MemberShape): RuntimeType { + UNREACHABLE("No protocol using CBOR serialization supports payload binding") + } + + override fun operationParser(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR deserializer if there is nothing bound to the HTTP body. + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + val outputShape = operationShape.outputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(outputShape), httpDocumentMembers) + } + + override fun errorParser(errorShape: StructureShape): RuntimeType? { + if (errorShape.members().isEmpty()) { + return null + } + return structureParser( + errorShape, + symbolProvider.symbolForBuilder(errorShape), + errorShape.members().toList(), + fnNameSuffix = "json_err", + ) + } + + override fun serverInputParser(operationShape: OperationShape): RuntimeType? { + val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (includedMembers.isEmpty()) { + return null + } + val inputShape = operationShape.inputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(inputShape), includedMembers) + } + + private fun RustWriter.deserializeMember(memberShape: MemberShape) = writable { + when (val target = model.expectShape(memberShape.target)) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("decoder.blob()") + is BooleanShape -> rust("decoder.boolean()") + + is StringShape -> deserializeString(target).invoke(this) + + is ByteShape -> rust("decoder.byte()") + is ShortShape -> rust("decoder.short()") + is IntegerShape -> rust("decoder.integer()") + is LongShape -> rust("decoder.long()") + + is FloatShape -> rust("decoder.float()") + is DoubleShape -> rust("decoder.double()") + + is TimestampShape -> rust("decoder.timestamp()") + + // TODO Document shapes have not been specced out yet. + // is DocumentShape -> rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope) + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + is StructureShape -> deserializeStruct(target) + is CollectionShape -> deserializeCollection(target) + is MapShape -> deserializeMap(target) + is UnionShape -> deserializeUnion(target) + else -> PANIC("unexpected shape: $target") + } + // TODO Boxing +// val symbol = symbolProvider.toSymbol(memberShape) +// if (symbol.isRustBoxed()) { +// for (customization in customizations) { +// customization.section(JsonParserSection.BeforeBoxingDeserializedMember(memberShape))(this) +// } +// rust(".map(Box::new)") +// } + } + + private fun RustWriter.deserializeString(target: StringShape, bubbleUp: Boolean = true) = writable { + // TODO Handle enum shapes + rust("decoder.string()") + } + + private fun RustWriter.deserializeCollection(shape: CollectionShape) { + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + // TODO Test `@sparse` and non-@sparse lists. + // - Clients should insert only non-null values in non-`@sparse` list. + // - Servers should reject upon encountering first null value in non-`@sparse` list. + // - Both clients and servers should insert null values in `@sparse` list. + + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = writable { + withBlock("let mut list = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{Vec}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + #{ListMemberParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeListLoop:W} + + Ok(list) + } + """, + "ReturnType" to returnSymbol, + "ListMemberParserFn" to listMemberParserFn( + returnSymbol, + isSparseList = shape.hasTrait(), + shape.member, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeListLoop" to decodeListLoop(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeMap(shape: MapShape) { + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + // TODO Test `@sparse` and non-@sparse maps. + // - Clients should insert only non-null values in non-`@sparse` map. + // - Servers should reject upon encountering first null value in non-`@sparse` map. + // - Both clients and servers should insert null values in `@sparse` map. + + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = writable { + withBlock("let mut map = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{HashMap}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}> { + #{MapPairParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeMapLoop:W} + + Ok(map) + } + """, + "ReturnType" to returnSymbol, + "MapPairParserFn" to mapPairParserFnWritable( + keyTarget, + shape.value, + isSparseMap = shape.hasTrait(), + returnSymbol, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeMapLoop" to decodeMapLoopWritable(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeStruct(shape: StructureShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{ReturnType}, #{Error}>", + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + val builderSymbol = symbolProvider.symbolForBuilder(shape) + val includedMembers = shape.members() + + rustTemplate( + """ + #{StructurePairParserFn:W} + + let mut builder = #{Builder}::default(); + + #{DecodeStructureMapLoop:W} + """, + *codegenScope, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "Builder" to builderSymbol, + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + ) + + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(builder)") + } else { + rust("Ok(builder.build())") + } + } + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeUnion(shape: UnionShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = protocolFunctions.deserializeFn(shape) { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> Result<#{UnionSymbol}, #{Error}> { + #{UnionPairParserFnWritable} + + match decoder.map()? { + None => { + let variant = pair(decoder)?; + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + Ok(variant) + } + ty => Err( + #{Error}::unexpected_union_variant( + ty, + decoder.position(), + ), + ), + } + } + Some(1) => pair(decoder), + Some(_) => Err(#{Error}::mixed_union_variants(decoder.position())) + } + } + """, + "UnionSymbol" to returnSymbolToParse.symbol, + "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 4833f808fb..5f248aad33 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -77,8 +77,6 @@ sealed class JsonParserSection(name: String) : Section(name) { */ typealias JsonParserCustomization = NamedCustomization -data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) - class JsonParserGenerator( private val codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, @@ -447,7 +445,7 @@ class JsonParserGenerator( } private fun RustWriter.deserializeMap(shape: MapShape) { - val keyTarget = model.expectShape(shape.key.target) as StringShape + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) val isSparse = shape.hasTrait() val returnSymbolToParse = returnSymbolToParse(shape) val parser = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt new file mode 100644 index 0000000000..d78f2d98fd --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt @@ -0,0 +1,8 @@ +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol + +/** + * Given a shape, parsers need to know the symbol to parse and return, and whether it's unconstrained or not. + */ +data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt index f8b053d80f..fd7e6fcb28 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt @@ -12,8 +12,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType interface StructuredDataParserGenerator { /** - * Generate a parse function for a given targeted as a payload. + * Generate a parse function for a given shape targeted with `@httpPayload`. * Entry point for payload-based parsing. + * * Roughly: * ```rust * fn parse_my_struct(input: &[u8]) -> Result { @@ -49,6 +50,7 @@ interface StructuredDataParserGenerator { /** * Generate a parser for a server operation input structure + * * ```rust * fn deser_operation_crate_operation_my_operation_input( * value: &[u8], builder: my_operation_input::Builder diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt new file mode 100644 index 0000000000..68533ef7d1 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -0,0 +1,469 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.outputShape + +// TODO Cleanup commented and unused code. + +/** + * Class describing a JSON serializer section that can be used in a customization. + */ +sealed class CborSerializerSection(name: String) : Section(name) { + /** + * Mutate the serializer prior to serializing any structure members. Eg: this can be used to inject `__type` + * to record the error type in the case of an error structure. + */ + data class BeforeSerializingStructureMembers(val structureShape: StructureShape, val encoderBindingName: String) : + CborSerializerSection("ServerError") + + /** Manipulate the serializer context for a map prior to it being serialized. **/ + data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : + CborSerializerSection("BeforeIteratingOverMapOrCollection") +} + +/** + * Customization for the CBOR serializer. + */ +typealias CborSerializerCustomization = NamedCustomization + +class CborSerializerGenerator( + codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + private val customizations: List = listOf(), +) : StructuredDataSerializerGenerator { + data class Context( + /** Expression representing the value to write to the encoder */ + var valueExpression: ValueExpression, + /** Path in the JSON to get here, used for errors */ + val shape: T, + ) + + data class MemberContext( + /** Name for the variable bound to the encoder object **/ + val encoderBindingName: String, + /** Expression representing the value to write to the `Encoder` */ + var valueExpression: ValueExpression, + val shape: MemberShape, + /** Whether to serialize null values if the type is optional */ + val writeNulls: Boolean = false, + ) { + companion object { + fun collectionMember(context: Context, itemName: String): MemberContext = + MemberContext( + "encoder", + ValueExpression.Reference(itemName), + context.shape.member, + writeNulls = true, + ) + + fun mapMember(context: Context, key: String, value: String): MemberContext = + MemberContext( + "encoder.str($key)", + ValueExpression.Reference(value), + context.shape.value, + writeNulls = true, + ) + + fun structMember( + context: StructContext, + member: MemberShape, + symProvider: RustSymbolProvider, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Value("${context.localName}.${symProvider.toMemberName(member)}"), + member, + ) + + fun unionMember( + variantReference: String, + member: MemberShape, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Reference(variantReference), + member, + ) + + /** Returns an expression to encode a key member **/ + private fun encodeKeyExpression(name: String): String = + "encoder.str(${name.dq()})" + } + } + + // Specialized since it holds a JsonObjectWriter expression rather than a JsonValueWriter + data class StructContext( + /** Name of the variable that holds the struct */ + val localName: String, + val shape: StructureShape, + ) + + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val codegenTarget = codegenContext.target + private val runtimeConfig = codegenContext.runtimeConfig + private val protocolFunctions = ProtocolFunctions(codegenContext) + // TODO Cleanup + private val codegenScope = arrayOf( + "String" to RuntimeType.String, + "Error" to runtimeConfig.serializationError(), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), + "ByteSlab" to RuntimeType.ByteSlab, + ) + private val serializerUtil = SerializerUtil(model, symbolProvider) + + /** + * Reusable structure serializer implementation that can be used to generate serializing code for + * operation outputs or errors. + * This function is only used by the server, the client uses directly [serializeStructure]. + */ + private fun serverSerializer( + structureShape: StructureShape, + includedMembers: List, + error: Boolean, + ): RuntimeType { + val suffix = when (error) { + true -> "error" + else -> "output" + } + return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> + rustBlockTemplate( + "pub fn $fnName(value: &#{target}) -> Result, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(structureShape), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(Vec::new());", *codegenScope) + // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference. + rustBlock("") { + rust("let encoder = &mut encoder;") + serializeStructure(StructContext("value", structureShape), includedMembers) + } + rust("Ok(encoder.into_writer())") + } + } + } + + // TODO + override fun payloadSerializer(member: MemberShape): RuntimeType { + val target = model.expectShape(member.target) + return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> std::result::Result<#{ByteSlab}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(target), + ) { + rust("let mut out = String::new();") + rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + when (target) { + is StructureShape -> serializeStructure(StructContext("input", target)) + is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) + else -> throw IllegalStateException("json payloadSerializer only supports structs and unions") + } + rust("object.finish();") + rustTemplate("Ok(out.into_bytes())", *codegenScope) + } + } + } + + // TODO Unclear whether we'll need this. + override fun unsetStructure(structure: StructureShape): RuntimeType = + ProtocolFunctions.crossOperationFn("rest_json_unsetpayload") { fnName -> + rustTemplate( + """ + pub fn $fnName() -> #{ByteSlab} { + b"{}"[..].into() + } + """, + *codegenScope, + ) + } + + override fun unsetUnion(union: UnionShape): RuntimeType { + // TODO + TODO("Not yet implemented") + } + + override fun operationInputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR serializer if there is no CBOR body. + val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + + val inputShape = operationShape.inputShape(model) + return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", + *codegenScope, "target" to symbolProvider.toSymbol(inputShape), + ) { + rust("let mut out = String::new();") + rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + serializeStructure(StructContext("input", inputShape), httpDocumentMembers) + rust("object.finish();") + rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) + } + } + } + + override fun documentSerializer(): RuntimeType { + return ProtocolFunctions.crossOperationFn("serialize_document") { fnName -> + rustTemplate( + """ + pub fn $fnName(input: &#{Document}) -> #{ByteSlab} { + let mut out = String::new(); + #{JsonValueWriter}::new(&mut out).document(input); + out.into_bytes() + } + """, + "Document" to RuntimeType.document(runtimeConfig), *codegenScope, + ) + } + } + + override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation JSON serializer if there was no operation output shape in the + // original (untransformed) model. + val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() + if (syntheticOutputTrait.originalId == null) { + return null + } + + // TODO + // Note that, unlike the client, we serialize an empty JSON document `"{}"` if the operation output shape is + // empty (has no members). + // The client instead serializes an empty payload `""` in _both_ these scenarios: + // 1. there is no operation input shape; and + // 2. the operation input shape is empty (has no members). + // The first case gets reduced to the second, because all operations get a synthetic input shape with + // the [OperationNormalizer] transformation. + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + + val outputShape = operationShape.outputShape(model) + return serverSerializer(outputShape, httpDocumentMembers, error = false) + } + + override fun serverErrorSerializer(shape: ShapeId): RuntimeType { + val errorShape = model.expectShape(shape, StructureShape::class.java) + val includedMembers = + httpBindingResolver.errorResponseBindings(shape).filter { it.location == HttpLocation.DOCUMENT } + .map { it.member } + return serverSerializer(errorShape, includedMembers, error = true) + } + + private fun RustWriter.serializeStructure( + context: StructContext, + includedMembers: List? = null, + ) { + // TODO Need to inject `__type` when serializing errors. + val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> Result<(), #{Error}>", + "StructureSymbol" to symbolProvider.toSymbol(context.shape), + *codegenScope, + ) { + // TODO If all members are non-`Option`-al, we know AOT the map's size and can use `.map()` + // instead of `.begin_map()` for efficiency. Add test. + rust("encoder.begin_map();") + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeSerializingStructureMembers(context.shape, "encoder"))(this) + } + context.copy(localName = "input").also { inner -> + val members = includedMembers ?: inner.shape.members() + for (member in members) { + serializeMember(MemberContext.structMember(inner, member, symbolProvider)) + } + } + rust("encoder.end();") + rust("Ok(())") + } + } + rust("#T(encoder, ${context.localName})?;", structureSerializer) + } + + private fun RustWriter.serializeMember(context: MemberContext) { + val targetShape = model.expectShape(context.shape.target) + if (symbolProvider.toSymbol(context.shape).isOptional()) { + safeName().also { local -> + rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { + context.valueExpression = ValueExpression.Reference(local) + serializeMemberValue(context, targetShape) + } + if (context.writeNulls) { + rustBlock("else") { + rust("${context.encoderBindingName}.null();") + } + } + } + } else { + with(serializerUtil) { + ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { + serializeMemberValue(context, targetShape) + } + } + } + } + + private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { + val encoder = context.encoderBindingName + val value = context.valueExpression + val containerShape = model.expectShape(context.shape.container) + + when (target) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("$encoder.blob(${value.asRef()});") + is BooleanShape -> rust("$encoder.boolean(${value.asValue()});") + + is StringShape -> rust("$encoder.str(${value.name}.as_str());") + + is ByteShape -> rust("$encoder.byte(${value.asValue()});") + is ShortShape -> rust("$encoder.short(${value.asValue()});") + is IntegerShape -> rust("$encoder.integer(${value.asValue()});") + is LongShape -> rust("$encoder.long(${value.asValue()});") + + is FloatShape -> rust("$encoder.float(${value.asValue()});") + is DoubleShape -> rust("$encoder.double(${value.asValue()});") + + is TimestampShape -> rust("$encoder.timestamp(${value.asRef()});") + + // TODO Document shapes have not been specced out yet. + // is DocumentShape -> rust("$encoder.document(${value.asRef()});") + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + else -> { + // This condition is equivalent to `containerShape !is CollectionShape`. + if (containerShape is StructureShape || containerShape is UnionShape || containerShape is MapShape) { + rust("$encoder;") // Encode the member key. + } + when (target) { + is StructureShape -> serializeStructure(StructContext(value.name, target)) + is CollectionShape -> serializeCollection(Context(value, target)) + is MapShape -> serializeMap(Context(value, target)) + is UnionShape -> serializeUnion(Context(value, target)) + else -> UNREACHABLE("Smithy added a new aggregate shape: $target") + } + } + } + } + + private fun RustWriter.serializeCollection(context: Context) { + // `.expect()` safety: `From for usize` is not in the standard library, but the conversion should be + // infallible (unless we ever have 128-bit machines I guess). + // See https://users.rust-lang.org/t/cant-convert-usize-to-u64/6243. + // TODO Point to a `static` to not inflate the binary. + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust( + """ + encoder.array( + (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") + ); + """ + ) + val itemName = safeName("item") + rustBlock("for $itemName in ${context.valueExpression.asRef()}") { + serializeMember(MemberContext.collectionMember(context, itemName)) + } + } + + private fun RustWriter.serializeMap(context: Context) { + val keyName = safeName("key") + val valueName = safeName("value") + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust( + """ + encoder.map( + (${context.valueExpression.asValue()}).len().try_into().expect("`usize` to `u64` conversion failed") + ); + """ + ) + rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { + val keyExpression = "$keyName.as_str()" + serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) + } + } + + private fun RustWriter.serializeUnion(context: Context) { + val unionSymbol = symbolProvider.toSymbol(context.shape) + val unionSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> Result<(), #{Error}>", + "UnionSymbol" to unionSymbol, + *codegenScope, + ) { + // A union is serialized identically as a `structure` shape, but only a single member can be set to a + // non-null value. + rust("encoder.map(1);") + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = if (member.isTargetUnit()) { + symbolProvider.toMemberName(member) + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + rustBlock("#T::$variantName =>", unionSymbol) { + serializeMember(MemberContext.unionMember("inner", member)) + } + } + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) + } + } + rust("Ok(())") + } + } + rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 46341bb09c..8b18dfa5c8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -212,12 +212,21 @@ class JsonSerializerGenerator( *codegenScope, "target" to symbolProvider.toSymbol(structureShape), ) { - rust("let mut out = String::new();") - rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + rustTemplate( + """ + let mut out = String::new(); + let mut object = #{JsonObjectWriter}::new(&mut out); + """, + *codegenScope, + ) serializeStructure(StructContext("object", "value", structureShape), includedMembers) customizations.forEach { it.section(makeSection(structureShape, "object"))(this) } - rust("object.finish();") - rustTemplate("Ok(out)", *codegenScope) + rust( + """ + object.finish(); + Ok(out) + """ + ) } } } diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index cab36dbbcc..1c9c9d4c7d 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -16,7 +16,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val pluginName = "rust-server-codegen" @@ -25,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-server-test/" dependencies { implementation(project(":codegen-server")) implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") @@ -44,6 +44,19 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), ), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), + // CodegenTest("aws.protocoltests.restxml#RestXml", "restXml"), + CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), + // Todo: change this to rpcv2extra + CodegenTest("com.amazonaws.simple#RpcV2Service", "rpcv2Extra", imports = listOf("$commonModels/rpcv2.smithy")), + CodegenTest( + "aws.protocoltests.rpcv2#RpcV2Protocol", + "adwait-main", + imports = listOf( + "$commonModels/adwait-main.smithy", + "$commonModels/adwait-cbor-structs.smithy", + "$commonModels/adwait-empty-input-output.smithy", + ) + ), CodegenTest( "com.amazonaws.constraints#ConstraintsService", "constraints_without_public_constrained_types", @@ -103,7 +116,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace", "generateCargoConfigToml") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile, defaultRustDocFlags) +project.registerCargoCommandsTasks(layout.buildDirectory.dir(workingDirUnderBuildDir).get().asFile) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-server-test/python/build.gradle.kts b/codegen-server-test/python/build.gradle.kts index f9129a2fde..b6ee22147a 100644 --- a/codegen-server-test/python/build.gradle.kts +++ b/codegen-server-test/python/build.gradle.kts @@ -16,7 +16,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val buildDir = layout.buildDirectory.get().asFile @@ -120,7 +119,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir), defaultRustDocFlags) +project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir)) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-server-test/typescript/build.gradle.kts b/codegen-server-test/typescript/build.gradle.kts index 428da17df6..22c27b7d90 100644 --- a/codegen-server-test/typescript/build.gradle.kts +++ b/codegen-server-test/typescript/build.gradle.kts @@ -16,7 +16,6 @@ plugins { } val smithyVersion: String by project -val defaultRustDocFlags: String by project val properties = PropertyRetriever(rootProject, project) val buildDir = layout.buildDirectory.get().asFile @@ -49,7 +48,7 @@ tasks["smithyBuild"].dependsOn("generateSmithyBuild") tasks["assemble"].finalizedBy("generateCargoWorkspace") project.registerModifyMtimeTask() -project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir), defaultRustDocFlags) +project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir)) tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 0ba262225d..2dbc33df05 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -26,6 +26,7 @@ dependencies { implementation(project(":codegen-core")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in `constraints.smithy`, which is used // in `CustomValidationExceptionWithReasonDecoratorTest`. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 01df0c0a93..a9c488503d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig */ object ServerCargoDependency { val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1.74")) + val Base64SimdDev: CargoDependency = CargoDependency("base64-simd", CratesIo("0.8"), scope = DependencyScope.Dev) val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1")) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3")) val Mime: CargoDependency = CargoDependency("mime", CratesIo("0.3")) @@ -26,7 +27,7 @@ object ServerCargoDependency { val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) - val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) + val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), scope = DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt new file mode 100644 index 0000000000..59ead97478 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.escape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +/** + * Smithy RPC v2 CBOR requires errors to be serialized in server responses with an additional `__type` field. + */ +class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = when (section) { + is CborSerializerSection.BeforeSerializingStructureMembers -> + if (section.structureShape.hasTrait()) { + writable { + rust( + """ + ${section.encoderBindingName} + .str("__type") + .str("${escape(section.structureShape.id.toString())}"); + """ + ) + } + } else { + emptySection + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt new file mode 100644 index 0000000000..42eb1f6843 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we encode over a _constrained_ map or collection shape in a CBOR serializer, + * unwrap the wrapper newtype and take a shared reference to the actual value within it. + * That value will be a `std::collections::HashMap` for map shapes, and a `std::vec::Vec` for collection shapes. + */ +class BeforeEncodingMapOrCollectionCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = when (section) { + is CborSerializerSection.BeforeIteratingOverMapOrCollection -> writable { + check(section.shape is CollectionShape || section.shape is MapShape) + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 2fb76bf879..d4281b4ddf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -21,18 +21,27 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator @@ -70,8 +79,8 @@ interface ServerProtocol : Protocol { fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType /** - * In some protocols, such as restJson1, - * when there is no modeled body input, content type must not be set and the body must be empty. + * In some protocols, such as restJson1 and rpcv2, + * when there is no modeled body input, `content-type` must not be set and the body must be empty. * Returns a boolean indicating whether to perform this check. */ fun serverContentTypeCheckNoModeledInput(): Boolean = false @@ -166,7 +175,10 @@ class ServerAwsJsonProtocol( rust("""String::from("$serviceName.$operationName")""") } - override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.String + // TODO This could technically be `&static str` right? + override fun serverRouterRequestSpecType( + requestSpecModule: RuntimeType, + ): RuntimeType = RuntimeType.String override fun serverRouterRuntimeConstructor() = when (version) { @@ -255,8 +267,8 @@ class ServerRestXmlProtocol( } /** - * A customization to, just before we box a recursive member that we've deserialized into `Option`, convert it into - * `MaybeConstrained` if the target shape can reach a constrained shape. + * A customization to, just before we box a recursive member that we've deserialized from JSON into `Option`, convert + * it into `MaybeConstrained` if the target shape can reach a constrained shape. */ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : JsonParserCustomization() { @@ -276,3 +288,74 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa else -> emptySection } } + +/** + * A customization to, just before we box a recursive member that we've deserialized from CBOR into `T` held in a + * variable binding `v`, convert it into `MaybeConstrained` if the target shape can reach a constrained shape. + */ +class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization(val codegenContext: ServerCodegenContext) : + CborParserCustomization() { + override fun section(section: CborParserSection): Writable = when (section) { + is CborParserSection.BeforeBoxingDeserializedMember -> writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust("let v = v.into();") + } + } + } +} + +class ServerRpcV2Protocol( + private val serverCodegenContext: ServerCodegenContext, +) : RpcV2(serverCodegenContext), ServerProtocol { + val runtimeConfig = codegenContext.runtimeConfig + + override val protocolModulePath = "rpc_v2" + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator( + serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( + serverCodegenContext, + ), + ), + ) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator { + return CborSerializerGenerator( + codegenContext, + httpBindingResolver, + listOf( + BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), + AddTypeFieldToServerErrorsCborCustomization(), + ), + ) + } + + override fun markerStruct() = ServerRuntimeType.protocol("RpcV2", "rpc_v2", runtimeConfig) + + override fun routerType() = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::rpc_v2::router::RpcV2Router") + + override fun serverRouterRequestSpec( + operationShape: OperationShape, + operationName: String, + serviceName: String, + requestSpecModule: RuntimeType, + ) = writable { + // This is just the key used by the router's map to store and lookup operations, it's completely arbitrary. + // We use the same key used by the awsJson1.x routers for simplicity. + rust("$serviceName.$operationName".dq()) + } + + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = + RuntimeType.StaticStr + + override fun serverRouterRuntimeConstructor() = "rpc_v2_router" + + override fun serverContentTypeCheckNoModeledInput() = false +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index c35518319c..c9ffa06cc4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -26,6 +26,7 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords @@ -42,6 +43,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors +import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember @@ -100,6 +102,7 @@ class ServerProtocolTestGenerator( private val codegenScope = arrayOf( + "Base64SimdDev" to ServerCargoDependency.Base64SimdDev.toType(), "Bytes" to RuntimeType.Bytes, "SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), "Http" to RuntimeType.Http, @@ -278,6 +281,11 @@ class ServerProtocolTestGenerator( testModuleWriter.newlinePrefix = "" Attribute.TokioTest.render(testModuleWriter) + Attribute.TracedTest.render(testModuleWriter) + // The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in + // case the code rendered by the test does not make use of `tracing` at all. + val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) } + testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency) if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") @@ -309,7 +317,7 @@ class ServerProtocolTestGenerator( } with(httpRequestTestCase) { - renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest(uri, method, headers, body.orNull(), bodyMediaType.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) @@ -387,7 +395,9 @@ class ServerProtocolTestGenerator( rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) + // TODO(https://github.com/smithy-lang/smithy/issues/1932): we send `null` for `bodyMediaType` for now but + // the Smithy protocol test should give it to us. + renderHttpRequest(uri.get(), method, headers, body.orNull(), bodyMediaType = null, queryParams, host.orNull()) } makeRequest( @@ -405,6 +415,7 @@ class ServerProtocolTestGenerator( method: String, headers: Map, body: String?, + bodyMediaType: String?, queryParams: List, host: String?, ) { @@ -434,8 +445,17 @@ class ServerProtocolTestGenerator( // // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() + + val encodedBody = + """ + #{Bytes}::from( + #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( + "`body` field of Smithy protocol test is not correctly base64 encoded" + ) + ) + """ - "#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))" + "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { "#{SmithyHttpServer}::body::Body::empty()" } @@ -660,13 +680,13 @@ class ServerProtocolTestGenerator( rustWriter.rustTemplate( """ // No body. - #{AssertEq}(std::str::from_utf8(&body).unwrap(), ""); + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) } else { assertOk(rustWriter) { - rustWriter.rust( + rust( "#T(&body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 910acf4a70..ac966853e6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -114,7 +114,7 @@ typealias ServerHttpBoundProtocolCustomization = NamedCustomization = listOf(), additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( - protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), - ) { + protocol, + ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), +) { + // TODO Delete, unused // Define suffixes for operation input / output / error wrappers companion object { const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper" @@ -604,13 +605,35 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } + private fun setResponseHeaderIfAbsent(writer: RustWriter, headerName: String, headerValue: String) { + // We can be a tad more efficient if there's a `const` `HeaderName` in the `http` crate that matches. + // https://docs.rs/http/latest/http/header/index.html#constants + val headerNameExpr = if (headerName == "content-type") { + "#{http}::header::CONTENT_TYPE" + } else { + "#{http}::header::HeaderName::from_static(\"$headerName\")" + } + + writer.rustTemplate( + """ + builder = #{header_util}::set_response_header_if_absent( + builder, + $headerNameExpr, + "${writer.escape(headerValue)}", + ); + """, + *codegenScope, + ) + } + + /** * Sets HTTP response headers for the operation's output shape or the operation's error shape. * It will generate response headers for the operation's output shape, unless [errorShape] is non-null, in which * case it will generate response headers for the given error shape. * * It sets three groups of headers in order. Headers from one group take precedence over headers in a later group. - * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. = null + * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ @@ -626,7 +649,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust( """ builder = #{T}($outputOwnedOrBorrowed, builder)?; - """.trimIndent(), + """, addHeadersFn, ) } @@ -635,32 +658,17 @@ class ServerHttpBoundProtocolTraitImplGenerator( // to allow operations that bind a member to `Content-Type` (which we set earlier) to take precedence (this is // because we always use `set_response_header_if_absent`, so the _first_ header value we set for a given // header name is the one that takes precedence). - val contentType = httpBindingResolver.responseContentType(operationShape) - if (contentType != null) { - rustTemplate( - """ - builder = #{header_util}::set_response_header_if_absent( - builder, - #{http}::header::CONTENT_TYPE, - "$contentType" - ); - """, - *codegenScope, - ) + httpBindingResolver.responseContentType(operationShape)?.let { contentTypeValue -> + setResponseHeaderIfAbsent(this, "content-type", contentTypeValue) + } + + for ((headerName, headerValue) in protocol.additionalResponseHeaders(operationShape)) { + setResponseHeaderIfAbsent(this, headerName, headerValue) } if (errorShape != null) { for ((headerName, headerValue) in protocol.additionalErrorResponseHeaders(errorShape)) { - rustTemplate( - """ - builder = #{header_util}::set_response_header_if_absent( - builder, - http::header::HeaderName::from_static("$headerName"), - "${escape(headerValue)}" - ); - """, - *codegenScope, - ) + setResponseHeaderIfAbsent(this, headerName, headerValue) } } } @@ -747,13 +755,14 @@ class ServerHttpBoundProtocolTraitImplGenerator( "RequestParts" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("http::RequestParts"), ) val parser = structuredDataParser.serverInputParser(operationShape) - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { // `null` is only returned by Smithy when there are no members, but we know there's at least one, since // there's something to parse (i.e. `parser != null`), so `!!` is safe here. val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! rustTemplate("let bytes = #{Hyper}::body::to_bytes(body).await?;", *codegenScope) + // TODO Isn't this VERY wrong? If there's modeled operation input, we must reject if there's no payload! + // We currently accept and silently build empty input! rustBlock("if !bytes.is_empty()") { rustTemplate( """ @@ -793,6 +802,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( serverRenderUriPathParser(this, operationShape) serverRenderQueryStringParser(this, operationShape) + val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { rustTemplate( @@ -1300,6 +1310,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( * Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the * shape targeted by the `httpPayload` trait. */ + // TODO This should not live here. Plus, only some protocols support `@httpPayload`. private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol { check(binding.location == HttpLocation.PAYLOAD) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ae87ec5723..446f805851 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -20,7 +21,7 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator -class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { +class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { is ServerHttpBoundProtocolSection.WrapStreamPayload -> @@ -79,6 +80,8 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap { + override fun protocol(codegenContext: ServerCodegenContext): Protocol = + ServerRpcV2Protocol(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = + ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2Protocol(codegenContext)) + + override fun support(): ProtocolSupport { + return ProtocolSupport( + /* Client support */ + requestSerialization = false, + requestBodySerialization = false, + responseDeserialization = false, + errorDeserialization = false, + /* Server support */ + requestDeserialization = true, + requestBodyDeserialization = true, + responseSerialization = true, + errorSerialization = true, + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt new file mode 100644 index 0000000000..3b69814ffd --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RpcV2Test.kt @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +// TODO This won't be needed since we'll cover it with a proper integration test. +internal class RpcV2Test { + val model = """ + ${"\$"}version: "2.0" + + namespace com.amazonaws.simple + + use smithy.protocols#rpcv2 + + @rpcv2(format: ["cbor"]) + service RpcV2Service { + version: "SomeVersion", + operations: [RpcV2Operation], + } + + @http(uri: "/operation", method: "POST") + operation RpcV2Operation { + input: OperationInputOutput + output: OperationInputOutput + } + + structure OperationInputOutput { + message: String + } + """.asSmithyModel() + + @Test + fun `generate a rpc v2 service that compiles`() { + serverIntegrationTest(model) { _, _ -> } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt new file mode 100644 index 0000000000..ed418ad71f --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerGeneratorTest.kt @@ -0,0 +1,126 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.protocoltests.traits.AppliesTo +import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2 +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.io.File + +internal class CborSerializerGeneratorTest { + class DeriveSerdeDeserializeSymbolMetadataProvider( + private val base: RustSymbolProvider, + ) : SymbolMetadataProvider(base) { + private fun addDeriveSerdeDeserialize(shape: Shape): RustMetadata { + check(shape !is MemberShape) + + val baseMetadata = base.toSymbol(shape).expectRustMetadata() + return baseMetadata.withDerives(RuntimeType.SerdeDeserialize) + } + + override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + + override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeDeserialize(structureShape) + override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeDeserialize(unionShape) + override fun enumMeta(stringShape: StringShape) = addDeriveSerdeDeserialize(stringShape) + + override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeDeserialize(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeDeserialize(mapShape) + override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveSerdeDeserialize(stringShape) + override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveSerdeDeserialize(numberShape) + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeDeserialize(blobShape) + } + + @Test + fun `we serialize and serde_cbor deserializes round trip`() { + val model = File("../codegen-core/common-test-models/rpcv2.smithy").readText().asSmithyModel() + + val addDeriveSerdeSerializeDecorator = object : ServerCodegenDecorator { + override val name: String = "Add `#[derive(serde::Deserialize)]`" + override val order: Byte = 0 + + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = + DeriveSerdeDeserializeSymbolMetadataProvider(base) + } + + serverIntegrationTest( + model, + additionalDecorators = listOf(addDeriveSerdeSerializeDecorator), + ) { codegenContext, rustCrate -> + val codegenScope = arrayOf( + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "SerdeCbor" to CargoDependency.SerdeCbor.toType(), + ) + + val instantiator = ServerInstantiator(codegenContext) + val rpcV2 = RpcV2(codegenContext) + + for (operationShape in codegenContext.model.operationShapes) { + val outputShape = operationShape.outputShape(codegenContext.model) + // TODO Use `httpRequestTests` and error tests too. + val tests = operationShape.getTrait() + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { + ServerProtocolTestGenerator.TestCase.ResponseTest( + it, + outputShape, + ) + } + val serializeFn = rpcV2 + .structuredDataSerializer() + .operationOutputSerializer(operationShape) + ?: continue // Skip if there's nothing to serialize. + + // TODO Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on + // which we can't `#[derive(Deserialize)]`. + rustCrate.withModule(ProtocolFunctions.serDeModule) { + for ((idx, test) in tests.withIndex()) { + unitTest("TODO_$idx") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes = #{SerializeFn}(&expected) + .expect("generated CBOR serializer failed"); + let actual = #{SerdeCbor}::from_slice(&bytes) + .expect("serde_cbor failed deserializing from bytes"); + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(test.targetShape, test.testCase.params), + "SerializeFn" to serializeFn, + *codegenScope, + ) + } + } + } + } + } + } +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d92869a661..f65cd8c4ec 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -8,8 +8,11 @@ members = [ "pokemon-service-lambda", "pokemon-service-server-sdk", "pokemon-service-client", +<<<<<<< HEAD +======= "pokemon-service-client-usage", +>>>>>>> release-2023-11-01 ] [profile.release] diff --git a/examples/Makefile b/examples/Makefile index 9c2bf9061d..791248c0a0 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -3,16 +3,23 @@ CUR_DIR := $(shell pwd) GRADLE := $(SRC_DIR)/gradlew SERVER_SDK_DST := $(CUR_DIR)/pokemon-service-server-sdk CLIENT_SDK_DST := $(CUR_DIR)/pokemon-service-client +RPCV2_SDK_DST := $(CUR_DIR)/rpcv2-server-sdk +RPCV2_CLIENT_DST := $(CUR_DIR)/rpcv2-pokemon-client + SERVER_SDK_SRC := $(SRC_DIR)/codegen-server-test/build/smithyprojections/codegen-server-test/pokemon-service-server-sdk/rust-server-codegen CLIENT_SDK_SRC := $(SRC_DIR)/codegen-client-test/build/smithyprojections/codegen-client-test/pokemon-service-client/rust-client-codegen +RPCV2_SDK_SRC := $(SRC_DIR)/codegen-server-test/build/smithyprojections/codegen-server-test/rpcv2/rust-server-codegen +RPCV2_CLIENT_SRC := $(SRC_DIR)/codegen-client-test/build/smithyprojections/codegen-client-test/rpcv2-pokemon-client/rust-client-codegen all: codegen codegen: - $(GRADLE) --project-dir $(SRC_DIR) -P modules='pokemon-service-server-sdk,pokemon-service-client' :codegen-client-test:assemble :codegen-server-test:assemble - mkdir -p $(SERVER_SDK_DST) $(CLIENT_SDK_DST) + $(GRADLE) --project-dir $(SRC_DIR) -P modules='pokemon-service-server-sdk,pokemon-service-client,rpcv2,rpcv2-pokemon-client' :codegen-client-test:clean :codegen-server-test:clean :codegen-client-test:assemble :codegen-server-test:assemble + mkdir -p $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(RPCV2_SDK_DST) $(RPCV2_CLIENT_DST) cp -av $(SERVER_SDK_SRC)/* $(SERVER_SDK_DST)/ cp -av $(CLIENT_SDK_SRC)/* $(CLIENT_SDK_DST)/ + cp -av $(RPCV2_SDK_SRC)/* $(RPCV2_SDK_DST)/ + cp -av $(RPCV2_CLIENT_SRC)/* $(RPCV2_CLIENT_DST)/ build: codegen cargo build @@ -39,6 +46,6 @@ lambda_invoke: cargo lambda invoke pokemon-service-lambda --data-file pokemon-service/tests/fixtures/example-apigw-request.json distclean: clean - rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) Cargo.lock + rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(RPCV2_SDK_DST) $(RPCV2_CLIENT_DST) Cargo.lock .PHONY: all diff --git a/gradle.properties b/gradle.properties index 6d06ec08dc..c9e5ab7774 100644 --- a/gradle.properties +++ b/gradle.properties @@ -31,10 +31,4 @@ kotlinVersion=1.9.20 ktlintVersion=1.0.1 kotestVersion=5.8.0 # Avoid registering dependencies/plugins/tasks that are only used for testing purposes -isTestingEnabled=true - -# TODO(https://github.com/smithy-lang/smithy-rs/issues/1068): Once doc normalization -# is completed, warnings can be prohibited in rustdoc. -# -# defaultRustDocFlags=-D warnings -defaultRustDocFlags= +isTestingEnabled=true \ No newline at end of file diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 3a618e6bbe..d7853b0099 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "inlineable", "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml new file mode 100644 index 0000000000..49c176a8f2 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "aws-smithy-cbor" +version = "0.0.0-smithy-rs-head" +authors = [ + "AWS Rust SDK Team ", + "David Pérez ", +] +description = "CBOR utilities for smithy-rs." +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/awslabs/smithy-rs" + +[dependencies.minicbor] +version = "0.19.1" # TODO Update +features = [ + # To write to a `Vec`: https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E + "alloc", + # To support reading `f16` to accomodate fewer bytes transmitted that fit the value. + "half", +] + +[dependencies] +aws-smithy-types = { path = "../aws-smithy-types" } + +[package.metadata.docs.rs] +all-features = true +targets = ["x86_64-unknown-linux-gnu"] +rustdoc-args = ["--cfg", "docsrs"] +# End of docs.rs metadata + +[dev-dependencies] +criterion = "0.5.1" + +[[bench]] +name = "string" +harness = false + +[[bench]] +name = "blob" +harness = false diff --git a/rust-runtime/aws-smithy-cbor/LICENSE b/rust-runtime/aws-smithy-cbor/LICENSE new file mode 100644 index 0000000000..67db858821 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/rust-runtime/aws-smithy-cbor/README.md b/rust-runtime/aws-smithy-cbor/README.md new file mode 100644 index 0000000000..4d0ad07413 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/README.md @@ -0,0 +1,7 @@ +# aws-smithy-cbor + +TODO + + +This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/awslabs/smithy-rs) code generator. In most cases, it should not be used directly. + diff --git a/rust-runtime/aws-smithy-cbor/benches/blob.rs b/rust-runtime/aws-smithy-cbor/benches/blob.rs new file mode 100644 index 0000000000..dd4960da0d --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/blob.rs @@ -0,0 +1,21 @@ +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn blob_benchmark(c: &mut Criterion) { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + let blob_indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, 0x79, + 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, 0x4e, 0x20, + 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0xff, + ]; + + c.bench_function("blob", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&blob_indefinite_bytes); + let _ = black_box(decoder.blob()); + }) + }); +} + +criterion_group!(benches, blob_benchmark); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/benches/string.rs b/rust-runtime/aws-smithy-cbor/benches/string.rs new file mode 100644 index 0000000000..18d6b2ceee --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/string.rs @@ -0,0 +1,131 @@ +use std::borrow::Cow; + +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn str_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("definite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +pub fn string_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result { + decoder.str_iter()?.collect() +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +fn str_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result, minicbor::decode::Error> { + // This implementation uses `next` twice to see if there is + // another str chunk. If there is, it returns a owned `String`. + let mut chunks_iter = decoder.str_iter()?; + let head = match chunks_iter.next() { + Some(Ok(head)) => head, + None => return Ok(Cow::Borrowed("")), + Some(Err(e)) => return Err(e), + }; + + match chunks_iter.next() { + None => Ok(Cow::Borrowed(head)), + Some(Err(e)) => Err(e), + Some(Ok(next)) => { + let mut concatenated_string = String::from(head); + concatenated_string.push_str(next); + for chunk in chunks_iter { + concatenated_string.push_str(chunk?); + } + Ok(Cow::Owned(concatenated_string)) + } + } +} + +// We have two `string` implementations. One uses `collect` the other +// uses `String::new` followed by `string::push`. +pub fn string_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("definite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); + + c.bench_function("indefinite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("indefinite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); +} + +criterion_group!(benches, string_benchmark, str_benchmark,); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/src/data.rs b/rust-runtime/aws-smithy-cbor/src/data.rs new file mode 100644 index 0000000000..a6eab6c549 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/data.rs @@ -0,0 +1,97 @@ +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)] +pub enum Type { + Bool, + Null, + Undefined, + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + Int, + F16, + F32, + F64, + Simple, + Bytes, + BytesIndef, + String, + StringIndef, + Array, + ArrayIndef, + Map, + MapIndef, + Tag, + Break, + Unknown(u8), +} + +impl Type { + pub(crate) fn new(ty: minicbor::data::Type) -> Self { + match ty { + minicbor::data::Type::Bool => Self::Bool, + minicbor::data::Type::Null => Self::Null, + minicbor::data::Type::Undefined => Self::Undefined, + minicbor::data::Type::U8 => Self::U8, + minicbor::data::Type::U16 => Self::U16, + minicbor::data::Type::U32 => Self::U32, + minicbor::data::Type::U64 => Self::U64, + minicbor::data::Type::I8 => Self::I8, + minicbor::data::Type::I16 => Self::I16, + minicbor::data::Type::I32 => Self::I32, + minicbor::data::Type::I64 => Self::I64, + minicbor::data::Type::Int => Self::Int, + minicbor::data::Type::F16 => Self::F16, + minicbor::data::Type::F32 => Self::F32, + minicbor::data::Type::F64 => Self::F64, + minicbor::data::Type::Simple => Self::Simple, + minicbor::data::Type::Bytes => Self::Bytes, + minicbor::data::Type::BytesIndef => Self::BytesIndef, + minicbor::data::Type::String => Self::String, + minicbor::data::Type::StringIndef => Self::StringIndef, + minicbor::data::Type::Array => Self::Array, + minicbor::data::Type::ArrayIndef => Self::ArrayIndef, + minicbor::data::Type::Map => Self::Map, + minicbor::data::Type::MapIndef => Self::MapIndef, + minicbor::data::Type::Tag => Self::Tag, + minicbor::data::Type::Break => Self::Break, + minicbor::data::Type::Unknown(byte) => Self::Unknown(byte), + } + } + + // This is just the reverse mapping of `new`. + pub(crate) fn into_minicbor_type(self) -> minicbor::data::Type { + match self { + Type::Bool => minicbor::data::Type::Bool, + Type::Null => minicbor::data::Type::Null, + Type::Undefined => minicbor::data::Type::Undefined, + Type::U8 => minicbor::data::Type::U8, + Type::U16 => minicbor::data::Type::U16, + Type::U32 => minicbor::data::Type::U32, + Type::U64 => minicbor::data::Type::U64, + Type::I8 => minicbor::data::Type::I8, + Type::I16 => minicbor::data::Type::I16, + Type::I32 => minicbor::data::Type::I32, + Type::I64 => minicbor::data::Type::I64, + Type::Int => minicbor::data::Type::Int, + Type::F16 => minicbor::data::Type::F16, + Type::F32 => minicbor::data::Type::F32, + Type::F64 => minicbor::data::Type::F64, + Type::Simple => minicbor::data::Type::Simple, + Type::Bytes => minicbor::data::Type::Bytes, + Type::BytesIndef => minicbor::data::Type::BytesIndef, + Type::String => minicbor::data::Type::String, + Type::StringIndef => minicbor::data::Type::StringIndef, + Type::Array => minicbor::data::Type::Array, + Type::ArrayIndef => minicbor::data::Type::ArrayIndef, + Type::Map => minicbor::data::Type::Map, + Type::MapIndef => minicbor::data::Type::MapIndef, + Type::Tag => minicbor::data::Type::Tag, + Type::Break => minicbor::data::Type::Break, + Type::Unknown(byte) => minicbor::data::Type::Unknown(byte), + } + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs new file mode 100644 index 0000000000..bbff32bfd5 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -0,0 +1,314 @@ +use std::borrow::Cow; + +use aws_smithy_types::{Blob, DateTime}; +use minicbor::decode::Error; + +use crate::data::Type; + +/// Provides functions for decoding a CBOR object with a known schema. +/// +/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema +/// is known in advance. Therefore, the caller can determine which object key exists at the current +/// position by calling `str` method, and call the relevant function based on the predetermined schema +/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip +/// over the element. +#[derive(Debug, Clone)] +pub struct Decoder<'b> { + decoder: minicbor::Decoder<'b>, +} + +/// When any of the decode methods are called they look for that particular data type at the current +/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned. +#[derive(Debug)] +pub struct DeserializeError { + #[allow(dead_code)] + _inner: Error, +} + +impl std::fmt::Display for DeserializeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self._inner.fmt(f) + } +} + +impl std::error::Error for DeserializeError {} + +impl DeserializeError { + pub(crate) fn new(inner: Error) -> Self { + Self { _inner: inner } + } + + /// More than one union variant was detected: `unexpected_type` was unexpected. + pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self { + Self { + _inner: Error::type_mismatch(unexpected_type.into_minicbor_type()) + .with_message("encountered unexpected union variant; expected end of union") + .at(at), + } + } + + /// More than one union variant was detected, but we never even got to parse the first one. + pub fn mixed_union_variants(at: usize) -> Self { + Self { + _inner: Error::message("encountered mixed variants in union; expected end of union") + .at(at), + } + } + + /// An unexpected type was encountered. + // We handle this one when decoding sparse collections: we have to expect either a `null` or an + // item, so we try decoding both. + pub fn is_type_mismatch(&self) -> bool { + self._inner.is_type_mismatch() + } +} + + +/// Macro for delegating method calls to the decoder. +/// +/// This macro generates wrapper methods for calling specific encoder methods on the decoder +/// and returning the result with error handling. +/// +/// # Example +/// +/// ``` +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the decoder. +/// encode_str_wrapper => encode_str(String); +/// /// Wrapper method for encoding method `encode_int` on the decoder. +/// encode_int_wrapper => encode_int(i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => { + $( + pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> { + self.decoder.$encoder_name().map_err(DeserializeError::new) + } + )+ + }; +} + +impl<'b> Decoder<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { + decoder: minicbor::Decoder::new(bytes), + } + } + + pub fn datatype(&self) -> Result { + self.decoder + .datatype() + .map(Type::new) + .map_err(DeserializeError::new) + } + + delegate_method! { + /// Skips the current CBOR element. + skip => skip(()); + /// Reads a boolean at the current position. + boolean => bool(bool); + /// Reads a byte at the current position. + byte => i8(i8); + /// Reads a short at the current position. + short => i16(i16); + /// Reads a integer at the current position. + integer => i32(i32); + /// Reads a long at the current position. + long => i64(i64); + /// Reads a float at the current position. + float => f32(f32); + /// Reads a double at the current position. + double => f64(f64); + /// Reads a null CBOR element at the current position. + null => null(()); + /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`. + list => array(Option); + /// Returns the number of elements in a definite map. For indefinite map it returns a `None`. + map => map(Option); + } + + /// Returns the current position of the buffer, which will be decoded when any of the methods is called. + pub fn position(&self) -> usize { + self.decoder.position() + } + + /// Returns a `cow::Borrowed(&str)` if the element at the current position in the buffer is a definite + /// length string. Otherwise, it returns a `cow::Owned(String)` if the element at the current position is an + /// indefinite-length string. An error is returned if the element is neither a definite length nor an + /// indefinite-length string. + pub fn str(&mut self) -> Result, DeserializeError> { + let bookmark = self.decoder.position(); + match self.decoder.str() { + Ok(str_value) => Ok(Cow::Borrowed(str_value)), + Err(e) if e.is_type_mismatch() => { + // Move the position back to the start of the CBOR element and then try + // decoding it as a indefinite length string. + self.decoder.set_position(bookmark); + Ok(Cow::Owned(self.string()?)) + } + Err(e) => Err(DeserializeError::new(e)), + } + } + + /// Allocates and returns a `String` if the element at the current position in the buffer is either a + /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type. + pub fn string(&mut self) -> Result { + let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?; + let head = iter.next(); + + let decoded_string = match head { + None => String::new(), + Some(head) => { + let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?); + for chunk in iter { + combined_chunks.push_str(chunk.map_err(DeserializeError::new)?); + } + combined_chunks + } + }; + + Ok(decoded_string) + } + + /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise, + /// a `DeserializeError` error is returned. + pub fn blob(&mut self) -> Result { + let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?; + let parts: Vec<&[u8]> = iter + .collect::>() + .map_err(DeserializeError::new)?; + + Ok(if parts.len() == 1 { + Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part. + } else { + Blob::new(parts.concat()) // Concatenate all parts into a single Blob. + }) + } + + /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise, + /// a `DeserializeError` error is returned. + pub fn timestamp(&mut self) -> Result { + let tag = self.decoder.tag().map_err(DeserializeError::new)?; + + if !matches!(tag, minicbor::data::Tag::Timestamp) { + Err(DeserializeError::new(Error::message( + "expected timestamp tag", + ))) + } else { + let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?; + Ok(DateTime::from_secs_f64(epoch_seconds)) + } + } +} + +#[derive(Debug)] +pub struct ArrayIter<'a, 'b, T> { + inner: minicbor::decode::ArrayIter<'a, 'b, T>, +} + +impl<'a, 'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'a, 'b, T> { + type Item = Result; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +#[derive(Debug)] +pub struct MapIter<'a, 'b, K, V> { + inner: minicbor::decode::MapIter<'a, 'b, K, V>, +} + +impl<'a, 'b, K, V> Iterator for MapIter<'a, 'b, K, V> +where + K: minicbor::Decode<'b, ()>, + V: minicbor::Decode<'b, ()>, +{ + type Item = Result<(K, V), DeserializeError>; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +pub fn set_optional(builder: B, decoder: &mut Decoder, f: F) -> Result +where + F: Fn(B, &mut Decoder) -> Result, +{ + match decoder.datatype()? { + crate::data::Type::Null => { + decoder.null()?; + Ok(builder) + } + _ => f(builder, decoder), + } +} + +#[cfg(test)] +mod tests { + use crate::Decoder; + + #[test] + fn test_definite_str_is_cow_borrowed() { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + let mut decoder = Decoder::new(&definite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Borrowed(_))); + } + + #[test] + fn test_indefinite_str_is_cow_owned() { + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, + 0x79, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Owned(_))); + } + + #[test] + fn test_empty_str_works() { + let bytes = [0x60]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.str().expect("could not decode empty str"); + assert_eq!(member, ""); + } + + #[test] + fn test_empty_blob_works() { + let bytes = [0x40]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.blob().expect("could not decode an empty blob"); + assert_eq!(member, aws_smithy_types::Blob::new(&[])); + } + + #[test] + fn test_indefinite_length_blob() { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff + let indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, + 0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, + 0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, + 0x61, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.blob().expect("could not decode blob"); + assert_eq!( + member, + aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes()) + ); + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs new file mode 100644 index 0000000000..475071ffcc --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -0,0 +1,92 @@ +use aws_smithy_types::{Blob, DateTime}; + +/// Macro for delegating method calls to the encoder. +/// +/// This macro generates wrapper methods for calling specific encoder methods on the encoder +/// and returning a mutable reference to self for method chaining. +/// +/// # Example +/// +/// ``` +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the encoder. +/// encode_str_wrapper => encode_str(data: &str); +/// /// Wrapper method for encoding method `encode_int` on the encoder. +/// encode_int_wrapper => encode_int(value: i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($($param_name:ident : $param_type:ty),*);)+) => { + $( + pub fn $wrapper_name(&mut self, $($param_name: $param_type),*) -> &mut Self { + self.encoder.$encoder_name($($param_name)*).expect(INFALLIBLE_WRITE); + self + } + )+ + }; +} + +#[derive(Debug, Clone)] +pub struct Encoder { + encoder: minicbor::Encoder>, +} + +// TODO docs +const INFALLIBLE_WRITE: &str = "write failed"; + +impl Encoder { + pub fn new(writer: Vec) -> Self { + Self { + encoder: minicbor::Encoder::new(writer), + } + } + + delegate_method! { + /// Writes a fixed length array of given length. + array => array(len: u64); + /// Used when we know the size in advance, i.e.: + /// - when a struct has all non-`Option`al members. + /// - when serializing `union` shapes (they can only have one member set). + map => map(len: u64); + /// Used when it's not cheap to calculate the size, i.e. when the struct has one or more + /// `Option`al members. + begin_map => begin_map(); + /// Writes a definite length string. + str => str(x: &str); + /// Writes a boolean value. + boolean => bool(x: bool); + /// Writes a byte value. + byte => i8(x: i8); + /// Writes a short value. + short => i16(x: i16); + /// Writes an integer value. + integer => i32(x: i32); + /// Writes an long value. + long => i64(x: i64); + /// Writes an float value. + float => f32(x: f32); + /// Writes an double value. + double => f64(x: f64); + /// Writes a null tag. + null => null(); + /// Writes an end tag. + end => end(); + } + + pub fn blob(&mut self, x: &Blob) -> &mut Self { + self.encoder.bytes(x.as_ref()).expect(INFALLIBLE_WRITE); + self + } + + pub fn timestamp(&mut self, x: &DateTime) -> &mut Self { + self.encoder + .tag(minicbor::data::Tag::Timestamp) + .expect(INFALLIBLE_WRITE); + self.encoder.f64(x.as_secs_f64()).expect(INFALLIBLE_WRITE); + self + } + + pub fn into_writer(self) -> Vec { + self.encoder.into_writer() + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/lib.rs b/rust-runtime/aws-smithy-cbor/src/lib.rs new file mode 100644 index 0000000000..1a547fedee --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/lib.rs @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! CBOR abstractions for Smithy. + +pub mod data; +pub mod decode; +pub mod encode; + +pub use decode::Decoder; +pub use encode::Encoder; diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 18b12aaa01..0009e5c0d5 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.61.2" +version = "0.61.3" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" @@ -23,6 +23,7 @@ aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["http-02x"] } aws-smithy-types = { path = "../aws-smithy-types", features = ["http-body-0-4-x", "hyper-0-14-x"] } aws-smithy-xml = { path = "../aws-smithy-xml" } +aws-smithy-cbor = { path = "../aws-smithy-cbor" } bytes = "1.1" futures-util = { version = "0.3.29", default-features = false } http = "0.2" diff --git a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml new file mode 100644 index 0000000000..e6ae4c9f6c --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "rpcv2-service" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +# Local paths +aws-smithy-http-server = { path = "../../" } +hyper = "0.14.24" +tokio = "1.25.0" +rpcv2-server-sdk = { path = "../rpcv2-server-sdk/", package = "rpcv2" } +rpcv2-pokemon-client = { path = "../rpcv2-pokemon-client/" } diff --git a/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs new file mode 100644 index 0000000000..96e2eb0221 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/rpcv2-service/src/main.rs @@ -0,0 +1,32 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +use std::net::SocketAddr; + +use aws_smithy_http_server::{body::Body, routing::Route}; +use rpcv2_server_sdk::{error, input, output, RpcV2Service}; + +async fn handler( + input: input::RpcV2OperationInput, +) -> Result { + println!("{input:#?}"); + + todo!() +} + +#[tokio::main] +async fn main() { + let service: RpcV2Service> = rpcv2_server_sdk::RpcV2Service::builder_without_plugins() + .rpc_v2_operation(handler) + .build() + .unwrap(); + + let server = service.into_make_service(); + let bind: SocketAddr = "127.0.0.1:6969" + .parse() + .expect("unable to parse the server bind address and port"); + + println!("Binding {bind}"); + hyper::Server::bind(&bind).serve(server).await.unwrap(); +} diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index 295c670b27..74fc44b8df 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -39,7 +39,7 @@ pub enum Error { // This constant determines when the `TinyMap` implementation switches from being a `Vec` to a // `HashMap`. This is chosen to be 15 as a result of the discussion around // https://github.com/smithy-lang/smithy-rs/pull/1429#issuecomment-1147516546 -const ROUTE_CUTOFF: usize = 15; +pub(crate) const ROUTE_CUTOFF: usize = 15; /// A [`Router`] supporting [`AWS JSON 1.0`] and [`AWS JSON 1.1`] protocols. /// @@ -65,6 +65,7 @@ impl AwsJsonRouter { } } + // TODO This function is not used? Codegen should probably delegate to this function. /// Applies type erasure to the inner route using [`Route::new`]. pub fn boxed(self) -> AwsJsonRouter> where diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs index 0b3faf0136..4686e94e07 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs @@ -9,6 +9,7 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; +pub mod rpc_v2; use crate::rejection::MissingContentTypeReason; use aws_smithy_runtime_api::http::Headers as SmithyHeaders; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs new file mode 100644 index 0000000000..716befd77a --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub mod rejection; +pub mod router; +pub mod runtime_error; + +// TODO: Fill link +/// [Smithy RPC V2](). +pub struct RpcV2; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs new file mode 100644 index 0000000000..2ec8b957af --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/rejection.rs @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::num::TryFromIntError; + +use crate::rejection::MissingContentTypeReason; +use aws_smithy_runtime_api::http::HttpError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ResponseRejection { + #[error("invalid bound HTTP status code; status codes must be inside the 100-999 range: {0}")] + InvalidHttpStatusCode(TryFromIntError), + #[error("error serializing CBOR-encoded body: {0}")] + Serialization(#[from] aws_smithy_types::error::operation::SerializationError), + #[error("error building HTTP response: {0}")] + HttpBuild(#[from] http::Error), +} + +#[derive(Debug, Error)] +pub enum RequestRejection { + #[error("error converting non-streaming body to bytes: {0}")] + BufferHttpBodyBytes(crate::Error), + #[error("request contains invalid value for `Accept` header")] + NotAcceptable, + #[error("expected `Content-Type` header not found: {0}")] + MissingContentType(#[from] MissingContentTypeReason), + #[error("error deserializing request HTTP body as CBOR: {0}")] + CborDeserialize(#[from] aws_smithy_cbor::decode::DeserializeError), + // Unlike the other protocols, RPC v2 uses CBOR, a binary serialization format, so we take in a + // `Vec` here instead of `String`. + #[error("request does not adhere to modeled constraints")] + ConstraintViolation(Vec), + + /// Typically happens when the request has headers that are not valid UTF-8. + #[error("failed to convert request: {0}")] + HttpConversion(#[from] HttpError), +} + +impl From for RequestRejection { + fn from(_err: std::convert::Infallible) -> Self { + match _err {} + } +} + +convert_to_request_rejection!(hyper::Error, BufferHttpBodyBytes); +convert_to_request_rejection!(Box, BufferHttpBodyBytes); diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs new file mode 100644 index 0000000000..3e73b8b9e0 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/router.rs @@ -0,0 +1,409 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::convert::Infallible; +use std::str::FromStr; + +use http::header::ToStrError; +use http::HeaderMap; +use once_cell::sync::Lazy; +use regex::Regex; +use thiserror::Error; +use tower::Layer; +use tower::Service; + +use crate::body::empty; +use crate::body::BoxBody; +use crate::extension::RuntimeErrorExtension; +use crate::protocol::aws_json_11::router::ROUTE_CUTOFF; +use crate::response::IntoResponse; +use crate::routing::tiny_map::TinyMap; +use crate::routing::Route; +use crate::routing::Router; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; + +use super::RpcV2; + +pub use crate::protocol::rest::router::*; + +/// An RPC v2 routing error. +#[derive(Debug, Error)] +pub enum Error { + /// Method was not `POST`. + #[error("method not POST")] + MethodNotAllowed, + /// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` + /// header. + #[error("contains forbidden headers")] + ForbiddenHeaders, + /// Unable to parse `smithy-protocol` header into a valid wire format value. + #[error("failed to parse `smithy-protocol` header into a valid wire format value")] + InvalidWireFormatHeader(#[from] WireFormatError), + /// Operation not found. + #[error("operation not found")] + NotFound, +} + +// TODO Docs +#[derive(Debug, Clone)] +pub struct RpcV2Router { + routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, +} + +/// Requests for the `rpcv2` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` +/// header. An `rpcv2` request is malformed if it contains either of these headers. Server-side +/// implementations MUST reject such requests for security reasons. +const FORBIDDEN_HEADERS: &'static [&'static str] = &["x-amz-target", "x-amzn-target"]; + +/// Matches the `Identifier` ABNF rule in +/// . +const IDENTIFIER_PATTERN: &'static str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; + +impl RpcV2Router { + // TODO Consider building a nom parser + fn uri_path_regex() -> &'static Regex { + // Every request for the `rpcv2Cbor` protocol MUST be sent to a URL with the + // following form: `{prefix?}/service/{serviceName}/operation/{operationName}` + // + // * The optional `prefix` segment may span multiple path segments and is not + // utilized by the Smithy RPC v2 CBOR protocol. For example, a service could + // use a `v1` prefix for the following URL path: `v1/service/FooService/operation/BarOperation` + // * The `serviceName` segment MUST be replaced by the [`shape + // name`](https://smithy.io/2.0/spec/model.html#grammar-token-smithy-Identifier) + // of the service's [Shape ID](https://smithy.io/2.0/spec/model.html#shape-id) + // in the Smithy model. The `serviceName` produced by client implementations + // MUST NOT contain the namespace of the `service` shape. Service + // implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it + // the same as if only the name was specified. For example, if the `service`'s + // absolute shape ID is `com.example#TheService`, a service should accept both + // `TheService` and `com.example.TheService` as values for the `serviceName` + // segment. + static PATH_REGEX: Lazy = Lazy::new(|| { + Regex::new(&format!( + r#"/service/({}\.)*(?P{})/operation/(?P{})$"#, + IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, + )) + .unwrap() + }); + + &PATH_REGEX + } + + pub fn wire_format_regex() -> &'static Regex { + static SMITHY_PROTOCOL_REGEX: Lazy = Lazy::new(|| Regex::new(r#"^rpc-v2-(?P\w+)$"#).unwrap()); + + &SMITHY_PROTOCOL_REGEX + } + + pub fn boxed(self) -> RpcV2Router> + where + S: Service, Response = http::Response, Error = Infallible>, + S: Send + Clone + 'static, + S::Future: Send + 'static, + { + RpcV2Router { + routes: self.routes.into_iter().map(|(key, s)| (key, Route::new(s))).collect(), + } + } + + /// Applies a [`Layer`] uniformly to all routes. + pub fn layer(self, layer: L) -> RpcV2Router + where + L: Layer, + { + RpcV2Router { + routes: self + .routes + .into_iter() + .map(|(key, route)| (key, layer.layer(route))) + .collect(), + } + } +} + +// TODO: Implement (current body copied from the rest xml impl) +// and document. +/// A Smithy RPC V2 routing error. +impl IntoResponse for Error { + fn into_response(self) -> http::Response { + match self { + Error::NotFound => http::Response::builder() + .status(http::StatusCode::NOT_FOUND) + // TODO + .header(http::header::CONTENT_TYPE, "application/xml") + .extension(RuntimeErrorExtension::new( + UNKNOWN_OPERATION_EXCEPTION.to_string(), + )) + .body(empty()) + .expect("invalid HTTP response for REST XML routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"), + Error::MethodNotAllowed => method_disallowed(), + // TODO + _ => todo!(), + } + } +} + +/// Errors that can happen when parsing the wire format from the `smithy-protocol` header. +#[derive(Debug, Error)] +pub enum WireFormatError { + /// Header not found. + #[error("`smithy-protocol` header not found")] + HeaderNotFound, + /// Header value is not visible ASCII. + #[error("`smithy-protocol` header not visible ASCII")] + HeaderValueNotVisibleAscii(ToStrError), + /// Header value does not match the `rpc-v2-{format}` pattern. The actual parsed header value + /// is stored in the tuple struct. + // https://doc.rust-lang.org/std/fmt/index.html#escaping + #[error("`smithy-protocol` header does not match the `rpc-v2-{{format}}` pattern: `{0}`")] + HeaderValueNotValid(String), + /// Header value matches the `rpc-v2-{format}` pattern, but the `format` is not supported. The + /// actual parsed header value is stored in the tuple struct. + #[error("found unsupported `smithy-protocol` wire format: `{0}`")] + WireFormatNotSupported(String), +} + +/// Smithy RPC V2 requests have a `smithy-protocol` header with the value +/// `"rpc-v2-{format}"`, where `format` is one of the supported wire formats +/// by the protocol (see [`WireFormat`]). +fn parse_wire_format_from_header(headers: &HeaderMap) -> Result { + let header = headers.get("smithy-protocol").ok_or(WireFormatError::HeaderNotFound)?; + let header = header + .to_str() + .map_err(|e| WireFormatError::HeaderValueNotVisibleAscii(e))?; + let captures = RpcV2Router::<()>::wire_format_regex() + .captures(header) + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let format = captures + .name("format") + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let wire_format_parse_res: Result = format.as_str().parse(); + wire_format_parse_res.map_err(|_| WireFormatError::WireFormatNotSupported(header.to_owned())) +} + +/// Supported wire formats by RPC V2. +enum WireFormat { + Cbor, +} + +struct WireFormatFromStrError; + +impl FromStr for WireFormat { + type Err = WireFormatFromStrError; + + fn from_str(format: &str) -> Result { + match format { + "cbor" => Ok(Self::Cbor), + _ => Err(WireFormatFromStrError), + } + } +} + +impl Router for RpcV2Router { + type Service = S; + + type Error = Error; + + fn match_route(&self, request: &http::Request) -> Result { + // Only `Method::POST` is allowed. + if request.method() != http::Method::POST { + return Err(Error::MethodNotAllowed); + } + + // Some headers are not allowed. + let request_has_forbidden_header = FORBIDDEN_HEADERS + .into_iter() + .any(|&forbidden_header| request.headers().contains_key(forbidden_header)); + if request_has_forbidden_header { + return Err(Error::ForbiddenHeaders); + } + + // Wire format has to be specified and supported. + let _wire_format = parse_wire_format_from_header(request.headers())?; + + // Extract the service name and the operation name from the request URI. + let request_path = request.uri().path(); + let regex = Self::uri_path_regex(); + + tracing::trace!(%request_path, "capturing service and operation from URI"); + let captures = regex.captures(request_path).ok_or(Error::NotFound)?; + let (service, operation) = (&captures["service"], &captures["operation"]); + tracing::trace!(%service, %operation, "captured service and operation from URI"); + + // Lookup in the `TinyMap` for a route for the target. + let route = self + .routes + .get((format!("{service}.{operation}")).as_str()) + .ok_or(Error::NotFound)?; + Ok(route.clone()) + } +} + +impl FromIterator<(&'static str, S)> for RpcV2Router { + fn from_iter>(iter: T) -> Self { + Self { + routes: iter.into_iter().collect(), + } + } +} + +#[cfg(test)] +mod tests { + use http::{HeaderMap, HeaderValue, Method}; + use regex::Regex; + + use crate::protocol::test_helpers::req; + + use super::{Error, Router, RpcV2Router}; + + fn identifier_regex() -> Regex { + Regex::new(&format!("^{}$", super::IDENTIFIER_PATTERN)).unwrap() + } + + #[test] + fn valid_identifiers() { + let valid_identifiers = vec!["a", "_a", "_0", "__0", "variable123", "_underscored_variable"]; + + for id in &valid_identifiers { + assert!(identifier_regex().is_match(id), "'{}' is incorrectly rejected", id); + } + } + + #[test] + fn invalid_identifiers() { + let invalid_identifiers = vec![ + "0", + "123starts_with_digit", + "@invalid_start_character", + " space_in_identifier", + "invalid-character", + "invalid@character", + "no#hashes", + ]; + + for id in &invalid_identifiers { + assert!(!identifier_regex().is_match(id), "'{}' is incorrectly accepted", id); + } + } + + #[test] + fn uri_regex_works_accepts() { + let regex = RpcV2Router::<()>::uri_path_regex(); + + for uri in [ + "/service/Service/operation/Operation", + "prefix/69/service/Service/operation/Operation", + // Here the prefix is up to the last occurrence of the string `/service`. + "prefix/69/service/Service/operation/Operation/service/Service/operation/Operation", + // Service implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it the same as + // if only the name was specified. For example, if the `service`'s absolute shape ID is + // `com.example#TheService`, a service should accept both `TheService` and + // `com.example.TheService` as values for the `serviceName` segment. + "/service/aws.protocoltests.rpcv2.Service/operation/Operation", + "/service/namespace.Service/operation/Operation", + ] { + let captures = regex.captures(uri).unwrap(); + assert_eq!("Service", &captures["service"], "uri: {}", uri); + assert_eq!("Operation", &captures["operation"], "uri: {}", uri); + } + } + + #[test] + fn uri_regex_works_rejects() { + let regex = RpcV2Router::<()>::uri_path_regex(); + + for uri in [ + "", + "foo", + "/servicee/Service/operation/Operation", + "/service/Service", + "/service/Service/operation/", + "/service/Service/operation/Operation/", + "/service/Service/operation/Operation/invalid-suffix", + "/service/namespace.foo#Service/operation/Operation", + "/service/namespace-Service/operation/Operation", + "/service/.Service/operation/Operation", + ] { + assert!(regex.captures(uri).is_none(), "uri: {}", uri); + } + } + + #[test] + fn wire_format_regex_works() { + let regex = RpcV2Router::<()>::wire_format_regex(); + + let captures = regex.captures("rpc-v2-something").unwrap(); + assert_eq!("something", &captures["format"]); + + let captures = regex.captures("rpc-v2-SomethingElse").unwrap(); + assert_eq!("SomethingElse", &captures["format"]); + + let invalid = regex.captures("rpc-v1-something"); + assert!(invalid.is_none()); + } + + /// Helper function returning the only strictly required header. + fn headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static("rpc-v2-cbor")); + headers + } + + #[test] + fn simple_routing() { + let router: RpcV2Router<_> = ["Service.Operation"].into_iter().map(|op| (op, ())).collect(); + let good_uri = "/prefix/service/Service/operation/Operation"; + + // The request should match. + let routing_result = router.match_route(&req(&Method::POST, good_uri, Some(headers()))); + assert!(routing_result.is_ok()); + + // The request would be valid if it used `Method::POST`. + let invalid_request = req(&Method::GET, good_uri, Some(headers())); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::MethodNotAllowed) + )); + + // The request would be valid if it did not have forbidden headers. + for forbidden_header_name in ["x-amz-target", "x-amzn-target"] { + let mut headers = headers(); + headers.insert(forbidden_header_name, HeaderValue::from_static("Service.Operation")); + let invalid_request = req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::ForbiddenHeaders) + )); + } + + for bad_uri in [ + // These requests would be valid if they used correct URIs. + "/prefix/Service/Service/operation/Operation", + "/prefix/service/Service/operation/Operation/suffix", + // These requests would be valid if their URI matched an existing operation. + "/prefix/service/ThisServiceDoesNotExist/operation/Operation", + "/prefix/service/Service/operation/ThisOperationDoesNotExist", + ] { + let invalid_request = &req(&Method::POST, bad_uri, Some(headers())); + assert!(matches!(router.match_route(&invalid_request), Err(Error::NotFound))); + } + + // The request would be valid if it specified a supported wire format in the + // `smithy-protocol` header. + for header_name in ["bad-header", "rpc-v2-json", "foo-rpc-v2-cbor", "rpc-v2-cbor-foo"] { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static(header_name)); + let invalid_request = &req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::InvalidWireFormatHeader(_)) + )); + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs new file mode 100644 index 0000000000..d8ee60796a --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2/runtime_error.rs @@ -0,0 +1,82 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::response::IntoResponse; +use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; +use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2::RpcV2}; +use http::StatusCode; + +use super::rejection::{RequestRejection, ResponseRejection}; + +#[derive(Debug)] +pub enum RuntimeError { + Serialization(crate::Error), + InternalFailure(crate::Error), + NotAcceptable, + UnsupportedMediaType, + Validation(Vec), +} + +impl RuntimeError { + pub fn name(&self) -> &'static str { + match self { + Self::Serialization(_) => "SerializationException", + Self::InternalFailure(_) => "InternalFailureException", + Self::NotAcceptable => "NotAcceptableException", + Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Validation(_) => "ValidationException", + } + } + + pub fn status_code(&self) -> StatusCode { + match self { + Self::Serialization(_) => StatusCode::BAD_REQUEST, + Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, + Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::Validation(_) => StatusCode::BAD_REQUEST, + } + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) + } +} + +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + let res = http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/cbor") + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + // TODO + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + // See https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html#empty-body-serialization + _ => crate::body::to_boxed("{}"), + }; + + res.body(body) + .expect(INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE) + } +} + +impl From for RuntimeError { + fn from(err: ResponseRejection) -> Self { + Self::Serialization(crate::Error::new(err)) + } +} + +impl From for RuntimeError { + fn from(err: RequestRejection) -> Self { + match err { + RequestRejection::ConstraintViolation(reason) => Self::Validation(reason), + _ => Self::Serialization(crate::Error::new(err)), + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 14f124f687..ede1f5117b 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -37,7 +37,6 @@ use futures_util::{ use http::Response; use http_body::Body as HttpBody; use tower::{util::Oneshot, Service, ServiceExt}; -use tracing::debug; use crate::{ body::{boxed, BoxBody}, @@ -191,12 +190,13 @@ where } fn call(&mut self, req: http::Request) -> Self::Future { + tracing::debug!("inside routing service call"); match self.router.match_route(&req) { // Successfully routed, use the routes `Service::call`. Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), // Failed to route, use the `R::Error`s `IntoResponse

`. Err(error) => { - debug!(%error, "failed to route"); + tracing::debug!(%error, "failed to route"); RoutingFuture::from_response(error.into_response()) } } diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index c056bbc020..e50222b200 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.60.7" +version = "0.60.8" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" @@ -10,6 +10,9 @@ repository = "https://github.com/smithy-lang/smithy-rs" [dependencies] # Not perfect for our needs, but good for now assert-json-diff = "1.1" +base64-simd = "0.8" +cbor-diag = "0.1" +serde_cbor = "0.11" http = "0.2.1" pretty_assertions = "1.3" regex-lite = "0.1.5" @@ -18,7 +21,6 @@ serde_json = "1" thiserror = "1.0.40" aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client"] } - [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 07839e2a38..63328ebba4 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -306,10 +306,12 @@ pub fn require_headers( #[derive(Clone)] pub enum MediaType { - /// Json media types are deserialized and compared + /// JSON media types are deserialized and compared Json, /// XML media types are normalized and compared Xml, + /// CBOR media types are decoded from base64 to binary and compared + Cbor, /// For x-www-form-urlencoded, do some map order comparison shenanigans UrlEncodedForm, /// Other media types are compared literally @@ -322,6 +324,7 @@ impl> From for MediaType { "application/json" => MediaType::Json, "application/x-amz-json-1.1" => MediaType::Json, "application/xml" => MediaType::Xml, + "application/cbor" => MediaType::Cbor, "application/x-www-form-urlencoded" => MediaType::UrlEncodedForm, other => MediaType::Other(other.to_string()), } @@ -336,11 +339,11 @@ pub fn validate_body>( let body_str = std::str::from_utf8(actual_body.as_ref()); match (media_type, body_str) { (MediaType::Json, Ok(actual_body)) => try_json_eq(expected_body, actual_body), - (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(expected_body, actual_body), (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(actual_body, expected_body), (MediaType::Xml, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "XML".to_owned(), found: "input was not valid UTF-8".to_owned(), @@ -352,6 +355,7 @@ pub fn validate_body>( expected: "x-www-form-urlencoded".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Cbor, _) => try_cbor_eq(actual_body, expected_body), (MediaType::Other(media_type), Ok(actual_body)) => { if actual_body != expected_body { Err(ProtocolTestFailure::BodyDidNotMatch { @@ -410,6 +414,63 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure> } } +fn try_cbor_eq>( + actual_body: T, + expected_body: &str, +) -> Result<(), ProtocolTestFailure> { + let decoded = base64_simd::STANDARD + .decode_to_vec(expected_body) + .expect("smithy protocol test `body` property is not properly base64 encoded"); + let expected_cbor_value: serde_cbor::Value = + serde_cbor::from_slice(decoded.as_slice()).unwrap(); + let actual_cbor_value: serde_cbor::Value = + serde_cbor::from_slice(actual_body.as_ref()).unwrap(); // TODO Don't panic + let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body); + + if expected_cbor_value != actual_cbor_value { + let expected_body_annotated_hex: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_hex(); + let expected_body_diag: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_diag_pretty(); + let actual_body_annotated_hex: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_hex(); + let actual_body_diag: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_diag_pretty(); + + Err(ProtocolTestFailure::BodyDidNotMatch { + comparison: PrettyString(format!( + "{}", + Comparison::new(&expected_cbor_value, &actual_cbor_value) + )), + // The last newline is important because the panic message ends with a `.` + hint: format!( + "expected body in diagnostic format: +{} +actual body in diagnostic format: +{} +expected body in annotated hex: +{} +actual body in annotated hex: +{} +actual body in base64 (useful to update the protocol test): +{} +", + expected_body_diag, + actual_body_diag, + expected_body_annotated_hex, + actual_body_annotated_hex, + actual_body_base64, + ), + }) + } else { + Ok(()) + } +} + #[cfg(test)] mod tests { use crate::{ diff --git a/rust-runtime/aws-smithy-types/src/error/operation.rs b/rust-runtime/aws-smithy-types/src/error/operation.rs index 5b61f1acfe..ad5032fc4c 100644 --- a/rust-runtime/aws-smithy-types/src/error/operation.rs +++ b/rust-runtime/aws-smithy-types/src/error/operation.rs @@ -21,6 +21,7 @@ pub struct SerializationError { kind: SerializationErrorKind, } +// TODO The docs in `main` are wrong. impl SerializationError { /// An error that occurs when serialization of an operation fails for an unknown reason. pub fn unknown_variant(union: &'static str) -> Self {