diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index 12ecad9b4b..a784b9a9ac 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.http -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -20,12 +19,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape class ServerRequestBindingGenerator( - protocol: Protocol, + val protocol: ServerProtocol, codegenContext: ServerCodegenContext, operationShape: OperationShape, additionalHttpBindingCustomizations: List = listOf(), @@ -50,12 +49,11 @@ class ServerRequestBindingGenerator( fun generateDeserializePayloadFn( binding: HttpBindingDescriptor, - errorSymbol: Symbol, structuredHandler: RustWriter.(String) -> Unit, ): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn( binding, - errorSymbol, + protocol.deserializePayloadErrorType(binding).toSymbol(), structuredHandler, HttpMessageType.REQUEST, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 6434c0290c..f31f6d92da 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -17,7 +18,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml @@ -70,8 +73,8 @@ interface ServerProtocol : Protocol { fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType /** - * In some protocols, such as restJson1, - * when there is no modeled body input, content type must not be set and the body must be empty. + * In some protocols, such as `restJson1` and `rpcv2Cbor`, + * when there is no modeled body input, `content-type` must not be set and the body must be empty. * Returns a boolean indicating whether to perform this check. */ fun serverContentTypeCheckNoModeledInput(): Boolean = false @@ -90,6 +93,19 @@ interface ServerProtocol : Protocol { fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType = ServerCargoDependency.smithyHttpServer(runtimeConfig) .toType().resolve("protocol::$protocolModulePath::runtime_error::RuntimeError") + + /** + * The function that deserializes a payload-bound shape takes as input a byte slab and returns a `Result` holding + * the deserialized shape if successful. What error type should we use in case of failure? + * + * The shape could be payload-bound either because of the `@httpPayload` trait, or because it's part of an event + * stream. + * + * Note that despite the trait (https://smithy.io/2.0/spec/http-bindings.html#httppayload-trait) being able to + * target any structure member shape, AWS Protocols only support binding the following shape types to the payload + * (and Smithy does indeed enforce this at model build-time): string, blob, structure, union, and document + */ + fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType } fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse { @@ -185,6 +201,18 @@ class ServerAwsJsonProtocol( override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType = ServerCargoDependency.smithyHttpServer(runtimeConfig) .toType().resolve("protocol::aws_json::runtime_error::RuntimeError") + + /* + * Note that despite the AWS JSON 1.x protocols not supporting the `@httpPayload` trait, event streams are bound + * to the payload. + */ + override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType = + deserializePayloadErrorType( + codegenContext, + binding, + requestRejection(runtimeConfig), + RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"), + ) } private fun restRouterType(runtimeConfig: RuntimeConfig) = @@ -227,6 +255,14 @@ class ServerRestJsonProtocol( override fun serverRouterRuntimeConstructor() = "new_rest_json_router" override fun serverContentTypeCheckNoModeledInput() = true + + override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType = + deserializePayloadErrorType( + codegenContext, + binding, + requestRejection(runtimeConfig), + RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"), + ) } class ServerRestXmlProtocol( @@ -252,6 +288,32 @@ class ServerRestXmlProtocol( override fun serverRouterRuntimeConstructor() = "new_rest_xml_router" override fun serverContentTypeCheckNoModeledInput() = true + + override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType = + deserializePayloadErrorType( + codegenContext, + binding, + requestRejection(runtimeConfig), + RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), + ) +} + +/** Just a common function to keep things DRY. **/ +fun deserializePayloadErrorType( + codegenContext: CodegenContext, + binding: HttpBindingDescriptor, + requestRejection: RuntimeType, + protocolSerializationFormatError: RuntimeType, +): RuntimeType { + check(binding.location == HttpLocation.PAYLOAD) + + if (codegenContext.model.expectShape(binding.member.target) is StringShape) { + // The only way deserializing a string can fail is if the HTTP body does not contain valid UTF-8. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3750): we're returning an incorrect `RequestRejection` variant here. + return requestRejection + } + + return protocolSerializationFormatError } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index f6ba45ab27..3d94bb8821 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -5,10 +5,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols -import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait -import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait -import software.amazon.smithy.aws.traits.protocols.RestJson1Trait -import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException @@ -20,7 +16,6 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait @@ -124,13 +119,7 @@ class ServerHttpBoundProtocolGenerator( ) : ServerProtocolGenerator( protocol, ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), - ) { - // Define suffixes for operation input / output / error wrappers - companion object { - const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper" - const val OPERATION_OUTPUT_WRAPPER_SUFFIX = "OperationOutputWrapper" - } -} + ) class ServerHttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, @@ -697,8 +686,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( inputShape: StructureShape, bindings: List, ) { - val httpBindingGenerator = - ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust( @@ -740,7 +727,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( for (binding in bindings) { val member = binding.member val parsedValue = - serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + serverRenderBindingParser(binding, operationShape, httpBindingGenerator(operationShape), structuredDataParser) val valueToSet = if (symbolProvider.toSymbol(binding.member).isOptional()) { "Some(value)" @@ -801,13 +788,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( val structureShapeHandler: RustWriter.(String) -> Unit = { body -> rust("#T($body)", structuredDataParser.payloadParser(binding.member)) } - val errorSymbol = getDeserializePayloadErrorSymbol(binding) val deserializer = - httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - structuredHandler = structureShapeHandler, - ) + httpBindingGenerator.generateDeserializePayloadFn(binding, structuredHandler = structureShapeHandler) return writable { if (binding.member.isStreaming(model)) { rustTemplate( @@ -1196,9 +1178,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( binding: HttpBindingDescriptor, operationShape: OperationShape, ) { - val httpBindingGenerator = - ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) - val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) + val deserializer = httpBindingGenerator(operationShape).generateDeserializeHeaderFn(binding) writer.rustTemplate( """ #{deserializer}(&headers)? @@ -1215,8 +1195,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) { check(binding.location == HttpLocation.PREFIX_HEADERS) - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) - val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding) + val deserializer = httpBindingGenerator(operationShape).generateDeserializePrefixHeadersFn(binding) writer.rustTemplate( """ #{deserializer}(&headers)? @@ -1300,33 +1279,13 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } - /** - * Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the - * shape targeted by the `httpPayload` trait. - */ - private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol { - check(binding.location == HttpLocation.PAYLOAD) - - if (model.expectShape(binding.member.target) is StringShape) { - return protocol.requestRejection(runtimeConfig).toSymbol() - } - return when (codegenContext.protocol) { - RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> { - RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol() - } - RestXmlTrait.ID -> { - RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol() - } - else -> { - TODO("Protocol ${codegenContext.protocol} not supported yet") - } - } - } - private fun streamingBodyTraitBounds(operationShape: OperationShape) = if (operationShape.inputShape(model).hasStreamingMember(model)) { "\n B: Into<#{SmithyTypes}::byte_stream::ByteStream>," } else { "" } + + private fun httpBindingGenerator(operationShape: OperationShape) = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) }