Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support server event streams #1479

Merged
merged 31 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7ec3c9c
Support server event streams
82marbag Jun 13, 2022
49e4cc5
fix function signature
82marbag Jul 14, 2022
cd3fa0d
revert tests for python
82marbag Jul 14, 2022
9c60845
update event stream marshalling tests
82marbag Jul 14, 2022
8f7e03f
remove todo on pokemon server
82marbag Jul 14, 2022
54ffd3f
document sign_empty None
82marbag Jul 15, 2022
a96ce2c
add documentation and rfc
82marbag Jul 15, 2022
dd1856b
update tests
82marbag Jul 18, 2022
dcc0dac
update tests
82marbag Jul 18, 2022
16f6fd9
address comments
82marbag Jul 18, 2022
09e9f31
update tests
82marbag Jul 18, 2022
2c700a5
render errors once
82marbag Jul 18, 2022
0ff3c5d
fix PythonCodegenServerPlugin
82marbag Jul 18, 2022
f1c932f
update tests
82marbag Jul 18, 2022
284e93d
update rfc
82marbag Jul 18, 2022
e1ff6e6
address comments
82marbag Jul 19, 2022
37f38e2
Merge branch 'main' into eventstreams
82marbag Jul 19, 2022
74fc001
address comments
82marbag Jul 19, 2022
011e02a
address comments
Jul 19, 2022
193ae52
render from errors once
82marbag Jul 19, 2022
6f77fb5
remove unused import in test
82marbag Jul 19, 2022
ecc60f4
refactor errors, generate on demand
82marbag Jul 21, 2022
6173c6e
Merge branch 'main' into eventstreams
82marbag Jul 21, 2022
5a10f87
Merge branch 'main' into eventstreams
82marbag Jul 21, 2022
0fce671
update CHANGELOG
82marbag Jul 21, 2022
4f51dd4
update TopLevelErrorGenerator
82marbag Jul 21, 2022
b446403
update CHANGELOG
82marbag Jul 22, 2022
5779a0d
Sort inline dependencies alphabetically by key
jdisanti Jul 21, 2022
f7fe78d
Revert "Sort inline dependencies alphabetically by key"
Jul 25, 2022
8f6ba8c
Merge branch 'main' into eventstreams
82marbag Jul 25, 2022
f389a45
rename rfc file
82marbag Jul 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class AwsInputPresignedMethod(
}

private fun RustWriter.writeInputPresignedMethod(section: OperationSection.InputImpl) {
val operationError = operationShape.errorSymbol(symbolProvider)
val operationError = operationShape.errorSymbol(coreCodegenContext.model, symbolProvider, coreCodegenContext.target)
val presignableOp = PRESIGNABLE_OPERATIONS.getValue(operationShape.id)

val makeOperationOp = if (presignableOp.hasModelTransforms()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@ import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.error.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol

/**
* Generates a unified error enum for [operation]. It depends on [ServerCombinedErrorGenerator]
* to generate the errors from the model and adds the Rust implementation `From<pyo3::PyErr>`.
*/
class PythonServerCombinedErrorGenerator(
model: Model,
private val model: Model,
private val symbolProvider: RustSymbolProvider,
private val operation: OperationShape
) : ServerCombinedErrorGenerator(model, symbolProvider, operation) {
) : ServerCombinedErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(operation), listOf()) {

private val operationIndex = OperationIndex.of(model)
private val errors = operationIndex.getErrors(operation)
Expand All @@ -53,7 +54,7 @@ class PythonServerCombinedErrorGenerator(

""",
"pyo3" to PythonServerCargoDependency.PyO3.asType(),
"Error" to operation.errorSymbol(symbolProvider),
"Error" to operation.errorSymbol(model, symbolProvider, CodegenTarget.SERVER),
"From" to RuntimeType.From,
"CastPyErrToRustError" to castPyErrToRustError()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
Expand Down Expand Up @@ -139,7 +140,7 @@ open class ServerOperationHandlerGenerator(
"Fun: FnOnce($inputName) -> Fut + Clone + Send + 'static,"
}
val outputType = if (operation.operationErrors(model).isNotEmpty()) {
"Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>"
"Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(model, symbolProvider, CodegenTarget.SERVER).fullyQualifiedName()}>"
} else {
symbolProvider.toSymbol(operation.outputShape(model)).fullName
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.smithy.Errors
import software.amazon.smithy.rust.codegen.smithy.Inputs
import software.amazon.smithy.rust.codegen.smithy.Outputs
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.util.getTrait
Expand Down Expand Up @@ -91,11 +92,11 @@ class ServerOperationRegistryGenerator(
}

writer.rustTemplate(
"""
"""
##[allow(clippy::tabs_in_doc_comments)]
/// The `$operationRegistryName` is the place where you can register
/// your service's operation implementations.
///
///
/// Use [`$operationRegistryBuilderName`] to construct the
/// `$operationRegistryName`. For each of the [operations] modeled in
/// your Smithy service, you need to provide an implementation in the
Expand All @@ -116,9 +117,9 @@ class ServerOperationRegistryGenerator(
/// type implementing [`tower::make::MakeService`], a _service
/// factory_. You can feed this value to a [Hyper server], and the
/// server will instantiate and [`serve`] your service.
///
///
/// Here's a full example to get you started:
///
///
/// ```rust
/// use std::net::SocketAddr;
$inputOutputErrorsImport
Expand Down Expand Up @@ -346,9 +347,9 @@ ${operationImplementationStubs(operations)}
} else ""
ret +
"""
/// ${it.signature()} {
/// todo!()
/// }
/// ${it.signature()} {
/// todo!()
/// }
""".trimIndent()
}

Expand All @@ -358,7 +359,7 @@ ${operationImplementationStubs(operations)}
private fun OperationShape.signature(): String {
val inputSymbol = symbolProvider.toSymbol(inputShape(model))
val outputSymbol = symbolProvider.toSymbol(outputShape(model))
val errorSymbol = errorSymbol(symbolProvider)
val errorSymbol = errorSymbol(model, symbolProvider, CodegenTarget.SERVER)

val inputT = "${Inputs.namespace}::${inputSymbol.name}"
val t = "${Outputs.namespace}::${outputSymbol.name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ open class ServerServiceGenerator(

// Render combined errors.
open fun renderCombinedErrors(writer: RustWriter, operation: OperationShape) {
ServerCombinedErrorGenerator(coreCodegenContext.model, coreCodegenContext.symbolProvider, operation).render(writer)
/* Subclasses can override */
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we make it abstract?

}

// Render operations handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerR
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.smithy.extractSymbolFromOption
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
Expand Down Expand Up @@ -215,7 +216,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
// Implement `into_response` for output types.

val outputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
val errorSymbol = operationShape.errorSymbol(symbolProvider)
val errorSymbol = operationShape.errorSymbol(model, symbolProvider, CodegenTarget.SERVER)

if (operationShape.operationErrors(model).isNotEmpty()) {
// The output of fallible operations is a `Result` which we convert into an
Expand Down Expand Up @@ -417,7 +418,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(

private fun serverSerializeError(operationShape: OperationShape): RuntimeType {
val fnName = "serialize_${operationShape.id.name.toSnakeCase()}_error"
val errorSymbol = operationShape.errorSymbol(symbolProvider)
val errorSymbol = operationShape.errorSymbol(model, symbolProvider, CodegenTarget.SERVER)
return RuntimeType.forInlineFun(fnName, operationSerModule) {
Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)
it.rustBlockTemplate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.error.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
Expand Down Expand Up @@ -52,7 +53,8 @@ class ServerCombinedErrorGeneratorTest {
listOf("FooException", "ComplexError", "InvalidGreeting").forEach {
model.lookup<StructureShape>("error#$it").renderWithModelBuilder(model, symbolProvider, writer, CodegenTarget.SERVER)
}
val generator = ServerCombinedErrorGenerator(model, symbolProvider, model.lookup("error#Greeting"))
val errors = listOf("FooException", "ComplexError", "InvalidGreeting").map { model.lookup<StructureShape>("error#$it") }
val generator = ServerCombinedErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(model.lookup("error#Greeting")), errors)
generator.render(writer)

writer.unitTest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
Expand All @@ -27,6 +26,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
Expand All @@ -40,6 +40,7 @@ import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.codegen.util.outputShape
import java.util.stream.Stream
import kotlin.streams.toList

private fun fillInBaseModel(
protocolName: String,
Expand Down Expand Up @@ -342,10 +343,15 @@ object EventStreamTestTools {
CodegenTarget.SERVER -> serverTestSymbolProvider(model)
}
val project = TestWorkspace.testProject(symbolProvider)
val operationSymbol = symbolProvider.toSymbol(operationShape)
project.withModule(RustModule.public("error")) {
val errors = model.shapes()
.filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }
.map { it.asStructureShape().get() }
.toList()
when (testCase.target) {
CodegenTarget.CLIENT -> CombinedErrorGenerator(model, symbolProvider, operationShape).render(it)
CodegenTarget.SERVER -> ServerCombinedErrorGenerator(model, symbolProvider, operationShape).render(it)
CodegenTarget.CLIENT -> CombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(it)
CodegenTarget.SERVER -> ServerCombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(it)
}
for (shape in model.shapes().filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }) {
StructureGenerator(model, symbolProvider, it, shape as StructureShape).render(testCase.target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class EventStreamSymbolProvider(
val error = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) {
RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream").toSymbol()
} else {
unionShape.eventStreamErrorSymbol(this).toSymbol()
unionShape.eventStreamErrorSymbol(model, this, target).toSymbol()
}
val errorFmt = error.rustType().render(fullyQualified = true)
val innerFmt = initial.rustType().stripOuter<RustType.Option>().render(fullyQualified = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class PaginatorGenerator private constructor(

private val inputType = symbolProvider.toSymbol(operation.inputShape(model))
private val outputType = operation.outputShape(model)
private val errorType = operation.errorSymbol(symbolProvider)
private val errorType = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)

private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun(
paginatorName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import software.amazon.smithy.rust.codegen.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfigGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
Expand Down Expand Up @@ -55,10 +54,6 @@ class ServiceGenerator(
ProtocolTestGenerator(clientCodegenContext, protocolSupport, operation, operationWriter).render()
}
}
// Render a service-level error enum containing every error that the service can emit
rustCrate.withModule(RustModule.Error) { writer ->
CombinedErrorGenerator(clientCodegenContext.model, clientCodegenContext.symbolProvider, operation).render(writer)
}
}

TopLevelErrorGenerator(clientCodegenContext, operations).render(rustCrate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.customize.Section
import software.amazon.smithy.rust.codegen.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.smithy.generators.PaginatorGenerator
Expand Down Expand Up @@ -394,7 +395,7 @@ class FluentClientGenerator(

val output = operation.outputShape(model)
val operationOk = symbolProvider.toSymbol(output)
val operationErr = operation.errorSymbol(symbolProvider).toSymbol()
val operationErr = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT).toSymbol()

val inputFieldsBody = generateOperationShapeDocs(writer, symbolProvider, operation, model).joinToString("\n") {
"/// - $it"
Expand Down Expand Up @@ -486,7 +487,7 @@ class FluentClientGenerator(
) {
val inputType = symbolProvider.toSymbol(operation.inputShape(model))
val outputType = symbolProvider.toSymbol(operation.outputShape(model))
val errorType = operation.errorSymbol(symbolProvider)
val errorType = operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)
rustTemplate(
"""
/// Creates a new `${operationSymbol.name}`.
Expand Down Expand Up @@ -534,7 +535,7 @@ class FluentClientGenerator(
customizations,
FluentClientSection.FluentBuilderImpl(
operation,
operation.errorSymbol(symbolProvider)
operation.errorSymbol(model, symbolProvider, CodegenTarget.CLIENT)
)
)
input.members().forEach { member ->
Expand Down
Loading