Skip to content

Commit

Permalink
Unknown enum variants removed from server (#1398)
Browse files Browse the repository at this point in the history
The server must have the most up to date variants and the unknown enum
variant should not be used. Clients are generated with it because they
might not have the most recent model and the server might return
an unknown variant to them.

Closes #1187

Signed-off-by: Daniele Ahmed <[email protected]>

Co-authored-by: Daniele Ahmed <[email protected]>
Co-authored-by: david-perez <[email protected]>
Co-authored-by: Matteo Bigoi <[email protected]>
  • Loading branch information
4 people authored May 23, 2022
1 parent d6e2944 commit 35989d2
Show file tree
Hide file tree
Showing 21 changed files with 327 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin {
override fun execute(context: PluginContext) {
// Suppress extremely noisy logs about reserved words
Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF
// Discover [RustCodegenDecorators] on the classpath. [RustCodegenDectorator] return different types of
// Discover [RustCodegenDecorators] on the classpath. [RustCodegenDecorator] return different types of
// customization. A customization is a function of:
// - location (e.g. the mutate section of an operation)
// - context (e.g. the of the operation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
Expand All @@ -28,7 +29,6 @@ import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
Expand Down Expand Up @@ -184,7 +184,7 @@ class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator:
logger.info("[rust-server-codegen] Generating an enum $shape")
shape.getTrait<EnumTrait>()?.also { enum ->
rustCrate.useShapeWriter(shape) { writer ->
EnumGenerator(model, symbolProvider, writer, shape, enum).render()
ServerEnumGenerator(model, symbolProvider, writer, shape, enum, codegenContext.runtimeConfig).render()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ServerCombinedErrorGenerator(
}
}

writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.stdfmt.member("Display")) {
writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.Display) {
rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
delegateToVariants {
rust("_inner.fmt(f)")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.smithy.CodegenMode
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.util.dq

class ServerEnumGenerator(
model: Model,
symbolProvider: RustSymbolProvider,
private val writer: RustWriter,
shape: StringShape,
enumTrait: EnumTrait,
private val runtimeConfig: RuntimeConfig,
) : EnumGenerator(model, symbolProvider, writer, shape, enumTrait) {
override var mode: CodegenMode = CodegenMode.Server
private val errorStruct = "${enumName}UnknownVariantError"

override fun renderFromForStr() {
writer.rust(
"""
##[derive(Debug, PartialEq, Eq, Hash)]
pub struct $errorStruct(String);
"""
)
writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.TryFrom) {
write("type Error = $errorStruct;")
writer.rustBlock("fn try_from(s: &str) -> Result<Self, <$enumName as #T<&str>>::Error>", RuntimeType.TryFrom) {
writer.rustBlock("match s") {
sortedMembers.forEach { member ->
write("${member.value.dq()} => Ok($enumName::${member.derivedName()}),")
}
write("_ => Err($errorStruct(s.to_owned()))")
}
}
}
writer.rustTemplate(
"""
impl #{From}<$errorStruct> for #{RequestRejection} {
fn from(e: $errorStruct) -> Self {
Self::EnumVariantNotFound(Box::new(e))
}
}
impl #{From}<$errorStruct> for #{JsonDeserialize} {
fn from(e: $errorStruct) -> Self {
Self::custom(format!("unknown variant {}", e))
}
}
impl #{StdError} for $errorStruct { }
impl #{Display} for $errorStruct {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
""",
"Display" to RuntimeType.Display,
"From" to RuntimeType.From,
"StdError" to RuntimeType.StdError,
"RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig),
"JsonDeserialize" to RuntimeType.jsonDeserialize(runtimeConfig),
)
}

override fun renderFromStr() {
writer.rust(
"""
impl std::str::FromStr for $enumName {
type Err = $errorStruct;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
$enumName::try_from(s)
}
}
"""
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class ServerOperationRegistryGenerator(
pub enum ${operationRegistryBuilderName}Error {
UninitializedField(&'static str)
}
impl std::fmt::Display for ${operationRegistryBuilderName}Error {
impl #{Display} for ${operationRegistryBuilderName}Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::UninitializedField(v) => write!(f, "{}", v),
Expand All @@ -128,7 +128,8 @@ class ServerOperationRegistryGenerator(
}
impl #{StdError} for ${operationRegistryBuilderName}Error {}
""".trimIndent(),
*codegenScope
*codegenScope,
"Display" to RuntimeType.Display,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.CodegenMode
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
Expand Down Expand Up @@ -71,7 +72,7 @@ class ServerProtocolTestGenerator(
private val operationErrorName = "crate::error::${operationSymbol.name}Error"

private val instantiator = with(codegenContext) {
Instantiator(symbolProvider, model, runtimeConfig)
Instantiator(symbolProvider, model, runtimeConfig, CodegenMode.Server)
}

private val codegenScope = arrayOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
Expand All @@ -40,6 +41,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerR
import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerResponseBindingGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.extractSymbolFromOption
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
Expand All @@ -53,13 +55,15 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPay
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.toOptional
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findStreamingMember
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.outputShape
Expand Down Expand Up @@ -855,15 +859,25 @@ private class ServerHttpBoundProtocolTraitImplGenerator(

when {
memberShape.isStringShape -> {
// `<_>::from()` is necessary to convert the `&str` into:
// `<_>::from()/try_from()` is necessary to convert the `&str` into:
// * the Rust enum in case the `string` shape has the `enum` trait; or
// * `String` in case it doesn't.
rustTemplate(
"""
let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
""".trimIndent(),
*codegenScope
)
if (memberShape.hasTrait<EnumTrait>()) {
rustTemplate(
"""
let v = <#{memberShape}>::try_from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref())?;
""",
*codegenScope,
"memberShape" to symbolProvider.toSymbol(memberShape),
)
} else {
rustTemplate(
"""
let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
""".trimIndent(),
*codegenScope
)
}
}
memberShape.isTimestampShape -> {
val index = HttpBindingIndex.of(model)
Expand Down Expand Up @@ -984,6 +998,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
private fun generateParsePercentEncodedStrAsStringFn(binding: HttpBindingDescriptor): RuntimeType {
val output = symbolProvider.toSymbol(binding.member)
val fnName = generateParseStrFnName(binding)
val symbol = output.extractSymbolFromOption()
return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
writer.rustBlockTemplate(
"pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
Expand All @@ -993,12 +1008,30 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
// `<_>::from()` is necessary to convert the `&str` into:
// * the Rust enum in case the `string` shape has the `enum` trait; or
// * `String` in case it doesn't.
rustTemplate(
when (symbol.rustType()) {
RustType.String ->
rustTemplate(
"""
let value = <#{T}>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
""",
*codegenScope,
"T" to symbol,
)
else -> { // RustType.Opaque, the Enum
check(symbol.rustType() is RustType.Opaque)
rustTemplate(
"""
let value = <#{T}>::try_from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref())?;
""",
*codegenScope,
"T" to symbol,
)
}
}
writer.write(
"""
let value = <_>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
""".trimIndent(),
*codegenScope,
"""
)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.server.smithy.generators

import io.kotest.matchers.string.shouldNotContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.lookup

class ServerEnumGeneratorTest {
private val model = """
namespace test
@enum([
{
value: "t2.nano",
name: "T2_NANO",
documentation: "T2 instances are Burstable Performance Instances.",
tags: ["ebsOnly"]
},
{
value: "t2.micro",
name: "T2_MICRO",
documentation: "T2 instances are Burstable Performance Instances.",
tags: ["ebsOnly"]
},
])
string InstanceType
""".asSmithyModel()

@Test
fun `it generates TryFrom, FromStr and errors for enums`() {
val provider = serverTestSymbolProvider(model)
val writer = RustWriter.forModule("model")
val shape = model.lookup<StringShape>("test#InstanceType")
val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig)
generator.render()
writer.compileAndTest(
"""
use std::str::FromStr;
assert_eq!(InstanceType::try_from("t2.nano").unwrap(), InstanceType::T2Nano);
assert_eq!(InstanceType::from_str("t2.nano").unwrap(), InstanceType::T2Nano);
assert_eq!(InstanceType::try_from("unknown").unwrap_err(), InstanceTypeUnknownVariantError("unknown".to_string()));
"""
)
}

@Test
fun `it generates enums without the unknown variant`() {
val provider = serverTestSymbolProvider(model)
val writer = RustWriter.forModule("model")
val shape = model.lookup<StringShape>("test#InstanceType")
val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig)
generator.render()
writer.compileAndTest(
"""
// check no unknown
let instance = InstanceType::T2Micro;
match instance {
InstanceType::T2Micro => (),
InstanceType::T2Nano => (),
}
"""
)
}

@Test
fun `it generates enums without non_exhaustive`() {
val provider = serverTestSymbolProvider(model)
val writer = RustWriter.forModule("model")
val shape = model.lookup<StringShape>("test#InstanceType")
val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig)
generator.render()
writer.toString() shouldNotContain "#[non_exhaustive]"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
val Clone = std.member("clone::Clone")
val Debug = stdfmt.member("Debug")
val Default: RuntimeType = RuntimeType("Default", dependency = null, namespace = "std::default")
val Display = stdfmt.member("Display")
val From = RuntimeType("From", dependency = null, namespace = "std::convert")
val TryFrom = RuntimeType("TryFrom", dependency = null, namespace = "std::convert")
val Infallible = RuntimeType("Infallible", dependency = null, namespace = "std::convert")
val PartialEq = std.member("cmp::PartialEq")
val StdError = RuntimeType("Error", dependency = null, namespace = "std::error")
Expand Down Expand Up @@ -290,6 +292,12 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
namespace = "aws_smithy_http::response"
)

fun jsonDeserialize(runtimeConfig: RuntimeConfig) = RuntimeType(
name = "Error",
dependency = CargoDependency.smithyJson(runtimeConfig),
namespace = "aws_smithy_json::deserialize"
)

fun ec2QueryErrors(runtimeConfig: RuntimeConfig) =
forInlineDependency(InlineDependency.ec2QueryErrors(runtimeConfig))

Expand Down
Loading

0 comments on commit 35989d2

Please sign in to comment.