Skip to content

Commit

Permalink
Add client-support for RPC v2 CBOR (#3767)
Browse files Browse the repository at this point in the history
## Motivation and Context
Follow-up on #2544 to add
client-side support for the protocol

## Description
The client implementation mainly focuses on a sub-section
[Requests](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#requests)
in the spec. To that end, this PR addresses `TODO` for the client to
fill in the blanks and includes additional adjustments/refactoring to
pass client protocol tests.

## Testing
- Existing tests in CI
- Upstream protocol test `rpcv2Cbor`
- Our handwritten protocol test `rpcv2Cbor-extras.smithy`

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
ysaito1001 authored Jul 31, 2024
1 parent 50148e6 commit 36b50b3
Show file tree
Hide file tree
Showing 22 changed files with 409 additions and 121 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ message = "Fix incorrect redaction of `@sensitive` types in maps and lists."
references = ["smithy-rs#3765", "smithy-rs#3757"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "landonxjames"

[[smithy-rs]]
message = "Fix client error correction to properly parse structure members that target a `Union` containing that structure recursively."
references = ["smithy-rs#3767"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "ysaito1001"
7 changes: 7 additions & 0 deletions codegen-client-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-client-test/"
dependencies {
implementation(project(":codegen-client"))
implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
}
Expand Down Expand Up @@ -72,6 +73,12 @@ val allCodegenTests = listOf(
ClientTest("aws.protocoltests.restxml#RestXml", "rest_xml", addMessageToErrors = false),
ClientTest("aws.protocoltests.query#AwsQuery", "aws_query", addMessageToErrors = false),
ClientTest("aws.protocoltests.ec2#AwsEc2", "ec2_query", addMessageToErrors = false),
ClientTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"),
ClientTest(
"smithy.protocoltests.rpcv2Cbor#RpcV2CborService",
"rpcv2Cbor_extras",
dependsOn = listOf("rpcv2Cbor-extras.smithy")
),
ClientTest(
"aws.protocoltests.restxml.xmlns#RestXmlWithNamespace",
"rest_xml_namespace",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,18 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri

target is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
target is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)
target is UnionShape -> rust("Some(#T::Unknown)", targetSymbol)
target is UnionShape ->
rustTemplate(
"Some(#{unknown})", *preludeScope,
"unknown" to
writable {
if (memberSymbol.isRustBoxed()) {
rust("Box::new(#T::Unknown)", targetSymbol)
} else {
rust("#T::Unknown", targetSymbol)
}
},
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Proto
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
Expand Down Expand Up @@ -78,6 +79,8 @@ class ClientProtocolTestGenerator(
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"),
FailingTest.RequestTest(REST_JSON, "RestJsonClientPopulatesDefaultValuesInInput"),
FailingTest.RequestTest(REST_JSON, "RestJsonClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborClientPopulatesDefaultValuesInInput"),
FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
)

private val BrokenTests:
Expand Down Expand Up @@ -268,6 +271,7 @@ class ClientProtocolTestGenerator(
""",
RT.sdkBody(runtimeConfig = rc),
)
val mediaType = testCase.bodyMediaType.orNull()
rustTemplate(
"""
use #{DeserializeResponse};
Expand All @@ -280,19 +284,19 @@ class ClientProtocolTestGenerator(
let parsed = de.deserialize_streaming(&mut http_response);
let parsed = parsed.unwrap_or_else(|| {
let http_response = http_response.map(|body| {
#{SdkBody}::from(#{copy_from_slice}(body.bytes().unwrap()))
#{SdkBody}::from(#{copy_from_slice}(&#{decode_body_data}(body.bytes().unwrap(), #{MediaType}::from(${(mediaType ?: "unknown").dq()}))))
});
de.deserialize_nonstreaming(&http_response)
});
""",
"copy_from_slice" to RT.Bytes.resolve("copy_from_slice"),
"SharedResponseDeserializer" to
RT.smithyRuntimeApiClient(rc)
.resolve("client::ser_de::SharedResponseDeserializer"),
"Operation" to codegenContext.symbolProvider.toSymbol(operationShape),
"decode_body_data" to RT.protocolTest(rc, "decode_body_data"),
"DeserializeResponse" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::DeserializeResponse"),
"MediaType" to RT.protocolTest(rc, "MediaType"),
"Operation" to codegenContext.symbolProvider.toSymbol(operationShape),
"RuntimePlugin" to RT.runtimePlugin(rc),
"SdkBody" to RT.sdkBody(rc),
"SharedResponseDeserializer" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::SharedResponseDeserializer"),
)
if (expectedShape.hasTrait<ErrorTrait>()) {
val errorSymbol = codegenContext.symbolProvider.symbolForOperationError(operationShape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.findStreamingMember
Expand Down Expand Up @@ -125,10 +124,8 @@ class RequestSerializerGenerator(
)
}

private fun needsContentLength(operationShape: OperationShape): Boolean {
return protocol.httpBindingResolver.requestBindings(operationShape)
.any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD }
}
private fun needsContentLength(operationShape: OperationShape): Boolean =
protocol.needsRequestContentLength(operationShape)

private fun createHttpRequest(operationShape: OperationShape): Writable =
writable {
Expand Down
21 changes: 17 additions & 4 deletions codegen-core/common-test-models/rpcv2Cbor-extras.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ apply ErrorSerializationOperation @httpMalformedRequestTests([
"Content-Type": "application/cbor"
}
// An empty CBOR map. We're missing a lot of `@required` members!
body: "oA=="
body: "oA==",
bodyMediaType: "application/cbor"
},
response: {
code: 400,
Expand All @@ -149,9 +150,9 @@ apply ErrorSerializationOperation @httpResponseTests([
id: "OperationOutputSerializationQuestionablyIncludesTypeField",
documentation: """
Despite the operation output being a structure shape with the `@error` trait,
`__type` field should, in a strict interpretation of the spec, not be included,
because we're not serializing a server error response. However, we do, because
there shouldn't™️ be any harm in doing so, and it greatly simplifies the
`__type` field should, in a strict interpretation of the spec, not be included,
because we're not serializing a server error response. However, we do, because
there shouldn't™️ be any harm in doing so, and it greatly simplifies the
code generator. This test just pins this behavior in case we ever modify it.""",
protocol: rpcv2Cbor,
code: 200,
Expand All @@ -170,6 +171,12 @@ apply SimpleStructOperation @httpResponseTests([
id: "SimpleStruct",
protocol: rpcv2Cbor,
code: 200, // Not used.
body: "v2RibG9iS2Jsb2JieSBibG9iZ2Jvb2xlYW70ZnN0cmluZ3hwVGhlcmUgYXJlIHRocmVlIHRoaW5ncyBhbGwgd2lzZSBtZW4gZmVhcjogdGhlIHNlYSBpbiBzdG9ybSwgYSBuaWdodCB3aXRoIG5vIG1vb24sIGFuZCB0aGUgYW5nZXIgb2YgYSBnZW50bGUgbWFuLmRieXRlGEVlc2hvcnQYRmdpbnRlZ2VyGEdkbG9uZxhIZWZsb2F0+j8wo9dmZG91Ymxl+z/mTQE6kqMFaXRpbWVzdGFtcMH7QdcKq2AAAABkZW51bWdESUFNT05EbHJlcXVpcmVkQmxvYktibG9iYnkgYmxvYm9yZXF1aXJlZEJvb2xlYW70bnJlcXVpcmVkU3RyaW5neHBUaGVyZSBhcmUgdGhyZWUgdGhpbmdzIGFsbCB3aXNlIG1lbiBmZWFyOiB0aGUgc2VhIGluIHN0b3JtLCBhIG5pZ2h0IHdpdGggbm8gbW9vbiwgYW5kIHRoZSBhbmdlciBvZiBhIGdlbnRsZSBtYW4ubHJlcXVpcmVkQnl0ZRhFbXJlcXVpcmVkU2hvcnQYRm9yZXF1aXJlZEludGVnZXIYR2xyZXF1aXJlZExvbmcYSG1yZXF1aXJlZEZsb2F0+j8wo9ducmVxdWlyZWREb3VibGX7P+ZNATqSowVxcmVxdWlyZWRUaW1lc3RhbXDB+0HXCqtgAAAAbHJlcXVpcmVkRW51bWdESUFNT05E/w==",
bodyMediaType: "application/cbor",
headers: {
"smithy-protocol": "rpc-v2-cbor",
"Content-Type": "application/cbor"
},
params: {
blob: "blobby blob",
boolean: false,
Expand Down Expand Up @@ -211,6 +218,12 @@ apply SimpleStructOperation @httpResponseTests([
id: "SimpleStructWithOptionsSetToNone",
protocol: rpcv2Cbor,
code: 200, // Not used.
body: "v2xyZXF1aXJlZEJsb2JLYmxvYmJ5IGJsb2JvcmVxdWlyZWRCb29sZWFu9G5yZXF1aXJlZFN0cmluZ3hwVGhlcmUgYXJlIHRocmVlIHRoaW5ncyBhbGwgd2lzZSBtZW4gZmVhcjogdGhlIHNlYSBpbiBzdG9ybSwgYSBuaWdodCB3aXRoIG5vIG1vb24sIGFuZCB0aGUgYW5nZXIgb2YgYSBnZW50bGUgbWFuLmxyZXF1aXJlZEJ5dGUYRW1yZXF1aXJlZFNob3J0GEZvcmVxdWlyZWRJbnRlZ2VyGEdscmVxdWlyZWRMb25nGEhtcmVxdWlyZWRGbG9hdPo/MKPXbnJlcXVpcmVkRG91Ymxl+z/mTQE6kqMFcXJlcXVpcmVkVGltZXN0YW1wwftB1wqrYAAAAGxyZXF1aXJlZEVudW1nRElBTU9ORP8=",
bodyMediaType: "application/cbor",
headers: {
"smithy-protocol": "rpc-v2-cbor",
"Content-Type": "application/cbor"
},
params: {
requiredBlob: "blobby blob",
requiredBoolean: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ class InlineDependency(
CargoDependency.smithyTypes(runtimeConfig),
)

fun cborErrors(runtimeConfig: RuntimeConfig): InlineDependency =
forInlineableRustFile(
"cbor_errors",
CargoDependency.smithyCbor(runtimeConfig),
CargoDependency.smithyRuntimeApi(runtimeConfig),
CargoDependency.smithyTypes(runtimeConfig),
)

fun ec2QueryErrors(runtimeConfig: RuntimeConfig): InlineDependency =
forInlineableRustFile("ec2_query_errors", CargoDependency.smithyXml(runtimeConfig))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
)

// inlinable types
fun cborErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.cborErrors(runtimeConfig))

fun ec2QueryErrors(runtimeConfig: RuntimeConfig) =
forInlineDependency(InlineDependency.ec2QueryErrors(runtimeConfig))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ interface Protocol {
* there are no response headers or statuses available to further inform the error parsing.
*/
fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType

/**
* Determines whether the `Content-Length` header should be set in an HTTP request.
*/
fun needsRequestContentLength(operationShape: OperationShape): Boolean =
httpBindingResolver.requestBindings(operationShape)
.any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD }
}

typealias ProtocolMap<T, C> = Map<ShapeId, ProtocolGeneratorFactory<T, C>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,28 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.outputShape

class RpcV2CborHttpBindingResolver(
private val model: Model,
private val contentTypes: ProtocolContentTypes,
private val serviceShape: ServiceShape,
) : HttpBindingResolver {
private fun bindings(shape: ToShapeId): List<HttpBindingDescriptor> {
val members = shape.let { model.expectShape(it.toShapeId()) }.members()
Expand All @@ -47,10 +50,12 @@ class RpcV2CborHttpBindingResolver(
.toList()
}

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
// In the server, this is only used when the protocol actually supports the `@http` trait.
// However, we will have to do this for client support. Perhaps this method deserves a rename.
override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait")
override fun httpTrait(operationShape: OperationShape): HttpTrait =
HttpTrait.builder()
.code(200)
.method("POST")
.uri(UriPattern.parse("/service/${serviceShape.id.name}/operation/${operationShape.id.name}"))
.build()

override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape)

Expand Down Expand Up @@ -87,6 +92,8 @@ class RpcV2CborHttpBindingResolver(
}

open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig

override val httpBindingResolver: HttpBindingResolver =
RpcV2CborHttpBindingResolver(
codegenContext.model,
Expand All @@ -96,26 +103,50 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
eventStreamContentType = "application/vnd.amazon.eventstream",
eventStreamMessageContentType = "application/cbor",
),
codegenContext.serviceShape,
)

// Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps
// using floating point seconds from the epoch.
override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

override fun additionalRequestHeaders(operationShape: OperationShape): List<Pair<String, String>> =
listOf("smithy-protocol" to "rpc-v2-cbor")

override fun additionalResponseHeaders(operationShape: OperationShape): List<Pair<String, String>> =
listOf("smithy-protocol" to "rpc-v2-cbor")

override fun structuredDataParser(): StructuredDataParserGenerator =
CborParserGenerator(codegenContext, httpBindingResolver)
CborParserGenerator(
codegenContext, httpBindingResolver,
handleNullForNonSparseCollection = { collectionName: String ->
writable {
// The client should drop a null value in a dense collection, see
// https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/cbor-maps.smithy#L158
rustTemplate(
"""
decoder.null()?;
return #{Ok}($collectionName)
""",
*RuntimeType.preludeScope,
)
}
},
)

override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
CborSerializerGenerator(codegenContext, httpBindingResolver)

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
TODO("rpcv2Cbor client support has not yet been implemented")
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
TODO("rpcv2Cbor event streams have not yet been implemented")

// Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
// unless there is no input or if the operation is an event stream, see
// https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/empty-input-output.smithy#L106
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3772): Do not set `Content-Length` for event stream operations
override fun needsRequestContentLength(operationShape: OperationShape) = operationShape.input.isPresent
}
Loading

0 comments on commit 36b50b3

Please sign in to comment.