diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index c320d19d74..1bd31fbc45 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -12,11 +12,14 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.orNull @@ -107,28 +110,34 @@ class EnumGeneratorTest { @deprecated(since: "1.2.3") string InstanceType """.asSmithyModel() - val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") + val shape = model.lookup("test#InstanceType") - val generator = EnumGenerator(model, provider, writer, shape, shape.expectTrait()) - generator.render() - writer.compileAndTest( - """ - let instance = InstanceType::T2Micro; - assert_eq!(instance.as_str(), "t2.micro"); - assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); - assert_eq!(InstanceType::from("other"), InstanceType::Unknown(UnknownVariantValue("other".to_owned()))); - // round trip unknown variants: - assert_eq!(InstanceType::from("other").as_str(), "other"); - """, - ) - val output = writer.toString() - output shouldContain "#[non_exhaustive]" - // on enum variant `T2Micro` - output shouldContain "#[deprecated]" - // on enum itself - output shouldContain "#[deprecated(since = \"1.2.3\")]" + val trait = shape.expectTrait() + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_generates_named_enums", + """ + let instance = InstanceType::T2Micro; + assert_eq!(instance.as_str(), "t2.micro"); + assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); + assert_eq!(InstanceType::from("other"), InstanceType::Unknown(UnknownVariantValue("other".to_owned()))); + // round trip unknown variants: + assert_eq!(InstanceType::from("other").as_str(), "other"); + """, + ) + val output = toString() + output.shouldContain("#[non_exhaustive]") + // on enum variant `T2Micro` + output.shouldContain("#[deprecated]") + // on enum itself + output.shouldContain("#[deprecated(since = \"1.2.3\")]") + } + project.compileAndTest() } @Test @@ -146,19 +155,25 @@ class EnumGeneratorTest { }]) string FooEnum """.asSmithyModel() + val shape = model.lookup("test#FooEnum") val trait = shape.expectTrait() - val writer = RustWriter.forModule("model") - val generator = EnumGenerator(model, testSymbolProvider(model), writer, shape, trait) - generator.render() - writer.compileAndTest( - """ - assert_eq!(FooEnum::Foo, FooEnum::Foo); - assert_ne!(FooEnum::Bar, FooEnum::Foo); - let mut hash_of_enums = std::collections::HashSet::new(); - hash_of_enums.insert(FooEnum::Foo); - """.trimIndent(), - ) + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "named_enums_implement_eq_and_hash", + """ + assert_eq!(FooEnum::Foo, FooEnum::Foo); + assert_ne!(FooEnum::Bar, FooEnum::Foo); + let mut hash_of_enums = std::collections::HashSet::new(); + hash_of_enums.insert(FooEnum::Foo); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -175,20 +190,26 @@ class EnumGeneratorTest { @deprecated string FooEnum """.asSmithyModel() + val shape = model.lookup("test#FooEnum") val trait = shape.expectTrait() - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, testSymbolProvider(model), writer, shape, trait) - generator.render() - writer.compileAndTest( - """ - assert_eq!(FooEnum::from("Foo"), FooEnum::from("Foo")); - assert_ne!(FooEnum::from("Bar"), FooEnum::from("Foo")); - let mut hash_of_enums = std::collections::HashSet::new(); - hash_of_enums.insert(FooEnum::from("Foo")); - """, - ) + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "unnamed_enums_implement_eq_and_hash", + """ + assert_eq!(FooEnum::from("Foo"), FooEnum::from("Foo")); + assert_ne!(FooEnum::from("Bar"), FooEnum::from("Foo")); + let mut hash_of_enums = std::collections::HashSet::new(); + hash_of_enums.insert(FooEnum::from("Foo")); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -214,19 +235,24 @@ class EnumGeneratorTest { ]) string FooEnum """.asSmithyModel() + val shape = model.lookup("test#FooEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, writer, shape, trait) - generator.render() - writer.compileAndTest( - """ - // Values should be sorted - assert_eq!(FooEnum::${EnumGenerator.Values}(), ["0", "1", "Bar", "Baz", "Foo"]); - """, - ) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_generates_unnamed_enums", + """ + // Values should be sorted + assert_eq!(FooEnum::${EnumGenerator.Values}(), ["0", "1", "Bar", "Baz", "Foo"]); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -241,19 +267,23 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - EnumGenerator(model, provider, writer, shape, trait).render() - - writer.compileAndTest( - """ - assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); - assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned()))); - """, - ) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model", + """ + assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); + assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned()))); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -269,19 +299,22 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val rendered = - RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() } - .toString() - - rendered shouldContain - """ - /// Some top-level documentation. - /// - /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ - """.trimIndent() + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + val rendered = toString() + rendered shouldContain + """ + /// Some top-level documentation. + /// + /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ + """.trimIndent() + } + project.compileAndTest() } @Test @@ -297,17 +330,20 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val rendered = - RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() } - .toString() - - rendered shouldContain - """ - /// Some top-level documentation. - """.trimIndent() + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + val rendered = toString() + rendered shouldContain + """ + /// Some top-level documentation. + """.trimIndent() + } + project.compileAndTest() } } @@ -322,38 +358,47 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - EnumGenerator(model, provider, writer, shape, trait).render() - - writer.compileAndTest( - """ - assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned()))); - """, - ) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_handles_variants_that_clash_with_rust_reserved_words", + """ + assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned()))); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test fun `matching on enum should be forward-compatible`() { fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { - val shape: StringShape = model.lookup(shapeId) + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - EnumGenerator(model, provider, writer, shape, trait).render() - - val matchExpressionUnderTest = """ - match $enumToMatchOn { - SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), - SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), - other @ _ if other.as_str() == "Variant3" => assert!(true), - _ => assert!(false, "expected `Variant3` but got `_`"), - } - """ - writer.compileAndTest(matchExpressionUnderTest) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "matching_on_enum_should_be_forward_compatible", + """ + match $enumToMatchOn { + SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), + SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), + other @ _ if other.as_str() == "Variant3" => assert!(true), + _ => assert!(false, "expected `Variant3` but got `_`"), + } + """.trimIndent(), + ) + } + project.compileAndTest() } val modelV1 = """