Skip to content

Commit

Permalink
Implement Python unions (#2427)
Browse files Browse the repository at this point in the history
* Add initial implementation of unions with very broken symbol provider

* Add support for creating new unions in Python

* Generate getters and static methods for unions

* Allow to compile misc model

Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com>

* Doesn't work

* Now it works

* Simplify code a little

* Remove leftover from the many tries I did

* Finally fixed model generation with unions

* Fix wrong import

* Update to reflect changes in decorators

* Remove debugging output

* Simplify symbol provider

* Follow PR suggestions

* Remove union operation from python example

* Return `PyUnionMarker` for wrapped type in `IntoPy` impl

---------

Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: Burak Varlı <burakvar@amazon.co.uk>
crisidev and unexge authored Mar 17, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d8a7d99 commit 48eda40
Showing 15 changed files with 338 additions and 48 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -50,3 +50,9 @@ target/

# IDEs
.idea/
.project
.settings
.classpath

# tools
.tool-versions
Original file line number Diff line number Diff line change
@@ -466,6 +466,7 @@ class Attribute(val inner: Writable) {
val AllowClippyUnnecessaryWraps = Attribute(allow("clippy::unnecessary_wraps"))
val AllowClippyUselessConversion = Attribute(allow("clippy::useless_conversion"))
val AllowClippyUnnecessaryLazyEvaluations = Attribute(allow("clippy::unnecessary_lazy_evaluations"))
val AllowClippyTooManyArguments = Attribute(allow("clippy::too_many_arguments"))
val AllowDeadCode = Attribute(allow("dead_code"))
val AllowDeprecated = Attribute(allow("deprecated"))
val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns"))
Original file line number Diff line number Diff line change
@@ -131,7 +131,7 @@ open class StructureGenerator(
writer.rustBlock("impl $name") {
// Render field accessor methods
forEachMember(accessorMembers) { member, memberName, memberSymbol ->
renderMemberDoc(member, memberSymbol)
writer.renderMemberDoc(member, memberSymbol)
writer.deprecatedShape(member)
val memberType = memberSymbol.rustType()
val returnType = when {
@@ -140,7 +140,7 @@ open class StructureGenerator(
memberType.isDeref() -> memberType.asDeref().asRef()
else -> memberType.asRef()
}
rustBlock("pub fn $memberName(&self) -> ${returnType.render()}") {
writer.rustBlock("pub fn $memberName(&self) -> ${returnType.render()}") {
when {
memberType.isCopy() -> rust("self.$memberName")
memberType is RustType.Option && memberType.member.isDeref() -> rust("self.$memberName.as_deref()")
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ fun CodegenTarget.renderUnknownVariant() = when (this) {
* Finally, if `[renderUnknownVariant]` is true (the default), it will render an `Unknown` variant. This is used by
* clients to allow response parsing to succeed, even if the server has added a new variant since the client was generated.
*/
class UnionGenerator(
open class UnionGenerator(
val model: Model,
private val symbolProvider: SymbolProvider,
private val writer: RustWriter,
@@ -60,7 +60,7 @@ class UnionGenerator(
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
private val unionSymbol = symbolProvider.toSymbol(shape)

fun render() {
open fun render() {
writer.documentShape(shape, model)
writer.deprecatedShape(shape)

12 changes: 12 additions & 0 deletions codegen-server-test/python/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -42,6 +42,18 @@ val allCodegenTests = "../../codegen-core/common-test-models".let { commonModels
listOf(
CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")),
CodegenTest("com.aws.example.python#PokemonService", "pokemon-service-server-sdk"),
CodegenTest(
"com.amazonaws.ebs#Ebs", "ebs",
imports = listOf("$commonModels/ebs.json"),
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
),
CodegenTest(
"aws.protocoltests.misc#MiscService",
"misc",
imports = listOf("$commonModels/misc.smithy"),
// TODO(https://github.com/awslabs/smithy-rs/issues/1401) `@uniqueItems` is used.
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
),
)
}

16 changes: 9 additions & 7 deletions codegen-server-test/python/model/pokemon.smithy
Original file line number Diff line number Diff line change
@@ -11,17 +11,19 @@ use com.aws.example#Storage
use com.aws.example#GetServerStatistics
use com.aws.example#DoNothing
use com.aws.example#CheckHealth
use smithy.framework#ValidationException


/// The Pokémon Service allows you to retrieve information about Pokémon species.
@title("Pokémon Service")
@restJson1
service PokemonService {
version: "2021-12-01",
resources: [PokemonSpecies],
version: "2021-12-01"
resources: [PokemonSpecies]
operations: [
GetServerStatistics,
DoNothing,
CheckHealth,
GetServerStatistics
DoNothing
CheckHealth
StreamPokemonRadio
],
}
@@ -30,13 +32,13 @@ service PokemonService {
@readonly
@http(uri: "/radio", method: "GET")
operation StreamPokemonRadio {
output: StreamPokemonRadioOutput,
output: StreamPokemonRadioOutput
}

@output
structure StreamPokemonRadioOutput {
@httpPayload
data: StreamingBlob,
data: StreamingBlob
}

@streaming
Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@
package software.amazon.smithy.rust.codegen.server.python.smithy

import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.OperationShape
@@ -22,19 +21,27 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor
import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleDocProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.createInlineModuleCreator
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput

/**
* Entrypoint for Python server-side code generation. This class will walk the in-memory model and
@@ -82,7 +89,7 @@ class PythonServerCodegenVisitor(
publicConstrainedTypes: Boolean,
includeConstraintShapeProvider: Boolean,
codegenDecorator: ServerCodegenDecorator,
) = RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, codegenDecorator)
) = RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, includeConstraintShapeProvider, codegenDecorator)

val serverSymbolProviders = ServerSymbolProviders.from(
settings,
@@ -178,7 +185,32 @@ class PythonServerCodegenVisitor(
* Note: this does not generate serializers
*/
override fun unionShape(shape: UnionShape) {
throw CodegenException("Union shapes are not supported in Python yet")
logger.info("[python-server-codegen] Generating an union shape $shape")
rustCrate.useShapeWriter(shape) {
PythonServerUnionGenerator(model, codegenContext.symbolProvider, this, shape, renderUnknownVariant = false).render()
}

if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape(
model,
codegenContext.symbolProvider,
)
) {
logger.info("[python-server-codegen] Generating an unconstrained type for union shape $shape")
rustCrate.withModule(ServerRustModule.UnconstrainedModule) modelsModuleWriter@{
UnconstrainedUnionGenerator(
codegenContext,
rustCrate.createInlineModuleCreator(),
this@modelsModuleWriter,
shape,
).render()
}
}

if (shape.isEventStream()) {
rustCrate.withModule(ServerRustModule.Error) {
ServerOperationErrorGenerator(model, codegenContext.symbolProvider, shape).render(this)
}
}
}

/**
Original file line number Diff line number Diff line change
@@ -31,10 +31,35 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTra
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import java.util.logging.Logger

/* Returns the Python implementation of the ByteStream shape or the original symbol that is provided in input. */
private fun toPythonByteStreamSymbolOrOriginal(model: Model, config: RustSymbolProviderConfig, initial: Symbol, shape: Shape): Symbol {
if (shape !is MemberShape) {
return initial
}

val target = model.expectShape(shape.target)
val container = model.expectShape(shape.container)

if (!container.hasTrait<SyntheticOutputTrait>() && !container.hasTrait<SyntheticInputTrait>()) {
return initial
}

// We are only targeting streaming blobs as the rest of the symbols do not change if streaming is enabled.
// For example a TimestampShape doesn't become a different symbol when streaming is involved, but BlobShape
// become a ByteStream.
return if (target is BlobShape && shape.isStreaming(model)) {
PythonServerRuntimeType.byteStream(config.runtimeConfig).toSymbol()
} else {
initial
}
}

/**
* Symbol visitor allowing that recursively replace symbols in nested shapes.
* Symbol provider that recursively replace symbols in nested shapes.
*
* Input / output / error structures can refer to complex types like the ones implemented inside
* `aws_smithy_types` (a good example is `aws_smithy_types::Blob`).
@@ -50,30 +75,13 @@ class PythonServerSymbolVisitor(
serviceShape: ServiceShape?,
config: RustSymbolProviderConfig,
) : SymbolVisitor(settings, model, serviceShape, config) {

private val runtimeConfig = config.runtimeConfig
private val logger = Logger.getLogger(javaClass.name)

override fun toSymbol(shape: Shape): Symbol {
val initial = shape.accept(this)

if (shape !is MemberShape) {
return initial
}
val target = model.expectShape(shape.target)
val container = model.expectShape(shape.container)

// We are only targeting non-synthetic inputs and outputs.
if (!container.hasTrait<SyntheticOutputTrait>() && !container.hasTrait<SyntheticInputTrait>()) {
return initial
}

// We are only targeting streaming blobs as the rest of the symbols do not change if streaming is enabled.
// For example a TimestampShape doesn't become a different symbol when streaming is involved, but BlobShape
// become a ByteStream.
return if (target is BlobShape && shape.isStreaming(model)) {
PythonServerRuntimeType.byteStream(config.runtimeConfig).toSymbol()
} else {
initial
}
return toPythonByteStreamSymbolOrOriginal(model, config, initial, shape)
}

override fun timestampShape(shape: TimestampShape?): Symbol {
@@ -89,6 +97,27 @@ class PythonServerSymbolVisitor(
}
}

/**
* Constrained symbol provider that recursively replace symbols in nested shapes.
*
* This symbol provider extends the `ConstrainedShapeSymbolProvider` to ensure constraints are
* applied properly and swaps out shapes that do not implement `pyo3::PyClass` with their
* wrappers.
*
* See `PythonServerSymbolVisitor` documentation for more info.
*/
class PythonConstrainedShapeSymbolProvider(
base: RustSymbolProvider,
serviceShape: ServiceShape,
publicConstrainedTypes: Boolean,
) : ConstrainedShapeSymbolProvider(base, serviceShape, publicConstrainedTypes) {

override fun toSymbol(shape: Shape): Symbol {
val initial = super.toSymbol(shape)
return toPythonByteStreamSymbolOrOriginal(model, config, initial, shape)
}
}

/**
* SymbolProvider to drop the PartialEq bounds in streaming shapes
*
@@ -98,6 +127,7 @@ class PythonServerSymbolVisitor(
* Note that since streaming members can only be used on the root shape, this can only impact input and output shapes.
*/
class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider) : SymbolMetadataProvider(base) {

override fun structureMeta(structureShape: StructureShape): RustMetadata {
val baseMetadata = base.toSymbol(structureShape).expectRustMetadata()
return if (structureShape.hasStreamingMember(model)) {
@@ -118,7 +148,6 @@ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider)

override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata()
override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata()

override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata()
override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata()
override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata()
Original file line number Diff line number Diff line change
@@ -17,7 +17,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.DeriveEqAndHashSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerReservedWords
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
@@ -76,6 +75,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin {
serviceShape: ServiceShape,
rustSymbolProviderConfig: RustSymbolProviderConfig,
constrainedTypes: Boolean = true,
includeConstrainedShapeProvider: Boolean = true,
codegenDecorator: ServerCodegenDecorator,
) =
// Rename a set of symbols that do not implement `PyClass` and have been wrapped in
@@ -85,7 +85,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin {
// In the Python server project, this is only done to generate constrained types for simple shapes (e.g.
// a `string` shape with the `length` trait), but these always remain `pub(crate)`.
.let {
if (constrainedTypes) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it
if (includeConstrainedShapeProvider) PythonConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it
}
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) }
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ResourceShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
@@ -67,12 +68,20 @@ class PythonServerModuleGenerator(
serviceShapes.forEach { shape ->
val moduleType = moduleType(shape)
if (moduleType != null) {
rustTemplate(
"""
$moduleType.add_class::<crate::$moduleType::${shape.id.name}>()?;
""",
*codegenScope,
)
when (shape) {
is UnionShape -> rustTemplate(
"""
$moduleType.add_class::<crate::$moduleType::PyUnionMarker${shape.id.name}>()?;
""",
*codegenScope,
)
else -> rustTemplate(
"""
$moduleType.add_class::<crate::$moduleType::${shape.id.name}>()?;
""",
*codegenScope,
)
}
}
}
rustTemplate(
Original file line number Diff line number Diff line change
@@ -75,6 +75,7 @@ class PythonServerStructureGenerator(

private fun renderPyO3Methods() {
Attribute.AllowClippyNewWithoutDefault.render(writer)
Attribute.AllowClippyTooManyArguments.render(writer)
Attribute(pyO3.resolve("pymethods")).render(writer)
writer.rustTemplate(
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.server.python.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.isCopy
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.python.smithy.pythonType
import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstring

/*
* Generate unions that are compatible with Python by wrapping the Rust implementation into
* a new structure and implementing `IntoPy` and `FromPyObject` to ensure the ability to extract
* the union inside the Python context.
*/
class PythonServerUnionGenerator(
model: Model,
private val symbolProvider: SymbolProvider,
private val writer: RustWriter,
shape: UnionShape,
private val renderUnknownVariant: Boolean = true,
) : UnionGenerator(model, symbolProvider, writer, shape, renderUnknownVariant) {
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
private val unionSymbol = symbolProvider.toSymbol(shape)

private val pyo3 = PythonServerCargoDependency.PyO3.toType()

override fun render() {
super.render()
renderPyUnionStruct()
renderPyUnionImpl()
renderPyObjectConverters()
}

private fun renderPyUnionStruct() {
writer.rust("""##[pyo3::pyclass(name = "${unionSymbol.name}")]""")
val containerMeta = unionSymbol.expectRustMetadata()
containerMeta.render(writer)
writer.rust("struct PyUnionMarker${unionSymbol.name}(pub ${unionSymbol.name});")
}

private fun renderPyUnionImpl() {
Attribute(pyo3.resolve("pymethods")).render(writer)
writer.rustBlock("impl PyUnionMarker${unionSymbol.name}") {
sortedMembers.forEach { member ->
val funcNamePart = member.memberName.toSnakeCase()
val variantName = symbolProvider.toMemberName(member)

if (sortedMembers.size == 1) {
Attribute.AllowIrrefutableLetPatterns.render(this)
}
renderNewVariant(writer, model, symbolProvider, member, variantName, funcNamePart, unionSymbol)
renderAsVariant(writer, model, symbolProvider, member, variantName, funcNamePart, unionSymbol)
rust("/// Returns true if this is a [`$variantName`](#T::$variantName).", unionSymbol)
rust("/// :rtype bool:")
rustBlock("pub fn is_$funcNamePart(&self) -> bool") {
rust("self.0.is_$funcNamePart()")
}
}
if (renderUnknownVariant) {
rust("/// Returns true if the union instance is the `Unknown` variant.")
rust("/// :rtype bool:")
rustBlock("pub fn is_unknown(&self) -> bool") {
rust("self.0.is_unknown()")
}
}
}
}

private fun renderPyObjectConverters() {
writer.rustBlockTemplate("impl #{pyo3}::IntoPy<#{pyo3}::PyObject> for ${unionSymbol.name}", "pyo3" to pyo3) {
rustBlockTemplate("fn into_py(self, py: #{pyo3}::Python<'_>) -> #{pyo3}::PyObject", "pyo3" to pyo3) {
rust("PyUnionMarker${unionSymbol.name}(self).into_py(py)")
}
}
writer.rustBlockTemplate("impl<'source> #{pyo3}::FromPyObject<'source> for ${unionSymbol.name}", "pyo3" to pyo3) {
rustBlockTemplate("fn extract(obj: &'source #{pyo3}::PyAny) -> #{pyo3}::PyResult<Self>", "pyo3" to pyo3) {
rust(
"""
let data: PyUnionMarker${unionSymbol.name} = obj.extract()?;
Ok(data.0)
""",
)
}
}
}

private fun renderNewVariant(
writer: RustWriter,
model: Model,
symbolProvider: SymbolProvider,
member: MemberShape,
variantName: String,
funcNamePart: String,
unionSymbol: Symbol,
) {
if (member.isTargetUnit()) {
Attribute("staticmethod").render(writer)
writer.rust(
"/// Creates a new union instance of [`$variantName`](#T::$variantName)",
unionSymbol,
)
writer.rust("/// :rtype ${unionSymbol.name}:")
writer.rustBlock("pub fn $funcNamePart() -> Self") {
rust("Self(${unionSymbol.name}::$variantName")
}
} else {
val memberSymbol = symbolProvider.toSymbol(member)
val pythonType = memberSymbol.rustType().pythonType()
val targetType = memberSymbol.rustType()
Attribute("staticmethod").render(writer)
writer.rust(
"/// Creates a new union instance of [`$variantName`](#T::$variantName)",
unionSymbol,
)
writer.rust("/// :param data ${pythonType.renderAsDocstring()}:")
writer.rust("/// :rtype ${unionSymbol.name}:")
writer.rustBlock("pub fn $funcNamePart(data: ${targetType.render()}) -> Self") {
rust("Self(${unionSymbol.name}::$variantName(data))")
}
}
}

private fun renderAsVariant(
writer: RustWriter,
model: Model,
symbolProvider: SymbolProvider,
member: MemberShape,
variantName: String,
funcNamePart: String,
unionSymbol: Symbol,
) {
if (member.isTargetUnit()) {
writer.rust(
"/// Tries to convert the union instance into [`$variantName`].",
)
writer.rust("/// :rtype None:")
writer.rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{pyo3}::PyResult<()>", "pyo3" to pyo3) {
rustTemplate(
"""
self.0.as_$funcNamePart().map_err(#{pyo3}::exceptions::PyValueError::new_err(
"${unionSymbol.name} variant is not None"
))
""",
"pyo3" to pyo3,
)
}
} else {
val memberSymbol = symbolProvider.toSymbol(member)
val pythonType = memberSymbol.rustType().pythonType()
val targetSymbol = symbolProvider.toSymbol(model.expectShape(member.target))
val rustType = memberSymbol.rustType()
writer.rust(
"/// Tries to convert the enum instance into [`$variantName`](#T::$variantName), extracting the inner #D.",
unionSymbol,
targetSymbol,
)
writer.rust("/// :rtype ${pythonType.renderAsDocstring()}:")
writer.rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{pyo3}::PyResult<${rustType.render()}>", "pyo3" to pyo3) {
val variantType = if (rustType.isCopy()) {
"*variant"
} else {
"variant.clone()"
}
rustTemplate(
"""
match self.0.as_$funcNamePart() {
Ok(variant) => Ok($variantType),
Err(_) => Err(#{pyo3}::exceptions::PyValueError::new_err(
"${unionSymbol.name} variant is not of type ${memberSymbol.rustType().pythonType().renderAsDocstring()}"
)),
}
""",
"pyo3" to pyo3,
)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructu
* whose associated types are `pub(crate)` and thus not exposed to the end
* user.
*/
class ConstrainedShapeSymbolProvider(
open class ConstrainedShapeSymbolProvider(
private val base: RustSymbolProvider,
private val serviceShape: ServiceShape,
private val publicConstrainedTypes: Boolean,
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pokemon-service-client/
pokemon-service-server-sdk/
wheels/
__pycache__
5 changes: 2 additions & 3 deletions rust-runtime/aws-smithy-http-server-python/src/types.rs
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ use pyo3::{
iter::IterNextOutput,
prelude::*,
};
use tokio::sync::Mutex;
use tokio::{runtime::Handle, sync::Mutex};
use tokio_stream::StreamExt;

use crate::PyError;
@@ -386,7 +386,6 @@ impl Default for ByteStream {
}
}

/// ByteStream Abstractions.
#[pymethods]
impl ByteStream {
/// Create a new [ByteStream](aws_smithy_http::byte_stream::ByteStream) from a slice of bytes.
@@ -408,7 +407,7 @@ impl ByteStream {
/// :rtype ByteStream:
#[staticmethod]
pub fn from_path_blocking(py: Python, path: String) -> PyResult<Py<PyAny>> {
let byte_stream = futures::executor::block_on(async {
let byte_stream = Handle::current().block_on(async {
aws_smithy_http::byte_stream::ByteStream::from_path(path)
.await
.map_err(|e| PyRuntimeError::new_err(e.to_string()))

0 comments on commit 48eda40

Please sign in to comment.