Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Python unions #2427

Merged
merged 16 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ target/

# IDEs
.idea/
.project
.settings
.classpath

# tools
.tool-versions
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ class Attribute(val inner: Writable) {
val AllowClippyUnnecessaryWraps = Attribute(allow("clippy::unnecessary_wraps"))
val AllowClippyUselessConversion = Attribute(allow("clippy::useless_conversion"))
val AllowClippyUnnecessaryLazyEvaluations = Attribute(allow("clippy::unnecessary_lazy_evaluations"))
val AllowClippyTooManyArguments = Attribute(allow("clippy::too_many_arguments"))
val AllowDeadCode = Attribute(allow("dead_code"))
val AllowDeprecated = Attribute(allow("deprecated"))
val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ open class StructureGenerator(
writer.rustBlock("impl $name") {
// Render field accessor methods
forEachMember(accessorMembers) { member, memberName, memberSymbol ->
renderMemberDoc(member, memberSymbol)
writer.renderMemberDoc(member, memberSymbol)
writer.deprecatedShape(member)
val memberType = memberSymbol.rustType()
val returnType = when {
Expand All @@ -140,7 +140,7 @@ open class StructureGenerator(
memberType.isDeref() -> memberType.asDeref().asRef()
else -> memberType.asRef()
}
rustBlock("pub fn $memberName(&self) -> ${returnType.render()}") {
writer.rustBlock("pub fn $memberName(&self) -> ${returnType.render()}") {
when {
memberType.isCopy() -> rust("self.$memberName")
memberType is RustType.Option && memberType.member.isDeref() -> rust("self.$memberName.as_deref()")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fun CodegenTarget.renderUnknownVariant() = when (this) {
* Finally, if `[renderUnknownVariant]` is true (the default), it will render an `Unknown` variant. This is used by
* clients to allow response parsing to succeed, even if the server has added a new variant since the client was generated.
*/
class UnionGenerator(
open class UnionGenerator(
val model: Model,
private val symbolProvider: SymbolProvider,
private val writer: RustWriter,
Expand All @@ -60,7 +60,7 @@ class UnionGenerator(
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
private val unionSymbol = symbolProvider.toSymbol(shape)

fun render() {
open fun render() {
writer.documentShape(shape, model)
writer.deprecatedShape(shape)

Expand Down
12 changes: 12 additions & 0 deletions codegen-server-test/python/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ val allCodegenTests = "../../codegen-core/common-test-models".let { commonModels
listOf(
CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")),
CodegenTest("com.aws.example.python#PokemonService", "pokemon-service-server-sdk"),
CodegenTest(
"com.amazonaws.ebs#Ebs", "ebs",
imports = listOf("$commonModels/ebs.json"),
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
),
CodegenTest(
"aws.protocoltests.misc#MiscService",
"misc",
imports = listOf("$commonModels/misc.smithy"),
// TODO(https://github.com/awslabs/smithy-rs/issues/1401) `@uniqueItems` is used.
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
),
)
}

Expand Down
16 changes: 9 additions & 7 deletions codegen-server-test/python/model/pokemon.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ use com.aws.example#Storage
use com.aws.example#GetServerStatistics
use com.aws.example#DoNothing
use com.aws.example#CheckHealth
use smithy.framework#ValidationException


/// The Pokémon Service allows you to retrieve information about Pokémon species.
@title("Pokémon Service")
@restJson1
service PokemonService {
version: "2021-12-01",
resources: [PokemonSpecies],
version: "2021-12-01"
resources: [PokemonSpecies]
operations: [
GetServerStatistics,
DoNothing,
CheckHealth,
GetServerStatistics
DoNothing
CheckHealth
StreamPokemonRadio
],
}
Expand All @@ -30,13 +32,13 @@ service PokemonService {
@readonly
@http(uri: "/radio", method: "GET")
operation StreamPokemonRadio {
output: StreamPokemonRadioOutput,
output: StreamPokemonRadioOutput
}

@output
structure StreamPokemonRadioOutput {
@httpPayload
data: StreamingBlob,
data: StreamingBlob
}

@streaming
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
package software.amazon.smithy.rust.codegen.server.python.smithy

import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.OperationShape
Expand All @@ -22,19 +21,27 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor
import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleDocProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.createInlineModuleCreator
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput

/**
* Entrypoint for Python server-side code generation. This class will walk the in-memory model and
Expand Down Expand Up @@ -82,7 +89,7 @@ class PythonServerCodegenVisitor(
publicConstrainedTypes: Boolean,
includeConstraintShapeProvider: Boolean,
codegenDecorator: ServerCodegenDecorator,
) = RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, codegenDecorator)
) = RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, includeConstraintShapeProvider, codegenDecorator)
unexge marked this conversation as resolved.
Show resolved Hide resolved

val serverSymbolProviders = ServerSymbolProviders.from(
settings,
Expand Down Expand Up @@ -178,7 +185,32 @@ class PythonServerCodegenVisitor(
* Note: this does not generate serializers
*/
override fun unionShape(shape: UnionShape) {
throw CodegenException("Union shapes are not supported in Python yet")
logger.info("[python-server-codegen] Generating an union shape $shape")
rustCrate.useShapeWriter(shape) {
PythonServerUnionGenerator(model, codegenContext.symbolProvider, this, shape, renderUnknownVariant = false).render()
}

if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape(
model,
codegenContext.symbolProvider,
)
) {
logger.info("[python-server-codegen] Generating an unconstrained type for union shape $shape")
rustCrate.withModule(ServerRustModule.UnconstrainedModule) modelsModuleWriter@{
UnconstrainedUnionGenerator(
codegenContext,
rustCrate.createInlineModuleCreator(),
this@modelsModuleWriter,
shape,
).render()
}
}

if (shape.isEventStream()) {
rustCrate.withModule(ServerRustModule.Error) {
ServerOperationErrorGenerator(model, codegenContext.symbolProvider, shape).render(this)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,35 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTra
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import java.util.logging.Logger

/* Returns the Python implementation of the ByteStream shape or the original symbol that is provided in input. */
private fun toPythonByteStreamSymbolOrOriginal(model: Model, config: RustSymbolProviderConfig, initial: Symbol, shape: Shape): Symbol {
if (shape !is MemberShape) {
return initial
}

val target = model.expectShape(shape.target)
val container = model.expectShape(shape.container)

if (!container.hasTrait<SyntheticOutputTrait>() && !container.hasTrait<SyntheticInputTrait>()) {
unexge marked this conversation as resolved.
Show resolved Hide resolved
return initial
}

// We are only targeting streaming blobs as the rest of the symbols do not change if streaming is enabled.
// For example a TimestampShape doesn't become a different symbol when streaming is involved, but BlobShape
// become a ByteStream.
return if (target is BlobShape && shape.isStreaming(model)) {
PythonServerRuntimeType.byteStream(config.runtimeConfig).toSymbol()
} else {
initial
}
}

/**
* Symbol visitor allowing that recursively replace symbols in nested shapes.
* Symbol provider that recursively replace symbols in nested shapes.
*
* Input / output / error structures can refer to complex types like the ones implemented inside
* `aws_smithy_types` (a good example is `aws_smithy_types::Blob`).
Expand All @@ -50,30 +75,13 @@ class PythonServerSymbolVisitor(
serviceShape: ServiceShape?,
config: RustSymbolProviderConfig,
) : SymbolVisitor(settings, model, serviceShape, config) {

private val runtimeConfig = config.runtimeConfig
private val logger = Logger.getLogger(javaClass.name)

override fun toSymbol(shape: Shape): Symbol {
val initial = shape.accept(this)

if (shape !is MemberShape) {
return initial
}
val target = model.expectShape(shape.target)
val container = model.expectShape(shape.container)

// We are only targeting non-synthetic inputs and outputs.
if (!container.hasTrait<SyntheticOutputTrait>() && !container.hasTrait<SyntheticInputTrait>()) {
return initial
}

// We are only targeting streaming blobs as the rest of the symbols do not change if streaming is enabled.
// For example a TimestampShape doesn't become a different symbol when streaming is involved, but BlobShape
// become a ByteStream.
return if (target is BlobShape && shape.isStreaming(model)) {
PythonServerRuntimeType.byteStream(config.runtimeConfig).toSymbol()
} else {
initial
}
return toPythonByteStreamSymbolOrOriginal(model, config, initial, shape)
}

override fun timestampShape(shape: TimestampShape?): Symbol {
Expand All @@ -89,6 +97,27 @@ class PythonServerSymbolVisitor(
}
}

/**
* Constrained symbol provider that recursively replace symbols in nested shapes.
*
* This symbol provider extends the `ConstrainedShapeSymbolProvider` to ensure constraints are
* applied properly and swaps out shapes that do not implement `pyo3::PyClass` with their
* wrappers.
*
* See `PythonServerSymbolVisitor` documentation for more info.
*/
class PythonConstrainedShapeSymbolProvider(
base: RustSymbolProvider,
serviceShape: ServiceShape,
publicConstrainedTypes: Boolean,
) : ConstrainedShapeSymbolProvider(base, serviceShape, publicConstrainedTypes) {

override fun toSymbol(shape: Shape): Symbol {
val initial = super.toSymbol(shape)
return toPythonByteStreamSymbolOrOriginal(model, config, initial, shape)
}
}

/**
* SymbolProvider to drop the PartialEq bounds in streaming shapes
*
Expand All @@ -98,6 +127,7 @@ class PythonServerSymbolVisitor(
* Note that since streaming members can only be used on the root shape, this can only impact input and output shapes.
*/
class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider) : SymbolMetadataProvider(base) {

override fun structureMeta(structureShape: StructureShape): RustMetadata {
val baseMetadata = base.toSymbol(structureShape).expectRustMetadata()
return if (structureShape.hasStreamingMember(model)) {
Expand All @@ -118,7 +148,6 @@ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider)

override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata()
override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata()

override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata()
override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata()
override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.DeriveEqAndHashSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerReservedWords
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
Expand Down Expand Up @@ -76,6 +75,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin {
serviceShape: ServiceShape,
rustSymbolProviderConfig: RustSymbolProviderConfig,
constrainedTypes: Boolean = true,
includeConstrainedShapeProvider: Boolean = true,
crisidev marked this conversation as resolved.
Show resolved Hide resolved
codegenDecorator: ServerCodegenDecorator,
) =
// Rename a set of symbols that do not implement `PyClass` and have been wrapped in
Expand All @@ -85,7 +85,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin {
// In the Python server project, this is only done to generate constrained types for simple shapes (e.g.
// a `string` shape with the `length` trait), but these always remain `pub(crate)`.
.let {
if (constrainedTypes) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it
if (includeConstrainedShapeProvider) PythonConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it
crisidev marked this conversation as resolved.
Show resolved Hide resolved
}
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ResourceShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
Expand Down Expand Up @@ -67,12 +68,20 @@ class PythonServerModuleGenerator(
serviceShapes.forEach { shape ->
val moduleType = moduleType(shape)
if (moduleType != null) {
rustTemplate(
"""
$moduleType.add_class::<crate::$moduleType::${shape.id.name}>()?;
""",
*codegenScope,
)
when (shape) {
is UnionShape -> rustTemplate(
"""
$moduleType.add_class::<crate::$moduleType::PyUnionMarker${shape.id.name}>()?;
""",
*codegenScope,
)
else -> rustTemplate(
"""
$moduleType.add_class::<crate::$moduleType::${shape.id.name}>()?;
""",
*codegenScope,
)
}
}
}
rustTemplate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class PythonServerStructureGenerator(

private fun renderPyO3Methods() {
Attribute.AllowClippyNewWithoutDefault.render(writer)
Attribute.AllowClippyTooManyArguments.render(writer)
Attribute(pyO3.resolve("pymethods")).render(writer)
writer.rustTemplate(
"""
Expand Down
Loading