diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 46a2722a42..0d67d890a8 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -22,3 +22,9 @@ message = "Fix incorrect redaction of `@sensitive` types in maps and lists." references = ["smithy-rs#3765", "smithy-rs#3757"] meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" } author = "landonxjames" + +[[smithy-rs]] +message = "Fix client error correction to properly parse structure members that target a `Union` containing that structure recursively." +references = ["smithy-rs#3767"] +meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" } +author = "ysaito1001" diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index c69a4ae719..d963166895 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -24,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-client-test/" dependencies { implementation(project(":codegen-client")) implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") } @@ -72,6 +73,12 @@ val allCodegenTests = listOf( ClientTest("aws.protocoltests.restxml#RestXml", "rest_xml", addMessageToErrors = false), ClientTest("aws.protocoltests.query#AwsQuery", "aws_query", addMessageToErrors = false), ClientTest("aws.protocoltests.ec2#AwsEc2", "ec2_query", addMessageToErrors = false), + ClientTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), + ClientTest( + "smithy.protocoltests.rpcv2Cbor#RpcV2CborService", + "rpcv2Cbor_extras", + dependsOn = listOf("rpcv2Cbor-extras.smithy") + ), ClientTest( "aws.protocoltests.restxml.xmlns#RestXmlWithNamespace", "rest_xml_namespace", diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt index 07617d9301..e75b0ed8ef 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt @@ -87,7 +87,18 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri target is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this) target is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this) - target is UnionShape -> rust("Some(#T::Unknown)", targetSymbol) + target is UnionShape -> + rustTemplate( + "Some(#{unknown})", *preludeScope, + "unknown" to + writable { + if (memberSymbol.isRustBoxed()) { + rust("Box::new(#T::Unknown)", targetSymbol) + } else { + rust("#T::Unknown", targetSymbol) + } + }, + ) } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 85d32c2bf3..0860dbe33b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Proto import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq @@ -78,6 +79,8 @@ class ClientProtocolTestGenerator( FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"), FailingTest.RequestTest(REST_JSON, "RestJsonClientPopulatesDefaultValuesInInput"), FailingTest.RequestTest(REST_JSON, "RestJsonClientUsesExplicitlyProvidedMemberValuesOverDefaults"), + FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborClientPopulatesDefaultValuesInInput"), + FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborClientUsesExplicitlyProvidedMemberValuesOverDefaults"), ) private val BrokenTests: @@ -268,6 +271,7 @@ class ClientProtocolTestGenerator( """, RT.sdkBody(runtimeConfig = rc), ) + val mediaType = testCase.bodyMediaType.orNull() rustTemplate( """ use #{DeserializeResponse}; @@ -280,19 +284,19 @@ class ClientProtocolTestGenerator( let parsed = de.deserialize_streaming(&mut http_response); let parsed = parsed.unwrap_or_else(|| { let http_response = http_response.map(|body| { - #{SdkBody}::from(#{copy_from_slice}(body.bytes().unwrap())) + #{SdkBody}::from(#{copy_from_slice}(&#{decode_body_data}(body.bytes().unwrap(), #{MediaType}::from(${(mediaType ?: "unknown").dq()})))) }); de.deserialize_nonstreaming(&http_response) }); """, "copy_from_slice" to RT.Bytes.resolve("copy_from_slice"), - "SharedResponseDeserializer" to - RT.smithyRuntimeApiClient(rc) - .resolve("client::ser_de::SharedResponseDeserializer"), - "Operation" to codegenContext.symbolProvider.toSymbol(operationShape), + "decode_body_data" to RT.protocolTest(rc, "decode_body_data"), "DeserializeResponse" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::DeserializeResponse"), + "MediaType" to RT.protocolTest(rc, "MediaType"), + "Operation" to codegenContext.symbolProvider.toSymbol(operationShape), "RuntimePlugin" to RT.runtimePlugin(rc), "SdkBody" to RT.sdkBody(rc), + "SharedResponseDeserializer" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::SharedResponseDeserializer"), ) if (expectedShape.hasTrait()) { val errorSymbol = codegenContext.symbolProvider.symbolForOperationError(operationShape) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt index d2c1f8fc2f..11959087cf 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator -import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.findStreamingMember @@ -125,10 +124,8 @@ class RequestSerializerGenerator( ) } - private fun needsContentLength(operationShape: OperationShape): Boolean { - return protocol.httpBindingResolver.requestBindings(operationShape) - .any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD } - } + private fun needsContentLength(operationShape: OperationShape): Boolean = + protocol.needsRequestContentLength(operationShape) private fun createHttpRequest(operationShape: OperationShape): Writable = writable { diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy index c60b93736d..61a04466c1 100644 --- a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -130,7 +130,8 @@ apply ErrorSerializationOperation @httpMalformedRequestTests([ "Content-Type": "application/cbor" } // An empty CBOR map. We're missing a lot of `@required` members! - body: "oA==" + body: "oA==", + bodyMediaType: "application/cbor" }, response: { code: 400, @@ -149,9 +150,9 @@ apply ErrorSerializationOperation @httpResponseTests([ id: "OperationOutputSerializationQuestionablyIncludesTypeField", documentation: """ Despite the operation output being a structure shape with the `@error` trait, - `__type` field should, in a strict interpretation of the spec, not be included, - because we're not serializing a server error response. However, we do, because - there shouldn't™️ be any harm in doing so, and it greatly simplifies the + `__type` field should, in a strict interpretation of the spec, not be included, + because we're not serializing a server error response. However, we do, because + there shouldn't™️ be any harm in doing so, and it greatly simplifies the code generator. This test just pins this behavior in case we ever modify it.""", protocol: rpcv2Cbor, code: 200, @@ -170,6 +171,12 @@ apply SimpleStructOperation @httpResponseTests([ id: "SimpleStruct", protocol: rpcv2Cbor, code: 200, // Not used. + body: "v2RibG9iS2Jsb2JieSBibG9iZ2Jvb2xlYW70ZnN0cmluZ3hwVGhlcmUgYXJlIHRocmVlIHRoaW5ncyBhbGwgd2lzZSBtZW4gZmVhcjogdGhlIHNlYSBpbiBzdG9ybSwgYSBuaWdodCB3aXRoIG5vIG1vb24sIGFuZCB0aGUgYW5nZXIgb2YgYSBnZW50bGUgbWFuLmRieXRlGEVlc2hvcnQYRmdpbnRlZ2VyGEdkbG9uZxhIZWZsb2F0+j8wo9dmZG91Ymxl+z/mTQE6kqMFaXRpbWVzdGFtcMH7QdcKq2AAAABkZW51bWdESUFNT05EbHJlcXVpcmVkQmxvYktibG9iYnkgYmxvYm9yZXF1aXJlZEJvb2xlYW70bnJlcXVpcmVkU3RyaW5neHBUaGVyZSBhcmUgdGhyZWUgdGhpbmdzIGFsbCB3aXNlIG1lbiBmZWFyOiB0aGUgc2VhIGluIHN0b3JtLCBhIG5pZ2h0IHdpdGggbm8gbW9vbiwgYW5kIHRoZSBhbmdlciBvZiBhIGdlbnRsZSBtYW4ubHJlcXVpcmVkQnl0ZRhFbXJlcXVpcmVkU2hvcnQYRm9yZXF1aXJlZEludGVnZXIYR2xyZXF1aXJlZExvbmcYSG1yZXF1aXJlZEZsb2F0+j8wo9ducmVxdWlyZWREb3VibGX7P+ZNATqSowVxcmVxdWlyZWRUaW1lc3RhbXDB+0HXCqtgAAAAbHJlcXVpcmVkRW51bWdESUFNT05E/w==", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + }, params: { blob: "blobby blob", boolean: false, @@ -211,6 +218,12 @@ apply SimpleStructOperation @httpResponseTests([ id: "SimpleStructWithOptionsSetToNone", protocol: rpcv2Cbor, code: 200, // Not used. + body: "v2xyZXF1aXJlZEJsb2JLYmxvYmJ5IGJsb2JvcmVxdWlyZWRCb29sZWFu9G5yZXF1aXJlZFN0cmluZ3hwVGhlcmUgYXJlIHRocmVlIHRoaW5ncyBhbGwgd2lzZSBtZW4gZmVhcjogdGhlIHNlYSBpbiBzdG9ybSwgYSBuaWdodCB3aXRoIG5vIG1vb24sIGFuZCB0aGUgYW5nZXIgb2YgYSBnZW50bGUgbWFuLmxyZXF1aXJlZEJ5dGUYRW1yZXF1aXJlZFNob3J0GEZvcmVxdWlyZWRJbnRlZ2VyGEdscmVxdWlyZWRMb25nGEhtcmVxdWlyZWRGbG9hdPo/MKPXbnJlcXVpcmVkRG91Ymxl+z/mTQE6kqMFcXJlcXVpcmVkVGltZXN0YW1wwftB1wqrYAAAAGxyZXF1aXJlZEVudW1nRElBTU9ORP8=", + bodyMediaType: "application/cbor", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + }, params: { requiredBlob: "blobby blob", requiredBoolean: false, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index f34921dfff..24c7e54268 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -147,6 +147,14 @@ class InlineDependency( CargoDependency.smithyTypes(runtimeConfig), ) + fun cborErrors(runtimeConfig: RuntimeConfig): InlineDependency = + forInlineableRustFile( + "cbor_errors", + CargoDependency.smithyCbor(runtimeConfig), + CargoDependency.smithyRuntimeApi(runtimeConfig), + CargoDependency.smithyTypes(runtimeConfig), + ) + fun ec2QueryErrors(runtimeConfig: RuntimeConfig): InlineDependency = forInlineableRustFile("ec2_query_errors", CargoDependency.smithyXml(runtimeConfig)) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index cd20c22502..61771ad61d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -519,6 +519,8 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) ) // inlinable types + fun cborErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.cborErrors(runtimeConfig)) + fun ec2QueryErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.ec2QueryErrors(runtimeConfig)) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index c7b139bfd5..7768087a25 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -78,6 +78,13 @@ interface Protocol { * there are no response headers or statuses available to further inform the error parsing. */ fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType + + /** + * Determines whether the `Content-Length` header should be set in an HTTP request. + */ + fun needsRequestContentLength(operationShape: OperationShape): Boolean = + httpBindingResolver.requestBindings(operationShape) + .any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD } } typealias ProtocolMap = Map> diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index d1af7ae72c..f67638edba 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -7,10 +7,15 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model +import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator @@ -18,14 +23,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.Structure import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer -import software.amazon.smithy.rust.codegen.core.util.PANIC -import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isStreaming -import software.amazon.smithy.rust.codegen.core.util.outputShape class RpcV2CborHttpBindingResolver( private val model: Model, private val contentTypes: ProtocolContentTypes, + private val serviceShape: ServiceShape, ) : HttpBindingResolver { private fun bindings(shape: ToShapeId): List { val members = shape.let { model.expectShape(it.toShapeId()) }.members() @@ -47,10 +50,12 @@ class RpcV2CborHttpBindingResolver( .toList() } - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) - // In the server, this is only used when the protocol actually supports the `@http` trait. - // However, we will have to do this for client support. Perhaps this method deserves a rename. - override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") + override fun httpTrait(operationShape: OperationShape): HttpTrait = + HttpTrait.builder() + .code(200) + .method("POST") + .uri(UriPattern.parse("/service/${serviceShape.id.name}/operation/${operationShape.id.name}")) + .build() override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape) @@ -87,6 +92,8 @@ class RpcV2CborHttpBindingResolver( } open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig + override val httpBindingResolver: HttpBindingResolver = RpcV2CborHttpBindingResolver( codegenContext.model, @@ -96,26 +103,50 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { eventStreamContentType = "application/vnd.amazon.eventstream", eventStreamMessageContentType = "application/cbor", ), + codegenContext.serviceShape, ) // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps // using floating point seconds from the epoch. override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS + override fun additionalRequestHeaders(operationShape: OperationShape): List> = + listOf("smithy-protocol" to "rpc-v2-cbor") + override fun additionalResponseHeaders(operationShape: OperationShape): List> = listOf("smithy-protocol" to "rpc-v2-cbor") override fun structuredDataParser(): StructuredDataParserGenerator = - CborParserGenerator(codegenContext, httpBindingResolver) + CborParserGenerator( + codegenContext, httpBindingResolver, + handleNullForNonSparseCollection = { collectionName: String -> + writable { + // The client should drop a null value in a dense collection, see + // https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/cbor-maps.smithy#L158 + rustTemplate( + """ + decoder.null()?; + return #{Ok}($collectionName) + """, + *RuntimeType.preludeScope, + ) + } + }, + ) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = CborSerializerGenerator(codegenContext, httpBindingResolver) - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = - TODO("rpcv2Cbor client support has not yet been implemented") + RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata") // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = TODO("rpcv2Cbor event streams have not yet been implemented") + + // Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set + // unless there is no input or if the operation is an event stream, see + // https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/empty-input-output.smithy#L106 + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3772): Do not set `Content-Length` for event stream operations + override fun needsRequestContentLength(operationShape: OperationShape) = operationShape.input.isPresent } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 99208b0b9a..0cc16c101f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -27,6 +27,7 @@ import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -42,7 +43,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName -import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation @@ -70,6 +70,10 @@ class CborParserGenerator( private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) }, + /** Lambda that controls what to do when seeing a NULL value while parsing an element for a non-sparse collection */ + private val handleNullForNonSparseCollection: (String) -> Writable, + /** Lambda that determines whether the input to a builder setter needs to be wrapped in `Some` */ + private val shouldWrapBuilderMemberSetterInputWithOption: (MemberShape) -> Boolean = { _ -> true }, private val customizations: List = emptyList(), ) : StructuredDataParserGenerator { private val model = codegenContext.model @@ -78,8 +82,10 @@ class CborParserGenerator( private val codegenTarget = codegenContext.target private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() private val protocolFunctions = ProtocolFunctions(codegenContext) + private val builderInstantiator = codegenContext.builderInstantiator() private val codegenScope = arrayOf( + *preludeScope, "SmithyCbor" to smithyCbor, "Decoder" to smithyCbor.resolve("Decoder"), "Error" to smithyCbor.resolve("decode::DeserializeError"), @@ -87,6 +93,29 @@ class CborParserGenerator( *preludeScope, ) + private fun handleNullForCollection( + collectionName: String, + isSparse: Boolean, + ) = writable { + if (isSparse) { + rustTemplate( + """ + decoder.null()?; + #{None} + """, + *codegenScope, + ) + } else { + rustTemplate( + "#{handle_null_for_non_sparse_collection:W}", + "handle_null_for_non_sparse_collection" to + handleNullForNonSparseCollection( + collectionName, + ), + ) + } + } + private fun listMemberParserFn( listSymbol: Symbol, isSparseList: Boolean, @@ -103,29 +132,26 @@ class CborParserGenerator( *codegenScope, "ListSymbol" to listSymbol, ) { - val deserializeMemberWritable = deserializeMember(memberShape) - if (isSparseList) { - rustTemplate( - """ - let value = match decoder.datatype()? { - #{SmithyCbor}::data::Type::Null => { - decoder.null()?; - None + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + #{handleNullForCollection:W} + } + _ => #{DeserializeMember:W} + }; + """, + *codegenScope, + "handleNullForCollection" to handleNullForCollection(CollectionKind.List.decoderMethodName(), isSparseList), + "DeserializeMember" to + writable { + conditionalBlock( + "Some(", ")", isSparseList, + ) { + rust("#T?", deserializeMember(memberShape)) } - _ => Some(#{DeserializeMember:W}?), - }; - """, - *codegenScope, - "DeserializeMember" to deserializeMemberWritable, - ) - } else { - rustTemplate( - """ - let value = #{DeserializeMember:W}?; - """, - "DeserializeMember" to deserializeMemberWritable, - ) - } + }, + ) if (returnUnconstrainedType) { rust("list.0.push(value);") @@ -161,30 +187,26 @@ class CborParserGenerator( """, "DeserializeKey" to deserializeKeyWritable, ) - val deserializeValueWritable = deserializeMember(valueShape) - if (isSparseMap) { - rustTemplate( - """ - let value = match decoder.datatype()? { - #{SmithyCbor}::data::Type::Null => { - decoder.null()?; - None + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + #{handleNullForCollection:W} + } + _ => #{DeserializeMember:W} + }; + """, + *codegenScope, + "handleNullForCollection" to handleNullForCollection(CollectionKind.Map.decoderMethodName(), isSparseMap), + "DeserializeMember" to + writable { + conditionalBlock( + "Some(", ")", isSparseMap, + ) { + rust("#T?", deserializeMember(valueShape)) } - _ => Some(#{DeserializeValue:W}?), - }; - """, - *codegenScope, - "DeserializeValue" to deserializeValueWritable, - ) - } else { - rustTemplate( - """ - let value = #{DeserializeValue:W}?; - """, - "DeserializeValue" to deserializeValueWritable, - ) - } - + }, + ) if (returnUnconstrainedType) { rust("map.0.insert(key, value);") } else { @@ -216,7 +238,7 @@ class CborParserGenerator( val callBuilderSetMemberFieldWritable = writable { withBlock("builder.${member.setterName()}(", ")") { - conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { + conditionalBlock("Some(", ")", shouldWrapBuilderMemberSetterInputWithOption(member)) { val symbol = symbolProvider.toSymbol(member) if (symbol.isRustBoxed()) { rustBlock("") { @@ -263,7 +285,7 @@ class CborParserGenerator( rust( """ - _ => { + _ => { decoder.skip()?; builder } @@ -293,9 +315,9 @@ class CborParserGenerator( if (member.isTargetUnit()) { rust( """ - ${member.memberName.dq()} => { + ${member.memberName.dq()} => { decoder.skip()?; - #T::$variantName + #T::$variantName } """, returnSymbolToParse.symbol, @@ -313,7 +335,7 @@ class CborParserGenerator( """ _ => { decoder.skip()?; - Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) + #{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} } """, "Union" to returnSymbolToParse.symbol, @@ -404,9 +426,9 @@ class CborParserGenerator( """ pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> #{Result}<#{Builder}, #{Error}> { #{StructurePairParserFn:W} - + let decoder = &mut #{Decoder}::new(value); - + #{DecodeStructureMapLoop:W} if decoder.position() != value.len() { @@ -496,7 +518,7 @@ class CborParserGenerator( if (this@CborParserGenerator.returnSymbolToParse(target).isUnconstrained) { rust("decoder.string()") } else { - rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) + rust("decoder.string().map(|s| #T::from(s.as_ref()))", symbolProvider.toSymbol(target)) } } false -> rust("decoder.string()") @@ -521,11 +543,11 @@ class CborParserGenerator( """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { #{ListMemberParserFn:W} - + #{InitContainerWritable:W} - + #{DecodeListLoop:W} - + Ok(list) } """, @@ -564,11 +586,11 @@ class CborParserGenerator( """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { #{MapPairParserFn:W} - + #{InitContainerWritable:W} - + #{DecodeMapLoop:W} - + Ok(map) } """, @@ -604,9 +626,9 @@ class CborParserGenerator( rustTemplate( """ #{StructurePairParserFn:W} - + let mut builder = #{Builder}::default(); - + #{DecodeStructureMapLoop:W} """, *codegenScope, @@ -619,7 +641,18 @@ class CborParserGenerator( if (returnSymbolToParse.isUnconstrained) { rust("Ok(builder)") } else { - rust("Ok(builder.build())") + val builder = + builderInstantiator.finalizeBuilder( + "builder", shape, + ) { + rustTemplate( + """|err| #{Error}::custom(err.to_string(), decoder.position())""", *codegenScope, + ) + } + rust("##[allow(clippy::needless_question_mark)]") + rustBlock("") { + rust("return Ok(#T);", builder) + } } } } @@ -634,7 +667,7 @@ class CborParserGenerator( """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> { #{UnionPairParserFnWritable} - + match decoder.map()? { None => { let variant = pair(decoder)?; diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index f96a8b7cbc..1eaf1cd4da 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -45,6 +45,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctio import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.isUnit import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -157,9 +158,10 @@ class CborSerializerGenerator( private val codegenScope = arrayOf( + *preludeScope, "Error" to runtimeConfig.serializationError(), "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), - *preludeScope, + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), ) private val serializerUtil = SerializerUtil(model, symbolProvider) @@ -210,14 +212,29 @@ class CborSerializerGenerator( UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") override fun operationInputSerializer(operationShape: OperationShape): RuntimeType? { - // Don't generate an operation CBOR serializer if there is no CBOR body. - val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) - if (httpDocumentMembers.isEmpty()) { + // Don't generate an operation CBOR serializer if there was no operation input shape in the + // original (untransformed) model. + if (!OperationNormalizer.hadUserModeledOperationInput(operationShape, model)) { return null } - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) - TODO("Client implementation should fill this out") + val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + val inputShape = operationShape.inputShape(model) + return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", + *codegenScope, "target" to symbolProvider.toSymbol(inputShape), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(Vec::new());", *codegenScope) + // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference + // which doesn't require us to pass `&mut encoder` where requested. + rustBlock("") { + rust("let encoder = &mut encoder;") + serializeStructure(StructContext("input", inputShape), httpDocumentMembers) + } + rustTemplate("Ok(#{SdkBody}::from(encoder.into_writer()))", *codegenScope) + } + } } override fun documentSerializer(): RuntimeType = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index d4984d65e9..59f0f5e5f0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -10,12 +10,15 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor @@ -316,6 +319,22 @@ class ServerRpcV2CborProtocol( override fun structuredDataParser(): StructuredDataParserGenerator = CborParserGenerator( serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), + handleNullForNonSparseCollection = { collectionName: String -> + writable { + rustTemplate( + """ + return #{Err}(#{Error}::custom("dense $collectionName cannot contain null values", decoder.position())) + """, + *RuntimeType.preludeScope, + "Error" to + CargoDependency.smithyCbor(runtimeConfig).toType() + .resolve("decode::DeserializeError"), + ) + } + }, + shouldWrapBuilderMemberSetterInputWithOption = { member: MemberShape -> + codegenContext.symbolProvider.toSymbol(member).isOptional() + }, listOf( ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( serverCodegenContext, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 09e6b635de..c05f515a2f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -11,7 +11,6 @@ import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.AppliesTo @@ -272,13 +271,15 @@ class ServerProtocolTestGenerator( private val codegenScope = arrayOf( + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), "Base64SimdDev" to ServerCargoDependency.Base64SimdDev.toType(), "Bytes" to RuntimeType.Bytes, "Hyper" to RuntimeType.Hyper, + "MediaType" to RuntimeType.protocolTest(codegenContext.runtimeConfig, "MediaType"), "Tokio" to ServerCargoDependency.TokioDev.toType(), "Tower" to RuntimeType.Tower, "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), - "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "decode_body_data" to RuntimeType.protocolTest(codegenContext.runtimeConfig, "decode_body_data"), ) override fun RustWriter.renderAllTestCases(allTests: List) { @@ -313,7 +314,6 @@ class ServerProtocolTestGenerator( headers, body.orNull(), bodyMediaType.orNull(), - protocol, queryParams, host.orNull(), ) @@ -387,15 +387,12 @@ class ServerProtocolTestGenerator( rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - // TODO(https://github.com/smithy-lang/smithy/issues/1932): we send `null` for `bodyMediaType` for now but - // the Smithy protocol test should give it to us. renderHttpRequest( uri.get(), method, headers, body.orNull(), - bodyMediaType = null, - testCase.protocol, + bodyMediaType.orNull(), queryParams, host.orNull(), ) @@ -417,7 +414,6 @@ class ServerProtocolTestGenerator( headers: Map, body: String?, bodyMediaType: String?, - protocol: ShapeId, queryParams: List, host: String?, ) { @@ -448,24 +444,12 @@ class ServerProtocolTestGenerator( // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - // TODO(https://github.com/smithy-lang/smithy/issues/1932): We're using the `protocol` field as a - // proxy for `bodyMediaType`. This works because `rpcv2Cbor` happens to be the only protocol where - // the body is base64-encoded in the protocol test, but checking `bodyMediaType` should be a more - // resilient check. val encodedBody = - if (protocol.toShapeId() == ShapeId.from("smithy.protocols#rpcv2Cbor")) { - """ - #{Bytes}::from( - #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( - "`body` field of Smithy protocol test is not correctly base64 encoded" - ) - ) """ - } else { - """ - #{Bytes}::from_static($sanitizedBody.as_bytes()) + #{Bytes}::copy_from_slice( + &#{decode_body_data}($sanitizedBody.as_bytes(), #{MediaType}::from(${(bodyMediaType ?: "unknown").dq()})) + ) """ - } "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml index b87366d6ef..123e506b5a 100644 --- a/rust-runtime/aws-smithy-cbor/Cargo.toml +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-cbor" -version = "0.60.6" +version = "0.60.7" authors = [ "AWS Rust SDK Team ", "David Pérez ", diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index 3cfe070397..ca71cdd250 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -82,6 +82,13 @@ impl DeserializeError { } } + /// Returns a custom error with an offset. + pub fn custom(message: impl Into>, at: usize) -> Self { + Self { + _inner: Error::message(message.into()).at(at), + } + } + /// An unexpected type was encountered. // We handle this one when decoding sparse collections: we have to expect either a `null` or an // item, so we try decoding both. @@ -223,8 +230,19 @@ impl<'b> Decoder<'b> { "expected timestamp tag", ))) } else { + // Values that are more granular than millisecond precision SHOULD be truncated to fit + // millisecond precision for epoch-seconds: + // https://smithy.io/2.0/spec/protocol-traits.html#timestamp-formats + // + // Without truncation, the `RpcV2CborDateTimeWithFractionalSeconds` protocol test would + // fail since the upstream test expect `123000000` in subsec but the decoded actual + // subsec would be `123000025`. + // https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/fractional-seconds.smithy#L17 let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?; - Ok(DateTime::from_secs_f64(epoch_seconds)) + let mut result = DateTime::from_secs_f64(epoch_seconds); + let subsec_nanos = result.subsec_nanos(); + result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000); + Ok(result) } } } @@ -279,6 +297,7 @@ where #[cfg(test)] mod tests { use crate::Decoder; + use aws_smithy_types::date_time::Format; #[test] fn test_definite_str_is_cow_borrowed() { @@ -338,4 +357,23 @@ mod tests { aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes()) ); } + + #[test] + fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() { + // Input bytes are derived from the `RpcV2CborDateTimeWithFractionalSeconds` protocol test, + // extracting portion representing a timestamp value. + let bytes = [ + 0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff, + ]; + let mut decoder = Decoder::new(&bytes); + let timestamp = decoder.timestamp().expect("should decode timestamp"); + assert_eq!( + timestamp, + aws_smithy_types::date_time::DateTime::from_str( + "2000-01-02T20:34:56.123Z", + Format::DateTime + ) + .unwrap() + ); + } } diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 06cdbc2ff2..f6bef0d78c 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -23,6 +23,7 @@ use aws_smithy_runtime_api::client::orchestrator::HttpRequest; use aws_smithy_runtime_api::http::Headers; use http::{HeaderMap, Uri}; use pretty_assertions::Comparison; +use std::borrow::Cow; use std::collections::HashSet; use std::fmt::{self, Debug}; use thiserror::Error; @@ -474,6 +475,17 @@ actual body in base64 (useful to update the protocol test): } } +pub fn decode_body_data(body: &[u8], media_type: MediaType) -> Cow<'_, [u8]> { + match media_type { + MediaType::Cbor => Cow::Owned( + base64_simd::STANDARD + .decode_to_vec(body) + .expect("smithy protocol test `body` property is not properly base64 encoded"), + ), + _ => Cow::Borrowed(body), + } +} + #[cfg(test)] mod tests { use crate::{ diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml index 8e7c0ebd79..00a6fca122 100644 --- a/rust-runtime/inlineable/Cargo.toml +++ b/rust-runtime/inlineable/Cargo.toml @@ -17,6 +17,7 @@ gated-tests = [] default = ["gated-tests"] [dependencies] +aws-smithy-cbor = { path = "../aws-smithy-cbor" } aws-smithy-compression = { path = "../aws-smithy-compression", features = ["http-body-0-4-x"] } aws-smithy-http = { path = "../aws-smithy-http", features = ["event-stream"] } aws-smithy-json = { path = "../aws-smithy-json" } diff --git a/rust-runtime/inlineable/src/cbor_errors.rs b/rust-runtime/inlineable/src/cbor_errors.rs new file mode 100644 index 0000000000..d96c5233aa --- /dev/null +++ b/rust-runtime/inlineable/src/cbor_errors.rs @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_cbor::decode::DeserializeError; +use aws_smithy_cbor::Decoder; +use aws_smithy_runtime_api::http::Headers; +use aws_smithy_types::error::metadata::{Builder as ErrorMetadataBuilder, ErrorMetadata}; + +// This function is a copy-paste from `json_errors::sanitize_error_code`, therefore the functional +// tests can be viewed in the unit tests there. +// Since this is in the `inlineable` crate, there aren't good modules for housing common utilities +// unless we move this to a Smithy runtime crate. +fn sanitize_error_code(error_code: &str) -> &str { + // Trim a trailing URL from the error code, which is done by removing the longest suffix + // beginning with a `:` + let error_code = match error_code.find(':') { + Some(idx) => &error_code[..idx], + None => error_code, + }; + + // Trim a prefixing namespace from the error code, beginning with a `#` + match error_code.find('#') { + Some(idx) => &error_code[idx + 1..], + None => error_code, + } +} + +pub fn parse_error_metadata( + _response_status: u16, + _response_headers: &Headers, + response_body: &[u8], +) -> Result { + fn error_code( + mut builder: ErrorMetadataBuilder, + decoder: &mut Decoder, + ) -> Result { + builder = match decoder.str()?.as_ref() { + "__type" => { + let code = decoder.str()?; + builder.code(sanitize_error_code(&code)) + } + _ => { + decoder.skip()?; + builder + } + }; + Ok(builder) + } + + let decoder = &mut Decoder::new(response_body); + let mut builder = ErrorMetadata::builder(); + + match decoder.map()? { + None => loop { + match decoder.datatype()? { + ::aws_smithy_cbor::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + builder = error_code(builder, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + builder = error_code(builder, decoder)?; + } + } + }; + + Ok(builder) +} diff --git a/rust-runtime/inlineable/src/json_errors.rs b/rust-runtime/inlineable/src/json_errors.rs index b785bbe63b..05da1c217c 100644 --- a/rust-runtime/inlineable/src/json_errors.rs +++ b/rust-runtime/inlineable/src/json_errors.rs @@ -16,7 +16,8 @@ pub fn is_error(response: &http::Response) -> bool { } fn sanitize_error_code(error_code: &str) -> &str { - // Trim a trailing URL from the error code, beginning with a `:` + // Trim a trailing URL from the error code, which is done by removing the longest suffix + // beginning with a `:` let error_code = match error_code.find(':') { Some(idx) => &error_code[..idx], None => error_code, diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index cf0d8705d1..0e2a815e05 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -9,6 +9,8 @@ #[allow(dead_code)] mod aws_query_compatible_errors; #[allow(unused)] +mod cbor_errors; +#[allow(unused)] mod client_http_checksum_required; #[allow(dead_code)] mod client_idempotency_token; diff --git a/tools/ci-cdk/canary-runner/Cargo.lock b/tools/ci-cdk/canary-runner/Cargo.lock index a26f76eaf7..e1a2834a4d 100644 --- a/tools/ci-cdk/canary-runner/Cargo.lock +++ b/tools/ci-cdk/canary-runner/Cargo.lock @@ -2558,6 +2558,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap 2.2.6", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha1" version = "0.10.6" @@ -2649,6 +2662,7 @@ dependencies = [ "semver", "serde", "serde_json", + "serde_yaml", "thiserror", "tokio", "toml 0.5.11", @@ -3082,6 +3096,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.7.1"