diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt index 05bcc1937e..3a5949c4d4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt @@ -64,6 +64,7 @@ class ServerServiceGeneratorV2( /** A `Writable` block containing all the `Handler` and `Operation` setters for the builder. */ private fun builderSetters(): Writable = writable { + val pluginType = listOf("Pl") for ((index, pair) in builderFieldNames.zip(operationStructNames).withIndex()) { val (fieldName, structName) = pair @@ -128,16 +129,17 @@ class ServerServiceGeneratorV2( /// [`$structName`](crate::operation_shape::$structName) using either /// [`OperationShape::from_handler`](#{SmithyHttpServer}::operation::OperationShapeExt::from_handler) or /// [`OperationShape::from_service`](#{SmithyHttpServer}::operation::OperationShapeExt::from_service). - pub fn ${fieldName}_operation(self, value: NewOp) -> $builderName<${(replacedOpGenerics + replacedExtGenerics).joinToString(", ")}> + pub fn ${fieldName}_operation(self, value: NewOp) -> $builderName<${(replacedOpGenerics + replacedExtGenerics + pluginType).joinToString(", ")}> { $builderName { ${switchedFields.joinToString(", ")}, - _exts: std::marker::PhantomData + _exts: std::marker::PhantomData, + plugin: self.plugin, } } """, "Protocol" to protocol.markerStruct(), - "HandlerSetterGenerics" to (replacedOpServiceGenerics + (replacedExtGenerics.map { writable(it) })).join(", "), + "HandlerSetterGenerics" to (replacedOpServiceGenerics + ((replacedExtGenerics + pluginType).map { writable(it) })).join(", "), *codegenScope, ) @@ -159,6 +161,7 @@ class ServerServiceGeneratorV2( crate::operation_shape::${symbolProvider.toSymbol(operation).name.toPascalCase()}, $exts, B, + Pl, >, $type::Service: Clone + Send + 'static, <$type::Service as #{Tower}::Service<#{Http}::Request>>::Future: Send + 'static, @@ -174,18 +177,26 @@ class ServerServiceGeneratorV2( /** Returns a `Writable` containing the builder struct definition and its implementations. */ private fun builder(): Writable = writable { val extensionTypesDefault = extensionTypes.map { "$it = ()" } - val structGenerics = (builderOps + extensionTypesDefault).joinToString(", ") - val builderGenerics = (builderOps + extensionTypes).joinToString(", ") + val pluginName = "Pl" + val pluginTypeList = listOf(pluginName) + val newPluginType = "New$pluginName" + val pluginTypeDefault = listOf("$pluginName = #{SmithyHttpServer}::plugin::IdentityPlugin") + val structGenerics = (builderOps + extensionTypesDefault + pluginTypeDefault).joinToString(", ") + val builderGenerics = (builderOps + extensionTypes + pluginTypeList).joinToString(", ") + val builderGenericsNoPlugin = (builderOps + extensionTypes).joinToString(", ") // Generate router construction block. val router = protocol .routerConstruction( builderFieldNames .map { - writable { rustTemplate("self.$it.upgrade()") } + writable { rustTemplate("self.$it.upgrade(&self.plugin)") } } .asIterable(), ) + val setterFields = builderFieldNames.map { item -> + "$item: self.$item" + }.joinToString(", ") rustTemplate( """ /// The service builder for [`$serviceName`]. @@ -194,7 +205,8 @@ class ServerServiceGeneratorV2( pub struct $builderName<$structGenerics> { ${builderFields.joinToString(", ")}, ##[allow(unused_parens)] - _exts: std::marker::PhantomData<(${extensionTypes.joinToString(", ")})> + _exts: std::marker::PhantomData<(${extensionTypes.joinToString(", ")})>, + plugin: $pluginName, } impl<$builderGenerics> $builderName<$builderGenerics> { @@ -213,6 +225,17 @@ class ServerServiceGeneratorV2( } } } + + impl<$builderGenerics, $newPluginType> #{SmithyHttpServer}::plugin::Pluggable<$newPluginType> for $builderName<$builderGenerics> { + type Output = $builderName<$builderGenericsNoPlugin, #{SmithyHttpServer}::plugin::PluginStack<$pluginName, $newPluginType>>; + fn apply(self, plugin: $newPluginType) -> Self::Output { + $builderName { + $setterFields, + _exts: self._exts, + plugin: #{SmithyHttpServer}::plugin::PluginStack::new(self.plugin, plugin), + } + } + } """, "Setters" to builderSetters(), "BuildConstraints" to buildConstraints.join(", "), @@ -265,7 +288,8 @@ class ServerServiceGeneratorV2( pub fn builder() -> $builderName<#{NotSetGenerics:W}> { $builderName { #{NotSetFields:W}, - _exts: std::marker::PhantomData + _exts: std::marker::PhantomData, + plugin: #{SmithyHttpServer}::plugin::IdentityPlugin } } @@ -276,7 +300,8 @@ class ServerServiceGeneratorV2( pub fn unchecked_builder() -> $builderName<#{InternalFailureGenerics:W}> { $builderName { #{InternalFailureFields:W}, - _exts: std::marker::PhantomData + _exts: std::marker::PhantomData, + plugin: #{SmithyHttpServer}::plugin::IdentityPlugin } } } diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs index 0089f1557c..4340a91cac 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs @@ -19,6 +19,9 @@ use pokemon_service_server_sdk::{error, input, model, model::CapturingPayload, o use rand::Rng; use tracing_subscriber::{prelude::*, EnvFilter}; +#[doc(hidden)] +pub mod plugin; + const PIKACHU_ENGLISH_FLAVOR_TEXT: &str = "When several of these Pokémon gather, their electricity could build and cause lightning storms."; const PIKACHU_SPANISH_FLAVOR_TEXT: &str = diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/plugin.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/plugin.rs new file mode 100644 index 0000000000..6c76df5613 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/plugin.rs @@ -0,0 +1,80 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http_server::plugin::Plugin; + +/// A [`Service`](tower::Service) that adds a print log. +#[derive(Clone, Debug)] +pub struct PrintService { + inner: S, + name: &'static str, +} + +impl tower::Service for PrintService +where + S: tower::Service, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: R) -> Self::Future { + println!("Hi {}", self.name); + self.inner.call(req) + } +} + +/// A [`Layer`](tower::Layer) which constructs the [`PrintService`]. +#[derive(Debug)] +pub struct PrintLayer { + name: &'static str, +} +impl tower::Layer for PrintLayer { + type Service = PrintService; + + fn layer(&self, service: S) -> Self::Service { + PrintService { + inner: service, + name: self.name, + } + } +} + +/// A [`Plugin`]() for a service builder to add a [`PrintLayer`] over operations. +#[derive(Debug)] +pub struct PrintPlugin; +impl Plugin for PrintPlugin +where + Op: aws_smithy_http_server::operation::OperationShape, +{ + type Service = S; + type Layer = tower::layer::util::Stack; + + fn map( + &self, + input: aws_smithy_http_server::operation::Operation, + ) -> aws_smithy_http_server::operation::Operation { + input.layer(PrintLayer { name: Op::NAME }) + } +} + +/// An extension to service builders to add the `print()` function. +pub trait PrintExt: aws_smithy_http_server::plugin::Pluggable { + /// Causes all operations to print the operation name when called. + /// + /// This works by applying the [`PrintPlugin`]. + fn print(self) -> Self::Output + where + Self: Sized, + { + self.apply(PrintPlugin) + } +} + +impl PrintExt for Builder where Builder: aws_smithy_http_server::plugin::Pluggable {} diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index b032060bbc..842b5cfdfd 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -18,6 +18,8 @@ pub mod logging; #[doc(hidden)] pub mod operation; #[doc(hidden)] +pub mod plugin; +#[doc(hidden)] pub mod protocols; #[doc(hidden)] pub mod rejection; diff --git a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs index 3d4a617ccf..5028231e21 100644 --- a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs +++ b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs @@ -18,6 +18,7 @@ use tracing::error; use crate::{ body::BoxBody, + plugin::Plugin, request::{FromParts, FromRequest}, response::IntoResponse, runtime_error::InternalFailureException, @@ -220,14 +221,14 @@ where /// Provides an interface to convert a representation of an operation to a HTTP [`Service`](tower::Service) with /// canonical associated types. -pub trait Upgradable { +pub trait Upgradable { type Service: Service, Response = http::Response>; /// Performs an upgrade from a representation of an operation to a HTTP [`Service`](tower::Service). - fn upgrade(self) -> Self::Service; + fn upgrade(self, plugin: &Plugin) -> Self::Service; } -impl Upgradable for Operation +impl Upgradable for Operation where // `Op` is used to specify the operation shape Op: OperationShape, @@ -245,21 +246,26 @@ where // The signature of the inner service is correct S: Service<(Op::Input, Exts), Response = Op::Output, Error = OperationError> + Clone, - // Layer applies correctly to `Upgrade` - L: Layer>, + // The plugin takes this operation as input + Pl: Plugin, + + // The modified Layer applies correctly to `Upgrade` + Pl::Layer: Layer>, // The signature of the output is correct - L::Service: Service, Response = http::Response>, + >>::Service: + Service, Response = http::Response>, { - type Service = L::Service; + type Service = >>::Service; - /// Takes the [`Operation`](Operation), applies [`UpgradeLayer`] to + /// Takes the [`Operation`](Operation), applies [`Plugin`], then applies [`UpgradeLayer`] to /// the modified `S`, then finally applies the modified `L`. /// /// The composition is made explicit in the method constraints and return type. - fn upgrade(self) -> Self::Service { - let layer = Stack::new(UpgradeLayer::new(), self.layer); - layer.layer(self.inner) + fn upgrade(self, plugin: &Pl) -> Self::Service { + let mapped = plugin.map(self); + let layer = Stack::new(UpgradeLayer::new(), mapped.layer); + layer.layer(mapped.inner) } } @@ -273,13 +279,13 @@ pub struct MissingOperation; /// This _does_ implement [`Upgradable`] but produces a [`Service`] which always returns an internal failure message. pub struct FailOnMissingOperation; -impl Upgradable for FailOnMissingOperation +impl Upgradable for FailOnMissingOperation where InternalFailureException: IntoResponse

, { type Service = MissingFailure

; - fn upgrade(self) -> Self::Service { + fn upgrade(self, _plugin: &Pl) -> Self::Service { MissingFailure { _protocol: PhantomData } } } diff --git a/rust-runtime/aws-smithy-http-server/src/plugin.rs b/rust-runtime/aws-smithy-http-server/src/plugin.rs new file mode 100644 index 0000000000..8ba1f5747a --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/plugin.rs @@ -0,0 +1,79 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::operation::Operation; + +/// Provides a standard interface for applying [`Plugin`]s to a service builder. This is implemented automatically for all builders. +/// As [`Plugin`]s modify the way in which [`Operation`]s are [`upgraded`](crate::operation::Upgradable) we can use [`Pluggable`] as a foundation +/// to write extension traits for all builders. +/// +/// # Example +/// +/// ``` +/// # struct PrintPlugin; +/// # use aws_smithy_http_server::plugin::Pluggable; +/// trait PrintExt: Pluggable { +/// fn print(self) -> Self::Output where Self: Sized { +/// self.apply(PrintPlugin) +/// } +/// } +/// impl PrintExt for Builder where Builder: Pluggable {} +/// ``` +pub trait Pluggable { + type Output; + + /// A service builder applies this `plugin`. + fn apply(self, plugin: NewPlugin) -> Self::Output; +} + +/// Maps one [`Operation`] to another, +/// parameterised by the protocol `P` and operation shape `Op` to allow for plugin behaviour to be specialised accordingly. +/// +/// This is passed to [`Pluggable::apply`] to modify the behaviour of the builder. +pub trait Plugin { + type Service; + type Layer; + + /// Map an [`Operation`] to another. + fn map(&self, input: Operation) -> Operation; +} + +/// An [`Plugin`] that maps an `input` [`Operation`] to itself. +pub struct IdentityPlugin; +impl Plugin for IdentityPlugin { + type Service = S; + type Layer = L; + + fn map(&self, input: Operation) -> Operation { + input + } +} + +/// A wrapper struct which composes an `Inner` and an `Outer` [`Plugin`]. +pub struct PluginStack { + inner: Inner, + outer: Outer, +} + +impl PluginStack { + /// Creates a new [`PluginStack`]. + pub fn new(inner: Inner, outer: Outer) -> Self { + PluginStack { inner, outer } + } +} + +impl Plugin for PluginStack +where + Inner: Plugin, + Outer: Plugin, +{ + type Service = Outer::Service; + type Layer = Outer::Layer; + + fn map(&self, input: Operation) -> Operation { + let inner = self.inner.map(input); + self.outer.map(inner) + } +}