Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Jul 8, 2022
1 parent 75cfd62 commit 2c4db07
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 24 deletions.
2 changes: 1 addition & 1 deletion codegen-server-test/model/pokemon.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ resource PokemonSpecies {
operation CapturePokemonOperation {
input: CapturePokemonOperationEventsInput,
output: CapturePokemonOperationEventsOutput,
errors: [UnsupportedRegionError]
errors: [UnsupportedRegionError, MasterBallUnsuccessful]
}

@input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol
import software.amazon.smithy.rust.codegen.smithy.transformers.getStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.util.isEventStream
import software.amazon.smithy.rust.codegen.util.toSnakeCase

Expand All @@ -36,7 +36,7 @@ open class ServerCombinedErrorGenerator(
private val operation: OperationShape
) {
fun render(writer: RustWriter) {
val errors = operation.nonEventStreamErrors(model)
val errors = operation.operationErrors(model)
val symbol = operation.errorSymbol(symbolProvider)
if (errors.isNotEmpty()) {
renderErrors(writer, errors.map { it.asStructureShape().get() }, symbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBou
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
Expand Down Expand Up @@ -140,7 +140,7 @@ open class ServerOperationHandlerGenerator(
} else {
"Fun: FnOnce($inputName) -> Fut + Clone + Send + 'static,"
}
val outputType = if (operation.nonEventStreamErrors(model).isNotEmpty()) {
val outputType = if (operation.operationErrors(model).isNotEmpty()) {
"Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>"
} else {
symbolProvider.toSymbol(operation.outputShape(model)).fullName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.toOptional
import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
Expand Down Expand Up @@ -218,7 +218,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val outputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
val errorSymbol = operationShape.errorSymbol(symbolProvider)

if (operationShape.nonEventStreamErrors(model).isNotEmpty()) {
if (operationShape.operationErrors(model).isNotEmpty()) {
// The output of fallible operations is a `Result` which we convert into an
// isomorphic `enum` type we control that can in turn be converted into a response.
val intoResponseImpl =
Expand Down Expand Up @@ -304,7 +304,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
}

// Implement conversion function to "wrap" from the model operation output types.
if (operationShape.nonEventStreamErrors(model).isNotEmpty()) {
if (operationShape.operationErrors(model).isNotEmpty()) {
rustTemplate(
"""
impl #{From}<Result<#{O}, #{E}>> for $outputName {
Expand Down Expand Up @@ -443,7 +443,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
withBlock("match error {", "}") {
val errors = operationShape.nonEventStreamErrors(model)
val errors = operationShape.operationErrors(model)
errors.forEach {
val variantShape = model.expectShape(it.id, StructureShape::class.java)
val errorTrait = variantShape.expectTrait<ErrorTrait>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.customize.Section
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.transformers.getStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.isEventStream
import software.amazon.smithy.rust.codegen.util.toSnakeCase
Expand Down Expand Up @@ -68,7 +68,7 @@ class CombinedErrorGenerator(

fun render(writer: RustWriter) {
val symbol = operation.errorSymbol(symbolProvider)
renderErrors(writer, operation.nonEventStreamErrors(model).map { it.asStructureShape().get() }.toMutableList(), symbol)
renderErrors(writer, operation.operationErrors(model).map { it.asStructureShape().get() }.toMutableList(), symbol)

if (operation.isEventStream(model)) {
val clientErrors = operation.getStreamErrors(model, CodegenTarget.CLIENT).map { it.asStructureShape().get() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.getStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors

/**
* Each service defines its own "top-level" error combining all possible errors that a service can emit.
Expand Down Expand Up @@ -78,7 +78,7 @@ class TopLevelErrorGenerator(codegenContext: CodegenContext, private val operati
}

private fun RustWriter.renderImplFrom(operationShape: OperationShape) {
val nonEventStreamErrors = operationShape.nonEventStreamErrors(model).map { it.id }
val nonEventStreamErrors = operationShape.operationErrors(model).map { it.id }
val allErrors: List<Pair<RuntimeType, List<ShapeId>>> = listOf(
Pair(operationShape.errorSymbol(symbolProvider), nonEventStreamErrors),
Pair(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTr
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember
import software.amazon.smithy.rust.codegen.smithy.transformers.nonEventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
Expand Down Expand Up @@ -168,7 +168,7 @@ class HttpBoundProtocolTraitImplGenerator(
protocol.parseHttpGenericError(operationShape),
errorSymbol
)
if (operationShape.nonEventStreamErrors(model).isNotEmpty()) {
if (operationShape.operationErrors(model).isNotEmpty()) {
rustTemplate(
"""
let error_code = match generic.code() {
Expand All @@ -181,7 +181,7 @@ class HttpBoundProtocolTraitImplGenerator(
"error_symbol" to errorSymbol,
)
withBlock("Err(match error_code {", "})") {
val errors = operationShape.nonEventStreamErrors(model)
val errors = operationShape.operationErrors(model)
errors.forEach { error ->
val errorShape = model.expectShape(error.id, StructureShape::class.java)
val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ package software.amazon.smithy.rust.codegen.smithy.transformers

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.AnnotationTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
Expand All @@ -26,10 +29,18 @@ import software.amazon.smithy.rust.codegen.util.isEventStream
*/
object EventStreamNormalizer {
fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape ->
if (shape is OperationShape && shape.isEventStream(model)) {
addStreamErrorsToOperationErrors(model, shape)
if (shape is OperationShape) {
val newErrors = shape.errors
.map { model.expectShape(it, StructureShape::class.java) }
.map { it.toBuilder().addTrait(SyntheticOperationErrorTrait()).build() }
val newShape = shape.toBuilder().errors(newErrors.map { it.id }).build()
if (newShape.isEventStream(model)) {
addStreamErrorsToOperationErrors(model, newShape)
} else {
newShape
}
} else if (shape is UnionShape && shape.isEventStream()) {
syntheticEquivalent(model, shape)
syntheticEquivalentEventStreamUnion(model, shape)
} else {
shape
}
Expand All @@ -54,7 +65,7 @@ object EventStreamNormalizer {
.build()
}

private fun syntheticEquivalent(model: Model, union: UnionShape): UnionShape {
private fun syntheticEquivalentEventStreamUnion(model: Model, union: UnionShape): UnionShape {
val (errorMembers, eventMembers) = union.members().partition { member ->
model.expectShape(member.target).hasTrait<ErrorTrait>()
}
Expand Down Expand Up @@ -82,14 +93,20 @@ fun OperationShape.getStreamErrors(model: Model, target: CodegenTarget): List<Sh
CodegenTarget.SERVER -> inputOutput.filter { it.expectTrait<ErrorTrait>().isClientError }
}
}
fun OperationShape.nonEventStreamErrors(model: Model): List<Shape> {
fun OperationShape.operationErrors(model: Model): List<Shape> {
val operationIndex = OperationIndex.of(model)
val eventStreamErrors = this.eventStreamErrors(model)
return operationIndex.getErrors(this).filter { !eventStreamErrors.contains(it) }
return operationIndex.getErrors(this)
.filter { it.hasTrait<SyntheticOperationErrorTrait>() }
}
fun OperationShape.eventStreamErrors(model: Model): List<Shape> {
return this.getStreamErrors(model, CodegenTarget.CLIENT) + this.getStreamErrors(model, CodegenTarget.SERVER)
}
fun OperationShape.allErrors(model: Model): List<Shape> {
return this.eventStreamErrors(model) + this.nonEventStreamErrors(model)
return this.eventStreamErrors(model) + this.operationErrors(model)
}

class SyntheticOperationErrorTrait : AnnotationTrait(ID, Node.objectNode()) {
companion object {
val ID = ShapeId.from("smithy.api.internal#syntheticOperationErrorTrait")
}
}

0 comments on commit 2c4db07

Please sign in to comment.