From d026ad1f32974b18a720bd3c5187adc21ddf189b Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Mon, 19 Aug 2024 15:45:14 +0100 Subject: [PATCH 01/14] Use constraints.smithy with CBor --- .../serialize/CborSerializerGenerator.kt | 22 ++++- .../server/smithy/ServerCodegenVisitor.kt | 4 +- ...eforeSerializingMemberCborCustomization.kt | 36 ++++++++ .../generators/protocol/ServerProtocol.kt | 2 + ...erverProtocolBasedTransformationFactory.kt | 82 +++++++++++++++++++ .../codegen/server/smithy/ConstraintsTest.kt | 78 ++++++++++++++++++ .../CborConstraintsIntegrationTest.kt | 32 ++++++++ 7 files changed, 254 insertions(+), 2 deletions(-) create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 1eaf1cd4da..625a669b5d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) { /** Manipulate the serializer context for a map prior to it being serialized. **/ data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : CborSerializerSection("BeforeIteratingOverMapOrCollection") + + /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ + data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) : + CborSerializerSection("BeforeSerializingNonNullMember") } /** @@ -311,6 +315,7 @@ class CborSerializerGenerator( safeName().also { local -> rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { context.valueExpression = ValueExpression.Reference(local) + resolveValueExpressionForConstrainedType(targetShape, context) serializeMemberValue(context, targetShape) } if (context.writeNulls) { @@ -320,6 +325,7 @@ class CborSerializerGenerator( } } } else { + resolveValueExpressionForConstrainedType(targetShape, context) with(serializerUtil) { ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { serializeMemberValue(context, targetShape) @@ -328,6 +334,20 @@ class CborSerializerGenerator( } } + private fun RustWriter.resolveValueExpressionForConstrainedType( + targetShape: Shape, + context: MemberContext, + ) { + for (customization in customizations) { + customization.section( + CborSerializerSection.BeforeSerializingNonNullMember( + targetShape, + context, + ), + )(this) + } + } + private fun RustWriter.serializeMemberValue( context: MemberContext, target: Shape, @@ -362,7 +382,7 @@ class CborSerializerGenerator( rust("$encoder;") // Encode the member key. } when (target) { - is StructureShape -> serializeStructure(StructContext(value.name, target)) + is StructureShape -> serializeStructure(StructContext(value.asRef(), target)) is CollectionShape -> serializeCollection(Context(value, target)) is MapShape -> serializeMap(Context(value, target)) is UnionShape -> serializeUnion(Context(value, target)) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 5be7d27254..769cb240de 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger @@ -133,7 +134,8 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory - model = codegenDecorator.transformModel(service, baseModel, settings) + val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) + model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt new file mode 100644 index 0000000000..71ced4d9df --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt @@ -0,0 +1,36 @@ +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeSerializingNonNullMember -> + writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + } + + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 59f0f5e5f0..ac506aebd4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator @@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol( listOf( BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), AddTypeFieldToServerErrorsCborCustomization(), + BeforeSerializingMemberCborCustomization(serverCodegenContext), ), ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt new file mode 100644 index 0000000000..32bf59a48d --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -0,0 +1,82 @@ +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.AbstractShapeBuilder +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.HttpLabelTrait +import software.amazon.smithy.model.traits.HttpPayloadTrait +import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait +import software.amazon.smithy.utils.SmithyBuilder +import software.amazon.smithy.utils.ToSmithyBuilder + +/** + * Each protocol may not support all of the features that Smithy allows. For instance, most + * RPC protocols do not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory + * object that transforms the model and removes specific traits based on the protocol being instantiated. + */ +object ServerProtocolBasedTransformationFactory { + fun createTransformer(protocolShapeId: ShapeId): Transformer = + when (protocolShapeId) { + Rpcv2CborTrait.ID -> Rpcv2Transformer() + else -> IdentityTransformer() + } + + interface Transformer { + fun transform( + model: Model, + service: ServiceShape, + ): Model + } + + fun T.removeTraitIfPresent( + traitId: ShapeId, + ): T + where T : ToSmithyBuilder, + B : AbstractShapeBuilder, + B : SmithyBuilder { + return if (this.hasTrait(traitId)) { + @Suppress("UNCHECKED_CAST") + (this.toBuilder() as B).removeTrait(traitId).build() + } else { + this + } + } + + class Rpcv2Transformer() : Transformer { + override fun transform( + model: Model, + service: ServiceShape, + ): Model { + val transformedModel = + ModelTransformer.create().mapShapes(model) { shape -> + when (shape) { + is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) + is MemberShape -> { + shape + .removeTraitIfPresent(HttpLabelTrait.ID) + .removeTraitIfPresent(HttpPayloadTrait.ID) + } + + else -> shape + } + } + + return transformedModel + } + } + + class IdentityTransformer() : Transformer { + override fun transform( + model: Model, + service: ServiceShape, + ): Model { + return model + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 30e5e64813..5be3930452 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -6,18 +6,96 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.inspectors.forAll +import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe +import java.io.File import org.junit.jupiter.api.Test +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.AbstractTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +enum class ModelProtocol(val trait: AbstractTrait) { + AwsJson10(AwsJson1_0Trait.builder().build()), + AwsJson11(AwsJson1_1Trait.builder().build()), + RestJson(RestJson1Trait.builder().build()), + RestXml(RestXmlTrait.builder().build()), + Rpcv2Cbor(Rpcv2CborTrait.builder().build()), +} + +fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair { + val (serviceShapeId, model) = loadSmithyConstraintsModel() + return Pair(serviceShapeId, model.replaceProtocolTrait(serviceShapeId, modelProtocol)) +} + +fun loadSmithyConstraintsModel(): Pair { + val filePath = "../codegen-core/common-test-models/constraints.smithy" + val serviceShapeId = ShapeId.from("com.amazonaws.constraints#ConstraintsService") + val model = + File(filePath).readText().asSmithyModel() + return Pair(serviceShapeId, model) +} + +/** + * Removes all existing protocol traits annotated on the given service, + * then sets the provided `protocol` as the sole protocol trait for the service. + */ +fun Model.replaceProtocolTrait( + serviceShapeId: ShapeId, + modelProtocol: ModelProtocol, +): Model { + val serviceBuilder = + this.expectShape(serviceShapeId, ServiceShape::class.java).toBuilder() + for (p in ModelProtocol.values()) { + serviceBuilder.removeTrait(p.trait.toShapeId()) + } + val service = serviceBuilder.addTrait(modelProtocol.trait).build() + return ModelTransformer.create().replaceShapes(this, listOf(service)) +} + +fun List.containsAnyShapeId(ids: Collection): Boolean { + return ids.any { id -> this.any { shape -> shape == id } } +} + +/** + * Removes the given operations from the model. + */ +fun Model.removeOperations( + serviceShapeId: ShapeId, + operationsToRemove: List, +): Model { + val service = this.expectShape(serviceShapeId, ServiceShape::class.java) + val serviceBuilder = service.toBuilder() + // The operation must exist in the service. + service.operations.map { it.toShapeId() }.containsAll(operationsToRemove) shouldBe true + // Remove all operations. + for (opToRemove in operationsToRemove) { + serviceBuilder.removeOperation(opToRemove) + } + val changedModel = ModelTransformer.create().replaceShapes(this, listOf(serviceBuilder.build())) + // The operation must not exist in the updated service. + val changedService = changedModel.expectShape(serviceShapeId, ServiceShape::class.java) + changedService.operations.size shouldBeGreaterThan 0 + changedService.operations.map { it.toShapeId() }.containsAnyShapeId(operationsToRemove) shouldBe false + + return changedModel +} + class ConstraintsTest { private val model = """ diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt new file mode 100644 index 0000000000..da808553a7 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -0,0 +1,32 @@ +package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol +import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol +import software.amazon.smithy.rust.codegen.server.smithy.removeOperations +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +class CborConstraintsIntegrationTest { + @Test + fun `ensure CBOR implementation works for all constraint types`() { + val (serviceShape, constraintModel) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor) + // Event streaming operations are not supported by `Rpcv2Cbor` implementation. + // https://github.com/smithy-lang/smithy-rs/issues/3573 + val nonSupportedOperations = + listOf("EventStreamsOperation", "StreamingBlobOperation") + .map { ShapeId.from("${serviceShape.namespace}#$it") } + val model = + constraintModel + .removeOperations(serviceShape, nonSupportedOperations) + // The test should compile; no further testing is required. + serverIntegrationTest( + model, + IntegrationTestParams( + service = serviceShape.toString(), + ), + ) { _, _ -> + } + } +} From fd8c25cc37a33e96bba40328b48012a034f27807 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Fri, 30 Aug 2024 12:04:49 +0100 Subject: [PATCH 02/14] Use CBOR encoded string for marhsalling tests --- codegen-core/build.gradle.kts | 1 + .../protocols/parse/CborParserGenerator.kt | 20 ++++- .../serialize/CborSerializerGenerator.kt | 22 ++++- .../core/testutil/EventStreamTestModels.kt | 88 +++++++++++++------ .../server/smithy/ServerCodegenVisitor.kt | 5 +- ...erverProtocolBasedTransformationFactory.kt | 63 +++++-------- .../codegen/server/smithy/ConstraintsTest.kt | 2 +- .../CborConstraintsIntegrationTest.kt | 4 +- 8 files changed, 128 insertions(+), 77 deletions(-) diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index eff612be35..0baf35b6a7 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -25,6 +25,7 @@ dependencies { implementation("org.jsoup:jsoup:1.16.2") api("software.amazon.smithy:smithy-codegen-core:$smithyVersion") api("com.moandjiezana.toml:toml4j:0.7.2") + implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 0cc16c101f..df6d802b17 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.util.PANIC -import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -447,7 +446,24 @@ class CborParserGenerator( } override fun payloadParser(member: MemberShape): RuntimeType { - UNREACHABLE("No protocol using CBOR serialization supports payload binding") + val shape = model.expectShape(member.target) + val returnSymbol = returnSymbolToParse(shape) + check(shape is UnionShape || shape is StructureShape) { + "Payload parser should only be used on structure and union shapes." + } + return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> { + let decoder = &mut #{Decoder}::new(value); + #{DeserializeMember} + } + """, + "ReturnType" to returnSymbol.symbol, + "DeserializeMember" to deserializeMember(member), + *codegenScope, + ) + } } override fun operationParser(operationShape: OperationShape): RuntimeType? { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 625a669b5d..426c1c354b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -204,9 +204,26 @@ class CborSerializerGenerator( } } - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun payloadSerializer(member: MemberShape): RuntimeType { - TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") + val target = model.expectShape(member.target) + return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(target), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) + rustBlock("") { + rust("let encoder = &mut encoder;") + when (target) { + is StructureShape -> serializeStructure(StructContext("input", target)) + is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) + else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions") + } + } + rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) + } + } } override fun unsetStructure(structure: StructureShape): RuntimeType = @@ -223,6 +240,7 @@ class CborSerializerGenerator( } val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + val inputShape = operationShape.inputShape(model) return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> rustBlockTemplate( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index dc37caf714..c09d8b3382 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.testutil +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.dataformat.cbor.CBORFactory import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson @@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor +import java.util.Base64 private fun fillInBaseModel( - protocolName: String, + namespacedProtocolName: String, extraServiceAnnotations: String = "", ): String = """ namespace test use smithy.framework#ValidationException - use aws.protocols#$protocolName + use $namespacedProtocolName union TestUnion { Foo: String, @@ -86,22 +90,24 @@ private fun fillInBaseModel( } $extraServiceAnnotations - @$protocolName + @${namespacedProtocolName.substringAfter("#")} service TestService { version: "123", operations: [TestStreamOp] } """ object EventStreamTestModels { - private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() + private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel() - private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() + private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel() - private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() + private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel() + + private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel() private fun awsQuery(): Model = - fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() private fun ec2Query(): Model = - fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() data class TestCase( val protocolShapeId: String, @@ -120,39 +126,67 @@ object EventStreamTestModels { override fun toString(): String = protocolShapeId } + private fun base64Encode(input: ByteArray): String { + val encodedBytes = Base64.getEncoder().encode(input) + return String(encodedBytes) + } + + private fun createCBORFromJSON(jsonString: String): ByteArray { + val jsonMapper = ObjectMapper() + val cborMapper = ObjectMapper(CBORFactory()) + // Parse JSON string to a generic type. + val jsonData = jsonMapper.readValue(jsonString, Any::class.java) + // Convert the parsed data to CBOR. + return cborMapper.writeValueAsBytes(jsonData) + } + + private val restJsonTestCase = + TestCase( + protocolShapeId = "aws.protocols#restJson1", + model = restJson1(), + mediaType = "application/json", + requestContentType = "application/vnd.amazon.eventstream", + responseContentType = "application/json", + eventStreamMessageContentType = "application/json", + validTestStruct = """{"someString":"hello","someInt":5}""", + validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", + validTestUnion = """{"Foo":"hello"}""", + validSomeError = """{"Message":"some error"}""", + validUnmodeledError = """{"Message":"unmodeled error"}""", + ) { RestJson(it) } + val TEST_CASES = listOf( // // restJson1 // - TestCase( - protocolShapeId = "aws.protocols#restJson1", - model = restJson1(), - mediaType = "application/json", - requestContentType = "application/vnd.amazon.eventstream", - responseContentType = "application/json", - eventStreamMessageContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { RestJson(it) }, + restJsonTestCase, + // + // rpcV2Cbor + // + restJsonTestCase.copy( + protocolShapeId = "smithy.protocols#rpcv2Cbor", + model = rpcv2Cbor(), + mediaType = "application/cbor", + responseContentType = "application/cbor", + eventStreamMessageContentType = "application/cbor", + validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)), + validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)), + validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)), + validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)), + validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)), + protocolBuilder = { RpcV2Cbor(it) }, + ), // // awsJson1_1 // - TestCase( + restJsonTestCase.copy( protocolShapeId = "aws.protocols#awsJson1_1", model = awsJson11(), mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", eventStreamMessageContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", ) { AwsJson(it, AwsJsonVersion.Json11) }, // // restXml diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 769cb240de..229943b8ac 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -134,8 +134,7 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory - val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) - model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) + model = codegenDecorator.transformModel(service, baseModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( @@ -210,6 +209,8 @@ open class ServerCodegenVisitor( .let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) } // Tag aggregate shapes reachable from operation input .let(ShapesReachableFromOperationInputTagger::transform) + // Remove traits that are not supported by the chosen protocol + .let { ServerProtocolBasedTransformationFactory.transform(it, settings) } // Normalize event stream operations .let(EventStreamNormalizer::transform) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt index 32bf59a48d..6c3357bf87 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -4,7 +4,6 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.AbstractShapeBuilder import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.HttpLabelTrait @@ -12,6 +11,8 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import software.amazon.smithy.utils.SmithyBuilder import software.amazon.smithy.utils.ToSmithyBuilder @@ -21,17 +22,27 @@ import software.amazon.smithy.utils.ToSmithyBuilder * object that transforms the model and removes specific traits based on the protocol being instantiated. */ object ServerProtocolBasedTransformationFactory { - fun createTransformer(protocolShapeId: ShapeId): Transformer = - when (protocolShapeId) { - Rpcv2CborTrait.ID -> Rpcv2Transformer() - else -> IdentityTransformer() + fun transform( + model: Model, + settings: ServerRustSettings, + ): Model { + val service = settings.getService(model) + if (!service.hasTrait()) { + return model } - interface Transformer { - fun transform( - model: Model, - service: ServiceShape, - ): Model + return ModelTransformer.create().mapShapes(model) { shape -> + when (shape) { + is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) + is MemberShape -> { + shape + .removeTraitIfPresent(HttpLabelTrait.ID) + .removeTraitIfPresent(HttpPayloadTrait.ID) + } + + else -> shape + } + } } fun T.removeTraitIfPresent( @@ -47,36 +58,4 @@ object ServerProtocolBasedTransformationFactory { this } } - - class Rpcv2Transformer() : Transformer { - override fun transform( - model: Model, - service: ServiceShape, - ): Model { - val transformedModel = - ModelTransformer.create().mapShapes(model) { shape -> - when (shape) { - is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) - is MemberShape -> { - shape - .removeTraitIfPresent(HttpLabelTrait.ID) - .removeTraitIfPresent(HttpPayloadTrait.ID) - } - - else -> shape - } - } - - return transformedModel - } - } - - class IdentityTransformer() : Transformer { - override fun transform( - model: Model, - service: ServiceShape, - ): Model { - return model - } - } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 5be3930452..ef7d75dd45 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.inspectors.forAll import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe -import java.io.File import org.junit.jupiter.api.Test import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait @@ -29,6 +28,7 @@ import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.io.File enum class ModelProtocol(val trait: AbstractTrait) { AwsJson10(AwsJson1_0Trait.builder().build()), diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt index da808553a7..60bf75d127 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -3,6 +3,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol import software.amazon.smithy.rust.codegen.server.smithy.removeOperations @@ -15,7 +16,7 @@ class CborConstraintsIntegrationTest { // Event streaming operations are not supported by `Rpcv2Cbor` implementation. // https://github.com/smithy-lang/smithy-rs/issues/3573 val nonSupportedOperations = - listOf("EventStreamsOperation", "StreamingBlobOperation") + listOf("StreamingBlobOperation") .map { ShapeId.from("${serviceShape.namespace}#$it") } val model = constraintModel @@ -25,6 +26,7 @@ class CborConstraintsIntegrationTest { model, IntegrationTestParams( service = serviceShape.toString(), + additionalSettings = ServerAdditionalSettings.builder().generateCodegenComments().toObjectNode(), ), ) { _, _ -> } From 6afc8bf266b0d96cf18ab33c14821728af0c6ec6 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Wed, 4 Sep 2024 10:10:47 +0100 Subject: [PATCH 03/14] Remove streaming trait from blob --- .../ServerProtocolBasedTransformationFactory.kt | 13 +++++++++++-- .../serialize/CborConstraintsIntegrationTest.kt | 12 +----------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt index 6c3357bf87..c38d8b623d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -2,6 +2,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.transformers import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.AbstractShapeBuilder +import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape @@ -9,6 +10,7 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.HttpLabelTrait import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -17,8 +19,8 @@ import software.amazon.smithy.utils.SmithyBuilder import software.amazon.smithy.utils.ToSmithyBuilder /** - * Each protocol may not support all of the features that Smithy allows. For instance, most - * RPC protocols do not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory + * Each protocol may not support all of the features that Smithy allows. For instance, `rpcv2Cbor` + * does not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory * object that transforms the model and removes specific traits based on the protocol being instantiated. */ object ServerProtocolBasedTransformationFactory { @@ -31,6 +33,10 @@ object ServerProtocolBasedTransformationFactory { return model } + // `rpcv2Cbor` does not support: + // 1. `@httpPayload` trait. + // 2. `@httpLabel` trait. + // 3. `@streaming` trait applied to a `Blob` (data streaming). return ModelTransformer.create().mapShapes(model) { shape -> when (shape) { is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) @@ -39,6 +45,9 @@ object ServerProtocolBasedTransformationFactory { .removeTraitIfPresent(HttpLabelTrait.ID) .removeTraitIfPresent(HttpPayloadTrait.ID) } + is BlobShape -> { + shape.removeTraitIfPresent(StreamingTrait.ID) + } else -> shape } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt index 60bf75d127..242bcb398c 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -1,26 +1,16 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol -import software.amazon.smithy.rust.codegen.server.smithy.removeOperations import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest class CborConstraintsIntegrationTest { @Test fun `ensure CBOR implementation works for all constraint types`() { - val (serviceShape, constraintModel) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor) - // Event streaming operations are not supported by `Rpcv2Cbor` implementation. - // https://github.com/smithy-lang/smithy-rs/issues/3573 - val nonSupportedOperations = - listOf("StreamingBlobOperation") - .map { ShapeId.from("${serviceShape.namespace}#$it") } - val model = - constraintModel - .removeOperations(serviceShape, nonSupportedOperations) + val (serviceShape, model) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor) // The test should compile; no further testing is required. serverIntegrationTest( model, From 3fede97219ce7073c385c82444165d4fea762a3d Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Mon, 19 Aug 2024 15:45:14 +0100 Subject: [PATCH 04/14] Use constraints.smithy with CBor --- .../serialize/CborSerializerGenerator.kt | 22 ++++- .../server/smithy/ServerCodegenVisitor.kt | 4 +- ...eforeSerializingMemberCborCustomization.kt | 36 ++++++++ .../generators/protocol/ServerProtocol.kt | 2 + ...erverProtocolBasedTransformationFactory.kt | 82 +++++++++++++++++++ .../codegen/server/smithy/ConstraintsTest.kt | 78 ++++++++++++++++++ .../CborConstraintsIntegrationTest.kt | 32 ++++++++ 7 files changed, 254 insertions(+), 2 deletions(-) create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 1eaf1cd4da..625a669b5d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) { /** Manipulate the serializer context for a map prior to it being serialized. **/ data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : CborSerializerSection("BeforeIteratingOverMapOrCollection") + + /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ + data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) : + CborSerializerSection("BeforeSerializingNonNullMember") } /** @@ -311,6 +315,7 @@ class CborSerializerGenerator( safeName().also { local -> rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { context.valueExpression = ValueExpression.Reference(local) + resolveValueExpressionForConstrainedType(targetShape, context) serializeMemberValue(context, targetShape) } if (context.writeNulls) { @@ -320,6 +325,7 @@ class CborSerializerGenerator( } } } else { + resolveValueExpressionForConstrainedType(targetShape, context) with(serializerUtil) { ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { serializeMemberValue(context, targetShape) @@ -328,6 +334,20 @@ class CborSerializerGenerator( } } + private fun RustWriter.resolveValueExpressionForConstrainedType( + targetShape: Shape, + context: MemberContext, + ) { + for (customization in customizations) { + customization.section( + CborSerializerSection.BeforeSerializingNonNullMember( + targetShape, + context, + ), + )(this) + } + } + private fun RustWriter.serializeMemberValue( context: MemberContext, target: Shape, @@ -362,7 +382,7 @@ class CborSerializerGenerator( rust("$encoder;") // Encode the member key. } when (target) { - is StructureShape -> serializeStructure(StructContext(value.name, target)) + is StructureShape -> serializeStructure(StructContext(value.asRef(), target)) is CollectionShape -> serializeCollection(Context(value, target)) is MapShape -> serializeMap(Context(value, target)) is UnionShape -> serializeUnion(Context(value, target)) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 5be7d27254..769cb240de 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger @@ -133,7 +134,8 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory - model = codegenDecorator.transformModel(service, baseModel, settings) + val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) + model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt new file mode 100644 index 0000000000..71ced4d9df --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt @@ -0,0 +1,36 @@ +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeSerializingNonNullMember -> + writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + } + + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 59f0f5e5f0..ac506aebd4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator @@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol( listOf( BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), AddTypeFieldToServerErrorsCborCustomization(), + BeforeSerializingMemberCborCustomization(serverCodegenContext), ), ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt new file mode 100644 index 0000000000..32bf59a48d --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -0,0 +1,82 @@ +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.AbstractShapeBuilder +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.HttpLabelTrait +import software.amazon.smithy.model.traits.HttpPayloadTrait +import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait +import software.amazon.smithy.utils.SmithyBuilder +import software.amazon.smithy.utils.ToSmithyBuilder + +/** + * Each protocol may not support all of the features that Smithy allows. For instance, most + * RPC protocols do not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory + * object that transforms the model and removes specific traits based on the protocol being instantiated. + */ +object ServerProtocolBasedTransformationFactory { + fun createTransformer(protocolShapeId: ShapeId): Transformer = + when (protocolShapeId) { + Rpcv2CborTrait.ID -> Rpcv2Transformer() + else -> IdentityTransformer() + } + + interface Transformer { + fun transform( + model: Model, + service: ServiceShape, + ): Model + } + + fun T.removeTraitIfPresent( + traitId: ShapeId, + ): T + where T : ToSmithyBuilder, + B : AbstractShapeBuilder, + B : SmithyBuilder { + return if (this.hasTrait(traitId)) { + @Suppress("UNCHECKED_CAST") + (this.toBuilder() as B).removeTrait(traitId).build() + } else { + this + } + } + + class Rpcv2Transformer() : Transformer { + override fun transform( + model: Model, + service: ServiceShape, + ): Model { + val transformedModel = + ModelTransformer.create().mapShapes(model) { shape -> + when (shape) { + is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) + is MemberShape -> { + shape + .removeTraitIfPresent(HttpLabelTrait.ID) + .removeTraitIfPresent(HttpPayloadTrait.ID) + } + + else -> shape + } + } + + return transformedModel + } + } + + class IdentityTransformer() : Transformer { + override fun transform( + model: Model, + service: ServiceShape, + ): Model { + return model + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 30e5e64813..5be3930452 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -6,18 +6,96 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.inspectors.forAll +import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe +import java.io.File import org.junit.jupiter.api.Test +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.AbstractTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +enum class ModelProtocol(val trait: AbstractTrait) { + AwsJson10(AwsJson1_0Trait.builder().build()), + AwsJson11(AwsJson1_1Trait.builder().build()), + RestJson(RestJson1Trait.builder().build()), + RestXml(RestXmlTrait.builder().build()), + Rpcv2Cbor(Rpcv2CborTrait.builder().build()), +} + +fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair { + val (serviceShapeId, model) = loadSmithyConstraintsModel() + return Pair(serviceShapeId, model.replaceProtocolTrait(serviceShapeId, modelProtocol)) +} + +fun loadSmithyConstraintsModel(): Pair { + val filePath = "../codegen-core/common-test-models/constraints.smithy" + val serviceShapeId = ShapeId.from("com.amazonaws.constraints#ConstraintsService") + val model = + File(filePath).readText().asSmithyModel() + return Pair(serviceShapeId, model) +} + +/** + * Removes all existing protocol traits annotated on the given service, + * then sets the provided `protocol` as the sole protocol trait for the service. + */ +fun Model.replaceProtocolTrait( + serviceShapeId: ShapeId, + modelProtocol: ModelProtocol, +): Model { + val serviceBuilder = + this.expectShape(serviceShapeId, ServiceShape::class.java).toBuilder() + for (p in ModelProtocol.values()) { + serviceBuilder.removeTrait(p.trait.toShapeId()) + } + val service = serviceBuilder.addTrait(modelProtocol.trait).build() + return ModelTransformer.create().replaceShapes(this, listOf(service)) +} + +fun List.containsAnyShapeId(ids: Collection): Boolean { + return ids.any { id -> this.any { shape -> shape == id } } +} + +/** + * Removes the given operations from the model. + */ +fun Model.removeOperations( + serviceShapeId: ShapeId, + operationsToRemove: List, +): Model { + val service = this.expectShape(serviceShapeId, ServiceShape::class.java) + val serviceBuilder = service.toBuilder() + // The operation must exist in the service. + service.operations.map { it.toShapeId() }.containsAll(operationsToRemove) shouldBe true + // Remove all operations. + for (opToRemove in operationsToRemove) { + serviceBuilder.removeOperation(opToRemove) + } + val changedModel = ModelTransformer.create().replaceShapes(this, listOf(serviceBuilder.build())) + // The operation must not exist in the updated service. + val changedService = changedModel.expectShape(serviceShapeId, ServiceShape::class.java) + changedService.operations.size shouldBeGreaterThan 0 + changedService.operations.map { it.toShapeId() }.containsAnyShapeId(operationsToRemove) shouldBe false + + return changedModel +} + class ConstraintsTest { private val model = """ diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt new file mode 100644 index 0000000000..da808553a7 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -0,0 +1,32 @@ +package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol +import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol +import software.amazon.smithy.rust.codegen.server.smithy.removeOperations +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +class CborConstraintsIntegrationTest { + @Test + fun `ensure CBOR implementation works for all constraint types`() { + val (serviceShape, constraintModel) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor) + // Event streaming operations are not supported by `Rpcv2Cbor` implementation. + // https://github.com/smithy-lang/smithy-rs/issues/3573 + val nonSupportedOperations = + listOf("EventStreamsOperation", "StreamingBlobOperation") + .map { ShapeId.from("${serviceShape.namespace}#$it") } + val model = + constraintModel + .removeOperations(serviceShape, nonSupportedOperations) + // The test should compile; no further testing is required. + serverIntegrationTest( + model, + IntegrationTestParams( + service = serviceShape.toString(), + ), + ) { _, _ -> + } + } +} From 5bb92b13879790315a21444c9bacb256cb399b3b Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Fri, 30 Aug 2024 12:04:49 +0100 Subject: [PATCH 05/14] Use CBOR encoded string for marhsalling tests --- codegen-core/build.gradle.kts | 1 + .../protocols/parse/CborParserGenerator.kt | 20 ++++- .../serialize/CborSerializerGenerator.kt | 22 ++++- .../core/testutil/EventStreamTestModels.kt | 88 +++++++++++++------ .../server/smithy/ServerCodegenVisitor.kt | 5 +- ...erverProtocolBasedTransformationFactory.kt | 63 +++++-------- .../codegen/server/smithy/ConstraintsTest.kt | 2 +- .../CborConstraintsIntegrationTest.kt | 4 +- 8 files changed, 128 insertions(+), 77 deletions(-) diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index eff612be35..0baf35b6a7 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -25,6 +25,7 @@ dependencies { implementation("org.jsoup:jsoup:1.16.2") api("software.amazon.smithy:smithy-codegen-core:$smithyVersion") api("com.moandjiezana.toml:toml4j:0.7.2") + implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 0cc16c101f..df6d802b17 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.util.PANIC -import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -447,7 +446,24 @@ class CborParserGenerator( } override fun payloadParser(member: MemberShape): RuntimeType { - UNREACHABLE("No protocol using CBOR serialization supports payload binding") + val shape = model.expectShape(member.target) + val returnSymbol = returnSymbolToParse(shape) + check(shape is UnionShape || shape is StructureShape) { + "Payload parser should only be used on structure and union shapes." + } + return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> { + let decoder = &mut #{Decoder}::new(value); + #{DeserializeMember} + } + """, + "ReturnType" to returnSymbol.symbol, + "DeserializeMember" to deserializeMember(member), + *codegenScope, + ) + } } override fun operationParser(operationShape: OperationShape): RuntimeType? { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 625a669b5d..426c1c354b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -204,9 +204,26 @@ class CborSerializerGenerator( } } - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun payloadSerializer(member: MemberShape): RuntimeType { - TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") + val target = model.expectShape(member.target) + return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(target), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) + rustBlock("") { + rust("let encoder = &mut encoder;") + when (target) { + is StructureShape -> serializeStructure(StructContext("input", target)) + is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) + else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions") + } + } + rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) + } + } } override fun unsetStructure(structure: StructureShape): RuntimeType = @@ -223,6 +240,7 @@ class CborSerializerGenerator( } val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + val inputShape = operationShape.inputShape(model) return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> rustBlockTemplate( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index dc37caf714..c09d8b3382 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.testutil +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.dataformat.cbor.CBORFactory import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson @@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor +import java.util.Base64 private fun fillInBaseModel( - protocolName: String, + namespacedProtocolName: String, extraServiceAnnotations: String = "", ): String = """ namespace test use smithy.framework#ValidationException - use aws.protocols#$protocolName + use $namespacedProtocolName union TestUnion { Foo: String, @@ -86,22 +90,24 @@ private fun fillInBaseModel( } $extraServiceAnnotations - @$protocolName + @${namespacedProtocolName.substringAfter("#")} service TestService { version: "123", operations: [TestStreamOp] } """ object EventStreamTestModels { - private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() + private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel() - private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() + private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel() - private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() + private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel() + + private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel() private fun awsQuery(): Model = - fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() private fun ec2Query(): Model = - fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() data class TestCase( val protocolShapeId: String, @@ -120,39 +126,67 @@ object EventStreamTestModels { override fun toString(): String = protocolShapeId } + private fun base64Encode(input: ByteArray): String { + val encodedBytes = Base64.getEncoder().encode(input) + return String(encodedBytes) + } + + private fun createCBORFromJSON(jsonString: String): ByteArray { + val jsonMapper = ObjectMapper() + val cborMapper = ObjectMapper(CBORFactory()) + // Parse JSON string to a generic type. + val jsonData = jsonMapper.readValue(jsonString, Any::class.java) + // Convert the parsed data to CBOR. + return cborMapper.writeValueAsBytes(jsonData) + } + + private val restJsonTestCase = + TestCase( + protocolShapeId = "aws.protocols#restJson1", + model = restJson1(), + mediaType = "application/json", + requestContentType = "application/vnd.amazon.eventstream", + responseContentType = "application/json", + eventStreamMessageContentType = "application/json", + validTestStruct = """{"someString":"hello","someInt":5}""", + validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", + validTestUnion = """{"Foo":"hello"}""", + validSomeError = """{"Message":"some error"}""", + validUnmodeledError = """{"Message":"unmodeled error"}""", + ) { RestJson(it) } + val TEST_CASES = listOf( // // restJson1 // - TestCase( - protocolShapeId = "aws.protocols#restJson1", - model = restJson1(), - mediaType = "application/json", - requestContentType = "application/vnd.amazon.eventstream", - responseContentType = "application/json", - eventStreamMessageContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { RestJson(it) }, + restJsonTestCase, + // + // rpcV2Cbor + // + restJsonTestCase.copy( + protocolShapeId = "smithy.protocols#rpcv2Cbor", + model = rpcv2Cbor(), + mediaType = "application/cbor", + responseContentType = "application/cbor", + eventStreamMessageContentType = "application/cbor", + validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)), + validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)), + validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)), + validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)), + validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)), + protocolBuilder = { RpcV2Cbor(it) }, + ), // // awsJson1_1 // - TestCase( + restJsonTestCase.copy( protocolShapeId = "aws.protocols#awsJson1_1", model = awsJson11(), mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", eventStreamMessageContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", ) { AwsJson(it, AwsJsonVersion.Json11) }, // // restXml diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 769cb240de..229943b8ac 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -134,8 +134,7 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory - val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) - model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) + model = codegenDecorator.transformModel(service, baseModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( @@ -210,6 +209,8 @@ open class ServerCodegenVisitor( .let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) } // Tag aggregate shapes reachable from operation input .let(ShapesReachableFromOperationInputTagger::transform) + // Remove traits that are not supported by the chosen protocol + .let { ServerProtocolBasedTransformationFactory.transform(it, settings) } // Normalize event stream operations .let(EventStreamNormalizer::transform) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt index 32bf59a48d..6c3357bf87 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -4,7 +4,6 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.AbstractShapeBuilder import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.HttpLabelTrait @@ -12,6 +11,8 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import software.amazon.smithy.utils.SmithyBuilder import software.amazon.smithy.utils.ToSmithyBuilder @@ -21,17 +22,27 @@ import software.amazon.smithy.utils.ToSmithyBuilder * object that transforms the model and removes specific traits based on the protocol being instantiated. */ object ServerProtocolBasedTransformationFactory { - fun createTransformer(protocolShapeId: ShapeId): Transformer = - when (protocolShapeId) { - Rpcv2CborTrait.ID -> Rpcv2Transformer() - else -> IdentityTransformer() + fun transform( + model: Model, + settings: ServerRustSettings, + ): Model { + val service = settings.getService(model) + if (!service.hasTrait()) { + return model } - interface Transformer { - fun transform( - model: Model, - service: ServiceShape, - ): Model + return ModelTransformer.create().mapShapes(model) { shape -> + when (shape) { + is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) + is MemberShape -> { + shape + .removeTraitIfPresent(HttpLabelTrait.ID) + .removeTraitIfPresent(HttpPayloadTrait.ID) + } + + else -> shape + } + } } fun T.removeTraitIfPresent( @@ -47,36 +58,4 @@ object ServerProtocolBasedTransformationFactory { this } } - - class Rpcv2Transformer() : Transformer { - override fun transform( - model: Model, - service: ServiceShape, - ): Model { - val transformedModel = - ModelTransformer.create().mapShapes(model) { shape -> - when (shape) { - is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) - is MemberShape -> { - shape - .removeTraitIfPresent(HttpLabelTrait.ID) - .removeTraitIfPresent(HttpPayloadTrait.ID) - } - - else -> shape - } - } - - return transformedModel - } - } - - class IdentityTransformer() : Transformer { - override fun transform( - model: Model, - service: ServiceShape, - ): Model { - return model - } - } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 5be3930452..ef7d75dd45 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.inspectors.forAll import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe -import java.io.File import org.junit.jupiter.api.Test import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait @@ -29,6 +28,7 @@ import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.io.File enum class ModelProtocol(val trait: AbstractTrait) { AwsJson10(AwsJson1_0Trait.builder().build()), diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt index da808553a7..60bf75d127 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -3,6 +3,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol import software.amazon.smithy.rust.codegen.server.smithy.removeOperations @@ -15,7 +16,7 @@ class CborConstraintsIntegrationTest { // Event streaming operations are not supported by `Rpcv2Cbor` implementation. // https://github.com/smithy-lang/smithy-rs/issues/3573 val nonSupportedOperations = - listOf("EventStreamsOperation", "StreamingBlobOperation") + listOf("StreamingBlobOperation") .map { ShapeId.from("${serviceShape.namespace}#$it") } val model = constraintModel @@ -25,6 +26,7 @@ class CborConstraintsIntegrationTest { model, IntegrationTestParams( service = serviceShape.toString(), + additionalSettings = ServerAdditionalSettings.builder().generateCodegenComments().toObjectNode(), ), ) { _, _ -> } From 763d598793be8c37899c76ccc05d6ec7f76380d4 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Wed, 4 Sep 2024 12:03:45 +0100 Subject: [PATCH 06/14] Fix formatting and comments --- .../protocols/parse/CborParserGenerator.kt | 8 ++++---- .../serialize/CborSerializerGenerator.kt | 1 - .../server/smithy/ServerCodegenVisitor.kt | 2 +- ...BeforeSerializingMemberCborCustomization.kt | 7 +++++++ .../codegen/server/smithy/ConstraintsTest.kt | 18 +++++++++++++----- .../CborConstraintsIntegrationTest.kt | 4 +--- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index df6d802b17..6830d7a1f7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -454,10 +454,10 @@ class CborParserGenerator( return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName -> rustTemplate( """ - pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> { - let decoder = &mut #{Decoder}::new(value); - #{DeserializeMember} - } + pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> { + let decoder = &mut #{Decoder}::new(value); + #{DeserializeMember} + } """, "ReturnType" to returnSymbol.symbol, "DeserializeMember" to deserializeMember(member), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 426c1c354b..1830fb962f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -240,7 +240,6 @@ class CborSerializerGenerator( } val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) - val inputShape = operationShape.inputShape(model) return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> rustBlockTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 229943b8ac..b1fc07a31a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -209,7 +209,7 @@ open class ServerCodegenVisitor( .let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) } // Tag aggregate shapes reachable from operation input .let(ShapesReachableFromOperationInputTagger::transform) - // Remove traits that are not supported by the chosen protocol + // Remove traits that are not supported by the chosen protocol. .let { ServerProtocolBasedTransformationFactory.transform(it, settings) } // Normalize event stream operations .let(EventStreamNormalizer::transform) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt index 71ced4d9df..728021e8f7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt @@ -13,6 +13,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Value import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType +/** + * Constrained shapes are wrapped in a Rust tuple struct that implements all necessary checks. However, + * for serialization purposes, the inner type of the constrained shape is used for serialization. + * + * The `BeforeSerializingMemberCborCustomization` class generates a reference to the inner type when the shape being + * code-generated is constrained and the `publicConstrainedTypes` codegen flag is set. + */ class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { override fun section(section: CborSerializerSection): Writable = when (section) { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index ef7d75dd45..1c2868966e 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -38,17 +38,25 @@ enum class ModelProtocol(val trait: AbstractTrait) { Rpcv2Cbor(Rpcv2CborTrait.builder().build()), } -fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair { - val (serviceShapeId, model) = loadSmithyConstraintsModel() - return Pair(serviceShapeId, model.replaceProtocolTrait(serviceShapeId, modelProtocol)) +/** + * Returns the Smithy constraints model from the common repository, with the specified protocol + * applied to the service. + */ +fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair { + val (model, serviceShapeId) = loadSmithyConstraintsModel() + return Pair(model.replaceProtocolTrait(serviceShapeId, modelProtocol), serviceShapeId) } -fun loadSmithyConstraintsModel(): Pair { +/** + * Loads the Smithy constraints model defined in the common repository and returns the model along with + * the service shape defined in it. + */ +fun loadSmithyConstraintsModel(): Pair { val filePath = "../codegen-core/common-test-models/constraints.smithy" val serviceShapeId = ShapeId.from("com.amazonaws.constraints#ConstraintsService") val model = File(filePath).readText().asSmithyModel() - return Pair(serviceShapeId, model) + return Pair(model, serviceShapeId) } /** diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt index eea9fbba37..10aea6a044 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -1,18 +1,16 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol -import software.amazon.smithy.rust.codegen.server.smithy.removeOperations import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest class CborConstraintsIntegrationTest { @Test fun `ensure CBOR implementation works for all constraint types`() { - val (serviceShape, model) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor) + val (model, serviceShape) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor) // The test should compile; no further testing is required. serverIntegrationTest( model, From 91d80a4264ee62920f5a8a45cadd85ac03f27e70 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Sun, 8 Sep 2024 20:23:24 +0100 Subject: [PATCH 07/14] Implement `parseEventStreamErrorMetadata`, and change client test case for event stream --- ...entEventStreamUnmarshallerGeneratorTest.kt | 3 +- .../core/smithy/protocols/RpcV2Cbor.kt | 20 +++++++++- .../EventStreamUnmarshallTestCases.kt | 40 ++++++++++++++++--- rust-runtime/inlineable/src/cbor_errors.rs | 14 +++++-- 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt index 92d0b3663c..0ae50bc013 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt @@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.testModule @@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest { "exception", "UnmodeledError", "${testCase.responseContentType}", - br#"${testCase.validUnmodeledError}"# + ${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)} ); let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index f67638edba..10c7ba01c9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -140,9 +141,24 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata") - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = - TODO("rpcv2Cbor event streams have not yet been implemented") + ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. + rustTemplate( + """ + pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> { + #{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload) + } + """, + "cbor_errors" to RuntimeType.cborErrors(runtimeConfig), + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "DeserializeError" to + CargoDependency.smithyCbor(runtimeConfig).toType() + .resolve("decode::DeserializeError"), + "Headers" to RuntimeType.headers(runtimeConfig), + ) + } // Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set // unless there is no input or if the operation is an event stream, see diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt index 4a94d0af3a..99d95dd3cf 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.util.lookup +import java.util.Base64 object EventStreamUnmarshallTestCases { fun RustWriter.writeUnmarshallTestCases( @@ -109,7 +110,7 @@ object EventStreamUnmarshallTestCases { "event", "MessageWithStruct", "${testCase.responseContentType}", - br##"${testCase.validTestStruct}"## + ${testCase.generateRustPayloadInitializer(testCase.validTestStruct)} ); let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); @@ -140,7 +141,7 @@ object EventStreamUnmarshallTestCases { "event", "MessageWithUnion", "${testCase.responseContentType}", - br##"${testCase.validTestUnion}"## + ${testCase.generateRustPayloadInitializer(testCase.validTestUnion)} ); let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); @@ -221,7 +222,7 @@ object EventStreamUnmarshallTestCases { "event", "MessageWithNoHeaderPayloadTraits", "${testCase.responseContentType}", - br##"${testCase.validMessageWithNoHeaderPayloadTraits}"## + ${testCase.generateRustPayloadInitializer(testCase.validMessageWithNoHeaderPayloadTraits)} ); let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); @@ -246,7 +247,7 @@ object EventStreamUnmarshallTestCases { "exception", "SomeError", "${testCase.responseContentType}", - br##"${testCase.validSomeError}"## + ${testCase.generateRustPayloadInitializer(testCase.validSomeError)} ); let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); @@ -267,7 +268,7 @@ object EventStreamUnmarshallTestCases { "event", "MessageWithBlob", "wrong-content-type", - br#"${testCase.validTestStruct}"# + ${testCase.generateRustPayloadInitializer(testCase.validTestStruct)} ); let result = $generator::new().unmarshall(&message); assert!(result.is_err(), "expected error, got: {:?}", result); @@ -275,6 +276,35 @@ object EventStreamUnmarshallTestCases { """, ) } + + /** + * Generates a Rust-compatible initializer string for a given payload. + * + * This function handles two different scenarios based on the event stream message content type: + * + * 1. For CBOR payloads (content type "application/cbor"): + * - The input payload is expected to be a base64 encoded CBOR value. + * - It decodes the base64 string and generates a Rust byte array initializer. + * - The output format is: &[0xFFu8, 0xFFu8, ...] where FF are hexadecimal values. + * + * 2. For all other content types: + * - It returns a Rust raw string literal initializer. + * - The output format is: br##"original_payload"## + */ + fun EventStreamTestModels.TestCase.generateRustPayloadInitializer(payload: String): String { + return if (this.eventStreamMessageContentType == "application/cbor") { + Base64.getDecoder().decode(payload) + .joinToString( + prefix = "&[", + postfix = "]", + transform = { "0x${it.toUByte().toString(16).padStart(2, '0')}u8" }, + ) + } else { + """ + br##"$payload"## + """ + } + } } internal fun conditionalBuilderInput( diff --git a/rust-runtime/inlineable/src/cbor_errors.rs b/rust-runtime/inlineable/src/cbor_errors.rs index d96c5233aa..44a611e4fb 100644 --- a/rust-runtime/inlineable/src/cbor_errors.rs +++ b/rust-runtime/inlineable/src/cbor_errors.rs @@ -32,7 +32,7 @@ pub fn parse_error_metadata( _response_headers: &Headers, response_body: &[u8], ) -> Result { - fn error_code( + fn error_code_and_message( mut builder: ErrorMetadataBuilder, decoder: &mut Decoder, ) -> Result { @@ -41,6 +41,14 @@ pub fn parse_error_metadata( let code = decoder.str()?; builder.code(sanitize_error_code(&code)) } + "message" | "Message" | "errorMessage" => { + // Silently skip if `message` is not a string. This allows for custom error + // structures that might use different types for the message field. + match decoder.str() { + Ok(message) => builder.message(message), + Err(_) => builder + } + } _ => { decoder.skip()?; builder @@ -60,13 +68,13 @@ pub fn parse_error_metadata( break; } _ => { - builder = error_code(builder, decoder)?; + builder = error_code_and_message(builder, decoder)?; } }; }, Some(n) => { for _ in 0..n { - builder = error_code(builder, decoder)?; + builder = error_code_and_message(builder, decoder)?; } } }; From d51cc7813b4f178c68f8fec5294bed3bff0e988c Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Mon, 9 Sep 2024 10:05:30 +0100 Subject: [PATCH 08/14] Add copyright --- .../BeforeSerializingMemberCborCustomization.kt | 4 ++++ .../transformers/ServerProtocolBasedTransformationFactory.kt | 4 ++++ .../protocols/serialize/CborConstraintsIntegrationTest.kt | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt index 728021e8f7..ee39022525 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt @@ -1,3 +1,7 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ package software.amazon.smithy.rust.codegen.server.smithy.customizations import software.amazon.smithy.model.shapes.BlobShape diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt index 87017dd3bd..b18212b8b1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -1,3 +1,7 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ package software.amazon.smithy.rust.codegen.server.smithy.transformers import software.amazon.smithy.model.Model diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt index 10aea6a044..0a80a125d6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -1,3 +1,7 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test From 050d9c5bb8f077b61ace73805bed2422bba4844c Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Mon, 9 Sep 2024 10:48:32 +0100 Subject: [PATCH 09/14] Add changelog and fix lint issues --- .changelog/2155171.md | 9 +++++++++ rust-runtime/inlineable/src/cbor_errors.rs | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 .changelog/2155171.md diff --git a/.changelog/2155171.md b/.changelog/2155171.md new file mode 100644 index 0000000000..f38d61bd37 --- /dev/null +++ b/.changelog/2155171.md @@ -0,0 +1,9 @@ +--- +applies_to: ["server","client"] +authors: ["drganjoo"] +references: [] +breaking: false +new_feature: true +bug_fix: false +--- +Support for [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission. diff --git a/rust-runtime/inlineable/src/cbor_errors.rs b/rust-runtime/inlineable/src/cbor_errors.rs index 44a611e4fb..f72ea882f4 100644 --- a/rust-runtime/inlineable/src/cbor_errors.rs +++ b/rust-runtime/inlineable/src/cbor_errors.rs @@ -46,7 +46,7 @@ pub fn parse_error_metadata( // structures that might use different types for the message field. match decoder.str() { Ok(message) => builder.message(message), - Err(_) => builder + Err(_) => builder, } } _ => { From 3e74cc83b4a906db47d7f9232d0ca88ea83551a0 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Tue, 17 Sep 2024 13:20:12 +0100 Subject: [PATCH 10/14] Update .changelog/2155171.md Co-authored-by: david-perez --- .changelog/2155171.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changelog/2155171.md b/.changelog/2155171.md index f38d61bd37..50b671fb49 100644 --- a/.changelog/2155171.md +++ b/.changelog/2155171.md @@ -6,4 +6,4 @@ breaking: false new_feature: true bug_fix: false --- -Support for [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission. +Support for the [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) protocol has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission. From 8bf71f1766f53304e8322b57e29a4abef5765354 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Tue, 17 Sep 2024 13:20:25 +0100 Subject: [PATCH 11/14] Update .changelog/2155171.md Co-authored-by: david-perez --- .changelog/2155171.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changelog/2155171.md b/.changelog/2155171.md index 50b671fb49..a737eefc68 100644 --- a/.changelog/2155171.md +++ b/.changelog/2155171.md @@ -1,7 +1,7 @@ --- applies_to: ["server","client"] authors: ["drganjoo"] -references: [] +references: [smithy-rs#3573] breaking: false new_feature: true bug_fix: false From 13c0c783ca073619bf16595d7bb7d2b854015777 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Tue, 17 Sep 2024 13:20:39 +0100 Subject: [PATCH 12/14] Update codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt Co-authored-by: david-perez --- .../smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index 10c7ba01c9..d3d9800b1a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -143,7 +143,6 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> - // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> { From c36dfb0ff0b317767a89f9acd17b054a871165cc Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Tue, 17 Sep 2024 13:20:48 +0100 Subject: [PATCH 13/14] Update codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt Co-authored-by: david-perez --- .../smithy/rust/codegen/core/testutil/EventStreamTestModels.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index c09d8b3382..2faafbbf2a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -131,7 +131,7 @@ object EventStreamTestModels { return String(encodedBytes) } - private fun createCBORFromJSON(jsonString: String): ByteArray { + private fun createCborFromJson(jsonString: String): ByteArray { val jsonMapper = ObjectMapper() val cborMapper = ObjectMapper(CBORFactory()) // Parse JSON string to a generic type. From 5bbfdc2eec7075237faa8dbce5a41bd2eb3b2947 Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Tue, 17 Sep 2024 13:38:22 +0100 Subject: [PATCH 14/14] Add comments to clarify that the ServerProtocolBasedTransformationFactory class will be removed later on --- .../codegen/core/testutil/EventStreamTestModels.kt | 10 +++++----- .../ServerProtocolBasedTransformationFactory.kt | 6 +++++- .../rust/codegen/server/smithy/ConstraintsTest.kt | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index 2faafbbf2a..82544953c6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -170,11 +170,11 @@ object EventStreamTestModels { mediaType = "application/cbor", responseContentType = "application/cbor", eventStreamMessageContentType = "application/cbor", - validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)), - validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)), - validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)), - validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)), - validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)), + validTestStruct = base64Encode(createCborFromJson(restJsonTestCase.validTestStruct)), + validMessageWithNoHeaderPayloadTraits = base64Encode(createCborFromJson(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)), + validTestUnion = base64Encode(createCborFromJson(restJsonTestCase.validTestUnion)), + validSomeError = base64Encode(createCborFromJson(restJsonTestCase.validSomeError)), + validUnmodeledError = base64Encode(createCborFromJson(restJsonTestCase.validUnmodeledError)), protocolBuilder = { RpcV2Cbor(it) }, ), // diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt index b18212b8b1..ea80d44a7d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt @@ -24,8 +24,12 @@ import software.amazon.smithy.utils.ToSmithyBuilder /** * Each protocol may not support all of the features that Smithy allows. For instance, `rpcv2Cbor` - * does not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory + * does not support HTTP bindings other than `@httpError`. `ServerProtocolBasedTransformationFactory` is a factory * object that transforms the model and removes specific traits based on the protocol being instantiated. + * + * In the long term, this class will be removed, and each protocol should be resilient enough to ignore extra + * traits that the model is annotated with. This will be addressed when we fix issue + * [#2979](https://github.com/smithy-lang/smithy-rs/issues/2979). */ object ServerProtocolBasedTransformationFactory { fun transform( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 1c2868966e..31123ef811 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -53,9 +53,9 @@ fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair { val filePath = "../codegen-core/common-test-models/constraints.smithy" - val serviceShapeId = ShapeId.from("com.amazonaws.constraints#ConstraintsService") val model = File(filePath).readText().asSmithyModel() + val serviceShapeId = model.shapes().filter { it.isServiceShape }.findFirst().orElseThrow().id return Pair(model, serviceShapeId) }