Skip to content

Commit

Permalink
Move inputs, outputs, and op errors into operation modules (#2394)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisanti authored Feb 17, 2023
1 parent afb1f16 commit 86bddca
Show file tree
Hide file tree
Showing 34 changed files with 208 additions and 162 deletions.
3 changes: 2 additions & 1 deletion aws/sdk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ fun generateSmithyBuild(services: AwsServices): String {
"codegen": {
"includeFluentClient": false,
"renameErrors": false,
"eventStreamAllowList": [$eventStreamAllowListMembers]
"eventStreamAllowList": [$eventStreamAllowListMembers],
"enableNewCrateOrganizationScheme": false
},
"service": "${service.service}",
"module": "$moduleName",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ class ClientCodegenVisitor(
runtimeConfig = settings.runtimeConfig,
renameExceptions = settings.codegenConfig.renameExceptions,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
moduleProvider = ClientModuleProvider,
moduleProvider = when (settings.codegenConfig.enableNewCrateOrganizationScheme) {
true -> ClientModuleProvider
else -> OldModuleSchemeClientModuleProvider
},
)
val baseModel = baselineTransform(context.model)
val untransformedService = settings.getService(baseModel)
Expand Down Expand Up @@ -263,7 +266,7 @@ class ClientCodegenVisitor(
* Generate errors for operation shapes
*/
override fun operationShape(shape: OperationShape) {
rustCrate.withModule(ClientRustModule.Error) {
rustCrate.withModule(symbolProvider.moduleForOperationError(shape)) {
OperationErrorGenerator(
model,
symbolProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,31 @@

package software.amazon.smithy.rust.codegen.client.smithy

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProviderContext
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

/**
* Modules for code generated client crates.
*/
object ClientRustModule {
/** crate::client */
val client = Client.self

object Client {
/** crate::client */
val self = RustModule.public("client", "Client and fluent builders for calling the service.")
Expand All @@ -40,20 +48,73 @@ object ClientRustModule {
}

object ClientModuleProvider : ModuleProvider {
override fun moduleForShape(shape: Shape): RustModule.LeafModule = when (shape) {
override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) {
is OperationShape -> perOperationModule(context, shape)
is StructureShape -> when {
shape.hasTrait<ErrorTrait>() -> ClientRustModule.Error
shape.hasTrait<SyntheticInputTrait>() -> perOperationModule(context, shape)
shape.hasTrait<SyntheticOutputTrait>() -> perOperationModule(context, shape)
else -> ClientRustModule.Model
}

else -> ClientRustModule.Model
}

override fun moduleForOperationError(
context: ModuleProviderContext,
operation: OperationShape,
): RustModule.LeafModule = perOperationModule(context, operation)

override fun moduleForEventStreamError(
context: ModuleProviderContext,
eventStream: UnionShape,
): RustModule.LeafModule = ClientRustModule.Error

private fun Shape.findOperation(model: Model): OperationShape {
val inputTrait = getTrait<SyntheticInputTrait>()
val outputTrait = getTrait<SyntheticOutputTrait>()
return when {
this is OperationShape -> this
inputTrait != null -> model.expectShape(inputTrait.operation, OperationShape::class.java)
outputTrait != null -> model.expectShape(outputTrait.operation, OperationShape::class.java)
else -> UNREACHABLE("this is only called with compatible shapes")
}
}

private fun perOperationModule(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule {
val operationShape = shape.findOperation(context.model)
val contextName = operationShape.contextName(context.serviceShape)
val operationModuleName =
RustReservedWords.escapeIfNeeded(contextName.toSnakeCase())
return RustModule.public(
operationModuleName,
parent = ClientRustModule.Operation,
documentation = "Types for the `$contextName` operation.",
)
}
}

// TODO(CrateReorganization): Remove this provider
object OldModuleSchemeClientModuleProvider : ModuleProvider {
override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) {
is OperationShape -> ClientRustModule.Operation
is StructureShape -> when {
shape.hasTrait<ErrorTrait>() -> ClientRustModule.Error
shape.hasTrait<SyntheticInputTrait>() -> ClientRustModule.Input
shape.hasTrait<SyntheticOutputTrait>() -> ClientRustModule.Output
else -> ClientRustModule.Model
}

else -> ClientRustModule.Model
}

override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
ClientRustModule.Error
override fun moduleForOperationError(
context: ModuleProviderContext,
operation: OperationShape,
): RustModule.LeafModule = ClientRustModule.Error

override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
ClientRustModule.Error
override fun moduleForEventStreamError(
context: ModuleProviderContext,
eventStream: UnionShape,
): RustModule.LeafModule = ClientRustModule.Error
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig) =
SymbolVisitor(model, serviceShape = serviceShape, config = rustSymbolProviderConfig)
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) }
.let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.CLIENT) }
// Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
.let { StreamingShapeSymbolProvider(it, model) }
.let { StreamingShapeSymbolProvider(it) }
// Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
.let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(NonExhaustive)) }
.let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(NonExhaustive)) }
// Streaming shapes need different derives (e.g. they cannot derive `PartialEq`)
.let { StreamingShapeMetadataProvider(it, model) }
.let { StreamingShapeMetadataProvider(it) }
// Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot
// be the name of an operation input
.let { RustReservedWordSymbolProvider(it, model) }
.let { RustReservedWordSymbolProvider(it) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class PaginatorGenerator private constructor(
}

private val paginatorName = "${operation.id.name.toPascalCase()}Paginator"
private val runtimeConfig = symbolProvider.config().runtimeConfig
private val runtimeConfig = symbolProvider.config.runtimeConfig
private val idx = PaginatedIndex.of(model)
private val paginationInfo =
idx.getPaginationInfo(service, operation).orNull() ?: PANIC("failed to load pagination info")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ErrorGenerator(
private val error: ErrorTrait,
private val implCustomizations: List<ErrorImplCustomization>,
) {
private val runtimeConfig = symbolProvider.config().runtimeConfig
private val runtimeConfig = symbolProvider.config.runtimeConfig

fun render() {
val symbol = symbolProvider.toSymbol(shape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ class OperationErrorGenerator(
private val operationOrEventStream: Shape,
private val customizations: List<ErrorCustomization>,
) {
private val runtimeConfig = symbolProvider.config().runtimeConfig
private val runtimeConfig = symbolProvider.config.runtimeConfig
private val symbol = symbolProvider.toSymbol(operationOrEventStream)
private val errorMetadata = errorMetadata(symbolProvider.config().runtimeConfig)
private val errorMetadata = errorMetadata(symbolProvider.config.runtimeConfig)
private val createUnhandledError =
RuntimeType.smithyHttp(runtimeConfig).resolve("result::CreateUnhandledError")

Expand Down Expand Up @@ -148,10 +148,10 @@ class OperationErrorGenerator(

writer.writeCustomizations(customizations, ErrorSection.OperationErrorAdditionalTraitImpls(errorSymbol, errors))

val retryErrorKindT = RuntimeType.retryErrorKind(symbolProvider.config().runtimeConfig)
val retryErrorKindT = RuntimeType.retryErrorKind(symbolProvider.config.runtimeConfig)
writer.rustBlock(
"impl #T for ${errorSymbol.name}",
RuntimeType.provideErrorKind(symbolProvider.config().runtimeConfig),
RuntimeType.provideErrorKind(symbolProvider.config.runtimeConfig),
) {
rustBlock("fn code(&self) -> Option<&str>") {
rust("#T::code(self)", RuntimeType.provideErrorMetadataTrait(runtimeConfig))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenConfig
import software.amazon.smithy.rust.codegen.client.smithy.ClientModuleProvider
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.client.smithy.OldModuleSchemeClientModuleProvider
import software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
Expand Down Expand Up @@ -53,7 +53,7 @@ val ClientTestRustSymbolProviderConfig = RustSymbolProviderConfig(
runtimeConfig = TestRuntimeConfig,
renameExceptions = true,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
moduleProvider = ClientModuleProvider,
moduleProvider = OldModuleSchemeClientModuleProvider,
)

fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSymbolProvider =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class EventStreamSymbolProviderTest {
)

val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), model, CodegenTarget.CLIENT)
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), CodegenTarget.CLIENT)

// Look up the synthetic input/output rather than the original input/output
val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape
Expand Down Expand Up @@ -82,7 +82,7 @@ class EventStreamSymbolProviderTest {
)

val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), model, CodegenTarget.CLIENT)
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), CodegenTarget.CLIENT)

// Look up the synthetic input/output rather than the original input/output
val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.core.rustlang
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider
import software.amazon.smithy.codegen.core.ReservedWords
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
Expand All @@ -23,8 +22,7 @@ import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, private val model: Model) :
WrappingSymbolProvider(base) {
class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : WrappingSymbolProvider(base) {
private val internal =
ReservedWordSymbolProvider.builder().symbolProvider(base).memberReservedWords(RustReservedWords).build()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
Expand All @@ -24,18 +21,12 @@ import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.core.util.isInputEventStream
import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream

fun UnionShape.eventStreamErrorSymbol(symbolProvider: RustSymbolProvider): RuntimeType {
val unionSymbol = symbolProvider.toSymbol(this)
return RustModule.Error.toType().resolve("${unionSymbol.name}Error")
}

/**
* Wrapping symbol provider to wrap modeled types with the aws-smithy-http Event Stream send/receive types.
*/
class EventStreamSymbolProvider(
private val runtimeConfig: RuntimeConfig,
base: RustSymbolProvider,
private val model: Model,
private val target: CodegenTarget,
) : WrappingSymbolProvider(base) {
override fun toSymbol(shape: Shape): Symbol {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
Expand All @@ -18,15 +20,19 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
/**
* SymbolProvider interface that carries both the inner configuration and a function to produce an enum variant name.
*/
interface RustSymbolProvider : SymbolProvider, ModuleProvider {
fun config(): RustSymbolProviderConfig
interface RustSymbolProvider : SymbolProvider {
val model: Model
val moduleProviderContext: ModuleProviderContext
val config: RustSymbolProviderConfig

fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?

override fun moduleForShape(shape: Shape): RustModule.LeafModule = config().moduleProvider.moduleForShape(shape)
override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
config().moduleProvider.moduleForOperationError(operation)
override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
config().moduleProvider.moduleForEventStreamError(eventStream)
fun moduleForShape(shape: Shape): RustModule.LeafModule =
config.moduleProvider.moduleForShape(moduleProviderContext, shape)
fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
config.moduleProvider.moduleForOperationError(moduleProviderContext, operation)
fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
config.moduleProvider.moduleForEventStreamError(moduleProviderContext, eventStream)

/** Returns the symbol for an operation error */
fun symbolForOperationError(operation: OperationShape): Symbol
Expand All @@ -35,18 +41,29 @@ interface RustSymbolProvider : SymbolProvider, ModuleProvider {
fun symbolForEventStreamError(eventStream: UnionShape): Symbol
}

/**
* Module providers can't use the full CodegenContext since they're invoked from
* inside the SymbolVisitor, which is created before CodegenContext is created.
*/
data class ModuleProviderContext(
val model: Model,
val serviceShape: ServiceShape?,
)

fun CodegenContext.toModuleProviderContext(): ModuleProviderContext = ModuleProviderContext(model, serviceShape)

/**
* Provider for RustModules so that the symbol provider knows where to organize things.
*/
interface ModuleProvider {
/** Returns the module for a shape */
fun moduleForShape(shape: Shape): RustModule.LeafModule
fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule

/** Returns the module for an operation error */
fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule
fun moduleForOperationError(context: ModuleProviderContext, operation: OperationShape): RustModule.LeafModule

/** Returns the module for an event stream error */
fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule
fun moduleForEventStreamError(context: ModuleProviderContext, eventStream: UnionShape): RustModule.LeafModule
}

/**
Expand All @@ -63,7 +80,10 @@ data class RustSymbolProviderConfig(
* Default delegator to enable easily decorating another symbol provider.
*/
open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider {
override fun config(): RustSymbolProviderConfig = base.config()
override val model: Model get() = base.model
override val moduleProviderContext: ModuleProviderContext get() = base.moduleProviderContext
override val config: RustSymbolProviderConfig get() = base.config

override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? = base.toEnumVariantName(definition)
override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape)
override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape)
Expand Down
Loading

0 comments on commit 86bddca

Please sign in to comment.