Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix constraint-related errors in Rpcv2CBOR server implementation #3794

Merged
merged 22 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d026ad1
Use constraints.smithy with CBor
Aug 19, 2024
fd8c25c
Use CBOR encoded string for marhsalling tests
Aug 30, 2024
6afc8bf
Remove streaming trait from blob
Sep 4, 2024
3fede97
Use constraints.smithy with CBor
Aug 19, 2024
5bb92b1
Use CBOR encoded string for marhsalling tests
Aug 30, 2024
f2c11b8
Merge remote-tracking branch 'cbor-fixe/fahadzub/cbor-constraint' int…
Sep 4, 2024
763d598
Fix formatting and comments
Sep 4, 2024
b961f18
Merge branch 'main' into fahadzub/cbor-constraint
Sep 4, 2024
91d80a4
Implement `parseEventStreamErrorMetadata`, and change client test cas…
Sep 8, 2024
d51cc78
Add copyright
Sep 9, 2024
050d9c5
Add changelog and fix lint issues
Sep 9, 2024
4e351cb
Merge branch 'main' into fahadzub/cbor-constraint
drganjoo Sep 9, 2024
3e74cc8
Update .changelog/2155171.md
drganjoo Sep 17, 2024
8bf71f1
Update .changelog/2155171.md
drganjoo Sep 17, 2024
13c0c78
Update codegen-core/src/main/kotlin/software/amazon/smithy/rust/codeg…
drganjoo Sep 17, 2024
c36dfb0
Update codegen-core/src/main/kotlin/software/amazon/smithy/rust/codeg…
drganjoo Sep 17, 2024
5bbfdc2
Add comments to clarify that the ServerProtocolBasedTransformationFac…
Sep 17, 2024
432e0f9
Merge branch 'main' into fahadzub/cbor-constraint
drganjoo Sep 17, 2024
9977516
Merge branch 'main' into fahadzub/cbor-constraint
drganjoo Sep 18, 2024
fcf670d
Merge branch 'main' into fahadzub/cbor-constraint
drganjoo Sep 30, 2024
35a9099
Merge branch 'main' into fahadzub/cbor-constraint
drganjoo Sep 30, 2024
27ca7f1
Merge branch 'main' into fahadzub/cbor-constraint
aws-sdk-rust-ci Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .changelog/2155171.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
applies_to: ["server","client"]
authors: ["drganjoo"]
references: [smithy-rs#3573]
breaking: false
new_feature: true
bug_fix: false
---
Support for the [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) protocol has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.testModule
Expand Down Expand Up @@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
"exception",
"UnmodeledError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
Expand Down
1 change: 1 addition & 0 deletions codegen-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies {
implementation("org.jsoup:jsoup:1.16.2")
api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
api("com.moandjiezana.toml:toml4j:0.7.2")
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
Expand Down Expand Up @@ -140,9 +141,23 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
TODO("rpcv2Cbor event streams have not yet been implemented")
ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName ->
rustTemplate(
"""
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> {
#{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload)
}
""",
"cbor_errors" to RuntimeType.cborErrors(runtimeConfig),
"Bytes" to RuntimeType.Bytes,
"ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
"DeserializeError" to
CargoDependency.smithyCbor(runtimeConfig).toType()
.resolve("decode::DeserializeError"),
"Headers" to RuntimeType.headers(runtimeConfig),
)
}

// Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
// unless there is no input or if the operation is an event stream, see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -447,7 +446,24 @@ class CborParserGenerator(
}

override fun payloadParser(member: MemberShape): RuntimeType {
UNREACHABLE("No protocol using CBOR serialization supports payload binding")
val shape = model.expectShape(member.target)
val returnSymbol = returnSymbolToParse(shape)
check(shape is UnionShape || shape is StructureShape) {
"Payload parser should only be used on structure and union shapes."
}
return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
rustTemplate(
"""
pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
let decoder = &mut #{Decoder}::new(value);
#{DeserializeMember}
}
""",
"ReturnType" to returnSymbol.symbol,
"DeserializeMember" to deserializeMember(member),
*codegenScope,
)
}
}

override fun operationParser(operationShape: OperationShape): RuntimeType? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) {
/** Manipulate the serializer context for a map prior to it being serialized. **/
data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) :
CborSerializerSection("BeforeIteratingOverMapOrCollection")

/** Manipulate the serializer context for a non-null member prior to it being serialized. **/
data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) :
CborSerializerSection("BeforeSerializingNonNullMember")
}

/**
Expand Down Expand Up @@ -200,9 +204,26 @@ class CborSerializerGenerator(
}
}

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun payloadSerializer(member: MemberShape): RuntimeType {
TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
val target = model.expectShape(member.target)
return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
rustBlockTemplate(
"pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
*codegenScope,
"target" to symbolProvider.toSymbol(target),
) {
rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
rustBlock("") {
rust("let encoder = &mut encoder;")
when (target) {
is StructureShape -> serializeStructure(StructContext("input", target))
is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
}
}
rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
}
}
}

override fun unsetStructure(structure: StructureShape): RuntimeType =
Expand Down Expand Up @@ -311,6 +332,7 @@ class CborSerializerGenerator(
safeName().also { local ->
rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
context.valueExpression = ValueExpression.Reference(local)
resolveValueExpressionForConstrainedType(targetShape, context)
serializeMemberValue(context, targetShape)
}
if (context.writeNulls) {
Expand All @@ -320,6 +342,7 @@ class CborSerializerGenerator(
}
}
} else {
resolveValueExpressionForConstrainedType(targetShape, context)
with(serializerUtil) {
ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) {
serializeMemberValue(context, targetShape)
Expand All @@ -328,6 +351,20 @@ class CborSerializerGenerator(
}
}

private fun RustWriter.resolveValueExpressionForConstrainedType(
targetShape: Shape,
context: MemberContext,
) {
for (customization in customizations) {
customization.section(
CborSerializerSection.BeforeSerializingNonNullMember(
targetShape,
context,
),
)(this)
}
}

private fun RustWriter.serializeMemberValue(
context: MemberContext,
target: Shape,
Expand Down Expand Up @@ -362,7 +399,7 @@ class CborSerializerGenerator(
rust("$encoder;") // Encode the member key.
}
when (target) {
is StructureShape -> serializeStructure(StructContext(value.name, target))
is StructureShape -> serializeStructure(StructContext(value.asRef(), target))
is CollectionShape -> serializeCollection(Context(value, target))
is MapShape -> serializeMap(Context(value, target))
is UnionShape -> serializeUnion(Context(value, target))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@

package software.amazon.smithy.rust.codegen.core.testutil

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
import java.util.Base64

private fun fillInBaseModel(
protocolName: String,
namespacedProtocolName: String,
extraServiceAnnotations: String = "",
): String =
"""
namespace test

use smithy.framework#ValidationException
use aws.protocols#$protocolName
use $namespacedProtocolName

union TestUnion {
Foo: String,
Expand Down Expand Up @@ -86,22 +90,24 @@ private fun fillInBaseModel(
}

$extraServiceAnnotations
@$protocolName
@${namespacedProtocolName.substringAfter("#")}
service TestService { version: "123", operations: [TestStreamOp] }
"""

object EventStreamTestModels {
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()

private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()

private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()

private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()

private fun awsQuery(): Model =
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

private fun ec2Query(): Model =
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

data class TestCase(
val protocolShapeId: String,
Expand All @@ -120,39 +126,67 @@ object EventStreamTestModels {
override fun toString(): String = protocolShapeId
}

private fun base64Encode(input: ByteArray): String {
val encodedBytes = Base64.getEncoder().encode(input)
return String(encodedBytes)
}

private fun createCborFromJson(jsonString: String): ByteArray {
val jsonMapper = ObjectMapper()
val cborMapper = ObjectMapper(CBORFactory())
// Parse JSON string to a generic type.
val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
// Convert the parsed data to CBOR.
return cborMapper.writeValueAsBytes(jsonData)
}

private val restJsonTestCase =
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) }

val TEST_CASES =
listOf(
//
// restJson1
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
restJsonTestCase,
//
// rpcV2Cbor
//
restJsonTestCase.copy(
protocolShapeId = "smithy.protocols#rpcv2Cbor",
model = rpcv2Cbor(),
mediaType = "application/cbor",
responseContentType = "application/cbor",
eventStreamMessageContentType = "application/cbor",
validTestStruct = base64Encode(createCborFromJson(restJsonTestCase.validTestStruct)),
validMessageWithNoHeaderPayloadTraits = base64Encode(createCborFromJson(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
validTestUnion = base64Encode(createCborFromJson(restJsonTestCase.validTestUnion)),
validSomeError = base64Encode(createCborFromJson(restJsonTestCase.validSomeError)),
validUnmodeledError = base64Encode(createCborFromJson(restJsonTestCase.validUnmodeledError)),
protocolBuilder = { RpcV2Cbor(it) },
),
//
// awsJson1_1
//
TestCase(
restJsonTestCase.copy(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
mediaType = "application/x-amz-json-1.1",
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
//
// restXml
Expand Down
Loading