diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt index 8c5a56b577..d76cfb1353 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt @@ -57,6 +57,7 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr is StringShape -> if (shape.hasTrait()) { enumMeta(shape) } else null + else -> null } return baseSymbol.toBuilder().meta(meta).build() @@ -100,11 +101,13 @@ class BaseSymbolMetadataProvider( ) } } + container.isUnionShape || container.isListShape || container.isSetShape || container.isMapShape -> RustMetadata(visibility = Visibility.PUBLIC) + else -> TODO("Unrecognized container type: $container") } } @@ -120,9 +123,10 @@ class BaseSymbolMetadataProvider( override fun enumMeta(stringShape: StringShape): RustMetadata { return containerDefault.withDerives( RuntimeType.std.member("hash::Hash"), - ).withDerives( // enums can be eq because they can only contain strings + ).withDerives( + // enums can be eq because the inner data also implements Eq RuntimeType.std.member("cmp::Eq"), - // enums can be Ord because they can only contain strings + // enums can be Ord because the inner data also implements Ord RuntimeType.std.member("cmp::PartialOrd"), RuntimeType.std.member("cmp::Ord"), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index e735ec5e11..4db0b8ca7b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -99,6 +99,9 @@ open class EnumGenerator( /** Name of the generated unknown enum member name for enums with named members. */ const val UnknownVariant = "Unknown" + /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */ + const val UnknownVariantValue = "UnknownVariantValue" + /** Name of the function on the enum impl to get a vec of value names */ const val Values = "values" } @@ -108,6 +111,10 @@ open class EnumGenerator( // pub enum Blah { V1, V2, .. } renderEnum() writer.insertTrailingNewline() + if (target == CodegenTarget.CLIENT) { + renderUnknownVariantValue() + } + writer.insertTrailingNewline() // impl From for Blah { ... } renderFromForStr() // impl FromStr for Blah { ... } @@ -168,8 +175,8 @@ open class EnumGenerator( writer.rustBlock("enum $enumName") { sortedMembers.forEach { member -> member.render(writer) } if (target == CodegenTarget.CLIENT) { - docs("$UnknownVariant contains new variants that have been added since this code was generated.") - rust("$UnknownVariant(String)") + docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") + rust("$UnknownVariant($UnknownVariantValue)") } } } @@ -183,7 +190,7 @@ open class EnumGenerator( rust("""$enumName::${member.derivedName()} => ${member.value.dq()},""") } if (target == CodegenTarget.CLIENT) { - rust("$enumName::$UnknownVariant(s) => s.as_ref()") + rust("$enumName::$UnknownVariant(value) => value.as_str()") } } } @@ -198,6 +205,17 @@ open class EnumGenerator( } } + private fun renderUnknownVariantValue() { + meta.render(writer) + writer.write("struct $UnknownVariantValue(String);") + writer.rustBlock("impl $UnknownVariantValue") { + // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct. + rustBlock("fn as_str(&self) -> &str") { + rust("&self.0") + } + } + } + protected open fun renderFromForStr() { writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.From) { rustBlock("fn from(s: &str) -> Self") { @@ -205,7 +223,7 @@ open class EnumGenerator( sortedMembers.forEach { member -> rust("""${member.value.dq()} => $enumName::${member.derivedName()},""") } - rust("other => $enumName::$UnknownVariant(other.to_owned())") + rust("other => $enumName::$UnknownVariant($UnknownVariantValue(other.to_owned()))") } } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index 90679f9773..d5858fcd52 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -117,7 +117,7 @@ class EnumGeneratorTest { let instance = InstanceType::T2Micro; assert_eq!(instance.as_str(), "t2.micro"); assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); - assert_eq!(InstanceType::from("other"), InstanceType::Unknown("other".to_owned())); + assert_eq!(InstanceType::from("other"), InstanceType::Unknown(UnknownVariantValue("other".to_owned()))); // round trip unknown variants: assert_eq!(InstanceType::from("other").as_str(), "other"); """, @@ -250,7 +250,7 @@ class EnumGeneratorTest { """ assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into())); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned()))); """, ) } @@ -271,7 +271,9 @@ class EnumGeneratorTest { val shape: StringShape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString() + val rendered = + RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() } + .toString() rendered shouldContain """ @@ -297,7 +299,9 @@ class EnumGeneratorTest { val shape: StringShape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString() + val rendered = + RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() } + .toString() rendered shouldContain """ @@ -326,7 +330,7 @@ class EnumGeneratorTest { writer.compileAndTest( """ assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into())); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned()))); """, ) }