From 7ec3c9c87049cbba0dda749e7e8e38042ed433f9 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Mon, 13 Jun 2022 13:23:14 +0200 Subject: [PATCH] Support server event streams * Server event streams * Rename EventStreamInput to EventStreamSender * Make event stream errors optional * Pokemon service model updated * Pokemon server event handler * Pokemon client to test event streams * EventStreamDecorator to make optional using SigV4 signing Closes: #1157 Signed-off-by: Daniele Ahmed --- CHANGELOG.next.toml | 12 ++ .../smithy/rustsdk/SigV4SigningDecorator.kt | 51 ++---- codegen-server-test/model/pokemon.smithy | 78 +++++++- .../smithy/PythonCodegenServerPlugin.kt | 2 +- .../server/smithy/RustCodegenServerPlugin.kt | 3 +- .../AdditionalErrorsDecorator.kt | 3 +- .../ServerCombinedErrorGenerator.kt | 76 +++++--- .../ServerOperationHandlerGenerator.kt | 3 +- .../protocol/ServerProtocolTestGenerator.kt | 3 +- .../ServerHttpBoundProtocolGenerator.kt | 62 ++++--- .../smithy/EventStreamSymbolProvider.kt | 24 ++- .../rust/codegen/smithy/RustCodegenPlugin.kt | 7 +- .../NoOpEventStreamSigningDecorator.kt | 63 +++++++ .../config/EventStreamSigningConfig.kt | 60 +++++++ .../error/CombinedErrorGenerator.kt | 87 +++++---- .../error/TopLevelErrorGenerator.kt | 55 +++--- .../protocols/HttpBoundProtocolGenerator.kt | 10 +- .../HttpBoundProtocolPayloadGenerator.kt | 104 ++++++++--- .../rust/codegen/smithy/protocols/RestJson.kt | 2 +- .../rust/codegen/smithy/protocols/RestXml.kt | 2 +- .../parse/EventStreamUnmarshallerGenerator.kt | 129 ++++++++++---- .../EventStreamErrorMarshallerGenerator.kt | 168 ++++++++++++++++++ .../EventStreamMarshallerGenerator.kt | 22 ++- .../transformers/EventStreamNormalizer.kt | 66 ++++++- .../RemoveEventStreamOperations.kt | 3 +- .../amazon/smithy/rust/codegen/util/Smithy.kt | 3 +- .../HttpVersionListGeneratorTest.kt | 4 +- .../smithy/EventStreamSymbolProviderTest.kt | 7 +- .../aws-smithy-eventstream/src/frame.rs | 14 +- .../examples/Cargo.toml | 2 +- .../examples/pokemon_service/Cargo.toml | 3 + .../examples/pokemon_service/src/lib.rs | 72 +++++++- .../examples/pokemon_service/src/main.rs | 5 +- .../tests/simple_integration_test.rs | 151 +++++++++++++++- .../aws-smithy-http/src/event_stream.rs | 8 +- .../event_stream/{output.rs => receiver.rs} | 30 +++- .../src/event_stream/{input.rs => sender.rs} | 165 +++++++++++------ 37 files changed, 1253 insertions(+), 306 deletions(-) create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/NoOpEventStreamSigningDecorator.kt create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/EventStreamSigningConfig.kt create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt rename rust-runtime/aws-smithy-http/src/event_stream/{output.rs => receiver.rs} (95%) rename rust-runtime/aws-smithy-http/src/event_stream/{input.rs => sender.rs} (56%) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 7e66a96517..0510f2d129 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -31,6 +31,18 @@ references = ["smithy-rs#1263"] meta = { "breaking" = false, "tada" = false, "bug" = false } author = "Velfi" +[[smithy-rs]] +message = "Rename EventStreamInput to EventStreamSender" +references = ["smithy-rs#1157"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "82marbag" + +[[aws-sdk-rust]] +message = "Rename EventStreamInput to EventStreamSender" +references = ["smithy-rs#1157"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "82marbag" + [[aws-sdk-rust]] message = "Re-export aws_types::SdkConfig in aws_config" references = ["smithy-rs#1457"] diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt index 77fa9dd95b..afd039fbc0 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.OptionalAuthTrait -import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust @@ -27,7 +26,7 @@ import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomizati import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig +import software.amazon.smithy.rust.codegen.smithy.generators.config.EventStreamSigningConfig import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait @@ -82,51 +81,37 @@ class SigV4SigningConfig( runtimeConfig: RuntimeConfig, private val serviceHasEventStream: Boolean, private val sigV4Trait: SigV4Trait -) : ConfigCustomization() { +) : EventStreamSigningConfig(runtimeConfig) { private val codegenScope = arrayOf( "SigV4Signer" to RuntimeType( "SigV4Signer", runtimeConfig.awsRuntimeDependency("aws-sig-auth", setOf("sign-eventstream")), "aws_sig_auth::event_stream" ), - "SharedPropertyBag" to RuntimeType( - "SharedPropertyBag", - CargoDependency.SmithyHttp(runtimeConfig), - "aws_smithy_http::property_bag" - ) ) - override fun section(section: ServiceConfig): Writable { - return when (section) { - is ServiceConfig.ConfigImpl -> writable { + override fun configImplSection() = renderEventStreamSignerFn { propertiesName -> + writable { + rustTemplate( + """ + /// The signature version 4 service signing name to use in the credential scope when signing requests. + /// + /// The signing service may be overridden by the `Endpoint`, or by specifying a custom + /// [`SigningService`](aws_types::SigningService) during operation construction + pub fn signing_service(&self) -> &'static str { + ${sigV4Trait.name.dq()} + } + """, + *codegenScope + ) + if (serviceHasEventStream) { rustTemplate( """ - /// The signature version 4 service signing name to use in the credential scope when signing requests. - /// - /// The signing service may be overridden by the `Endpoint`, or by specifying a custom - /// [`SigningService`](aws_types::SigningService) during operation construction - pub fn signing_service(&self) -> &'static str { - ${sigV4Trait.name.dq()} - } + #{SigV4Signer}::new($propertiesName) """, *codegenScope ) - if (serviceHasEventStream) { - rustTemplate( - """ - /// Creates a new Event Stream `SignMessage` implementor. - pub fn new_event_stream_signer( - &self, - properties: #{SharedPropertyBag} - ) -> #{SigV4Signer} { - #{SigV4Signer}::new(properties) - } - """, - *codegenScope - ) - } } - else -> emptySection } } } diff --git a/codegen-server-test/model/pokemon.smithy b/codegen-server-test/model/pokemon.smithy index 2f8bd5a84d..0bb3b58003 100644 --- a/codegen-server-test/model/pokemon.smithy +++ b/codegen-server-test/model/pokemon.smithy @@ -10,7 +10,7 @@ use aws.protocols#restJson1 service PokemonService { version: "2021-12-01", resources: [PokemonSpecies], - operations: [GetServerStatistics, EmptyOperation], + operations: [GetServerStatistics, EmptyOperation, CapturePokemonOperation], } /// A Pokémon species forms the basis for at least one Pokémon. @@ -22,6 +22,82 @@ resource PokemonSpecies { read: GetPokemonSpecies, } +/// Capture Pokémons via event streams +@http(uri: "/capture-pokemon-event/{region}", method: "POST") +operation CapturePokemonOperation { + input: CapturePokemonOperationEventsInput, + output: CapturePokemonOperationEventsOutput, + errors: [UnsupportedRegionError, ThrottlingError] +} + +@input +structure CapturePokemonOperationEventsInput { + @httpPayload + events: AttemptCapturingPokemonEvent, + + @httpLabel + @required + region: String, +} + +@output +structure CapturePokemonOperationEventsOutput { + @httpPayload + events: CapturePokemonEvents, +} + +@streaming +union AttemptCapturingPokemonEvent { + event: CapturingEvent, + masterball_unsuccessful: MasterBallUnsuccessful, +} + +structure CapturingEvent { + @eventPayload + payload: CapturingPayload, +} + +structure CapturingPayload { + name: String, + pokeball: String, +} + +@streaming +union CapturePokemonEvents { + event: CaptureEvent, + invalid_pokeball: InvalidPokeballError, + throttlingError: ThrottlingError, +} + +structure CaptureEvent { + @eventHeader + name: String, + @eventHeader + captured: Boolean, + @eventHeader + shiny: Boolean, + @eventPayload + pokedex_update: Blob, +} + +@error("server") +structure UnsupportedRegionError { + @required + region: String, +} +@error("client") +structure InvalidPokeballError { + @required + pokeball: String, +} +@error("server") +structure MasterBallUnsuccessful { + @required + message: String, +} +@error("client") +structure ThrottlingError {} + /// Retrieve information about a Pokémon species. @readonly @http(uri: "/pokemon-species/{name}", method: "GET") diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt index 42400cc0f7..e87cbda4ec 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt @@ -38,7 +38,7 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { override fun execute(context: PluginContext) { // Suppress extremely noisy logs about reserved words Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF - // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDectorator] return different types of + // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDecorator] return different types of // customization. A customization is a function of: // - location (e.g. the mutate section of an operation) // - context (e.g. the of the operation) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt index 8395ceb44e..43c1b342dc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt @@ -20,6 +20,7 @@ import software.amazon.smithy.rust.codegen.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import java.util.logging.Level import java.util.logging.Logger @@ -64,7 +65,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) // Generate different types for EventStream shapes (e.g. transcribe streaming) .let { - EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) + EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt index 23e59d5001..9bbc1bab71 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt @@ -15,6 +15,7 @@ import software.amazon.smithy.model.traits.RequiredTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator +import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors /** * Add at least one error to all operations in the model. @@ -35,7 +36,7 @@ class AddInternalServerErrorToInfallibleOperationsDecorator : RustCodegenDecorat override val order: Byte = 0 override fun transformModel(service: ServiceShape, model: Model): Model = - addErrorShapeToModelOperations(service, model) { shape -> shape.errors.isEmpty() } + addErrorShapeToModelOperations(service, model) { shape -> shape.allErrors(model).isEmpty() } } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt index 64683f4448..15aa2a1f83 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt @@ -5,9 +5,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model -import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.rustlang.RustWriter @@ -19,6 +20,10 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol +import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors +import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.toSnakeCase /** @@ -30,12 +35,35 @@ open class ServerCombinedErrorGenerator( private val symbolProvider: RustSymbolProvider, private val operation: OperationShape ) { - private val operationIndex = OperationIndex.of(model) - - open fun render(writer: RustWriter) { - val errors = operationIndex.getErrors(operation) - val operationSymbol = symbolProvider.toSymbol(operation) + fun render(writer: RustWriter) { + val errors = operation.operationErrors(model) val symbol = operation.errorSymbol(symbolProvider) + val operationSymbol = symbolProvider.toSymbol(operation) + if (errors.isNotEmpty()) { + renderErrors(writer, errors.map { it.asStructureShape().get() }, symbol, operationSymbol) + } + + if (operation.isEventStream(model)) { + operation.eventStreamErrors(model) + .forEach { (unionShape, unionErrors) -> + if (unionErrors.isNotEmpty()) { + renderErrors( + writer, + unionErrors, + unionShape.eventStreamErrorSymbol(symbolProvider), + symbolProvider.toSymbol(unionShape) + ) + } + } + } + } + + private fun renderErrors( + writer: RustWriter, + errors: List, + errorSymbol: RuntimeType, + operationSymbol: Symbol + ) { val meta = RustMetadata( derives = Attribute.Derives(setOf(RuntimeType.Debug)), visibility = Visibility.PUBLIC @@ -44,7 +72,7 @@ open class ServerCombinedErrorGenerator( writer.rust("/// Error type for the `${operationSymbol.name}` operation.") writer.rust("/// Each variant represents an error that can occur for the `${operationSymbol.name}` operation.") meta.render(writer) - writer.rustBlock("enum ${symbol.name}") { + writer.rustBlock("enum ${errorSymbol.name}") { errors.forEach { errorVariant -> documentShape(errorVariant, model) val errorVariantSymbol = symbolProvider.toSymbol(errorVariant) @@ -52,44 +80,44 @@ open class ServerCombinedErrorGenerator( } } - writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.Display) { + writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) { rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { - delegateToVariants { + delegateToVariants(errors, errorSymbol) { rust("_inner.fmt(f)") } } } - writer.rustBlock("impl ${symbol.name}") { + writer.rustBlock("impl ${errorSymbol.name}") { errors.forEach { error -> - val errorSymbol = symbolProvider.toSymbol(error) - val fnName = errorSymbol.name.toSnakeCase() - writer.rust("/// Returns `true` if the error kind is `${symbol.name}::${errorSymbol.name}`.") + val errorVariantSymbol = symbolProvider.toSymbol(error) + val fnName = errorVariantSymbol.name.toSnakeCase() + writer.rust("/// Returns `true` if the error kind is `${errorSymbol.name}::${errorVariantSymbol.name}`.") writer.rustBlock("pub fn is_$fnName(&self) -> bool") { - rust("matches!(&self, ${symbol.name}::${errorSymbol.name}(_))") + rust("matches!(&self, ${errorSymbol.name}::${errorVariantSymbol.name}(_))") } } writer.rust("/// Returns the error name string by matching the correct variant.") writer.rustBlock("pub fn name(&self) -> &'static str") { - delegateToVariants { + delegateToVariants(errors, errorSymbol) { rust("_inner.name()") } } } - writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) { + writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) { rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { - delegateToVariants { + delegateToVariants(errors, errorSymbol) { rust("Some(_inner)") } } } for (error in errors) { - val errorSymbol = symbolProvider.toSymbol(error) - writer.rustBlock("impl #T<#T> for #T", RuntimeType.From, errorSymbol, symbol) { - rustBlock("fn from(variant: #T) -> #T", errorSymbol, symbol) { - rust("Self::${errorSymbol.name}(variant)") + val errorVariantSymbol = symbolProvider.toSymbol(error) + writer.rustBlock("impl #T<#T> for #T", RuntimeType.From, errorVariantSymbol, errorSymbol) { + rustBlock("fn from(variant: #T) -> #T", errorVariantSymbol, errorSymbol) { + rust("Self::${errorVariantSymbol.name}(variant)") } } } @@ -112,10 +140,10 @@ open class ServerCombinedErrorGenerator( * The field will always be bound as `_inner`. */ private fun RustWriter.delegateToVariants( - writable: Writable + errors: List, + symbol: RuntimeType, + writable: Writable, ) { - val errors = operationIndex.getErrors(operation) - val symbol = operation.errorSymbol(symbolProvider) rustBlock("match &self") { errors.forEach { val errorSymbol = symbolProvider.toSymbol(it) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index eb09b3e8f1..23fa52561a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBou import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape @@ -137,7 +138,7 @@ open class ServerOperationHandlerGenerator( } else { "Fun: FnOnce($inputName) -> Fut + Clone + Send + 'static," } - val outputType = if (operation.errors.isNotEmpty()) { + val outputType = if (operation.operationErrors(model).isNotEmpty()) { "Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>" } else { symbolProvider.toSymbol(operation.outputShape(model)).fullName diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index e694e2cc4e..092fee62da 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.testutil.TokioTest import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.getTrait @@ -282,7 +283,7 @@ class ServerProtocolTestGenerator( writeInline("let output =") instantiator.render(this, shape, testCase.params) write(";") - val operationImpl = if (operationShape.errors.isNotEmpty()) { + val operationImpl = if (operationShape.allErrors(model).isNotEmpty()) { if (shape.hasTrait()) { val variant = symbolProvider.toSymbol(shape).name "$operationImplementationName::Error($operationErrorName::$variant(output))" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 8d197d6625..71a206ffd1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -59,6 +59,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.toOptional +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.smithy.wrapOptional import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait @@ -216,7 +217,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val outputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - if (operationShape.errors.isNotEmpty()) { + if (operationShape.operationErrors(model).isNotEmpty()) { // The output of fallible operations is a `Result` which we convert into an // isomorphic `enum` type we control that can in turn be converted into a response. val intoResponseImpl = @@ -302,7 +303,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } // Implement conversion function to "wrap" from the model operation output types. - if (operationShape.errors.isNotEmpty()) { + if (operationShape.operationErrors(model).isNotEmpty()) { rustTemplate( """ impl #{From}> for $outputName { @@ -441,11 +442,12 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) withBlock("match error {", "}") { - operationShape.errors.forEach { - val variantShape = model.expectShape(it, StructureShape::class.java) + val errors = operationShape.operationErrors(model) + errors.forEach { + val variantShape = model.expectShape(it.id, StructureShape::class.java) val errorTrait = variantShape.expectTrait() val variantSymbol = symbolProvider.toSymbol(variantShape) - val serializerSymbol = structuredDataSerializer.serverErrorSerializer(it) + val serializerSymbol = structuredDataSerializer.serverErrorSerializer(it.id) rustBlock("#T::${variantSymbol.name}(output) =>", errorSymbol) { rust( @@ -507,13 +509,10 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // Fallback to the default code of `http::response::Builder`, 200. operationShape.outputShape(model).findStreamingMember(model)?.let { - val memberName = symbolProvider.toMemberName(it) - rustTemplate( - """ - let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName); - """, - *codegenScope, - ) + val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) + withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) { + payloadGenerator.generatePayload(this, "output", operationShape) + } } ?: run { val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) withBlockTemplate("let payload = ", ";") { @@ -521,11 +520,17 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } serverRenderContentLengthHeader() + + rustTemplate( + """ + let body = #{SmithyHttpServer}::body::to_boxed(payload); + """, + *codegenScope, + ) } rustTemplate( """ - let body = #{SmithyHttpServer}::body::to_boxed(payload); builder.body(body)? """, *codegenScope, @@ -703,29 +708,28 @@ private class ServerHttpBoundProtocolTraitImplGenerator( HttpLocation.HEADER -> writable { serverRenderHeaderParser(this, binding, operationShape) } HttpLocation.PREFIX_HEADERS -> writable { serverRenderPrefixHeadersParser(this, binding, operationShape) } HttpLocation.PAYLOAD -> { - return if (binding.member.isStreaming(model)) { - writable { + val structureShapeHandler: RustWriter.(String) -> Unit = { body -> + rust("#T($body)", structuredDataParser.payloadParser(binding.member)) + } + val errorSymbol = getDeserializePayloadErrorSymbol(binding) + val deserializer = httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + structuredHandler = structureShapeHandler + ) + return writable { + if (binding.member.isStreaming(model)) { rustTemplate( """ { let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; - Some(body.into()) + Some(#{Deserializer}(&mut body.into().into_inner())?) } - """.trimIndent(), + """, + "Deserializer" to deserializer, *codegenScope ) - } - } else { - val structureShapeHandler: RustWriter.(String) -> Unit = { body -> - rust("#T($body)", structuredDataParser.payloadParser(binding.member)) - } - val errorSymbol = getDeserializePayloadErrorSymbol(binding) - val deserializer = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - structuredHandler = structureShapeHandler - ) - writable { + } else { rustTemplate( """ { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt index bb92cebf6c..10cb47ab64 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt @@ -14,12 +14,15 @@ import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.stripOuter -import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget +import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.isInputEventStream +import software.amazon.smithy.rust.codegen.util.isOutputEventStream /** * Wrapping symbol provider to wrap modeled types with the aws-smithy-http Event Stream send/receive types. @@ -27,8 +30,10 @@ import software.amazon.smithy.rust.codegen.util.isInputEventStream class EventStreamSymbolProvider( private val runtimeConfig: RuntimeConfig, base: RustSymbolProvider, - private val model: Model + private val model: Model, + private val target: CodegenTarget, ) : WrappingSymbolProvider(base) { + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) override fun toSymbol(shape: Shape): Symbol { val initial = super.toSymbol(shape) @@ -42,20 +47,27 @@ class EventStreamSymbolProvider( } // If we find an operation shape, then we can wrap the type if (operationShape != null) { - val error = operationShape.errorSymbol(this).toSymbol() + val unionShape = model.expectShape(shape.target).asUnionShape().get() + val error = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream").toSymbol() + } else { + unionShape.eventStreamErrorSymbol(this).toSymbol() + } val errorFmt = error.rustType().render(fullyQualified = true) val innerFmt = initial.rustType().stripOuter().render(fullyQualified = true) - val outer = when (shape.isInputEventStream(model)) { - true -> "EventStreamInput<$innerFmt>" + val isSender = (shape.isInputEventStream(model) && target == CodegenTarget.CLIENT) || + (shape.isOutputEventStream(model) && target == CodegenTarget.SERVER) + val outer = when (isSender) { + true -> "EventStreamSender<$innerFmt, $errorFmt>" else -> "Receiver<$innerFmt, $errorFmt>" } val rustType = RustType.Opaque(outer, "aws_smithy_http::event_stream") return initial.toBuilder() .name(rustType.name) .rustType(rustType) - .addReference(error) .addReference(initial) .addDependency(CargoDependency.SmithyHttp(runtimeConfig).withFeature("event-stream")) + .addReference(error) .build() } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt index 47eb606d1f..9e02e6c67a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt @@ -14,7 +14,9 @@ import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhau import software.amazon.smithy.rust.codegen.rustlang.RustReservedWordSymbolProvider import software.amazon.smithy.rust.codegen.smithy.customizations.ClientCustomizations import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator +import software.amazon.smithy.rust.codegen.smithy.customize.NoOpEventStreamSigningDecorator import software.amazon.smithy.rust.codegen.smithy.customize.RequiredCustomizations +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.client.FluentClientDecorator import java.util.logging.Level import java.util.logging.Logger @@ -42,7 +44,8 @@ class RustCodegenPlugin : SmithyBuildPlugin { context, ClientCustomizations(), RequiredCustomizations(), - FluentClientDecorator() + FluentClientDecorator(), + NoOpEventStreamSigningDecorator() ) // CodegenVisitor is the main driver of code generation that traverses the model and generates code @@ -59,7 +62,7 @@ class RustCodegenPlugin : SmithyBuildPlugin { fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig) = SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) } + .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) } // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/NoOpEventStreamSigningDecorator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/NoOpEventStreamSigningDecorator.kt new file mode 100644 index 0000000000..88774ea6ef --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/NoOpEventStreamSigningDecorator.kt @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.smithy.customize + +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.smithy.generators.config.EventStreamSigningConfig +import software.amazon.smithy.rust.codegen.util.hasEventStreamOperations + +/** + * The NoOpEventStreamSigningDecorator: + * - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer + */ +open class NoOpEventStreamSigningDecorator : RustCodegenDecorator { + override val name: String = "NoOpEventStreamSigning" + override val order: Byte = 0 + + private fun applies(codegenContext: CoreCodegenContext): Boolean = + codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) + + override fun configCustomizations( + codegenContext: C, + baseCustomizations: List + ): List { + if (!applies(codegenContext)) + return baseCustomizations + return baseCustomizations + NoOpEventStreamSigningConfig( + codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model), + codegenContext.runtimeConfig, + ) + } +} + +class NoOpEventStreamSigningConfig( + private val serviceHasEventStream: Boolean, + runtimeConfig: RuntimeConfig, +) : EventStreamSigningConfig(runtimeConfig) { + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) + private val codegenScope = arrayOf( + "NoOpSigner" to RuntimeType("NoOpSigner", smithyEventStream, "aws_smithy_eventstream::frame"), + ) + + override fun configImplSection() = renderEventStreamSignerFn { + writable { + if (serviceHasEventStream) { + rustTemplate( + """ + #{NoOpSigner}{} + """, + *codegenScope + ) + } + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/EventStreamSigningConfig.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/EventStreamSigningConfig.kt new file mode 100644 index 0000000000..37e12e31c2 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/EventStreamSigningConfig.kt @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.smithy.generators.config + +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.Writable +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.smithy.RuntimeType + +open class EventStreamSigningConfig( + runtimeConfig: RuntimeConfig, +) : ConfigCustomization() { + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) + private val codegenScope = arrayOf( + "NoOpSigner" to RuntimeType("NoOpSigner", smithyEventStream, "aws_smithy_eventstream::frame"), + "SharedPropertyBag" to RuntimeType( + "SharedPropertyBag", + CargoDependency.SmithyHttp(runtimeConfig), + "aws_smithy_http::property_bag" + ), + "SignMessage" to RuntimeType( + "SignMessage", + CargoDependency.SmithyEventStream(runtimeConfig), + "aws_smithy_eventstream::frame" + ), + ) + open fun signer(): Writable = emptySection + + override fun section(section: ServiceConfig): Writable { + return when (section) { + is ServiceConfig.ConfigImpl -> configImplSection() + else -> emptySection + } + } + + open fun configImplSection(): Writable = emptySection + + fun renderEventStreamSignerFn(signerInstantiator: (String) -> Writable): Writable { + return writable { + rustTemplate( + """ + /// Creates a new Event Stream `SignMessage` implementor. + pub fn new_event_stream_signer( + &self, + _properties: #{SharedPropertyBag} + ) -> impl #{SignMessage} { + #{signer:W} + } + """, + *codegenScope, + "signer" to signerInstantiator("_properties"), + ) + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt index 88049b89e9..738c639c46 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt @@ -8,9 +8,10 @@ package software.amazon.smithy.rust.codegen.smithy.generators.error import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model -import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.RetryableTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustMetadata @@ -25,7 +26,10 @@ import software.amazon.smithy.rust.codegen.rustlang.writable import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.customize.Section +import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.toSnakeCase /** @@ -43,6 +47,10 @@ fun OperationShape.errorSymbol(symbolProvider: SymbolProvider): RuntimeType { return RuntimeType("${symbol.name}Error", null, "crate::error") } +fun UnionShape.eventStreamErrorSymbol(symbolProvider: SymbolProvider): RuntimeType { + val symbol = symbolProvider.toSymbol(this) + return RuntimeType("${symbol.name}Error", null, "crate::error") +} /** * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants, * but we must still combine those variants into an enum covering all possible errors for a given operation. @@ -52,14 +60,33 @@ class CombinedErrorGenerator( private val symbolProvider: RustSymbolProvider, private val operation: OperationShape ) { - private val operationIndex = OperationIndex.of(model) private val runtimeConfig = symbolProvider.config().runtimeConfig private val genericError = RuntimeType.GenericError(symbolProvider.config().runtimeConfig) fun render(writer: RustWriter) { - val errors = operationIndex.getErrors(operation) - val operationSymbol = symbolProvider.toSymbol(operation) val symbol = operation.errorSymbol(symbolProvider) + val operationSymbol = symbolProvider.toSymbol(operation) + renderErrors(writer, operation.operationErrors(model).map { it.asStructureShape().get() }.toMutableList(), symbol, operationSymbol) + + if (operation.isEventStream(model)) { + operation.eventStreamErrors(model) + .forEach { (unionShape, unionErrors) -> + renderErrors( + writer, + unionErrors, + unionShape.eventStreamErrorSymbol(symbolProvider), + symbolProvider.toSymbol(unionShape) + ) + } + } + } + + private fun renderErrors( + writer: RustWriter, + errors: List, + errorSymbol: RuntimeType, + operationSymbol: Symbol + ) { val meta = RustMetadata( derives = Attribute.Derives(setOf(RuntimeType.Debug)), additionalAttributes = listOf(Attribute.NonExhaustive), @@ -68,11 +95,11 @@ class CombinedErrorGenerator( writer.rust("/// Error type for the `${operationSymbol.name}` operation.") meta.render(writer) - writer.rustBlock("struct ${symbol.name}") { + writer.rustBlock("struct ${errorSymbol.name}") { rust( """ /// Kind of error that occurred. - pub kind: ${symbol.name}Kind, + pub kind: ${errorSymbol.name}Kind, /// Additional metadata about the error, including error code, message, and request ID. pub (crate) meta: #T """, @@ -81,11 +108,11 @@ class CombinedErrorGenerator( } writer.rust("/// Types of errors that can occur for the `${operationSymbol.name}` operation.") meta.render(writer) - writer.rustBlock("enum ${symbol.name}Kind") { + writer.rustBlock("enum ${errorSymbol.name}Kind") { errors.forEach { errorVariant -> documentShape(errorVariant, model) - val errorSymbol = symbolProvider.toSymbol(errorVariant) - write("${errorSymbol.name}(#T),", errorSymbol) + val errorVariantSymbol = symbolProvider.toSymbol(errorVariant) + write("${errorVariantSymbol.name}(#T),", errorVariantSymbol) } rust( """ @@ -95,9 +122,9 @@ class CombinedErrorGenerator( RuntimeType.StdError ) } - writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.Display) { + writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) { rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { - delegateToVariants { + delegateToVariants(errors, errorSymbol) { writable { rust("_inner.fmt(f)") } } } @@ -105,11 +132,11 @@ class CombinedErrorGenerator( val errorKindT = RuntimeType.errorKind(symbolProvider.config().runtimeConfig) writer.rustBlock( - "impl #T for ${symbol.name}", + "impl #T for ${errorSymbol.name}", RuntimeType.provideErrorKind(symbolProvider.config().runtimeConfig) ) { rustBlock("fn code(&self) -> Option<&str>") { - rust("${symbol.name}::code(self)") + rust("${errorSymbol.name}::code(self)") } rustBlock("fn retryable_error_kind(&self) -> Option<#T>", errorKindT) { @@ -119,8 +146,8 @@ class CombinedErrorGenerator( } else { rustBlock("match &self.kind") { retryableVariants.forEach { - val errorSymbol = symbolProvider.toSymbol(it) - rust("${symbol.name}Kind::${errorSymbol.name}(inner) => Some(inner.retryable_error_kind()),") + val errorVariantSymbol = symbolProvider.toSymbol(it) + rust("${errorSymbol.name}Kind::${errorVariantSymbol.name}(inner) => Some(inner.retryable_error_kind()),") } rust("_ => None") } @@ -128,27 +155,27 @@ class CombinedErrorGenerator( } } - writer.rustBlock("impl ${symbol.name}") { + writer.rustBlock("impl ${errorSymbol.name}") { writer.rustTemplate( """ - /// Creates a new `${symbol.name}`. - pub fn new(kind: ${symbol.name}Kind, meta: #{generic_error}) -> Self { + /// Creates a new `${errorSymbol.name}`. + pub fn new(kind: ${errorSymbol.name}Kind, meta: #{generic_error}) -> Self { Self { kind, meta } } - /// Creates the `${symbol.name}::Unhandled` variant from any error type. + /// Creates the `${errorSymbol.name}::Unhandled` variant from any error type. pub fn unhandled(err: impl Into>) -> Self { Self { - kind: ${symbol.name}Kind::Unhandled(err.into()), + kind: ${errorSymbol.name}Kind::Unhandled(err.into()), meta: Default::default() } } - /// Creates the `${symbol.name}::Unhandled` variant from a `#{generic_error}`. + /// Creates the `${errorSymbol.name}::Unhandled` variant from a `#{generic_error}`. pub fn generic(err: #{generic_error}) -> Self { Self { meta: err.clone(), - kind: ${symbol.name}Kind::Unhandled(err.into()), + kind: ${errorSymbol.name}Kind::Unhandled(err.into()), } } @@ -176,18 +203,18 @@ class CombinedErrorGenerator( "generic_error" to genericError, "std_error" to RuntimeType.StdError ) errors.forEach { error -> - val errorSymbol = symbolProvider.toSymbol(error) - val fnName = errorSymbol.name.toSnakeCase() - writer.rust("/// Returns `true` if the error kind is `${symbol.name}Kind::${errorSymbol.name}`.") + val errorVariantSymbol = symbolProvider.toSymbol(error) + val fnName = errorVariantSymbol.name.toSnakeCase() + writer.rust("/// Returns `true` if the error kind is `${errorSymbol.name}Kind::${errorVariantSymbol.name}`.") writer.rustBlock("pub fn is_$fnName(&self) -> bool") { - rust("matches!(&self.kind, ${symbol.name}Kind::${errorSymbol.name}(_))") + rust("matches!(&self.kind, ${errorSymbol.name}Kind::${errorVariantSymbol.name}(_))") } } } - writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) { + writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) { rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { - delegateToVariants { + delegateToVariants(errors, errorSymbol) { writable { when (it) { is VariantMatch.Unhandled -> rust("Some(_inner.as_ref())") @@ -222,10 +249,10 @@ class CombinedErrorGenerator( * The field will always be bound as `_inner`. */ private fun RustWriter.delegateToVariants( + errors: List, + symbol: RuntimeType, handler: (VariantMatch) -> Writable ) { - val errors = operationIndex.getErrors(operation) - val symbol = operation.errorSymbol(symbolProvider) rustBlock("match &self.kind") { errors.forEach { val errorSymbol = symbolProvider.toSymbol(it) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt index b63dbd8b12..303b566499 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.smithy.generators.error import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency @@ -21,6 +22,9 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustCrate +import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors /** * Each service defines its own "top-level" error combining all possible errors that a service can emit. @@ -41,7 +45,7 @@ class TopLevelErrorGenerator(coreCodegenContext: CoreCodegenContext, private val private val symbolProvider = coreCodegenContext.symbolProvider private val model = coreCodegenContext.model - private val allErrors = operations.flatMap { it.errors }.distinctBy { it.getName(coreCodegenContext.serviceShape) } + private val allErrors = operations.flatMap { it.allErrors(model) }.map { it.id }.distinctBy { it.getName(coreCodegenContext.serviceShape) } .map { coreCodegenContext.model.expectShape(it, StructureShape::class.java) } .sortedBy { it.id.getName(coreCodegenContext.serviceShape) } @@ -73,27 +77,38 @@ class TopLevelErrorGenerator(coreCodegenContext: CoreCodegenContext, private val } private fun RustWriter.renderImplFrom(operationShape: OperationShape) { - val operationError = operationShape.errorSymbol(symbolProvider) - rustBlock( - "impl From<#T<#T, R>> for Error where R: Send + Sync + std::fmt::Debug + 'static", - sdkError, - operationError - ) { - rustBlockTemplate( - "fn from(err: #{SdkError}<#{OpError}, R>) -> Self", - "SdkError" to sdkError, - "OpError" to operationError - ) { - rustBlock("match err") { - val operationErrors = operationShape.errors.map { model.expectShape(it) } - rustBlock("#T::ServiceError { err, ..} => match err.kind", sdkError) { - operationErrors.forEach { errorShape -> - val errSymbol = symbolProvider.toSymbol(errorShape) - rust("#TKind::${errSymbol.name}(inner) => Error::${errSymbol.name}(inner),", operationError) + val nonEventStreamErrors = operationShape.operationErrors(model).map { it.id } + val allErrors: List>> = listOf( + Pair(operationShape.errorSymbol(symbolProvider), nonEventStreamErrors), + ) + operationShape.eventStreamErrors(model) + .map { Pair(it.key.eventStreamErrorSymbol(symbolProvider), it.value.map { it.id }) } + allErrors.forEach { (symbol, errors) -> + if (errors.isNotEmpty()) { + rustBlock( + "impl From<#T<#T, R>> for Error where R: Send + Sync + std::fmt::Debug + 'static", + sdkError, + symbol + ) { + rustBlockTemplate( + "fn from(err: #{SdkError}<#{OpError}, R>) -> Self", + "SdkError" to sdkError, + "OpError" to symbol + ) { + rustBlock("match err") { + val operationErrors = errors.map { model.expectShape(it) } + rustBlock("#T::ServiceError { err, ..} => match err.kind", sdkError) { + operationErrors.forEach { errorShape -> + val errSymbol = symbolProvider.toSymbol(errorShape) + rust( + "#TKind::${errSymbol.name}(inner) => Error::${errSymbol.name}(inner),", + symbol + ) + } + rust("#TKind::Unhandled(inner) => Error::Unhandled(inner),", symbol) + } + rust("_ => Error::Unhandled(err.into()),") } - rust("#TKind::Unhandled(inner) => Error::Unhandled(inner),", operationError) } - rust("_ => Error::Unhandled(err.into()),") } } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt index 65d61f66cf..d5db77ef36 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -32,6 +32,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTr import software.amazon.smithy.rust.codegen.smithy.generators.setterName import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.UNREACHABLE import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasStreamingMember @@ -167,7 +168,7 @@ class HttpBoundProtocolTraitImplGenerator( protocol.parseHttpGenericError(operationShape), errorSymbol ) - if (operationShape.errors.isNotEmpty()) { + if (operationShape.operationErrors(model).isNotEmpty()) { rustTemplate( """ let error_code = match generic.code() { @@ -180,9 +181,10 @@ class HttpBoundProtocolTraitImplGenerator( "error_symbol" to errorSymbol, ) withBlock("Err(match error_code {", "})") { - operationShape.errors.forEach { error -> - val errorShape = model.expectShape(error, StructureShape::class.java) - val variantName = symbolProvider.toSymbol(model.expectShape(error)).name + val errors = operationShape.operationErrors(model) + errors.forEach { error -> + val errorShape = model.expectShape(error.id, StructureShape::class.java) + val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name val errorCode = httpBindingResolver.errorCode(errorShape).dq() withBlock( "$errorCode => #1T { meta: generic, kind: #1TKind::$variantName({", diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 5efac1f935..53e709cbd7 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -25,11 +25,12 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.EventStreamErrorMarshallerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.EventStreamMarshallerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.util.PANIC @@ -39,6 +40,7 @@ import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.isInputEventStream +import software.amazon.smithy.rust.codegen.util.isOutputEventStream import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -56,13 +58,15 @@ class HttpBoundProtocolPayloadGenerator( private val operationSerModule = RustModule.private("operation_ser") + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) private val codegenScope = arrayOf( "hyper" to CargoDependency.HyperWithStream.asType(), "ByteStream" to RuntimeType.ByteStream(runtimeConfig), "ByteSlab" to RuntimeType.ByteSlab, "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "BuildError" to runtimeConfig.operationBuildError(), - "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType() + "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(), + "NoOpSigner" to RuntimeType("NoOpSigner", smithyEventStream, "aws_smithy_eventstream::frame"), ) override fun payloadMetadata(operationShape: OperationShape): ProtocolPayloadGenerator.PayloadMetadata { @@ -108,8 +112,7 @@ class HttpBoundProtocolPayloadGenerator( val serializerGenerator = protocol.structuredDataSerializer(operationShape) generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape)) } else { - val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - generatePayloadMemberSerializer(writer, self, operationShape, payloadMember) + generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) } } @@ -120,8 +123,7 @@ class HttpBoundProtocolPayloadGenerator( val serializerGenerator = protocol.structuredDataSerializer(operationShape) generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape)) } else { - val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) - generatePayloadMemberSerializer(writer, self, operationShape, payloadMember) + generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) } } @@ -129,15 +131,26 @@ class HttpBoundProtocolPayloadGenerator( writer: RustWriter, self: String, operationShape: OperationShape, - payloadMember: MemberShape + payloadMemberName: String, ) { val serializerGenerator = protocol.structuredDataSerializer(operationShape) - // TODO(https://github.com/awslabs/smithy-rs/issues/1157) Add support for server event streams. - if (operationShape.isInputEventStream(model)) { - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator) + if (operationShape.isEventStream(model)) { + if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { + val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self") + } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { + val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output") + } else { + throw CodegenException("Payload serializer for event streams with an invalid configuration") + } } else { val bodyMetadata = payloadMetadata(operationShape) + val payloadMember = when (httpMessageType) { + HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName) + HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName) + } writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator) } } @@ -156,11 +169,26 @@ class HttpBoundProtocolPayloadGenerator( private fun RustWriter.serializeViaEventStream( operationShape: OperationShape, memberShape: MemberShape, - serializerGenerator: StructuredDataSerializerGenerator + serializerGenerator: StructuredDataSerializerGenerator, + outerName: String, ) { val memberName = symbolProvider.toMemberName(memberShape) val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) + val contentType = when (target) { + CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape) + CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape) + } + val errorMarshallerConstructorFn = EventStreamErrorMarshallerGenerator( + model, + target, + runtimeConfig, + symbolProvider, + unionShape, + operationShape, + serializerGenerator, + contentType ?: throw CodegenException("event streams must set a content type"), + ).render() val marshallerConstructorFn = EventStreamMarshallerGenerator( model, target, @@ -168,27 +196,47 @@ class HttpBoundProtocolPayloadGenerator( symbolProvider, unionShape, serializerGenerator, - httpBindingResolver.requestContentType(operationShape) - ?: throw CodegenException("event streams must set a content type"), + contentType ?: throw CodegenException("event streams must set a content type"), ).render() // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the - // parameters that are not `@eventHeader` or `@eventPayload`. - rustTemplate( - """ - { - let marshaller = #{marshallerConstructorFn}(); - let signer = _config.new_event_stream_signer(properties.clone()); - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> = - self.$memberName.into_body_stream(marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); - body + // parameters that are not `@eventHeader` or `@eventPayload`. + when (target) { + CodegenTarget.CLIENT -> + rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let signer = _config.new_event_stream_signer(properties.clone()); + let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> = + $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + body + } + """, + *codegenScope, + "marshallerConstructorFn" to marshallerConstructorFn, + "errorMarshallerConstructorFn" to errorMarshallerConstructorFn, + ) + CodegenTarget.SERVER -> { + rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let signer = #{NoOpSigner}{}; + let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> = + $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer); + adapter + } + """, + *codegenScope, + "marshallerConstructorFn" to marshallerConstructorFn, + "errorMarshallerConstructorFn" to errorMarshallerConstructorFn, + ) } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "OperationError" to operationShape.errorSymbol(symbolProvider) - ) + } } private fun RustWriter.serializeViaPayload( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt index a70de5d3b4..973e61578f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt @@ -100,7 +100,7 @@ class RestJson(private val coreCodegenContext: CoreCodegenContext) : Protocol { private val jsonDeserModule = RustModule.private("json_deser") override val httpBindingResolver: HttpBindingResolver = - RestJsonHttpBindingResolver(coreCodegenContext.model, ProtocolContentTypes.consistent("application/json")) + RestJsonHttpBindingResolver(coreCodegenContext.model, ProtocolContentTypes("application/json", "application/json", "application/vnd.amazon.eventstream")) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt index 0866ec832d..65d783ef36 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt @@ -68,7 +68,7 @@ open class RestXml(private val coreCodegenContext: CoreCodegenContext) : Protoco } override val httpBindingResolver: HttpBindingResolver = - HttpTraitHttpBindingResolver(coreCodegenContext.model, ProtocolContentTypes.consistent("application/xml")) + HttpTraitHttpBindingResolver(coreCodegenContext.model, ProtocolContentTypes("application/xml", "application/xml", "application/vnd.amazon.eventstream")) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 1bbfc6482e..910fe3d5ce 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -36,10 +36,11 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait +import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.hasTrait @@ -55,8 +56,12 @@ class EventStreamUnmarshallerGenerator( private val target: CodegenTarget, ) { private val unionSymbol = symbolProvider.toSymbol(unionShape) - private val operationErrorSymbol = operationShape.errorSymbol(symbolProvider) private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) + private val errorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream").toSymbol() + } else { + unionShape.eventStreamErrorSymbol(symbolProvider).toSymbol() + } private val eventStreamSerdeModule = RustModule.private("event_stream_serde") private val codegenScope = arrayOf( "Blob" to RuntimeType("Blob", CargoDependency.SmithyTypes(runtimeConfig), "aws_smithy_types"), @@ -65,7 +70,7 @@ class EventStreamUnmarshallerGenerator( "Header" to RuntimeType("Header", smithyEventStream, "aws_smithy_eventstream::frame"), "HeaderValue" to RuntimeType("HeaderValue", smithyEventStream, "aws_smithy_eventstream::frame"), "Message" to RuntimeType("Message", smithyEventStream, "aws_smithy_eventstream::frame"), - "OpError" to operationErrorSymbol, + "OpError" to errorSymbol, "SmithyError" to RuntimeType("Error", CargoDependency.SmithyTypes(runtimeConfig), "aws_smithy_types"), "tracing" to CargoDependency.Tracing.asType(), "UnmarshalledMessage" to RuntimeType("UnmarshalledMessage", smithyEventStream, "aws_smithy_eventstream::frame"), @@ -99,7 +104,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope ) { rust("type Output = #T;", unionSymbol) - rust("type Error = #T;", operationErrorSymbol) + rust("type Error = #T;", errorSymbol) rustBlockTemplate( """ @@ -292,51 +297,111 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshallError() { - rustTemplate( - """ - let generic = match #{parse_generic_error}(message.payload()) { - Ok(generic) => generic, - Err(err) => return Ok(#{UnmarshalledMessage}::Error(#{OpError}::unhandled(err))), - }; - """, - "parse_generic_error" to protocol.parseEventStreamGenericError(operationShape), - *codegenScope - ) + when (target) { + CodegenTarget.CLIENT -> { + rustTemplate( + """ + let generic = match #{parse_generic_error}(message.payload()) { + Ok(generic) => generic, + Err(err) => return Ok(#{UnmarshalledMessage}::Error(#{OpError}::unhandled(err))), + }; + """, + "parse_generic_error" to protocol.parseEventStreamGenericError(operationShape), + *codegenScope + ) + } + CodegenTarget.SERVER -> {} + } val syntheticUnion = unionShape.expectTrait() if (syntheticUnion.errorMembers.isNotEmpty()) { - rustBlock("match response_headers.smithy_type.as_str()") { - for (member in syntheticUnion.errorMembers) { - val target = model.expectShape(member.target, StructureShape::class.java) - rustBlock("${member.memberName.dq()} => ") { - val parser = protocol.structuredDataParser(operationShape).errorParser(target) - if (parser != null) { - rust("let mut builder = #T::builder();", symbolProvider.toSymbol(target)) - // TODO(EventStream): Errors on the operation can be disjoint with errors in the union, - // so we need to generate a new top-level Error type for each event stream union. + // clippy::single-match implied, using if when there's only one error + val (header, matchOperator) = if (syntheticUnion.errorMembers.size > 1) { + listOf("match response_headers.smithy_type.as_str() {", "=>") + } else { + listOf("if response_headers.smithy_type.as_str() == ", "") + } + rust(header) + for (member in syntheticUnion.errorMembers) { + rustBlock("${member.memberName.dq()} $matchOperator ") { + // TODO(EventStream): Errors on the operation can be disjoint with errors in the union, + // so we need to generate a new top-level Error type for each event stream union. + when (target) { + CodegenTarget.CLIENT -> { + val target = model.expectShape(member.target, StructureShape::class.java) + val parser = protocol.structuredDataParser(operationShape).errorParser(target) + if (parser != null) { + rust("let mut builder = #T::builder();", symbolProvider.toSymbol(target)) + rustTemplate( + """ + builder = #{parser}(&message.payload()[..], builder) + .map_err(|err| { + #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) + })?; + return Ok(#{UnmarshalledMessage}::Error( + #{OpError}::new( + #{OpError}Kind::${member.target.name}(builder.build()), + generic, + ) + )) + """, + "parser" to parser, + *codegenScope + ) + } + } + CodegenTarget.SERVER -> { + val target = model.expectShape(member.target, StructureShape::class.java) + val parser = protocol.structuredDataParser(operationShape).errorParser(target) + val mut = if (parser != null) { " mut" } else { "" } + rust("let$mut builder = #T::builder();", symbolProvider.toSymbol(target)) + if (parser != null) { + rustTemplate( + """ + builder = #{parser}(&message.payload()[..], builder) + .map_err(|err| { + #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) + })?; + """, + "parser" to parser, + *codegenScope + ) + } rustTemplate( """ - builder = #{parser}(&message.payload()[..], builder) - .map_err(|err| { - #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) - })?; return Ok(#{UnmarshalledMessage}::Error( - #{OpError}::new( - #{OpError}Kind::${symbolProvider.toMemberName(member)}(builder.build()), - generic, + #{OpError}::${member.target.name}( + builder.build() ) )) """, - "parser" to parser, *codegenScope ) } } } + } + if (syntheticUnion.errorMembers.size > 1) { + // it's: match ... { rust("_ => {}") + rust("}") + } + } + when (target) { + CodegenTarget.CLIENT -> { + rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope) + } + CodegenTarget.SERVER -> { + rustTemplate( + """ + return Err(aws_smithy_eventstream::error::Error::Unmarshalling( + format!("unrecognized exception: {}", response_headers.smithy_type.as_str()), + )); + """, + *codegenScope + ) } } - rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope) } private fun UnionShape.eventStreamUnmarshallerType(): RuntimeType { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt new file mode 100644 index 0000000000..005a36dacf --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt @@ -0,0 +1,168 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.smithy.protocols.serialize + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EventHeaderTrait +import software.amazon.smithy.model.traits.EventPayloadTrait +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.render +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget +import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.smithy.generators.unknownVariantError +import software.amazon.smithy.rust.codegen.smithy.rustType +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait +import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors +import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.toPascalCase + +class EventStreamErrorMarshallerGenerator( + private val model: Model, + private val target: CodegenTarget, + runtimeConfig: RuntimeConfig, + private val symbolProvider: RustSymbolProvider, + private val unionShape: UnionShape, + private val operationShape: OperationShape, + private val serializerGenerator: StructuredDataSerializerGenerator, + private val payloadContentType: String, +) : EventStreamMarshallerGenerator(model, target, runtimeConfig, symbolProvider, unionShape, serializerGenerator, payloadContentType) { + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) + private val operationErrorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream").toSymbol() + } else { + unionShape.eventStreamErrorSymbol(symbolProvider).toSymbol() + } + private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + val errorsShape = unionShape.expectTrait() + private val codegenScope = arrayOf( + "MarshallMessage" to RuntimeType("MarshallMessage", smithyEventStream, "aws_smithy_eventstream::frame"), + "Message" to RuntimeType("Message", smithyEventStream, "aws_smithy_eventstream::frame"), + "Header" to RuntimeType("Header", smithyEventStream, "aws_smithy_eventstream::frame"), + "HeaderValue" to RuntimeType("HeaderValue", smithyEventStream, "aws_smithy_eventstream::frame"), + "Error" to RuntimeType("Error", smithyEventStream, "aws_smithy_eventstream::error"), + ) + + override fun render(): RuntimeType { + val marshallerType = operationShape.eventStreamMarshallerType() + val unionSymbol = symbolProvider.toSymbol(unionShape) + + return RuntimeType.forInlineFun("${marshallerType.name}::new", eventStreamSerdeModule) { inlineWriter -> + inlineWriter.renderMarshaller(marshallerType, unionSymbol) + } + } + + private fun RustWriter.renderMarshaller(marshallerType: RuntimeType, unionSymbol: Symbol) { + rust( + """ + ##[non_exhaustive] + ##[derive(Debug)] + pub struct ${marshallerType.name}; + + impl ${marshallerType.name} { + pub fn new() -> Self { + ${marshallerType.name} + } + } + """ + ) + + rustBlockTemplate( + "impl #{MarshallMessage} for ${marshallerType.name}", + *codegenScope + ) { + rust("type Input = ${operationErrorSymbol.rustType().render(fullyQualified = true)};") + + rustBlockTemplate( + "fn marshall(&self, input: Self::Input) -> std::result::Result<#{Message}, #{Error}>", + *codegenScope + ) { + rust("let mut headers = Vec::new();") + addStringHeader(":message-type", """"exception".into()""") + val kind = when (target) { + CodegenTarget.CLIENT -> ".kind" + CodegenTarget.SERVER -> "" + } + if (errorsShape.errorMembers.isEmpty()) { + rustBlock("let payload = match input$kind") { + rust("_ => Vec::new()") + } + } else { + rustBlock("let payload = match input$kind") { + val symbol = operationErrorSymbol + val errorName = when (target) { + CodegenTarget.CLIENT -> "${symbol}Kind" + CodegenTarget.SERVER -> "$symbol" + } + + errorsShape.errorMembers.forEach { error -> + val errorSymbol = symbolProvider.toSymbol(error) + val errorString = error.memberName + val target = model.expectShape(error.target, StructureShape::class.java) + rustBlock("$errorName::${errorSymbol.name}(inner) => ") { + addStringHeader(":exception-type", "${errorString.dq()}.into()") + renderMarshallEvent(error, target) + } + } + if (target.renderUnknownVariant()) { + rustTemplate( + """ + $errorName::Unhandled(_inner) => return Err( + #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) + ), + """, + *codegenScope + ) + } + } + } + rustTemplate("; Ok(#{Message}::new_from_parts(headers, payload))", *codegenScope) + } + } + } + + fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { + val headerMembers = eventStruct.members().filter { it.hasTrait() } + val payloadMember = eventStruct.members().firstOrNull { it.hasTrait() } + for (member in headerMembers) { + val memberName = symbolProvider.toMemberName(member) + val target = model.expectShape(member.target) + renderMarshallEventHeader(memberName, member, target) + } + if (payloadMember != null) { + val memberName = symbolProvider.toMemberName(payloadMember) + val target = model.expectShape(payloadMember.target) + val serializerFn = serializerGenerator.serverErrorSerializer(payloadMember.toShapeId()) + renderMarshallEventPayload("inner.$memberName", payloadMember, target, serializerFn) + } else if (headerMembers.isEmpty()) { + val serializerFn = serializerGenerator.serverErrorSerializer(unionMember.target.toShapeId()) + renderMarshallEventPayload("inner", unionMember, eventStruct, serializerFn) + } else { + rust("Vec::new()") + } + } + + private fun OperationShape.eventStreamMarshallerType(): RuntimeType { + val symbol = symbolProvider.toSymbol(this) + return RuntimeType("${symbol.name.toPascalCase()}ErrorMarshaller", null, "crate::event_stream_serde") + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index a12a2be9a8..70ce60909f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -43,7 +43,7 @@ import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.toPascalCase -class EventStreamMarshallerGenerator( +open class EventStreamMarshallerGenerator( private val model: Model, private val target: CodegenTarget, runtimeConfig: RuntimeConfig, @@ -62,7 +62,7 @@ class EventStreamMarshallerGenerator( "Error" to RuntimeType("Error", smithyEventStream, "aws_smithy_eventstream::error"), ) - fun render(): RuntimeType { + open fun render(): RuntimeType { val marshallerType = unionShape.eventStreamMarshallerType() val unionSymbol = symbolProvider.toSymbol(unionShape) @@ -134,15 +134,17 @@ class EventStreamMarshallerGenerator( if (payloadMember != null) { val memberName = symbolProvider.toMemberName(payloadMember) val target = model.expectShape(payloadMember.target) - renderMarshallEventPayload("inner.$memberName", payloadMember, target) + val serializerFn = serializerGenerator.payloadSerializer(payloadMember) + renderMarshallEventPayload("inner.$memberName", payloadMember, target, serializerFn) } else if (headerMembers.isEmpty()) { - renderMarshallEventPayload("inner", unionMember, eventStruct) + val serializerFn = serializerGenerator.payloadSerializer(unionMember) + renderMarshallEventPayload("inner", unionMember, eventStruct, serializerFn) } else { rust("Vec::new()") } } - private fun RustWriter.renderMarshallEventHeader(memberName: String, member: MemberShape, target: Shape) { + protected fun RustWriter.renderMarshallEventHeader(memberName: String, member: MemberShape, target: Shape) { val headerName = member.memberName handleOptional( symbolProvider.toSymbol(member).isOptional(), @@ -175,7 +177,12 @@ class EventStreamMarshallerGenerator( else -> throw IllegalStateException("unsupported event stream header shape type: $target") } - private fun RustWriter.renderMarshallEventPayload(inputExpr: String, member: MemberShape, target: Shape) { + protected fun RustWriter.renderMarshallEventPayload( + inputExpr: String, + member: Shape, + target: Shape, + serializerFn: RuntimeType + ) { val optional = symbolProvider.toSymbol(member).isOptional() if (target is BlobShape || target is StringShape) { data class PayloadContext(val conversionFn: String, val contentType: String) @@ -196,7 +203,6 @@ class EventStreamMarshallerGenerator( } else { addStringHeader(":content-type", "${payloadContentType.dq()}.into()") - val serializerFn = serializerGenerator.payloadSerializer(member) handleOptional( optional, inputExpr, @@ -237,7 +243,7 @@ class EventStreamMarshallerGenerator( } } - private fun RustWriter.addStringHeader(name: String, valueExpr: String) { + protected fun RustWriter.addStringHeader(name: String, valueExpr: String) { rustTemplate("headers.push(#{Header}::new(${name.dq()}, #{HeaderValue}::String($valueExpr)));", *codegenScope) } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt index 7ea38e6708..b71f3228a9 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt @@ -6,12 +6,21 @@ package software.amazon.smithy.rust.codegen.smithy.transformers import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait +import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.isEventStream +import software.amazon.smithy.rust.codegen.util.outputShape /** * Generates synthetic unions to replace the modeled unions for Event Stream types. @@ -20,14 +29,35 @@ import software.amazon.smithy.rust.codegen.util.isEventStream */ object EventStreamNormalizer { fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape -> - if (shape is UnionShape && shape.isEventStream()) { - syntheticEquivalent(model, shape) + if (shape is OperationShape && shape.isEventStream(model)) { + addStreamErrorsToOperationErrors(model, shape) + } else if (shape is UnionShape && shape.isEventStream()) { + syntheticEquivalentEventStreamUnion(model, shape) } else { shape } } - private fun syntheticEquivalent(model: Model, union: UnionShape): UnionShape { + private fun addStreamErrorsToOperationErrors(model: Model, operation: OperationShape): OperationShape { + if (!operation.isEventStream(model)) { + return operation + } + val getStreamErrors = { shape: ShapeId -> + model.expectShape(shape).members() + .filter { it.isEventStream(model) } + .map { model.expectShape(it.target) }.flatMap { it.members() } + .filter { model.expectShape(it.target).hasTrait() } + .map { model.expectShape(it.target).id } + } + val inputs = operation.input.map { getStreamErrors(it) } + val outputs = operation.output.map { getStreamErrors(it) } + val streamErrors = inputs.orElse(listOf()) + outputs.orElse(listOf()) + return operation.toBuilder() + .addErrors(streamErrors.map { it.toShapeId() }) + .build() + } + + private fun syntheticEquivalentEventStreamUnion(model: Model, union: UnionShape): UnionShape { val (errorMembers, eventMembers) = union.members().partition { member -> model.expectShape(member.target).hasTrait() } @@ -37,3 +67,33 @@ object EventStreamNormalizer { .build() } } + +fun OperationShape.operationErrors(model: Model): List { + val operationIndex = OperationIndex.of(model) + return operationIndex.getErrors(this) +} + +fun eventStreamErrors(model: Model, shape: Shape): Map> { + return Walker(model).walkShapes(shape) + .filter { it is UnionShape && it.isEventStream() } + .map { it.asUnionShape().get() } + .associateWith { unionShape -> + unionShape.expectTrait().errorMembers + .map { model.expectShape(it.target).asStructureShape().get() } + } +} + +fun UnionShape.eventStreamErrors(): List { + if (!this.isEventStream()) { + return listOf() + } + return this.expectTrait().errorMembers +} + +fun OperationShape.eventStreamErrors(model: Model): Map> { + return eventStreamErrors(model, inputShape(model)) + eventStreamErrors(model, outputShape(model)) +} + +fun OperationShape.allErrors(model: Model): List { + return this.eventStreamErrors(model).values.flatten() + this.operationErrors(model) +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt index f3c16c9417..8eef2c644a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt @@ -21,7 +21,8 @@ object RemoveEventStreamOperations { fun transform(model: Model, settings: CoreRustSettings): Model { // If Event Stream is allowed in build config, then don't remove the operations - if (settings.codegenConfig.eventStreamAllowList.contains(settings.moduleName)) { + val allowList = settings.codegenConfig.eventStreamAllowList + if (allowList.isEmpty() || allowList.contains(settings.moduleName)) { return model } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt index 5c9b2ed202..a9b9149ca2 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait inline fun Model.lookup(shapeId: String): T { return this.expectShape(ShapeId.from(shapeId), T::class.java) @@ -57,7 +58,7 @@ fun MemberShape.isInputEventStream(model: Model): Boolean { } fun MemberShape.isOutputEventStream(model: Model): Boolean { - return isEventStream(model) && model.expectShape(container).hasTrait() + return isEventStream(model) && model.expectShape(container).hasTrait() } fun Shape.hasEventStreamMember(model: Model): Boolean { diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt index 93d91e4501..450b130cf1 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt @@ -299,8 +299,8 @@ class FakeSigningConfig( Ok(message) } - fn sign_empty(&mut self) -> Result<#{Message}, #{SignMessageError}> { - Ok(#{Message}::new(Vec::new())) + fn sign_empty(&mut self) -> Option> { + Some(Ok(#{Message}::new(Vec::new()))) } } """, diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt index f05fcce5ad..7deacbdd47 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.testutil.TestSymbolVisitorConfig @@ -42,7 +43,7 @@ class EventStreamSymbolProviderTest { ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, TestSymbolVisitorConfig), model) + val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, TestSymbolVisitorConfig), model, CodegenTarget.CLIENT) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape @@ -51,7 +52,7 @@ class EventStreamSymbolProviderTest { val inputType = provider.toSymbol(inputStream).rustType() val outputType = provider.toSymbol(outputStream).rustType() - inputType shouldBe RustType.Opaque("EventStreamInput", "aws_smithy_http::event_stream") + inputType shouldBe RustType.Opaque("EventStreamSender", "aws_smithy_http::event_stream") outputType shouldBe RustType.Opaque("Receiver", "aws_smithy_http::event_stream") } @@ -78,7 +79,7 @@ class EventStreamSymbolProviderTest { ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, TestSymbolVisitorConfig), model) + val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, TestSymbolVisitorConfig), model, CodegenTarget.CLIENT) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index 4cdb95099f..6732ae72a5 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -27,7 +27,19 @@ pub type SignMessageError = Box; pub trait SignMessage: fmt::Debug { fn sign(&mut self, message: Message) -> Result; - fn sign_empty(&mut self) -> Result; + fn sign_empty(&mut self) -> Option>; +} + +#[derive(Debug)] +pub struct NoOpSigner {} +impl SignMessage for NoOpSigner { + fn sign(&mut self, message: Message) -> Result { + Ok(message) + } + + fn sign_empty(&mut self) -> Option> { + None + } } /// Converts a Smithy modeled Event Stream type into a [`Message`](Message). diff --git a/rust-runtime/aws-smithy-http-server/examples/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/Cargo.toml index 2102e16a70..7dc0f10a1a 100644 --- a/rust-runtime/aws-smithy-http-server/examples/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/examples/Cargo.toml @@ -3,7 +3,7 @@ members = [ "pokemon_service", "pokemon_service_sdk", - "pokemon_service_client" + "pokemon_service_client", ] [profile.release] diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml index 5a2e295e42..0c788ab91e 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml @@ -7,8 +7,10 @@ authors = ["Smithy-rs Server Team "] description = "A smithy Rust service to retrieve information about Pokémon." [dependencies] +async-stream = "0.3" clap = { version = "~3.2.1", features = ["derive"] } hyper = {version = "0.14", features = ["server"] } +rand = "0.8" tokio = "1" tower = "0.4" tower-http = { version = "0.3", features = ["trace"] } @@ -22,6 +24,7 @@ pokemon_service_sdk = { path = "../pokemon_service_sdk/" } [dev-dependencies] assert_cmd = "2.0" home = "0.5" +serial_test = "0.7.0" wrk-api-bench = "0.0.7" # Local paths diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs index d23a7b6ada..555e3c42ad 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs @@ -13,8 +13,10 @@ use std::{ sync::{atomic::AtomicU64, Arc}, }; +use async_stream::stream; use aws_smithy_http_server::Extension; -use pokemon_service_sdk::{error, input, model, output}; +use pokemon_service_sdk::{error, input, model, model::CapturingPayload, output, types::Blob}; +use rand::Rng; use tracing_subscriber::{prelude::*, EnvFilter}; const PIKACHU_ENGLISH_FLAVOR_TEXT: &str = @@ -191,6 +193,74 @@ pub async fn get_server_statistics( output::GetServerStatisticsOutput { calls_count } } +/// Attempts to capture a Pokémon +pub async fn capture_pokemon( + mut input: input::CapturePokemonOperationInput, +) -> Result { + if input.region != "Kanto" { + return Err(error::CapturePokemonOperationError::UnsupportedRegionError( + error::UnsupportedRegionError::builder().build(), + )); + } + let output_stream = stream! { + loop { + use std::time::Duration; + match input.events.recv().await { + Ok(maybe_event) => match maybe_event { + Some(event) => { + let capturing_event = event.as_event(); + // TODO: verify the events from the Pokémon trainer + if let Ok(attempt) = capturing_event { + let payload = attempt.payload.clone().unwrap_or(CapturingPayload::builder().build()); + let pokeball = payload.pokeball.as_ref().map(|ball| ball.as_str()).unwrap_or(""); + if ! matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { + yield Err( + crate::error::CapturePokemonEventsError::InvalidPokeballError( + crate::error::InvalidPokeballError::builder().pokeball(pokeball).build() + ) + ); + } else { + let captured = match pokeball { + "Master Ball" => true, + "Great Ball" => rand::thread_rng().gen_range(0..100) > 33, + "Fast Ball" => rand::thread_rng().gen_range(0..100) > 66, + _ => unreachable!("invalid pokeball"), + }; + // Only support Kanto + tokio::time::sleep(Duration::from_millis(1000)).await; + // Will it capture the Pokémon? + if captured { + let shiny = rand::thread_rng().gen_range(0..4096) == 0; + let pokemon = payload + .name + .as_ref() + .map(|name| name.as_str()) + .unwrap_or("") + .to_string(); + let pokedex: Vec = (0..255).collect(); + yield Ok(crate::model::CapturePokemonEvents::Event( + crate::model::CaptureEvent::builder() + .name(pokemon) + .shiny(shiny) + .pokedex_update(Blob::new(pokedex)) + .build(), + )); + } + } + } + } + None => break, + }, + Err(e) => println!("{:?}", e), + } + } + }; + Ok(output::CapturePokemonOperationOutput::builder() + .events(output_stream.into()) + .build() + .unwrap()) +} + /// Empty operation used to benchmark the service. pub async fn empty_operation(_input: input::EmptyOperationInput) -> output::EmptyOperationOutput { output::EmptyOperationOutput {} diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs index e84559daa3..a2a6f19a7c 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs @@ -8,7 +8,9 @@ use std::{net::SocketAddr, sync::Arc}; use aws_smithy_http_server::{AddExtensionLayer, Router}; use clap::Parser; -use pokemon_service::{empty_operation, get_pokemon_species, get_server_statistics, setup_tracing, State}; +use pokemon_service::{ + capture_pokemon, empty_operation, get_pokemon_species, get_server_statistics, setup_tracing, State, +}; use pokemon_service_sdk::operation_registry::OperationRegistryBuilder; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; @@ -34,6 +36,7 @@ pub async fn main() { // return the operation's output. .get_pokemon_species(get_pokemon_species) .get_server_statistics(get_server_statistics) + .capture_pokemon_operation(capture_pokemon) .empty_operation(empty_operation) .build() .expect("Unable to build operation registry") diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/tests/simple_integration_test.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/tests/simple_integration_test.rs index 67c5831bfa..a82e6e0bb5 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/tests/simple_integration_test.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/tests/simple_integration_test.rs @@ -10,15 +10,42 @@ use std::time::Duration; use crate::helpers::{client, PokemonService}; +use async_stream::stream; +use pokemon_service_client::{ + error::AttemptCapturingPokemonEventError, error::AttemptCapturingPokemonEventErrorKind, + error::MasterBallUnsuccessful, model::AttemptCapturingPokemonEvent, model::CapturingEvent, model::CapturingPayload, +}; +use rand::Rng; +use serial_test::serial; use tokio::time; mod helpers; +fn get_pokeball() -> String { + let random = rand::thread_rng().gen_range(0..100); + let pokeball = if random < 5 { + "Master Ball" + } else if random < 30 { + "Great Ball" + } else if random < 80 { + "Fast Ball" + } else { + "Smithy Ball" + }; + pokeball.to_string() +} + +fn get_pokemon_to_capture() -> String { + let pokemons = vec!["Charizard", "Pikachu", "Regieleki"]; + pokemons[rand::thread_rng().gen_range(0..pokemons.len())].to_string() +} + #[tokio::test] +#[serial] async fn simple_integration_test() { let _program = PokemonService::run(); - // Give PokemonSérvice some time to start up. - time::sleep(Duration::from_millis(50)).await; + // Give PokémonService some time to start up. + time::sleep(Duration::from_millis(500)).await; let service_statistics_out = client().get_server_statistics().send().await.unwrap(); assert_eq!(0, service_statistics_out.calls_count.unwrap()); @@ -43,3 +70,123 @@ async fn simple_integration_test() { let service_statistics_out = client().get_server_statistics().send().await.unwrap(); assert_eq!(2, service_statistics_out.calls_count.unwrap()); } + +#[tokio::test] +#[serial] +async fn event_stream_test() { + let _program = PokemonService::run(); + // Give PokémonService some time to start up. + time::sleep(Duration::from_millis(500)).await; + + let mut team = vec![]; + let input_stream = stream! { + // Always Pikachu + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload(CapturingPayload::builder() + .name("Pikachu") + .pokeball("Master Ball") + .build()) + .build() + )); + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload(CapturingPayload::builder() + .name("Regieleki") + .pokeball("Fast Ball") + .build()) + .build() + )); + yield Err(AttemptCapturingPokemonEventError::new( + AttemptCapturingPokemonEventErrorKind::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build()), + Default::default() + )); + // The next event should not happen + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload(CapturingPayload::builder() + .name("Charizard") + .pokeball("Great Ball") + .build()) + .build() + )); + }; + + // Throw many! + let mut output = client() + .capture_pokemon_operation() + .region("Kanto") + .events(input_stream.into()) + .send() + .await + .unwrap(); + loop { + match output.events.recv().await { + Ok(Some(capture)) => { + let pokemon = capture.as_event().unwrap().name.as_ref().unwrap().clone(); + let pokedex = capture.as_event().unwrap().pokedex_update.as_ref().unwrap().clone(); + let shiny = if *capture.as_event().unwrap().shiny.as_ref().unwrap() { + "" + } else { + "not " + }; + let expected_pokedex: Vec = (0..255).collect(); + println!("captured {} ({}shiny)", pokemon, shiny); + if expected_pokedex == pokedex.into_inner() { + println!("pokedex updated") + } + team.push(pokemon); + } + Err(e) => { + println!("error from the server: {:?}", e); + break; + } + Ok(None) => break, + } + } + + while team.len() < 6 { + let pokeball = get_pokeball(); + let pokemon = get_pokemon_to_capture(); + let input_stream = stream! { + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload(CapturingPayload::builder() + .name(pokemon) + .pokeball(pokeball) + .build()) + .build() + )) + }; + let mut output = client() + .capture_pokemon_operation() + .region("Kanto") + .events(input_stream.into()) + .send() + .await + .unwrap(); + match output.events.recv().await { + Ok(Some(capture)) => { + let pokemon = capture.as_event().unwrap().name.as_ref().unwrap().clone(); + let pokedex = capture.as_event().unwrap().pokedex_update.as_ref().unwrap().clone(); + let shiny = if *capture.as_event().unwrap().shiny.as_ref().unwrap() { + "" + } else { + "not " + }; + let expected_pokedex: Vec = (0..255).collect(); + println!("captured {} ({}shiny)", pokemon, shiny); + if expected_pokedex == pokedex.into_inner() { + println!("pokedex updated") + } + team.push(pokemon); + } + Err(e) => { + println!("error from the server: {:?}", e); + break; + } + Ok(None) => {} + } + } + println!("Team: {:?}", team); +} diff --git a/rust-runtime/aws-smithy-http/src/event_stream.rs b/rust-runtime/aws-smithy-http/src/event_stream.rs index bc0f97ccd6..ae6176cbd2 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream.rs @@ -7,13 +7,13 @@ use std::error::Error as StdError; -mod input; -mod output; +mod receiver; +mod sender; pub type BoxError = Box; #[doc(inline)] -pub use input::{EventStreamInput, MessageStreamAdapter}; +pub use sender::{EventStreamSender, MessageStreamAdapter, MessageStreamError}; #[doc(inline)] -pub use output::{Error, RawMessage, Receiver}; +pub use receiver::{Error, RawMessage, Receiver}; diff --git a/rust-runtime/aws-smithy-http/src/event_stream/output.rs b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs similarity index 95% rename from rust-runtime/aws-smithy-http/src/event_stream/output.rs rename to rust-runtime/aws-smithy-http/src/event_stream/receiver.rs index be25474c20..ee728d4b14 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/output.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs @@ -23,24 +23,26 @@ enum RecvBuf { /// Nothing has been buffered yet. Empty, /// Some data has been buffered. - /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary + /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary. Partial(SegmentedBuf), /// The end of the stream has been reached, but there may still be some buffered data. EosPartial(SegmentedBuf), + /// An exception terminated this stream. + Terminated, } impl RecvBuf { /// Returns true if there's more buffered data. fn has_data(&self) -> bool { match self { - RecvBuf::Empty => false, + RecvBuf::Empty | RecvBuf::Terminated => false, RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0, } } /// Returns true if the stream has ended. fn is_eos(&self) -> bool { - matches!(self, RecvBuf::EosPartial(_)) + matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated) } /// Returns a mutable reference to the underlying buffered data. @@ -49,6 +51,7 @@ impl RecvBuf { RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"), RecvBuf::Partial(segmented) => segmented, RecvBuf::EosPartial(segmented) => segmented, + RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"), } } @@ -65,8 +68,8 @@ impl RecvBuf { segmented.push(partial); RecvBuf::Partial(segmented) } - RecvBuf::EosPartial(_) => { - panic!("cannot buffer more data after the stream has ended; this is a bug") + RecvBuf::EosPartial(_) | RecvBuf::Terminated => { + panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug") } } } @@ -77,6 +80,7 @@ impl RecvBuf { RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()), RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented), RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"), + RecvBuf::Terminated => panic!("stream terminated; this is a bug"), } } } @@ -239,10 +243,22 @@ impl Receiver { /// messages. pub async fn recv(&mut self) -> Result, SdkError> { if let Some(buffered) = self.buffered_message.take() { - return self.unmarshall(buffered); + return match self.unmarshall(buffered) { + Ok(message) => Ok(message), + Err(error) => { + self.buffer = RecvBuf::Terminated; + Err(error) + } + }; } if let Some(message) = self.next_message().await? { - self.unmarshall(message) + match self.unmarshall(message) { + Ok(message) => Ok(message), + Err(error) => { + self.buffer = RecvBuf::Terminated; + Err(error) + } + } } else { Ok(None) } diff --git a/rust-runtime/aws-smithy-http/src/event_stream/input.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs similarity index 56% rename from rust-runtime/aws-smithy-http/src/event_stream/input.rs rename to rust-runtime/aws-smithy-http/src/event_stream/sender.rs index afbc320cdd..5143162bb7 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/input.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -3,76 +3,121 @@ * SPDX-License-Identifier: Apache-2.0 */ -use super::BoxError; use crate::result::SdkError; use aws_smithy_eventstream::frame::{MarshallMessage, SignMessage}; use bytes::Bytes; use futures_core::Stream; use std::error::Error as StdError; use std::fmt; +use std::fmt::Debug; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; /// Input type for Event Streams. -pub struct EventStreamInput { - input_stream: Pin> + Send>>, +pub struct EventStreamSender { + input_stream: Pin> + Send>>, } -impl fmt::Debug for EventStreamInput { +impl Debug for EventStreamSender { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "EventStreamInput(Box)") + write!(f, "EventStreamSender(Box)") } } -impl EventStreamInput { +impl EventStreamSender { #[doc(hidden)] - pub fn into_body_stream( + pub fn into_body_stream( self, marshaller: impl MarshallMessage + Send + Sync + 'static, + error_marshaller: impl MarshallMessage + Send + Sync + 'static, signer: impl SignMessage + Send + Sync + 'static, ) -> MessageStreamAdapter { - MessageStreamAdapter::new(marshaller, signer, self.input_stream) + MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream) } } -impl From for EventStreamInput +impl From for EventStreamSender where - S: Stream> + Send + 'static, + S: Stream> + Send + 'static, { fn from(stream: S) -> Self { - EventStreamInput { + EventStreamSender { input_stream: Box::pin(stream), } } } +#[derive(Debug)] +pub struct MessageStreamError { + kind: MessageStreamErrorKind, + pub(crate) meta: aws_smithy_types::Error, +} + +#[derive(Debug)] +pub enum MessageStreamErrorKind { + Unhandled(Box), +} + +impl MessageStreamError { + /// Creates the `MessageStreamError::Unhandled` variant from any error type. + pub fn unhandled(err: impl Into>) -> Self { + Self { + meta: Default::default(), + kind: MessageStreamErrorKind::Unhandled(err.into()), + } + } + + /// Creates the `MessageStreamError::Unhandled` variant from a `aws_smithy_types::Error`. + pub fn generic(err: aws_smithy_types::Error) -> Self { + Self { + meta: err.clone(), + kind: MessageStreamErrorKind::Unhandled(err.into()), + } + } + + /// Returns error metadata, which includes the error code, message, + /// request ID, and potentially additional information. + pub fn meta(&self) -> &aws_smithy_types::Error { + &self.meta + } +} + +impl StdError for MessageStreamError {} +impl fmt::Display for MessageStreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + MessageStreamErrorKind::Unhandled(inner) => std::fmt::Display::fmt(inner, f), + } + } +} + /// Adapts a `Stream` to a signed `Stream` by using the provided /// message marshaller and signer implementations. /// /// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be /// marshalled into an Event Stream frame, (e.g., if the message payload was too large). -pub struct MessageStreamAdapter { +pub struct MessageStreamAdapter { marshaller: Box + Send + Sync>, + error_marshaller: Box + Send + Sync>, signer: Box, - stream: Pin> + Send>>, + stream: Pin> + Send>>, end_signal_sent: bool, _phantom: PhantomData, } -impl Unpin for MessageStreamAdapter {} +impl Unpin for MessageStreamAdapter {} -impl MessageStreamAdapter -where - E: StdError + Send + Sync + 'static, -{ +impl MessageStreamAdapter { pub fn new( marshaller: impl MarshallMessage + Send + Sync + 'static, + error_marshaller: impl MarshallMessage + Send + Sync + 'static, signer: impl SignMessage + Send + Sync + 'static, - stream: Pin> + Send>>, + stream: Pin> + Send>>, ) -> Self { MessageStreamAdapter { marshaller: Box::new(marshaller), + error_marshaller: Box::new(error_marshaller), signer: Box::new(signer), stream, end_signal_sent: false, @@ -81,22 +126,23 @@ where } } -impl Stream for MessageStreamAdapter -where - E: StdError + Send + Sync + 'static, -{ +impl Stream for MessageStreamAdapter { type Item = Result>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.stream.as_mut().poll_next(cx) { Poll::Ready(message_option) => { if let Some(message_result) = message_option { - let message_result = - message_result.map_err(|err| SdkError::ConstructionFailure(err)); - let message = self - .marshaller - .marshall(message_result?) - .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + let message = match message_result { + Ok(message) => self + .marshaller + .marshall(message) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?, + Err(message) => self + .error_marshaller + .marshall(message) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?, + }; let message = self .signer .sign(message) @@ -109,12 +155,15 @@ where } else if !self.end_signal_sent { self.end_signal_sent = true; let mut buffer = Vec::new(); - self.signer - .sign_empty() - .map_err(|err| SdkError::ConstructionFailure(err))? - .write_to(&mut buffer) - .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; - Poll::Ready(Some(Ok(Bytes::from(buffer)))) + match self.signer.sign_empty() { + Some(sign) => { + sign.map_err(|err| SdkError::ConstructionFailure(err))? + .write_to(&mut buffer) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + Poll::Ready(Some(Ok(Bytes::from(buffer)))) + } + None => Poll::Ready(None), + } } else { Poll::Ready(None) } @@ -127,7 +176,7 @@ where #[cfg(test)] mod tests { use super::MarshallMessage; - use crate::event_stream::{EventStreamInput, MessageStreamAdapter}; + use crate::event_stream::{EventStreamSender, MessageStreamAdapter}; use crate::result::SdkError; use async_stream::stream; use aws_smithy_eventstream::error::Error as EventStreamError; @@ -179,8 +228,10 @@ mod tests { Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true)))) } - fn sign_empty(&mut self) -> Result { - Ok(Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))) + fn sign_empty(&mut self) -> Option> { + Some(Ok( + Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true))) + )) } } @@ -188,7 +239,7 @@ mod tests { where S: Stream> + Send + 'static, O: Into + 'static, - E: Into> + 'static, + E: Into> + 'static, { stream } @@ -198,13 +249,15 @@ mod tests { let stream = stream! { yield Ok(TestMessage("test".into())); }; - let mut adapter = - check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( - Marshaller, TestSigner, Box::pin(stream) - )); + let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< + TestMessage, + TestServiceError, + >::new( + Marshaller, + Some(Marshaller), + TestSigner, + Box::pin(stream), + )); let mut sent_bytes = adapter.next().await.unwrap().unwrap(); let sent = Message::read_from(&mut sent_bytes).unwrap(); @@ -225,13 +278,15 @@ mod tests { let stream = stream! { yield Err(EventStreamError::InvalidMessageLength.into()); }; - let mut adapter = - check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( - Marshaller, TestSigner, Box::pin(stream) - )); + let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< + TestMessage, + TestServiceError, + >::new( + Marshaller, + Some(Marshaller), + TestSigner, + Box::pin(stream), + )); let result = adapter.next().await.unwrap(); assert!(result.is_err()); @@ -244,8 +299,8 @@ mod tests { // Verify the developer experience for this compiles #[allow(unused)] fn event_stream_input_ergonomics() { - fn check(input: impl Into>) { - let _: EventStreamInput = input.into(); + fn check(input: impl Into>) { + let _: EventStreamSender = input.into(); } check(stream! { yield Ok(TestMessage("test".into()));