Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render a Union member of type Unit to an enum variant without inner data #1989

Merged
merged 25 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
bdf7858
Avoid explicitly emitting Unit type within Union
ysaito1001 Nov 15, 2022
2b148aa
Address test failures washed out in CI
ysaito1001 Nov 16, 2022
1f702df
Update codegen-core/src/test/kotlin/software/amazon/smithy/rust/codeg…
ysaito1001 Nov 16, 2022
11caf0c
Remove commented-out code
ysaito1001 Nov 16, 2022
a7002aa
Merge branch 'ysaito/remove-unit-from-generated-rust-enum' of https:/…
ysaito1001 Nov 16, 2022
fd60317
Add a helper for comparing against ShapeId for Unit
ysaito1001 Nov 17, 2022
2455eb6
Merge branch 'main' into ysaito/remove-unit-from-generated-rust-enum
ysaito1001 Nov 17, 2022
74041ca
Move Unit type bifurcation logic to jsonObjectWriter
ysaito1001 Nov 17, 2022
5802bd5
Make QuerySerializerGenerator in sync with the change
ysaito1001 Nov 17, 2022
d48b115
Update CHANGELOG.next.toml
ysaito1001 Nov 17, 2022
ccd2e63
Merge branch 'main' into ysaito/remove-unit-from-generated-rust-enum
ysaito1001 Nov 17, 2022
de1f5cf
Refactor ofTypeUnit -> isTargetUnit
ysaito1001 Nov 30, 2022
7c3a0a8
Merge branch 'main' into ysaito/remove-unit-from-generated-rust-enum
ysaito1001 Nov 30, 2022
5e95a55
Update codegen-core/src/main/kotlin/software/amazon/smithy/rust/codeg…
ysaito1001 Dec 3, 2022
4b9a355
Simplify if-else in jsonObjectWriter
ysaito1001 Dec 3, 2022
e4f6559
Avoid the union member's reference name being empty
ysaito1001 Dec 3, 2022
0e8b621
CHANGELOG.next.toml
ysaito1001 Dec 3, 2022
45ef80d
Ensure Union with Unit target can be serialized
ysaito1001 Dec 6, 2022
884155e
Ensure Union with Unit target can be parsed
ysaito1001 Dec 6, 2022
9962214
Merge branch 'main' into ysaito/remove-unit-from-generated-rust-enum
ysaito1001 Dec 6, 2022
2269361
Ensure match arm for Unit works in custom Debug impl
ysaito1001 Dec 6, 2022
5445282
Merge branch 'main' into ysaito/remove-unit-from-generated-rust-enum
ysaito1001 Dec 6, 2022
fd8e8fc
Fix unused variables warnings in CI
ysaito1001 Dec 7, 2022
2d1f01d
Fix E0658 on unused_variables
ysaito1001 Dec 7, 2022
d0c8b0b
Merge branch 'main' into ysaito/remove-unit-from-generated-rust-enum
ysaito1001 Dec 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,15 @@ x-amzn-errortype: com.example.service#InvalidRequestException
references = ["smithy-rs#1982"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" }
author = "david-perez"

[[aws-sdk-rust]]
message = "The Unit type for a Union member is no longer rendered."
references = ["smithy-rs#1989"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "ysaito1001"

[[smithy-rs]]
message = "The Unit type for a Union member is no longer rendered."
references = ["smithy-rs#1989"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "ysaito1001"
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,10 @@ open class Instantiator(
val member = shape.expectMember(memberName)
writer.rust("#T::${symbolProvider.toMemberName(member)}", unionSymbol)
// Unions should specify exactly one member.
writer.withBlock("(", ")") {
renderMember(this, member, variant.second, ctx)
if (!member.ofTypeUnit()) {
writer.withBlock("(", ")") {
renderMember(this, member, variant.second, ctx)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand All @@ -26,6 +28,12 @@ fun CodegenTarget.renderUnknownVariant() = when (this) {
CodegenTarget.CLIENT -> true
}

private val unitShapeId = ShapeId.from("smithy.api#Unit")

internal fun MemberShape.ofTypeUnit(): Boolean {
ysaito1001 marked this conversation as resolved.
Show resolved Hide resolved
return this.target == unitShapeId
}
ysaito1001 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Generate an `enum` for a Smithy Union Shape
*
Expand All @@ -49,16 +57,18 @@ class UnionGenerator(
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }

fun render() {
renderUnion()
}

private fun renderUnion() {
writer.documentShape(shape, model)
writer.deprecatedShape(shape)

val unionSymbol = symbolProvider.toSymbol(shape)
val containerMeta = unionSymbol.expectRustMetadata()
containerMeta.render(writer)

renderUnion(unionSymbol)
renderImplBlock(unionSymbol)
}

private fun renderUnion(unionSymbol: Symbol) {
writer.rustBlock("enum ${unionSymbol.name}") {
sortedMembers.forEach { member ->
val memberSymbol = symbolProvider.toSymbol(member)
Expand All @@ -67,7 +77,7 @@ class UnionGenerator(
documentShape(member, model, note = note)
deprecatedShape(member)
memberSymbol.expectRustMetadata().renderAttributes(this)
write("${symbolProvider.toMemberName(member)}(#T),", symbolProvider.toSymbol(member))
writer.renderVariant(symbolProvider, member, memberSymbol)
}
if (renderUnknownVariant) {
docs("""The `Unknown` variant represents cases where new union variant was received. Consider upgrading the SDK to the latest available version.""")
Expand All @@ -82,6 +92,9 @@ class UnionGenerator(
rust("Unknown,")
}
}
}

private fun renderImplBlock(unionSymbol: Symbol) {
writer.rustBlock("impl ${unionSymbol.name}") {
sortedMembers.forEach { member ->
val memberSymbol = symbolProvider.toSymbol(member)
Expand All @@ -91,11 +104,7 @@ class UnionGenerator(
if (sortedMembers.size == 1) {
Attribute.Custom("allow(irrefutable_let_patterns)").render(this)
}
rust("/// Tries to convert the enum instance into [`$variantName`](#T::$variantName), extracting the inner #D.", unionSymbol, memberSymbol)
rust("/// Returns `Err(&Self)` if it can't be converted.")
rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<&#T, &Self>", memberSymbol) {
rust("if let ${unionSymbol.name}::$variantName(val) = &self { Ok(val) } else { Err(self) }")
}
writer.renderAsVariant(member, variantName, funcNamePart, unionSymbol, memberSymbol)
rust("/// Returns true if this is a [`$variantName`](#T::$variantName).", unionSymbol)
rustBlock("pub fn is_$funcNamePart(&self) -> bool") {
rust("self.as_$funcNamePart().is_ok()")
Expand All @@ -114,7 +123,44 @@ class UnionGenerator(
const val UnknownVariantName = "Unknown"
}
}

fun unknownVariantError(union: String) =
"Cannot serialize `$union::${UnionGenerator.UnknownVariantName}` for the request. " +
"The `Unknown` variant is intended for responses only. " +
"It occurs when an outdated client is used after a new enum variant was added on the server side."

private fun RustWriter.renderVariant(symbolProvider: SymbolProvider, member: MemberShape, memberSymbol: Symbol) {
if (member.ofTypeUnit()) {
write("${symbolProvider.toMemberName(member)},")
} else {
write("${symbolProvider.toMemberName(member)}(#T),", memberSymbol)
}
}

private fun RustWriter.renderAsVariant(
member: MemberShape,
variantName: String,
funcNamePart: String,
unionSymbol: Symbol,
memberSymbol: Symbol,
) {
if (member.ofTypeUnit()) {
rust(
"/// Tries to convert the enum instance into [`$variantName`], extracting the inner `()`.",
)
rust("/// Returns `Err(&Self)` if it can't be converted.")
rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<(), &Self>") {
ysaito1001 marked this conversation as resolved.
Show resolved Hide resolved
rust("if let ${unionSymbol.name}::$variantName = &self { Ok(()) } else { Err(self) }")
}
} else {
rust(
"/// Tries to convert the enum instance into [`$variantName`](#T::$variantName), extracting the inner #D.",
unionSymbol,
memberSymbol,
)
rust("/// Returns `Err(&Self)` if it can't be converted.")
rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<&#T, &Self>", memberSymbol) {
rust("if let ${unionSymbol.name}::$variantName(val) = &self { Ok(val) } else { Err(self) }")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGen
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.ofTypeUnit
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
Expand Down Expand Up @@ -311,6 +312,7 @@ class JsonParserGenerator(
rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target))
}
}

else -> rust("u.into_owned()")
}
}
Expand Down Expand Up @@ -510,9 +512,19 @@ class JsonParserGenerator(
for (member in shape.members()) {
val variantName = symbolProvider.toMemberName(member)
rustBlock("${jsonName(member).dq()} =>") {
withBlock("Some(#T::$variantName(", "))", returnSymbolToParse.symbol) {
deserializeMember(member)
unwrapOrDefaultOrError(member)
if (member.ofTypeUnit()) {
rustTemplate(
"""
#{skip_value}(tokens)?;
Some(#{Union}::$variantName)
""",
"Union" to returnSymbolToParse.symbol, *codegenScope,
)
} else {
withBlock("Some(#T::$variantName(", "))", returnSymbolToParse.symbol) {
deserializeMember(member)
unwrapOrDefaultOrError(member)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGen
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.ofTypeUnit
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
Expand Down Expand Up @@ -415,9 +416,14 @@ class JsonSerializerGenerator(

private fun RustWriter.jsonObjectWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) {
safeName("object").also { objectName ->
rust("let mut $objectName = ${context.writerExpression}.start_object();")
inner(objectName)
rust("$objectName.finish();")
if (context.shape.ofTypeUnit()) {
rust("let $objectName = ${context.writerExpression}.start_object();")
rust("$objectName.finish();")
} else {
rust("let mut $objectName = ${context.writerExpression}.start_object();")
inner(objectName)
rust("$objectName.finish();")
}
}
}

Expand Down Expand Up @@ -452,8 +458,15 @@ class JsonSerializerGenerator(
rustBlock("match input") {
for (member in context.shape.members()) {
val variantName = symbolProvider.toMemberName(member)
withBlock("#T::$variantName(inner) => {", "},", unionSymbol) {
serializeMember(MemberContext.unionMember(context, "inner", member, jsonName))

if (member.ofTypeUnit()) {
withBlock("#T::$variantName => {", "},", unionSymbol) {
serializeMember(MemberContext.unionMember(context, "", member, jsonName))
ysaito1001 marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
withBlock("#T::$variantName(inner) => {", "},", unionSymbol) {
serializeMember(MemberContext.unionMember(context, "inner", member, jsonName))
}
}
}
if (codegenTarget.renderUnknownVariant()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.ofTypeUnit
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
Expand Down Expand Up @@ -312,14 +313,20 @@ abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : Struct
rustBlock("match input") {
for (member in context.shape.members()) {
val variantName = symbolProvider.toMemberName(member)
withBlock("#T::$variantName(inner) => {", "},", unionSymbol) {
serializeMember(
MemberContext.unionMember(
context.copy(writerExpression = "writer"),
"inner",
member,
),
)
if (member.ofTypeUnit()) {
withBlock("#T::$variantName => {", "},", unionSymbol) {
serializeMember(MemberContext.unionMember(context.copy(writerExpression = "writer"), "", member))
}
} else {
withBlock("#T::$variantName(inner) => {", "},", unionSymbol) {
serializeMember(
MemberContext.unionMember(
context.copy(writerExpression = "writer"),
"inner",
member,
),
)
}
}
}
if (target.renderUnknownVariant()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ class UnionGeneratorTest {
writer.compileAndTest()
}

@Test
fun `unit types should not appear in generated enum`() {
val writer = generateUnion("union MyUnion { a: Unit, b: String }", unknownVariant = true)
writer.compileAndTest(
"""
let a = MyUnion::A;
assert_eq!(Ok(()), a.as_a());
""",
)
}

private fun generateUnion(modelSmithy: String, unionName: String = "MyUnion", unknownVariant: Boolean = true): RustWriter {
val model = "namespace test\n$modelSmithy".asSmithyModel()
val provider: SymbolProvider = testSymbolProvider(model)
Expand Down