Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow injecting methods with generic type parameters in the config object #3274

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
@@ -13,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
@@ -33,7 +35,7 @@ data class ConfigMethod(
val docs: String,
/** The parameters of the method. **/
val params: List<Binding>,
/** In case the method is fallible, the error type it returns. **/
/** In case the method is fallible, the concrete error type it returns. **/
val errorType: RuntimeType?,
/** The code block inside the method. **/
val initializer: Initializer,
@@ -104,15 +106,45 @@ data class Initializer(
* }
*
* has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a
* `u64` variable.
* `u64` variable. Both are bindings that use concrete types. Types can also be generic:
*
* ```rust
* fn foo<T>(bar: T) { }
* ```
*/
data class Binding(
/** The name of the variable. */
val name: String,
/** The type of the variable. */
val ty: RuntimeType,
)
sealed class Binding {
data class Generic(
/** The name of the variable. The name of the type parameter will be the PascalCased variable name. */
val name: String,
/** The type of the variable. */
val ty: RuntimeType,
/**
* The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec<T>` with `T` being a
* generic type parameter, then `genericTys` should be a singleton set containing `"T"`.
* You can't use `L`, `H`, or `M` as the names to refer to any generic types.
* */
val genericTys: Set<String>,
) : Binding()

data class Concrete(
/** The name of the variable. */
val name: String,
/** The type of the variable. */
val ty: RuntimeType,
) : Binding()

fun name() =
when (this) {
is Concrete -> this.name
is Generic -> this.name
}

fun ty() =
when (this) {
is Concrete -> this.ty
is Generic -> this.ty
}
}

class ServiceConfigGenerator(
codegenContext: ServerCodegenContext,
@@ -271,10 +303,10 @@ class ServiceConfigGenerator(
writable {
rustTemplate(
"""
if !self.${it.requiredBuilderFlagName()} {
return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()});
}
""",
if !self.${it.requiredBuilderFlagName()} {
return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()});
}
""",
*codegenScope,
)
}
@@ -303,19 +335,19 @@ class ServiceConfigGenerator(
writable {
rust(
"""
##[error("service is not fully configured; invoke `${it.name}` on the config builder")]
${it.requiredErrorVariant()},
""",
##[error("service is not fully configured; invoke `${it.name}` on the config builder")]
${it.requiredErrorVariant()},
""",
)
}
}
rustTemplate(
"""
##[derive(Debug, #{ThisError}::Error)]
pub enum ${serviceName}ConfigError {
#{Variants:W}
}
""",
##[derive(Debug, #{ThisError}::Error)]
pub enum ${serviceName}ConfigError {
#{Variants:W}
}
""",
"ThisError" to ServerCargoDependency.ThisError.toType(),
"Variants" to variants.join("\n"),
)
@@ -327,8 +359,20 @@ class ServiceConfigGenerator(
writable {
val paramBindings =
it.params.map { binding ->
writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) }
writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) }
}.join("\n")
val genericBindings = it.params.filterIsInstance<Binding.Generic>()
val lhmBindings =
genericBindings.filter {
it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M")
}
if (lhmBindings.isNotEmpty()) {
throw CodegenException(
"Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings",
)
}
val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet()
val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray())

// This produces a nested type like: "S<B, S<A, T>>", where
// - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack
@@ -345,7 +389,7 @@ class ServiceConfigGenerator(
rustTemplate(
"#{StackType}<#{Ty}, #{Acc:W}>",
"StackType" to stackType,
"Ty" to next.ty,
"Ty" to next.ty(),
"Acc" to acc,
)
}
@@ -362,12 +406,12 @@ class ServiceConfigGenerator(
writable {
rustTemplate(
"""
${serviceName}ConfigBuilder<
#{LayersReturnTy:W},
#{HttpPluginsReturnTy:W},
#{ModelPluginsReturnTy:W},
>
""",
${serviceName}ConfigBuilder<
#{LayersReturnTy:W},
#{HttpPluginsReturnTy:W},
#{ModelPluginsReturnTy:W},
>
""",
"LayersReturnTy" to layersReturnTy,
"HttpPluginsReturnTy" to httpPluginsReturnTy,
"ModelPluginsReturnTy" to modelPluginsReturnTy,
@@ -391,14 +435,15 @@ class ServiceConfigGenerator(
docs(it.docs)
rustBlockTemplate(
"""
pub fn ${it.name}(
##[allow(unused_mut)]
mut self,
#{ParamBindings:W}
) -> #{ReturnTy:W}
""",
pub fn ${it.name}#{ParamBindingsGenericsWritable}(
##[allow(unused_mut)]
mut self,
#{ParamBindings:W}
) -> #{ReturnTy:W}
""",
"ReturnTy" to returnTy,
"ParamBindings" to paramBindings,
"ParamBindingsGenericsWritable" to paramBindingsGenericsWritable,
) {
rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code)

@@ -412,9 +457,9 @@ class ServiceConfigGenerator(
conditionalBlock("Ok(", ")", conditional = it.errorType != null) {
val registrations =
(
it.initializer.layerBindings.map { ".layer(${it.name})" } +
it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } +
it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" }
it.initializer.layerBindings.map { ".layer(${it.name()})" } +
it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } +
it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" }
).joinToString("")
rust("self$registrations")
}
@@ -437,9 +482,9 @@ class ServiceConfigGenerator(
writable {
rustBlockTemplate(
"""
/// Build the configuration.
pub fn build(self) -> #{BuilderBuildReturnTy:W}
""",
/// Build the configuration.
pub fn build(self) -> #{BuilderBuildReturnTy:W}
""",
"BuilderBuildReturnTy" to builderBuildReturnType(),
) {
rustTemplate(
@@ -450,12 +495,12 @@ class ServiceConfigGenerator(
conditionalBlock("Ok(", ")", isBuilderFallible) {
rust(
"""
super::${serviceName}Config {
layers: self.layers,
http_plugins: self.http_plugins,
model_plugins: self.model_plugins,
}
""",
super::${serviceName}Config {
layers: self.layers,
http_plugins: self.http_plugins,
model_plugins: self.model_plugins,
}
""",
)
}
}
Original file line number Diff line number Diff line change
@@ -5,7 +5,10 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
@@ -43,8 +46,9 @@ internal class ServiceConfigGeneratorTest {
docs = "Docs",
params =
listOf(
Binding("auth_spec", RuntimeType.String),
Binding("authorizer", RuntimeType.U64),
Binding.Concrete("auth_spec", RuntimeType.String),
Binding.Concrete("authorizer", RuntimeType.U64),
Binding.Generic("generic_list", RuntimeType("::std::vec::Vec<T>"), setOf("T")),
),
errorType = RuntimeType.std.resolve("io::Error"),
initializer =
@@ -53,30 +57,30 @@ internal class ServiceConfigGeneratorTest {
writable {
rustTemplate(
"""
if authorizer != 69 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1"));
}
if auth_spec.len() != 69 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2"));
}
let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin;
let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin;
""",
if authorizer != 69 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1"));
}
if auth_spec.len() != 69 && generic_list.len() != 69 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2"));
}
let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin;
let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin;
""",
*codegenScope,
)
},
layerBindings = emptyList(),
httpPluginBindings =
listOf(
Binding(
Binding.Concrete(
"authn_plugin",
smithyHttpServer.resolve("plugin::IdentityPlugin"),
),
),
modelPluginBindings =
listOf(
Binding(
Binding.Concrete(
"authz_plugin",
smithyHttpServer.resolve("plugin::IdentityPlugin"),
),
@@ -108,7 +112,7 @@ internal class ServiceConfigGeneratorTest {
// One model plugin has been applied.
PluginStack<IdentityPlugin, IdentityPlugin>,
> = SimpleServiceConfig::builder()
.aws_auth("a".repeat(69).to_owned(), 69)
.aws_auth("a".repeat(69).to_owned(), 69, vec![69])
.expect("failed to configure aws_auth")
.build()
.unwrap();
@@ -120,7 +124,7 @@ internal class ServiceConfigGeneratorTest {
rust(
"""
let actual_err = SimpleServiceConfig::builder()
.aws_auth("a".to_owned(), 69)
.aws_auth("a".to_owned(), 69, vec![69])
.unwrap_err();
let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string();
assert_eq!(actual_err.to_string(), expected);
@@ -132,7 +136,7 @@ internal class ServiceConfigGeneratorTest {
rust(
"""
let actual_err = SimpleServiceConfig::builder()
.aws_auth("a".repeat(69).to_owned(), 6969)
.aws_auth("a".repeat(69).to_owned(), 6969, vec!["69"])
.unwrap_err();
let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string();
assert_eq!(actual_err.to_string(), expected);
@@ -154,7 +158,7 @@ internal class ServiceConfigGeneratorTest {
}

@Test
fun `it should inject an method that applies three non-required layers`() {
fun `it should inject a method that applies three non-required layers`() {
val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel()

val decorator =
@@ -182,18 +186,18 @@ internal class ServiceConfigGeneratorTest {
writable {
rustTemplate(
"""
let layer1 = #{Identity}::new();
let layer2 = #{Identity}::new();
let layer3 = #{Identity}::new();
""",
let layer1 = #{Identity}::new();
let layer2 = #{Identity}::new();
let layer3 = #{Identity}::new();
""",
*codegenScope,
)
},
layerBindings =
listOf(
Binding("layer1", identityLayer),
Binding("layer2", identityLayer),
Binding("layer3", identityLayer),
Binding.Concrete("layer1", identityLayer),
Binding.Concrete("layer2", identityLayer),
Binding.Concrete("layer3", identityLayer),
),
httpPluginBindings = emptyList(),
modelPluginBindings = emptyList(),
@@ -240,4 +244,50 @@ internal class ServiceConfigGeneratorTest {
}
}
}

@Test
fun `it should throw an exception if a generic binding using L, H, or M is used`() {
val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel()

val decorator =
object : ServerCodegenDecorator {
override val name: String
get() = "InvalidGenericBindingsDecorator"
override val order: Byte
get() = 69

override fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> {
val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity")
return listOf(
ConfigMethod(
name = "invalid_generic_bindings",
docs = "Docs",
params =
listOf(
Binding.Generic("param1_bad", identityLayer, setOf("L")),
Binding.Generic("param2_bad", identityLayer, setOf("H")),
Binding.Generic("param3_bad", identityLayer, setOf("M")),
Binding.Generic("param4_ok", identityLayer, setOf("N")),
),
errorType = null,
initializer =
Initializer(
code = writable {},
layerBindings = emptyList(),
httpPluginBindings = emptyList(),
modelPluginBindings = emptyList(),
),
isRequired = false,
),
)
}
}

val codegenException =
shouldThrow<CodegenException> {
serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, _ -> }
}

codegenException.message.shouldContain("Injected config method `invalid_generic_bindings` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings:")
}
}