Skip to content

Commit

Permalink
Refactor determining server error type when deserializing an `@httpPa…
Browse files Browse the repository at this point in the history
…yload` (#3752)

Determining the error type when deserializing an `@httpPayload` is a
protocol-specific concern, and as such should not live in
`ServerHttpBoundProtocolGenerator`, which should remain
protocol-agnostic. This commits makes that determination part of the
`ServerProtocol` interface.

As a drive-by improvement, the companion object in
`ServerHttpBoundProtocolGenerator` has also been removed, since its
members have been unused for a long time.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
david-perez authored Jul 10, 2024
1 parent dc66ae4 commit 4c30f00
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators.http

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand All @@ -20,12 +19,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

class ServerRequestBindingGenerator(
protocol: Protocol,
val protocol: ServerProtocol,
codegenContext: ServerCodegenContext,
operationShape: OperationShape,
additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
Expand All @@ -50,12 +49,11 @@ class ServerRequestBindingGenerator(

fun generateDeserializePayloadFn(
binding: HttpBindingDescriptor,
errorSymbol: Symbol,
structuredHandler: RustWriter.(String) -> Unit,
): RuntimeType =
httpBindingGenerator.generateDeserializePayloadFn(
binding,
errorSymbol,
protocol.deserializePayloadErrorType(binding).toSymbol(),
structuredHandler,
HttpMessageType.REQUEST,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand All @@ -17,7 +18,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
Expand Down Expand Up @@ -70,8 +73,8 @@ interface ServerProtocol : Protocol {
fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType

/**
* In some protocols, such as restJson1,
* when there is no modeled body input, content type must not be set and the body must be empty.
* In some protocols, such as `restJson1` and `rpcv2Cbor`,
* when there is no modeled body input, `content-type` must not be set and the body must be empty.
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false
Expand All @@ -90,6 +93,19 @@ interface ServerProtocol : Protocol {
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("protocol::$protocolModulePath::runtime_error::RuntimeError")

/**
* The function that deserializes a payload-bound shape takes as input a byte slab and returns a `Result` holding
* the deserialized shape if successful. What error type should we use in case of failure?
*
* The shape could be payload-bound either because of the `@httpPayload` trait, or because it's part of an event
* stream.
*
* Note that despite the trait (https://smithy.io/2.0/spec/http-bindings.html#httppayload-trait) being able to
* target any structure member shape, AWS Protocols only support binding the following shape types to the payload
* (and Smithy does indeed enforce this at model build-time): string, blob, structure, union, and document
*/
fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType
}

fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse {
Expand Down Expand Up @@ -185,6 +201,18 @@ class ServerAwsJsonProtocol(
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("protocol::aws_json::runtime_error::RuntimeError")

/*
* Note that despite the AWS JSON 1.x protocols not supporting the `@httpPayload` trait, event streams are bound
* to the payload.
*/
override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
deserializePayloadErrorType(
codegenContext,
binding,
requestRejection(runtimeConfig),
RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
)
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
Expand Down Expand Up @@ -227,6 +255,14 @@ class ServerRestJsonProtocol(
override fun serverRouterRuntimeConstructor() = "new_rest_json_router"

override fun serverContentTypeCheckNoModeledInput() = true

override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
deserializePayloadErrorType(
codegenContext,
binding,
requestRejection(runtimeConfig),
RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
)
}

class ServerRestXmlProtocol(
Expand All @@ -252,6 +288,32 @@ class ServerRestXmlProtocol(
override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"

override fun serverContentTypeCheckNoModeledInput() = true

override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
deserializePayloadErrorType(
codegenContext,
binding,
requestRejection(runtimeConfig),
RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"),
)
}

/** Just a common function to keep things DRY. **/
fun deserializePayloadErrorType(
codegenContext: CodegenContext,
binding: HttpBindingDescriptor,
requestRejection: RuntimeType,
protocolSerializationFormatError: RuntimeType,
): RuntimeType {
check(binding.location == HttpLocation.PAYLOAD)

if (codegenContext.model.expectShape(binding.member.target) is StringShape) {
// The only way deserializing a string can fail is if the HTTP body does not contain valid UTF-8.
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3750): we're returning an incorrect `RequestRejection` variant here.
return requestRejection
}

return protocolSerializationFormatError
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
Expand All @@ -20,7 +16,6 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
Expand Down Expand Up @@ -124,13 +119,7 @@ class ServerHttpBoundProtocolGenerator(
) : ServerProtocolGenerator(
protocol,
ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations),
) {
// Define suffixes for operation input / output / error wrappers
companion object {
const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper"
const val OPERATION_OUTPUT_WRAPPER_SUFFIX = "OperationOutputWrapper"
}
}
)

class ServerHttpBoundProtocolPayloadGenerator(
codegenContext: CodegenContext,
Expand Down Expand Up @@ -697,8 +686,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
inputShape: StructureShape,
bindings: List<HttpBindingDescriptor>,
) {
val httpBindingGenerator =
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
val structuredDataParser = protocol.structuredDataParser()
Attribute.AllowUnusedMut.render(this)
rust(
Expand Down Expand Up @@ -740,7 +727,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
for (binding in bindings) {
val member = binding.member
val parsedValue =
serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
serverRenderBindingParser(binding, operationShape, httpBindingGenerator(operationShape), structuredDataParser)
val valueToSet =
if (symbolProvider.toSymbol(binding.member).isOptional()) {
"Some(value)"
Expand Down Expand Up @@ -801,13 +788,8 @@ class ServerHttpBoundProtocolTraitImplGenerator(
val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
rust("#T($body)", structuredDataParser.payloadParser(binding.member))
}
val errorSymbol = getDeserializePayloadErrorSymbol(binding)
val deserializer =
httpBindingGenerator.generateDeserializePayloadFn(
binding,
errorSymbol,
structuredHandler = structureShapeHandler,
)
httpBindingGenerator.generateDeserializePayloadFn(binding, structuredHandler = structureShapeHandler)
return writable {
if (binding.member.isStreaming(model)) {
rustTemplate(
Expand Down Expand Up @@ -1196,9 +1178,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
binding: HttpBindingDescriptor,
operationShape: OperationShape,
) {
val httpBindingGenerator =
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding)
val deserializer = httpBindingGenerator(operationShape).generateDeserializeHeaderFn(binding)
writer.rustTemplate(
"""
#{deserializer}(&headers)?
Expand All @@ -1215,8 +1195,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
) {
check(binding.location == HttpLocation.PREFIX_HEADERS)

val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding)
val deserializer = httpBindingGenerator(operationShape).generateDeserializePrefixHeadersFn(binding)
writer.rustTemplate(
"""
#{deserializer}(&headers)?
Expand Down Expand Up @@ -1300,33 +1279,13 @@ class ServerHttpBoundProtocolTraitImplGenerator(
}
}

/**
* Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the
* shape targeted by the `httpPayload` trait.
*/
private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol {
check(binding.location == HttpLocation.PAYLOAD)

if (model.expectShape(binding.member.target) is StringShape) {
return protocol.requestRejection(runtimeConfig).toSymbol()
}
return when (codegenContext.protocol) {
RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol()
}
RestXmlTrait.ID -> {
RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol()
}
else -> {
TODO("Protocol ${codegenContext.protocol} not supported yet")
}
}
}

private fun streamingBodyTraitBounds(operationShape: OperationShape) =
if (operationShape.inputShape(model).hasStreamingMember(model)) {
"\n B: Into<#{SmithyTypes}::byte_stream::ByteStream>,"
} else {
""
}

private fun httpBindingGenerator(operationShape: OperationShape) =
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
}

0 comments on commit 4c30f00

Please sign in to comment.