Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split RuntimeError and RequestRejection by protocol #2517

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ open class ServerCodegenVisitor(

val baseModel = baselineTransform(context.model)
val service = settings.getService(baseModel)
val (protocol, generator) =
val (protocolShape, protocolGeneratorFactory) =
ServerProtocolLoader(
codegenDecorator.protocols(
service.id,
ServerProtocolLoader.DefaultProtocols,
),
)
.protocolFor(context.model, service)
protocolGeneratorFactory = generator
this.protocolGeneratorFactory = protocolGeneratorFactory

model = codegenDecorator.transformModel(service, baseModel)

Expand All @@ -145,7 +145,7 @@ open class ServerCodegenVisitor(
serverSymbolProviders.symbolProvider,
null,
service,
protocol,
protocolShape,
settings,
serverSymbolProviders.unconstrainedShapeSymbolProvider,
serverSymbolProviders.constrainedShapeSymbolProvider,
Expand All @@ -169,7 +169,7 @@ open class ServerCodegenVisitor(
settings.codegenConfig,
codegenContext.expectModuleDocProvider(),
)
protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
protocolGenerator = this.protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
}

/**
Expand Down Expand Up @@ -315,7 +315,12 @@ open class ServerCodegenVisitor(
writer: RustWriter,
) {
if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) {
val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator)
val serverBuilderGenerator = ServerBuilderGenerator(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGenerator.render(rustCrate, writer)

if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
Expand All @@ -336,7 +341,12 @@ open class ServerCodegenVisitor(

if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) {
val serverBuilderGeneratorWithoutPublicConstrainedTypes =
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator)
ServerBuilderGeneratorWithoutPublicConstrainedTypes(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGeneratorWithoutPublicConstrainedTypes.render(rustCrate, writer)

writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy

import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

Expand All @@ -15,17 +14,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
* For a runtime type that is used in the client, or in both the client and the server, use [RuntimeType] directly.
*/
object ServerRuntimeType {
fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency)
fun router(runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError")

fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection")

fun responseRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::ResponseRejection")

fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")
fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")

fun protocol(runtimeConfig: RuntimeConfig) = protocol("Protocol", "", runtimeConfig)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo
Expand All @@ -35,6 +34,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitI
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -67,11 +67,7 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
override val shapeId: ShapeId =
ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse)

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -89,7 +85,8 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo
Expand All @@ -34,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -66,11 +66,7 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
override val shapeId: ShapeId = SHAPE_ID

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -87,7 +83,8 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
Expand Down Expand Up @@ -92,6 +92,7 @@ class ServerBuilderGenerator(
val codegenContext: ServerCodegenContext,
private val shape: StructureShape,
private val customValidationExceptionWithReasonConversionGenerator: ValidationExceptionConversionGenerator,
private val protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -148,7 +149,7 @@ class ServerBuilderGenerator(
ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down Expand Up @@ -222,7 +223,8 @@ class ServerBuilderGenerator(
"""
#{Converter:W}
""",
"Converter" to customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(),
"Converter" to
customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(protocol),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.withInMemoryInlineModule

/**
Expand All @@ -49,6 +49,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
private val codegenContext: ServerCodegenContext,
shape: StructureShape,
validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -85,7 +86,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false, validationExceptionConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Collection of methods that will be invoked by the respective generators to generate code to convert constraint
Expand All @@ -26,7 +27,7 @@ interface ValidationExceptionConversionGenerator {
* Convert from a top-level operation input's constraint violation into
* `aws_smithy_http_server::rejection::RequestRejection`.
*/
fun renderImplFromConstraintViolationForRequestRejection(): Writable
fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable

// Simple shapes.
fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection<StringTraitInfo>): Writable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
Expand Down Expand Up @@ -37,12 +36,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJson
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

private fun allOperations(codegenContext: CodegenContext): List<OperationShape> {
val index = TopDownIndex.of(codegenContext.model)
return index.getContainedOperations(codegenContext.serviceShape).sortedBy { it.id }
}

interface ServerProtocol : Protocol {
/** The path such that `aws_smithy_http_server::proto::$path` points to the protocol's module. */
val protocolModulePath: String
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though this drys things up, I'm not sure how I feel about this. Might a constructor on RuntimeType, for requestRejection, responseRejection, and runtimeError, taking protocolModulePath: String be sufficient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm the ServerProtocol should know where these types lie, we shouldn't invert the responsibility on the caller in ServerRuntimeType. protocolModulePath is just to make it so that most implementers can rely on the default implementations, but AWS JSON needs to switch based on version.

Copy link
Contributor

@hlbarber hlbarber Mar 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protocolModulePath is just to make it so that most implementers can rely on the default implementations

Right, in Rust this would be akin to adding a const DEFAULT: &'static str; to a trait in order to get default function implementations to work? In the case where functions don't follow the default implementation this field becomes useless?

Copy link
Contributor

@hlbarber hlbarber Mar 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm just misunderstanding Kotlin here, is the function implementation overridable?

fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
        ServerCargoDependency.smithyHttpServer(runtimeConfig)
            .toType().resolve("proto::$protocolModulePath::rejection::RequestRejection")

If not then I actually have no problem with this - we're tightening the ServerProtocol contract but there's no redundancy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is overridable. ServerAwsJsonProtocol overrides it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case where functions don't follow the default implementation this field becomes useless?

Not useless, the field can be used as implementations see fit. See e.g. the implementation of markerStruct in ServerRestJsonProtocol.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case where it's not being used by the default function impls then it can live as just a val in the implementing class?

This is a style nit, I won't block on this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used by the default function impls.

Copy link
Contributor

@hlbarber hlbarber Apr 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean in the case we are not using the default function impls, instead we're rolling our own*, then the val can live in the class implementing the interface.


/** Returns the Rust marker struct enjoying `OperationShape`. */
fun markerStruct(): RuntimeType

Expand Down Expand Up @@ -76,6 +73,21 @@ interface ServerProtocol : Protocol {
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false

/** The protocol-specific `RequestRejection` type. **/
fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::$protocolModulePath::rejection::RequestRejection")

/** The protocol-specific `ResponseRejection` type. **/
fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::$protocolModulePath::rejection::ResponseRejection")

/** The protocol-specific `RuntimeError` type. **/
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::$protocolModulePath::runtime_error::RuntimeError")
}

class ServerAwsJsonProtocol(
Expand All @@ -84,6 +96,12 @@ class ServerAwsJsonProtocol(
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
private val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String
get() = when (version) {
is AwsJsonVersion.Json10 -> "aws_json_10"
is AwsJsonVersion.Json11 -> "aws_json_11"
}

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) {
Expand All @@ -107,12 +125,8 @@ class ServerAwsJsonProtocol(

override fun markerStruct(): RuntimeType {
return when (version) {
is AwsJsonVersion.Json10 -> {
ServerRuntimeType.protocol("AwsJson1_0", "aws_json_10", runtimeConfig)
}
is AwsJsonVersion.Json11 -> {
ServerRuntimeType.protocol("AwsJson1_1", "aws_json_11", runtimeConfig)
}
is AwsJsonVersion.Json10 -> ServerRuntimeType.protocol("AwsJson1_0", protocolModulePath, runtimeConfig)
is AwsJsonVersion.Json11 -> ServerRuntimeType.protocol("AwsJson1_1", protocolModulePath, runtimeConfig)
}
}

Expand All @@ -139,6 +153,16 @@ class ServerAwsJsonProtocol(
AwsJsonVersion.Json10 -> "new_aws_json_10_router"
AwsJsonVersion.Json11 -> "new_aws_json_11_router"
}

override fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::RequestRejection")
override fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::ResponseRejection")
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::runtime_error::RuntimeError")
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
Expand All @@ -150,6 +174,8 @@ class ServerRestJsonProtocol(
) : RestJson(serverCodegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String = "rest_json_1"

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) {
Expand All @@ -173,7 +199,7 @@ class ServerRestJsonProtocol(
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", "rest_json_1", runtimeConfig)
override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand All @@ -196,8 +222,9 @@ class ServerRestXmlProtocol(
codegenContext: CodegenContext,
) : RestXml(codegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig
override val protocolModulePath = "rest_xml"

override fun markerStruct() = ServerRuntimeType.protocol("RestXml", "rest_xml", runtimeConfig)
override fun markerStruct() = ServerRuntimeType.protocol("RestXml", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand Down
Loading