Skip to content

Commit

Permalink
Add new service builder codegen (#1693)
Browse files Browse the repository at this point in the history
* Add `ServerProtocol` interface to allow for server side protocol specific methods.

* Make public the structs merged in #1679.

* Add `ServerOperationGenerator`, which generates a ZST and implements `OperationShape` on it.

* Add `ServerServiceGeneratorV2`, which generates the service newtype around a router and a service builder.

* Add `hidden` argument to `RustModule` which allows modules to be marked with `#[doc(hidden)]`.

* Add `BuildModifier` trait to provide a common interface for extending service builders.

* Add `Upgradable` trait to simplifying bounds when upgrading from an `Operation` to a HTTP service.

* Add `FromRequest`, `FromParts`, and `IntoResponse` implementations.

* Make `RoutingService` accept general body types `B` for the inner services `http::Response<B>`.

* Use new service builder in protocol tests.
  • Loading branch information
hlbarber authored Sep 12, 2022
1 parent 56f4be3 commit fd94858
Show file tree
Hide file tree
Showing 21 changed files with 979 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ fun Writable.isEmpty(): Boolean {
return writer.toString() == RustWriter.root().toString()
}

operator fun Writable.plus(other: Writable): Writable {
val first = this
return writable {
rustTemplate("#{First:W}#{Second:W}", "First" to first, "Second" to other)
}
}

/**
* Helper allowing a `Iterable<Writable>` to be joined together using a `String` separator.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class AwsJsonSerializerGenerator(
}

open class AwsJson(
private val coreCodegenContext: CoreCodegenContext,
val coreCodegenContext: CoreCodegenContext,
private val awsJsonVersion: AwsJsonVersion,
) : Protocol {
private val runtimeConfig = coreCodegenContext.runtimeConfig
Expand All @@ -143,6 +143,8 @@ open class AwsJson(
)
private val jsonDeserModule = RustModule.private("json_deser")

val version: AwsJsonVersion get() = awsJsonVersion

override val httpBindingResolver: HttpBindingResolver =
AwsJsonHttpBindingResolver(coreCodegenContext.model, awsJsonVersion)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class RestJsonHttpBindingResolver(
}
}

class RestJson(private val coreCodegenContext: CoreCodegenContext) : Protocol {
open class RestJson(val coreCodegenContext: CoreCodegenContext) : Protocol {
private val runtimeConfig = coreCodegenContext.runtimeConfig
private val errorScope = arrayOf(
"Bytes" to RuntimeType.Bytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RestXmlFactory(
}
}

open class RestXml(private val coreCodegenContext: CoreCodegenContext) : Protocol {
open class RestXml(val coreCodegenContext: CoreCodegenContext) : Protocol {
private val restXml = coreCodegenContext.serviceShape.expectTrait<RestXmlTrait>()
private val runtimeConfig = coreCodegenContext.runtimeConfig
private val errorScope = arrayOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ object ServerRuntimeType {
fun ResponseRejection(runtimeConfig: RuntimeConfig) =
RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection")

fun Protocol(runtimeConfig: RuntimeConfig) =
RuntimeType("Protocol", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols")
fun Protocol(name: String, runtimeConfig: RuntimeConfig) =
RuntimeType(name, ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols")

fun Protocol(runtimeConfig: RuntimeConfig) = Protocol("Protocol", runtimeConfig)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.rustlang.Writable
import software.amazon.smithy.rust.codegen.client.rustlang.asType
import software.amazon.smithy.rust.codegen.client.rustlang.documentShape
import software.amazon.smithy.rust.codegen.client.rustlang.rust
import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.client.rustlang.writable
import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.client.util.toPascalCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency

class ServerOperationGenerator(
coreCodegenContext: CoreCodegenContext,
private val operation: OperationShape,
) {
private val runtimeConfig = coreCodegenContext.runtimeConfig
private val codegenScope =
arrayOf(
"SmithyHttpServer" to
ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(),
)
private val symbolProvider = coreCodegenContext.symbolProvider
private val model = coreCodegenContext.model

private val operationName = symbolProvider.toSymbol(operation).name.toPascalCase()
private val operationId = operation.id

/** Returns `std::convert::Infallible` if the model provides no errors. */
private fun operationError(): Writable = writable {
if (operation.errors.isEmpty()) {
rust("std::convert::Infallible")
} else {
rust("crate::error::${operationName}Error")
}
}

fun render(writer: RustWriter) {
writer.documentShape(operation, model)

writer.rustTemplate(
"""
pub struct $operationName;
impl #{SmithyHttpServer}::operation::OperationShape for $operationName {
const NAME: &'static str = "${operationId.toString().replace("#", "##")}";
type Input = crate::input::${operationName}Input;
type Output = crate::output::${operationName}Output;
type Error = #{Error:W};
}
""",
"Error" to operationError(),
*codegenScope,
)
// Adds newline to end of render
writer.rust("")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.client.rustlang.Attribute
import software.amazon.smithy.rust.codegen.client.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.client.rustlang.RustModule
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.rustlang.Visibility
import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.DefaultPublicModules
import software.amazon.smithy.rust.codegen.client.smithy.RustCrate
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator

/**
Expand Down Expand Up @@ -63,6 +67,36 @@ open class ServerServiceGenerator(
) { writer ->
renderOperationRegistry(writer, operations)
}

// TODO(https://github.com/awslabs/smithy-rs/issues/1707): Remove, this is temporary.
rustCrate.withModule(
RustModule(
"operation_shape",
RustMetadata(
visibility = Visibility.PUBLIC,
additionalAttributes = listOf(
Attribute.DocHidden,
),
),
null,
),
) { writer ->
for (operation in operations) {
ServerOperationGenerator(coreCodegenContext, operation).render(writer)
}
}

// TODO(https://github.com/awslabs/smithy-rs/issues/1707): Remove, this is temporary.
rustCrate.withModule(
RustModule("service", RustMetadata(visibility = Visibility.PUBLIC, additionalAttributes = listOf(Attribute.DocHidden)), null),
) { writer ->
val serverProtocol = ServerProtocol.fromCoreProtocol(protocol)
ServerServiceGeneratorV2(
coreCodegenContext,
serverProtocol,
).render(writer)
}

renderExtras(operations)
}

Expand Down
Loading

0 comments on commit fd94858

Please sign in to comment.