diff --git a/codegen-server-test/model/pokemon.smithy b/codegen-server-test/model/pokemon.smithy index 6d4ead7eee9..72469895afd 100644 --- a/codegen-server-test/model/pokemon.smithy +++ b/codegen-server-test/model/pokemon.smithy @@ -27,7 +27,7 @@ resource PokemonSpecies { operation CapturePokemonOperation { input: CapturePokemonOperationEventsInput, output: CapturePokemonOperationEventsOutput, - errors: [UnsupportedRegionError] + errors: [UnsupportedRegionError, MasterBallUnsuccessful] } @input diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt index 7af59162e68..58dbfa481b1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt @@ -22,7 +22,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.smithy.transformers.getStreamErrors -import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -36,7 +36,7 @@ open class ServerCombinedErrorGenerator( private val operation: OperationShape ) { fun render(writer: RustWriter) { - val errors = operation.nonEventStreamErrors(model) + val errors = operation.operationErrors(model) val symbol = operation.errorSymbol(symbolProvider) if (errors.isNotEmpty()) { renderErrors(writer, errors.map { it.asStructureShape().get() }, symbol) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index 320f8a41630..89d49df9cb3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -17,7 +17,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBou import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol -import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape @@ -140,7 +140,7 @@ open class ServerOperationHandlerGenerator( } else { "Fun: FnOnce($inputName) -> Fut + Clone + Send + 'static," } - val outputType = if (operation.nonEventStreamErrors(model).isNotEmpty()) { + val outputType = if (operation.operationErrors(model).isNotEmpty()) { "Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>" } else { symbolProvider.toSymbol(operation.outputShape(model)).fullName diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 8499eab2245..514d0485171 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -59,7 +59,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.toOptional -import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.smithy.wrapOptional import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait @@ -218,7 +218,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val outputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - if (operationShape.nonEventStreamErrors(model).isNotEmpty()) { + if (operationShape.operationErrors(model).isNotEmpty()) { // The output of fallible operations is a `Result` which we convert into an // isomorphic `enum` type we control that can in turn be converted into a response. val intoResponseImpl = @@ -304,7 +304,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } // Implement conversion function to "wrap" from the model operation output types. - if (operationShape.nonEventStreamErrors(model).isNotEmpty()) { + if (operationShape.operationErrors(model).isNotEmpty()) { rustTemplate( """ impl #{From}> for $outputName { @@ -443,7 +443,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) withBlock("match error {", "}") { - val errors = operationShape.nonEventStreamErrors(model) + val errors = operationShape.operationErrors(model) errors.forEach { val variantShape = model.expectShape(it.id, StructureShape::class.java) val errorTrait = variantShape.expectTrait() diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt index 5b3c78e60f6..d9ac691d22b 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/CombinedErrorGenerator.kt @@ -27,7 +27,7 @@ import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.customize.Section import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.transformers.getStreamErrors -import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -68,7 +68,7 @@ class CombinedErrorGenerator( fun render(writer: RustWriter) { val symbol = operation.errorSymbol(symbolProvider) - renderErrors(writer, operation.nonEventStreamErrors(model).map { it.asStructureShape().get() }.toMutableList(), symbol) + renderErrors(writer, operation.operationErrors(model).map { it.asStructureShape().get() }.toMutableList(), symbol) if (operation.isEventStream(model)) { val clientErrors = operation.getStreamErrors(model, CodegenTarget.CLIENT).map { it.asStructureShape().get() } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt index abaf650bc05..0ae3cce3e66 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt @@ -25,7 +25,7 @@ import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.smithy.transformers.getStreamErrors -import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors /** * Each service defines its own "top-level" error combining all possible errors that a service can emit. @@ -78,7 +78,7 @@ class TopLevelErrorGenerator(codegenContext: CodegenContext, private val operati } private fun RustWriter.renderImplFrom(operationShape: OperationShape) { - val nonEventStreamErrors = operationShape.nonEventStreamErrors(model).map { it.id } + val nonEventStreamErrors = operationShape.operationErrors(model).map { it.id } val allErrors: List>> = listOf( Pair(operationShape.errorSymbol(symbolProvider), nonEventStreamErrors), Pair( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt index 858ecdf0d95..6fb2e960e00 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -32,7 +32,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTr import software.amazon.smithy.rust.codegen.smithy.generators.setterName import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember -import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors +import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.util.UNREACHABLE import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasStreamingMember @@ -168,7 +168,7 @@ class HttpBoundProtocolTraitImplGenerator( protocol.parseHttpGenericError(operationShape), errorSymbol ) - if (operationShape.nonEventStreamErrors(model).isNotEmpty()) { + if (operationShape.operationErrors(model).isNotEmpty()) { rustTemplate( """ let error_code = match generic.code() { @@ -181,7 +181,7 @@ class HttpBoundProtocolTraitImplGenerator( "error_symbol" to errorSymbol, ) withBlock("Err(match error_code {", "})") { - val errors = operationShape.nonEventStreamErrors(model) + val errors = operationShape.operationErrors(model) errors.forEach { error -> val errorShape = model.expectShape(error.id, StructureShape::class.java) val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt index eb22d8fe7d7..6432ca82101 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/EventStreamNormalizer.kt @@ -7,10 +7,13 @@ package software.amazon.smithy.rust.codegen.smithy.transformers import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.AnnotationTrait import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget @@ -26,10 +29,18 @@ import software.amazon.smithy.rust.codegen.util.isEventStream */ object EventStreamNormalizer { fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape -> - if (shape is OperationShape && shape.isEventStream(model)) { - addStreamErrorsToOperationErrors(model, shape) + if (shape is OperationShape) { + val newErrors = shape.errors + .map { model.expectShape(it, StructureShape::class.java) } + .map { it.toBuilder().addTrait(SyntheticOperationErrorTrait()).build() } + val newShape = shape.toBuilder().errors(newErrors.map { it.id }).build() + if (newShape.isEventStream(model)) { + addStreamErrorsToOperationErrors(model, newShape) + } else { + newShape + } } else if (shape is UnionShape && shape.isEventStream()) { - syntheticEquivalent(model, shape) + syntheticEquivalentEventStreamUnion(model, shape) } else { shape } @@ -54,7 +65,7 @@ object EventStreamNormalizer { .build() } - private fun syntheticEquivalent(model: Model, union: UnionShape): UnionShape { + private fun syntheticEquivalentEventStreamUnion(model: Model, union: UnionShape): UnionShape { val (errorMembers, eventMembers) = union.members().partition { member -> model.expectShape(member.target).hasTrait() } @@ -82,14 +93,20 @@ fun OperationShape.getStreamErrors(model: Model, target: CodegenTarget): List inputOutput.filter { it.expectTrait().isClientError } } } -fun OperationShape.nonEventStreamErrors(model: Model): List { +fun OperationShape.operationErrors(model: Model): List { val operationIndex = OperationIndex.of(model) - val eventStreamErrors = this.eventStreamErrors(model) - return operationIndex.getErrors(this).filter { !eventStreamErrors.contains(it) } + return operationIndex.getErrors(this) + .filter { it.hasTrait() } } fun OperationShape.eventStreamErrors(model: Model): List { return this.getStreamErrors(model, CodegenTarget.CLIENT) + this.getStreamErrors(model, CodegenTarget.SERVER) } fun OperationShape.allErrors(model: Model): List { - return this.eventStreamErrors(model) + this.nonEventStreamErrors(model) + return this.eventStreamErrors(model) + this.operationErrors(model) +} + +class SyntheticOperationErrorTrait : AnnotationTrait(ID, Node.objectNode()) { + companion object { + val ID = ShapeId.from("smithy.api.internal#syntheticOperationErrorTrait") + } }