From efba260686238cb44eeea2cd27aae919ded9c6e2 Mon Sep 17 00:00:00 2001 From: david-perez Date: Thu, 30 Nov 2023 13:12:06 +0100 Subject: [PATCH 1/3] Allow injecting methods with generic type parameters in the config object This is a follow-up to #3111. Currently, the injected methods are limited to taking in concrete types. This PR allows for these methods to take in generic type parameters as well. ```rust impl SimpleServiceConfigBuilder { pub fn aws_auth(config: C) { ... } } ``` --- .../generators/ServiceConfigGenerator.kt | 59 ++++++++++++++----- .../generators/ServiceConfigGeneratorTest.kt | 27 +++++---- 2 files changed, 59 insertions(+), 27 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 525487968f..93da475202 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope @@ -33,7 +34,7 @@ data class ConfigMethod( val docs: String, /** The parameters of the method. **/ val params: List, - /** In case the method is fallible, the error type it returns. **/ + /** In case the method is fallible, the concrete error type it returns. **/ val errorType: RuntimeType?, /** The code block inside the method. **/ val initializer: Initializer, @@ -104,15 +105,42 @@ data class Initializer( * } * * has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a - * `u64` variable. + * `u64` variable. Both are bindings that use concrete types. Types can also be generic: + * + * ```rust + * fn foo(bar: T) { } * ``` */ -data class Binding( - /** The name of the variable. */ - val name: String, - /** The type of the variable. */ - val ty: RuntimeType, -) +sealed class Binding { + data class Generic( + /** The name of the variable. The name of the type parameter will be the PascalCased variable name. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, + /** + * The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec` with `T` being a + * generic type parameter, then `genericTys` should be a singleton list containing `"T"`. + * */ + val genericTys: List + ): Binding() + + data class Concrete( + /** The name of the variable. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, + ): Binding() + + fun name() = when (this) { + is Concrete -> this.name + is Generic -> this.name + } + + fun ty() = when (this) { + is Concrete -> this.ty + is Generic -> this.ty + } +} class ServiceConfigGenerator( codegenContext: ServerCodegenContext, @@ -317,8 +345,10 @@ class ServiceConfigGenerator( private fun injectedMethods() = configMethods.map { writable { val paramBindings = it.params.map { binding -> - writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } + writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) } }.join("\n") + val paramBindingsGenericTys = it.params.filterIsInstance().flatMap { it.genericTys } + val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray()) // This produces a nested type like: "S>", where // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack @@ -332,7 +362,7 @@ class ServiceConfigGenerator( rustTemplate( "#{StackType}<#{Ty}, #{Acc:W}>", "StackType" to stackType, - "Ty" to next.ty, + "Ty" to next.ty(), "Acc" to acc, ) } @@ -376,7 +406,7 @@ class ServiceConfigGenerator( docs(it.docs) rustBlockTemplate( """ - pub fn ${it.name}( + pub fn ${it.name}#{ParamBindingsGenericsWritable}( ##[allow(unused_mut)] mut self, #{ParamBindings:W} @@ -384,6 +414,7 @@ class ServiceConfigGenerator( """, "ReturnTy" to returnTy, "ParamBindings" to paramBindings, + "ParamBindingsGenericsWritable" to paramBindingsGenericsWritable, ) { rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) @@ -396,9 +427,9 @@ class ServiceConfigGenerator( } conditionalBlock("Ok(", ")", conditional = it.errorType != null) { val registrations = ( - it.initializer.layerBindings.map { ".layer(${it.name})" } + - it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + - it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" } + it.initializer.layerBindings.map { ".layer(${it.name()})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" } ).joinToString("") rust("self$registrations") } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt index c2c568b291..f1264b672f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -40,8 +40,9 @@ internal class ServiceConfigGeneratorTest { name = "aws_auth", docs = "Docs", params = listOf( - Binding("auth_spec", RuntimeType.String), - Binding("authorizer", RuntimeType.U64), + Binding.Concrete("auth_spec", RuntimeType.String), + Binding.Concrete("authorizer", RuntimeType.U64), + Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), listOf("T")), ), errorType = RuntimeType.std.resolve("io::Error"), initializer = Initializer( @@ -51,8 +52,8 @@ internal class ServiceConfigGeneratorTest { if authorizer != 69 { return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); } - - if auth_spec.len() != 69 { + + if auth_spec.len() != 69 && generic_list.len() != 69 { return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); } let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; @@ -63,13 +64,13 @@ internal class ServiceConfigGeneratorTest { }, layerBindings = emptyList(), httpPluginBindings = listOf( - Binding( + Binding.Concrete( "authn_plugin", smithyHttpServer.resolve("plugin::IdentityPlugin"), ), ), modelPluginBindings = listOf( - Binding( + Binding.Concrete( "authz_plugin", smithyHttpServer.resolve("plugin::IdentityPlugin"), ), @@ -101,7 +102,7 @@ internal class ServiceConfigGeneratorTest { // One model plugin has been applied. PluginStack, > = SimpleServiceConfig::builder() - .aws_auth("a".repeat(69).to_owned(), 69) + .aws_auth("a".repeat(69).to_owned(), 69, vec![69]) .expect("failed to configure aws_auth") .build() .unwrap(); @@ -113,7 +114,7 @@ internal class ServiceConfigGeneratorTest { rust( """ let actual_err = SimpleServiceConfig::builder() - .aws_auth("a".to_owned(), 69) + .aws_auth("a".to_owned(), 69, vec![69]) .unwrap_err(); let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string(); assert_eq!(actual_err.to_string(), expected); @@ -125,7 +126,7 @@ internal class ServiceConfigGeneratorTest { rust( """ let actual_err = SimpleServiceConfig::builder() - .aws_auth("a".repeat(69).to_owned(), 6969) + .aws_auth("a".repeat(69).to_owned(), 6969, vec!["69"]) .unwrap_err(); let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string(); assert_eq!(actual_err.to_string(), expected); @@ -147,7 +148,7 @@ internal class ServiceConfigGeneratorTest { } @Test - fun `it should inject an method that applies three non-required layers`() { + fun `it should inject a method that applies three non-required layers`() { val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() val decorator = object : ServerCodegenDecorator { @@ -179,9 +180,9 @@ internal class ServiceConfigGeneratorTest { ) }, layerBindings = listOf( - Binding("layer1", identityLayer), - Binding("layer2", identityLayer), - Binding("layer3", identityLayer), + Binding.Concrete("layer1", identityLayer), + Binding.Concrete("layer2", identityLayer), + Binding.Concrete("layer3", identityLayer), ), httpPluginBindings = emptyList(), modelPluginBindings = emptyList(), From 213c32bda809dde1cdeceb22a0c9835dc2cc6b6d Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 19 Dec 2023 15:00:01 +0100 Subject: [PATCH 2/3] Address comments --- .../generators/ServiceConfigGenerator.kt | 15 ++++-- .../generators/ServiceConfigGeneratorTest.kt | 48 ++++++++++++++++++- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 93da475202..f1917789ef 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock @@ -119,9 +120,10 @@ sealed class Binding { val ty: RuntimeType, /** * The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec` with `T` being a - * generic type parameter, then `genericTys` should be a singleton list containing `"T"`. + * generic type parameter, then `genericTys` should be a singleton set containing `"T"`. + * You can't use `L`, `H`, or `M` as the names to refer to any generic types. * */ - val genericTys: List + val genericTys: Set ): Binding() data class Concrete( @@ -347,7 +349,14 @@ class ServiceConfigGenerator( val paramBindings = it.params.map { binding -> writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) } }.join("\n") - val paramBindingsGenericTys = it.params.filterIsInstance().flatMap { it.genericTys } + val genericBindings = it.params.filterIsInstance() + val lhmBindings = genericBindings.filter { it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M") } + if (lhmBindings.isNotEmpty()) { + throw CodegenException( + "Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings" + ) + } + val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet() val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray()) // This produces a nested type like: "S>", where diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt index f1264b672f..e05895d92f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -5,7 +5,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain import org.junit.jupiter.api.Test +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -42,7 +46,7 @@ internal class ServiceConfigGeneratorTest { params = listOf( Binding.Concrete("auth_spec", RuntimeType.String), Binding.Concrete("authorizer", RuntimeType.U64), - Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), listOf("T")), + Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), setOf("T")), ), errorType = RuntimeType.std.resolve("io::Error"), initializer = Initializer( @@ -229,4 +233,46 @@ internal class ServiceConfigGeneratorTest { } } } + + @Test + fun `it should throw an exception if a generic binding using L, H, or M is used`() { + val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() + + val decorator = object : ServerCodegenDecorator { + override val name: String + get() = "InvalidGenericBindingsDecorator" + override val order: Byte + get() = 69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + return listOf( + ConfigMethod( + name = "invalid_generic_bindings", + docs = "Docs", + params = listOf( + Binding.Generic("param1_bad", identityLayer, setOf("L")), + Binding.Generic("param2_bad", identityLayer, setOf("H")), + Binding.Generic("param3_bad", identityLayer, setOf("M")), + Binding.Generic("param4_ok", identityLayer, setOf("N")), + ), + errorType = null, + initializer = Initializer( + code = writable {}, + layerBindings = emptyList(), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, + ), + ) + } + } + + val codegenException = shouldThrow { + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, _ -> } + } + + codegenException.message.shouldContain("Injected config method `invalid_generic_bindings` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings:") + } } From 85330f6b81ec104a3efb59dd90228db7b3e34646 Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 19 Dec 2023 15:17:55 +0100 Subject: [PATCH 3/3] ./gradlew ktlintFormat --- .../generators/ServiceConfigGenerator.kt | 307 ++++++++++-------- .../generators/ServiceConfigGeneratorTest.kt | 204 ++++++------ 2 files changed, 273 insertions(+), 238 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 8541ec965b..56a662348f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -123,25 +123,27 @@ sealed class Binding { * generic type parameter, then `genericTys` should be a singleton set containing `"T"`. * You can't use `L`, `H`, or `M` as the names to refer to any generic types. * */ - val genericTys: Set - ): Binding() + val genericTys: Set, + ) : Binding() data class Concrete( /** The name of the variable. */ val name: String, /** The type of the variable. */ val ty: RuntimeType, - ): Binding() + ) : Binding() - fun name() = when (this) { - is Concrete -> this.name - is Generic -> this.name - } + fun name() = + when (this) { + is Concrete -> this.name + is Generic -> this.name + } - fun ty() = when (this) { - is Concrete -> this.ty - is Generic -> this.ty - } + fun ty() = + when (this) { + is Concrete -> this.ty + is Generic -> this.ty + } } class ServiceConfigGenerator( @@ -296,184 +298,203 @@ class ServiceConfigGenerator( private val isBuilderFallible = configMethods.isBuilderFallible() - private fun builderBuildRequiredMethodChecks() = configMethods.filter { it.isRequired }.map { - writable { - rustTemplate( - """ + private fun builderBuildRequiredMethodChecks() = + configMethods.filter { it.isRequired }.map { + writable { + rustTemplate( + """ if !self.${it.requiredBuilderFlagName()} { return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()}); } """, - *codegenScope, - ) - } - }.join("\n") + *codegenScope, + ) + } + }.join("\n") - private fun builderRequiredMethodFlagsDefinitions() = configMethods.filter { it.isRequired }.map { - writable { rust("pub(crate) ${it.requiredBuilderFlagName()}: bool,") } - }.join("\n") + private fun builderRequiredMethodFlagsDefinitions() = + configMethods.filter { it.isRequired }.map { + writable { rust("pub(crate) ${it.requiredBuilderFlagName()}: bool,") } + }.join("\n") - private fun builderRequiredMethodFlagsInit() = configMethods.filter { it.isRequired }.map { - writable { rust("${it.requiredBuilderFlagName()}: false,") } - }.join("\n") + private fun builderRequiredMethodFlagsInit() = + configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: false,") } + }.join("\n") - private fun builderRequiredMethodFlagsMove() = configMethods.filter { it.isRequired }.map { - writable { rust("${it.requiredBuilderFlagName()}: self.${it.requiredBuilderFlagName()},") } - }.join("\n") + private fun builderRequiredMethodFlagsMove() = + configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: self.${it.requiredBuilderFlagName()},") } + }.join("\n") - private fun builderRequiredMethodError() = writable { - if (isBuilderFallible) { - val variants = configMethods.filter { it.isRequired }.map { - writable { - rust( - """ + private fun builderRequiredMethodError() = + writable { + if (isBuilderFallible) { + val variants = + configMethods.filter { it.isRequired }.map { + writable { + rust( + """ ##[error("service is not fully configured; invoke `${it.name}` on the config builder")] ${it.requiredErrorVariant()}, """, - ) - } - } - rustTemplate( - """ + ) + } + } + rustTemplate( + """ ##[derive(Debug, #{ThisError}::Error)] pub enum ${serviceName}ConfigError { #{Variants:W} } """, - "ThisError" to ServerCargoDependency.ThisError.toType(), - "Variants" to variants.join("\n"), - ) - } - } - - private fun injectedMethods() = configMethods.map { - writable { - val paramBindings = it.params.map { binding -> - writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) } - }.join("\n") - val genericBindings = it.params.filterIsInstance() - val lhmBindings = genericBindings.filter { it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M") } - if (lhmBindings.isNotEmpty()) { - throw CodegenException( - "Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings" + "ThisError" to ServerCargoDependency.ThisError.toType(), + "Variants" to variants.join("\n"), ) } - val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet() - val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray()) - - // This produces a nested type like: "S>", where - // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack - // and the second is the "outer" part of the stack. The outer part gets executed first. For an example, - // see `aws_smithy_http_server::plugin::PluginStack`. - // - "A", "B" are the types of the "things" that are added. - // - "T" is the generic type variable name used in the enclosing impl block. - fun List.stackReturnType(genericTypeVarName: String, stackType: RuntimeType): Writable = - this.fold(writable { rust(genericTypeVarName) }) { acc, next -> - writable { - rustTemplate( - "#{StackType}<#{Ty}, #{Acc:W}>", - "StackType" to stackType, - "Ty" to next.ty(), - "Acc" to acc, - ) + } + + private fun injectedMethods() = + configMethods.map { + writable { + val paramBindings = + it.params.map { binding -> + writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) } + }.join("\n") + val genericBindings = it.params.filterIsInstance() + val lhmBindings = + genericBindings.filter { + it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M") } + if (lhmBindings.isNotEmpty()) { + throw CodegenException( + "Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings", + ) } + val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet() + val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray()) + + // This produces a nested type like: "S>", where + // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack + // and the second is the "outer" part of the stack. The outer part gets executed first. For an example, + // see `aws_smithy_http_server::plugin::PluginStack`. + // - "A", "B" are the types of the "things" that are added. + // - "T" is the generic type variable name used in the enclosing impl block. + fun List.stackReturnType( + genericTypeVarName: String, + stackType: RuntimeType, + ): Writable = + this.fold(writable { rust(genericTypeVarName) }) { acc, next -> + writable { + rustTemplate( + "#{StackType}<#{Ty}, #{Acc:W}>", + "StackType" to stackType, + "Ty" to next.ty(), + "Acc" to acc, + ) + } + } - val layersReturnTy = - it.initializer.layerBindings.stackReturnType("L", RuntimeType.Tower.resolve("layer::util::Stack")) - val httpPluginsReturnTy = - it.initializer.httpPluginBindings.stackReturnType("H", smithyHttpServer.resolve("plugin::PluginStack")) - val modelPluginsReturnTy = - it.initializer.modelPluginBindings.stackReturnType("M", smithyHttpServer.resolve("plugin::PluginStack")) + val layersReturnTy = + it.initializer.layerBindings.stackReturnType("L", RuntimeType.Tower.resolve("layer::util::Stack")) + val httpPluginsReturnTy = + it.initializer.httpPluginBindings.stackReturnType("H", smithyHttpServer.resolve("plugin::PluginStack")) + val modelPluginsReturnTy = + it.initializer.modelPluginBindings.stackReturnType("M", smithyHttpServer.resolve("plugin::PluginStack")) - val configBuilderReturnTy = writable { - rustTemplate( - """ + val configBuilderReturnTy = + writable { + rustTemplate( + """ ${serviceName}ConfigBuilder< #{LayersReturnTy:W}, #{HttpPluginsReturnTy:W}, #{ModelPluginsReturnTy:W}, > """, - "LayersReturnTy" to layersReturnTy, - "HttpPluginsReturnTy" to httpPluginsReturnTy, - "ModelPluginsReturnTy" to modelPluginsReturnTy, - ) - } + "LayersReturnTy" to layersReturnTy, + "HttpPluginsReturnTy" to httpPluginsReturnTy, + "ModelPluginsReturnTy" to modelPluginsReturnTy, + ) + } - val returnTy = if (it.errorType != null) { - writable { - rustTemplate( - "#{Result}<#{T:W}, #{E}>", - "T" to configBuilderReturnTy, - "E" to it.errorType, - *codegenScope, - ) - } - } else { - configBuilderReturnTy - } + val returnTy = + if (it.errorType != null) { + writable { + rustTemplate( + "#{Result}<#{T:W}, #{E}>", + "T" to configBuilderReturnTy, + "E" to it.errorType, + *codegenScope, + ) + } + } else { + configBuilderReturnTy + } - docs(it.docs) - rustBlockTemplate( - """ + docs(it.docs) + rustBlockTemplate( + """ pub fn ${it.name}#{ParamBindingsGenericsWritable}( ##[allow(unused_mut)] mut self, #{ParamBindings:W} ) -> #{ReturnTy:W} """, - "ReturnTy" to returnTy, - "ParamBindings" to paramBindings, - "ParamBindingsGenericsWritable" to paramBindingsGenericsWritable, - ) { - rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) - - check(it.initializer.layerBindings.size + it.initializer.httpPluginBindings.size + it.initializer.modelPluginBindings.size > 0) { - "This method's initializer does not register any layers, HTTP plugins, or model plugins. It must register at least something!" - } + "ReturnTy" to returnTy, + "ParamBindings" to paramBindings, + "ParamBindingsGenericsWritable" to paramBindingsGenericsWritable, + ) { + rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) + + check(it.initializer.layerBindings.size + it.initializer.httpPluginBindings.size + it.initializer.modelPluginBindings.size > 0) { + "This method's initializer does not register any layers, HTTP plugins, or model plugins. It must register at least something!" + } - if (it.isRequired) { - rust("self.${it.requiredBuilderFlagName()} = true;") - } - conditionalBlock("Ok(", ")", conditional = it.errorType != null) { - val registrations = ( - it.initializer.layerBindings.map { ".layer(${it.name()})" } + - it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } + - it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" } - ).joinToString("") - rust("self$registrations") + if (it.isRequired) { + rust("self.${it.requiredBuilderFlagName()} = true;") + } + conditionalBlock("Ok(", ")", conditional = it.errorType != null) { + val registrations = + ( + it.initializer.layerBindings.map { ".layer(${it.name()})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" } + ).joinToString("") + rust("self$registrations") + } } } - } - }.join("\n\n") + }.join("\n\n") - private fun builderBuildReturnType() = writable { - val t = "super::${serviceName}Config" + private fun builderBuildReturnType() = + writable { + val t = "super::${serviceName}Config" - if (isBuilderFallible) { - rustTemplate("#{Result}<$t, ${serviceName}ConfigError>", *codegenScope) - } else { - rust(t) + if (isBuilderFallible) { + rustTemplate("#{Result}<$t, ${serviceName}ConfigError>", *codegenScope) + } else { + rust(t) + } } - } - private fun builderBuildMethod() = writable { - rustBlockTemplate( - """ + private fun builderBuildMethod() = + writable { + rustBlockTemplate( + """ /// Build the configuration. pub fn build(self) -> #{BuilderBuildReturnTy:W} """, - "BuilderBuildReturnTy" to builderBuildReturnType(), - ) { - rustTemplate( - "#{BuilderBuildRequiredMethodChecks:W}", - "BuilderBuildRequiredMethodChecks" to builderBuildRequiredMethodChecks(), - ) - - conditionalBlock("Ok(", ")", isBuilderFallible) { - rust( - """ + "BuilderBuildReturnTy" to builderBuildReturnType(), + ) { + rustTemplate( + "#{BuilderBuildRequiredMethodChecks:W}", + "BuilderBuildRequiredMethodChecks" to builderBuildRequiredMethodChecks(), + ) + + conditionalBlock("Ok(", ")", isBuilderFallible) { + rust( + """ super::${serviceName}Config { layers: self.layers, http_plugins: self.http_plugins, diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt index 46edac4377..a744e53791 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -34,25 +34,29 @@ internal class ServiceConfigGeneratorTest { override val order: Byte get() = -69 - override fun configMethods(codegenContext: ServerCodegenContext): List { - val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() - val codegenScope = arrayOf( - "SmithyHttpServer" to smithyHttpServer, - ) - return listOf( - ConfigMethod( - name = "aws_auth", - docs = "Docs", - params = listOf( - Binding.Concrete("auth_spec", RuntimeType.String), - Binding.Concrete("authorizer", RuntimeType.U64), - Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), setOf("T")), - ), - errorType = RuntimeType.std.resolve("io::Error"), - initializer = Initializer( - code = writable { - rustTemplate( - """ + override fun configMethods(codegenContext: ServerCodegenContext): List { + val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() + val codegenScope = + arrayOf( + "SmithyHttpServer" to smithyHttpServer, + ) + return listOf( + ConfigMethod( + name = "aws_auth", + docs = "Docs", + params = + listOf( + Binding.Concrete("auth_spec", RuntimeType.String), + Binding.Concrete("authorizer", RuntimeType.U64), + Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), setOf("T")), + ), + errorType = RuntimeType.std.resolve("io::Error"), + initializer = + Initializer( + code = + writable { + rustTemplate( + """ if authorizer != 69 { return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); } @@ -63,28 +67,30 @@ internal class ServiceConfigGeneratorTest { let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; """, - *codegenScope, - ) - }, - layerBindings = emptyList(), - httpPluginBindings = listOf( - Binding.Concrete( - "authn_plugin", - smithyHttpServer.resolve("plugin::IdentityPlugin"), - ), - ), - modelPluginBindings = listOf( - Binding.Concrete( - "authz_plugin", - smithyHttpServer.resolve("plugin::IdentityPlugin"), + *codegenScope, + ) + }, + layerBindings = emptyList(), + httpPluginBindings = + listOf( + Binding.Concrete( + "authn_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), + modelPluginBindings = + listOf( + Binding.Concrete( + "authz_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), ), - ), + isRequired = true, ), - isRequired = true ) - ) + } } - } serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> rustCrate.testModule { @@ -162,41 +168,45 @@ internal class ServiceConfigGeneratorTest { override val order: Byte get() = 69 - override fun configMethods(codegenContext: ServerCodegenContext): List { - val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") - val codegenScope = arrayOf( - "Identity" to identityLayer, - ) - return listOf( - ConfigMethod( - name = "three_non_required_layers", - docs = "Docs", - params = emptyList(), - errorType = null, - initializer = Initializer( - code = writable { - rustTemplate( - """ + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + val codegenScope = + arrayOf( + "Identity" to identityLayer, + ) + return listOf( + ConfigMethod( + name = "three_non_required_layers", + docs = "Docs", + params = emptyList(), + errorType = null, + initializer = + Initializer( + code = + writable { + rustTemplate( + """ let layer1 = #{Identity}::new(); let layer2 = #{Identity}::new(); let layer3 = #{Identity}::new(); """, - *codegenScope, - ) - }, - layerBindings = listOf( - Binding.Concrete("layer1", identityLayer), - Binding.Concrete("layer2", identityLayer), - Binding.Concrete("layer3", identityLayer), - ), - httpPluginBindings = emptyList(), - modelPluginBindings = emptyList(), + *codegenScope, + ) + }, + layerBindings = + listOf( + Binding.Concrete("layer1", identityLayer), + Binding.Concrete("layer2", identityLayer), + Binding.Concrete("layer3", identityLayer), + ), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, ), - isRequired = false, ) - ) + } } - } serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> rustCrate.testModule { @@ -239,40 +249,44 @@ internal class ServiceConfigGeneratorTest { fun `it should throw an exception if a generic binding using L, H, or M is used`() { val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() - val decorator = object : ServerCodegenDecorator { - override val name: String - get() = "InvalidGenericBindingsDecorator" - override val order: Byte - get() = 69 + val decorator = + object : ServerCodegenDecorator { + override val name: String + get() = "InvalidGenericBindingsDecorator" + override val order: Byte + get() = 69 - override fun configMethods(codegenContext: ServerCodegenContext): List { - val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") - return listOf( - ConfigMethod( - name = "invalid_generic_bindings", - docs = "Docs", - params = listOf( - Binding.Generic("param1_bad", identityLayer, setOf("L")), - Binding.Generic("param2_bad", identityLayer, setOf("H")), - Binding.Generic("param3_bad", identityLayer, setOf("M")), - Binding.Generic("param4_ok", identityLayer, setOf("N")), - ), - errorType = null, - initializer = Initializer( - code = writable {}, - layerBindings = emptyList(), - httpPluginBindings = emptyList(), - modelPluginBindings = emptyList(), + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + return listOf( + ConfigMethod( + name = "invalid_generic_bindings", + docs = "Docs", + params = + listOf( + Binding.Generic("param1_bad", identityLayer, setOf("L")), + Binding.Generic("param2_bad", identityLayer, setOf("H")), + Binding.Generic("param3_bad", identityLayer, setOf("M")), + Binding.Generic("param4_ok", identityLayer, setOf("N")), + ), + errorType = null, + initializer = + Initializer( + code = writable {}, + layerBindings = emptyList(), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, ), - isRequired = false, - ), - ) + ) + } } - } - val codegenException = shouldThrow { - serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, _ -> } - } + val codegenException = + shouldThrow { + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, _ -> } + } codegenException.message.shouldContain("Injected config method `invalid_generic_bindings` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings:") }