Skip to content

Commit

Permalink
Split RuntimeError and RequestRejection by protocol
Browse files Browse the repository at this point in the history
As outlined in the [Protocol Specific Errors] of the [Service Builder
Improvements RFC], `RuntimeError` should be split up into smaller,
protocol specific, errors which accurately model the failure cases of
each protocol.

The same goes for `RequestRejection`.

Closes #1703.

[Protocol Specific Errors]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md#protocol-specific-errors
[Service Builder Improvements RFC]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md
  • Loading branch information
david-perez committed Mar 30, 2023
1 parent 92316f7 commit 946df69
Show file tree
Hide file tree
Showing 26 changed files with 802 additions and 466 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ open class ServerCodegenVisitor(

val baseModel = baselineTransform(context.model)
val service = settings.getService(baseModel)
val (protocol, generator) =
val (protocolShape, protocolGeneratorFactory) =
ServerProtocolLoader(
codegenDecorator.protocols(
service.id,
ServerProtocolLoader.DefaultProtocols,
),
)
.protocolFor(context.model, service)
protocolGeneratorFactory = generator
this.protocolGeneratorFactory = protocolGeneratorFactory

model = codegenDecorator.transformModel(service, baseModel)

Expand All @@ -145,7 +145,7 @@ open class ServerCodegenVisitor(
serverSymbolProviders.symbolProvider,
null,
service,
protocol,
protocolShape,
settings,
serverSymbolProviders.unconstrainedShapeSymbolProvider,
serverSymbolProviders.constrainedShapeSymbolProvider,
Expand All @@ -169,7 +169,7 @@ open class ServerCodegenVisitor(
settings.codegenConfig,
codegenContext.expectModuleDocProvider(),
)
protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
protocolGenerator = this.protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
}

/**
Expand Down Expand Up @@ -315,7 +315,12 @@ open class ServerCodegenVisitor(
writer: RustWriter,
) {
if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) {
val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator)
val serverBuilderGenerator = ServerBuilderGenerator(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGenerator.render(rustCrate, writer)

if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
Expand All @@ -336,7 +341,12 @@ open class ServerCodegenVisitor(

if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) {
val serverBuilderGeneratorWithoutPublicConstrainedTypes =
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator)
ServerBuilderGeneratorWithoutPublicConstrainedTypes(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGeneratorWithoutPublicConstrainedTypes.render(rustCrate, writer)

writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,19 @@ package software.amazon.smithy.rust.codegen.server.smithy
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Object used *exclusively* in the runtime of the server, for separation concerns.
* Analogous to the companion object in [RuntimeType]; see its documentation for details.
* For a runtime type that is used in the client, or in both the client and the server, use [RuntimeType] directly.
*/
object ServerRuntimeType {
fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency)
fun router(runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError")

fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection")

fun responseRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::ResponseRejection")

fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")
fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")

fun protocol(runtimeConfig: RuntimeConfig) = protocol("Protocol", "", runtimeConfig)
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitI
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -67,11 +68,7 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
override val shapeId: ShapeId =
ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse)

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -89,7 +86,8 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -66,11 +67,7 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
override val shapeId: ShapeId = SHAPE_ID

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -87,7 +84,8 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
Expand Down Expand Up @@ -92,6 +93,7 @@ class ServerBuilderGenerator(
val codegenContext: ServerCodegenContext,
private val shape: StructureShape,
private val customValidationExceptionWithReasonConversionGenerator: ValidationExceptionConversionGenerator,
private val protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -148,7 +150,7 @@ class ServerBuilderGenerator(
ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down Expand Up @@ -222,7 +224,8 @@ class ServerBuilderGenerator(
"""
#{Converter:W}
""",
"Converter" to customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(),
"Converter" to
customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(protocol),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.withInMemoryInlineModule

/**
Expand All @@ -49,6 +50,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
private val codegenContext: ServerCodegenContext,
shape: StructureShape,
validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -85,7 +87,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false, validationExceptionConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Collection of methods that will be invoked by the respective generators to generate code to convert constraint
Expand All @@ -26,7 +27,7 @@ interface ValidationExceptionConversionGenerator {
* Convert from a top-level operation input's constraint violation into
* `aws_smithy_http_server::rejection::RequestRejection`.
*/
fun renderImplFromConstraintViolationForRequestRejection(): Writable
fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable

// Simple shapes.
fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection<StringTraitInfo>): Writable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJson
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

private fun allOperations(codegenContext: CodegenContext): List<OperationShape> {
val index = TopDownIndex.of(codegenContext.model)
return index.getContainedOperations(codegenContext.serviceShape).sortedBy { it.id }
}

interface ServerProtocol : Protocol {
/** The path such that `aws_smithy_http_server::proto::$path` points to the protocol's module. */
val protocolModulePath: String;

/** Returns the Rust marker struct enjoying `OperationShape`. */
fun markerStruct(): RuntimeType

Expand Down Expand Up @@ -76,6 +74,21 @@ interface ServerProtocol : Protocol {
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false

/** The protocol-specific `RequestRejection` type. **/
fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::${protocolModulePath}::rejection::RequestRejection")

/** The protocol-specific `ResponseRejection` type. **/
fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::${protocolModulePath}::rejection::ResponseRejection")

/** The protocol-specific `RuntimeError` type. **/
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::${protocolModulePath}::runtime_error::RuntimeError")
}

class ServerAwsJsonProtocol(
Expand All @@ -84,6 +97,12 @@ class ServerAwsJsonProtocol(
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
private val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String
get() = when (version) {
is AwsJsonVersion.Json10 -> "aws_json_10"
is AwsJsonVersion.Json11 -> "aws_json_11"
}

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) {
Expand All @@ -107,12 +126,8 @@ class ServerAwsJsonProtocol(

override fun markerStruct(): RuntimeType {
return when (version) {
is AwsJsonVersion.Json10 -> {
ServerRuntimeType.protocol("AwsJson1_0", "aws_json_10", runtimeConfig)
}
is AwsJsonVersion.Json11 -> {
ServerRuntimeType.protocol("AwsJson1_1", "aws_json_11", runtimeConfig)
}
is AwsJsonVersion.Json10 -> ServerRuntimeType.protocol("AwsJson1_0", protocolModulePath, runtimeConfig)
is AwsJsonVersion.Json11 -> ServerRuntimeType.protocol("AwsJson1_1", protocolModulePath, runtimeConfig)
}
}

Expand All @@ -139,6 +154,16 @@ class ServerAwsJsonProtocol(
AwsJsonVersion.Json10 -> "new_aws_json_10_router"
AwsJsonVersion.Json11 -> "new_aws_json_11_router"
}

override fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::RequestRejection")
override fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::ResponseRejection")
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::runtime_error::RuntimeError")
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
Expand All @@ -150,6 +175,8 @@ class ServerRestJsonProtocol(
) : RestJson(serverCodegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String = "rest_json_1"

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) {
Expand All @@ -173,7 +200,8 @@ class ServerRestJsonProtocol(
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", "rest_json_1", runtimeConfig)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand All @@ -196,8 +224,9 @@ class ServerRestXmlProtocol(
codegenContext: CodegenContext,
) : RestXml(codegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig
override val protocolModulePath = "rest_xml"

override fun markerStruct() = ServerRuntimeType.protocol("RestXml", "rest_xml", runtimeConfig)
override fun markerStruct() = ServerRuntimeType.protocol("RestXml", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ class ServerHttpBoundProtocolTraitImplGenerator(
"Regex" to RuntimeType.Regex,
"SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig),
"SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(),
"RuntimeError" to ServerRuntimeType.runtimeError(runtimeConfig),
"RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig),
"ResponseRejection" to ServerRuntimeType.responseRejection(runtimeConfig),
"RuntimeError" to protocol.runtimeError(runtimeConfig),
"RequestRejection" to protocol.requestRejection(runtimeConfig),
"ResponseRejection" to protocol.responseRejection(runtimeConfig),
"PinProjectLite" to ServerCargoDependency.PinProjectLite.toType(),
"http" to RuntimeType.Http,
)
Expand All @@ -159,12 +159,11 @@ class ServerHttpBoundProtocolTraitImplGenerator(
outputSymbol: Symbol,
operationShape: OperationShape,
) {
val operationName = symbolProvider.toSymbol(operationShape).name
val verifyAcceptHeader = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
if ! #{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) {
if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) {
return Err(#{RuntimeError}::NotAcceptable)
}
""",
Expand Down Expand Up @@ -1149,7 +1148,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
check(binding.location == HttpLocation.PAYLOAD)

if (model.expectShape(binding.member.target) is StringShape) {
return ServerRuntimeType.requestRejection(runtimeConfig).toSymbol()
return protocol.requestRejection(runtimeConfig).toSymbol()
}
return when (codegenContext.protocol) {
RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
Expand Down
Loading

0 comments on commit 946df69

Please sign in to comment.