Skip to content

Commit

Permalink
Remove Protocol enum (#1829)
Browse files Browse the repository at this point in the history
* Remove Protocol enum

* Update Python implementation
  • Loading branch information
hlbarber authored Oct 10, 2022
1 parent 238cf8b commit 78022d6
Show file tree
Hide file tree
Showing 16 changed files with 226 additions and 260 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,15 @@ message = "Fix regression where `connect_timeout` and `read_timeout` fields are
references = ["smithy-rs#1822"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "kevinpark1217"

[[smithy-rs]]
message = "Remove `Protocol` enum, removing an obstruction to extending smithy to third-party protocols."
references = ["smithy-rs#1829"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" }
author = "hlbarber"

[[smithy-rs]]
message = "Convert the `protocol` argument on `PyMiddlewares::new` constructor to a type parameter."
references = ["smithy-rs#1829"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" }
author = "hlbarber"
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Generates a Python compatible application and server that can be configured from Python.
Expand Down Expand Up @@ -62,13 +63,13 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
*/
class PythonApplicationGenerator(
codegenContext: CodegenContext,
private val protocol: ServerProtocol,
private val operations: List<OperationShape>,
) {
private val symbolProvider = codegenContext.symbolProvider
private val libName = "lib${codegenContext.settings.moduleName.toSnakeCase()}"
private val runtimeConfig = codegenContext.runtimeConfig
private val model = codegenContext.model
private val protocol = codegenContext.protocol
private val codegenScope =
arrayOf(
"SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(),
Expand All @@ -88,6 +89,7 @@ class PythonApplicationGenerator(
fun render(writer: RustWriter) {
renderPyApplicationRustDocs(writer)
renderAppStruct(writer)
renderAppDefault(writer)
renderAppClone(writer)
renderPyAppTrait(writer)
renderAppImpl(writer)
Expand All @@ -98,7 +100,7 @@ class PythonApplicationGenerator(
writer.rustTemplate(
"""
##[#{pyo3}::pyclass]
##[derive(Debug, Default)]
##[derive(Debug)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
Expand Down Expand Up @@ -128,6 +130,25 @@ class PythonApplicationGenerator(
)
}

private fun renderAppDefault(writer: RustWriter) {
writer.rustTemplate(
"""
impl Default for App {
fn default() -> Self {
Self {
handlers: Default::default(),
middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]),
context: None,
workers: #{parking_lot}::Mutex::new(vec![]),
}
}
}
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}

private fun renderAppImpl(writer: RustWriter) {
writer.rustBlockTemplate(
"""
Expand Down Expand Up @@ -165,28 +186,24 @@ class PythonApplicationGenerator(
rustTemplate(
"""
let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop);
use #{SmithyPython}::PyApp;
let service = #{tower}::ServiceBuilder::new().layer(
#{SmithyPython}::PyMiddlewareLayer::new(
self.middlewares.clone(),
self.protocol(),
middleware_locals
)?,
);
let service = #{tower}::ServiceBuilder::new()
.layer(
#{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals),
);
let router: #{SmithyServer}::routing::Router = router
.build()
.expect("Unable to build operation registry")
.into();
Ok(router.layer(service))
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}
}
}

private fun renderPyAppTrait(writer: RustWriter) {
val protocol = protocol.toString().replace("#", "##")
writer.rustTemplate(
"""
impl #{SmithyPython}::PyApp for App {
Expand All @@ -202,9 +219,6 @@ class PythonApplicationGenerator(
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
fn protocol(&self) -> &'static str {
"$protocol"
}
}
""",
*codegenScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationHandlerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* The Rust code responsible to run the Python business logic on the Python interpreter
Expand All @@ -33,8 +34,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperat
*/
class PythonServerOperationHandlerGenerator(
codegenContext: CodegenContext,
protocol: ServerProtocol,
private val operations: List<OperationShape>,
) : ServerOperationHandlerGenerator(codegenContext, operations) {
) : ServerOperationHandlerGenerator(codegenContext, protocol, operations) {
private val symbolProvider = codegenContext.symbolProvider
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ class PythonServerServiceGenerator(
}

override fun renderOperationHandler(writer: RustWriter, operations: List<OperationShape>) {
PythonServerOperationHandlerGenerator(context, operations).render(writer)
PythonServerOperationHandlerGenerator(context, protocol, operations).render(writer)
}

override fun renderExtras(operations: List<OperationShape>) {
rustCrate.withModule(RustModule.public("python_server_application", "Python server and application implementation.")) { writer ->
PythonApplicationGenerator(context, operations)
PythonApplicationGenerator(context, protocol, operations)
.render(writer)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,21 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErr
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator

/**
* ServerOperationHandlerGenerator
*/
open class ServerOperationHandlerGenerator(
codegenContext: CodegenContext,
val protocol: ServerProtocol,
private val operations: List<OperationShape>,
) {
private val serverCrate = "aws_smithy_http_server"
private val service = codegenContext.serviceShape
private val model = codegenContext.model
private val protocol = codegenContext.protocol
private val symbolProvider = codegenContext.symbolProvider
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf(
Expand Down Expand Up @@ -83,11 +82,8 @@ open class ServerOperationHandlerGenerator(
Ok(v) => v,
Err(extension_not_found_rejection) => {
let extension = $serverCrate::extension::RuntimeErrorExtension::new(extension_not_found_rejection.to_string());
let runtime_error = $serverCrate::runtime_error::RuntimeError {
protocol: #{SmithyHttpServer}::protocols::Protocol::${protocol.name.toPascalCase()},
kind: extension_not_found_rejection.into(),
};
let mut response = runtime_error.into_response();
let runtime_error = $serverCrate::runtime_error::RuntimeError::from(extension_not_found_rejection);
let mut response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error);
response.extensions_mut().insert(extension);
return response.map($serverCrate::body::boxed);
}
Expand All @@ -109,7 +105,8 @@ open class ServerOperationHandlerGenerator(
let input_wrapper = match $inputWrapperName::from_request(&mut req).await {
Ok(v) => v,
Err(runtime_error) => {
return runtime_error.into_response().map($serverCrate::body::boxed);
let response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error);
return response.map($serverCrate::body::boxed);
}
};
$callImpl
Expand All @@ -120,6 +117,7 @@ open class ServerOperationHandlerGenerator(
response.map(#{SmithyHttpServer}::body::boxed)
}
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ open class ServerServiceGenerator(
private val rustCrate: RustCrate,
private val protocolGenerator: ServerProtocolGenerator,
private val protocolSupport: ProtocolSupport,
private val protocol: ServerProtocol,
val protocol: ServerProtocol,
private val codegenContext: CodegenContext,
) {
private val index = TopDownIndex.of(codegenContext.model)
Expand Down Expand Up @@ -107,7 +107,7 @@ open class ServerServiceGenerator(

// Render operations handler.
open fun renderOperationHandler(writer: RustWriter, operations: List<OperationShape>) {
ServerOperationHandlerGenerator(codegenContext, operations).render(writer)
ServerOperationHandlerGenerator(codegenContext, protocol, operations).render(writer)
}

// Render operations registry.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol

open class ServerProtocolGenerator(
codegenContext: CodegenContext,
protocol: Protocol,
val protocol: ServerProtocol,
makeOperationGenerator: MakeOperationGenerator,
private val traitGenerator: ProtocolTraitImplGenerator,
) : ProtocolGenerator(codegenContext, protocol, makeOperationGenerator, traitGenerator) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,9 @@ class ServerProtocolTestGenerator(
"""
let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request);
let rejection = super::$operationName::from_request(&mut http_request).await.expect_err("request was accepted but we expected it to be rejected");
let http_response = rejection.into_response();
let http_response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(rejection);
""",
"Protocol" to protocolGenerator.protocol.markerStruct(),
*codegenScope,
)
checkResponse(this, testCase.response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
Expand Down Expand Up @@ -177,10 +176,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
rustTemplate(
"""
if ! #{SmithyHttpServer}::protocols::accept_header_classifier(req, ${contentType.dq()}) {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable,
})
return Err(#{RuntimeError}::NotAcceptable)
}
""",
*codegenScope,
Expand All @@ -200,10 +196,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
rustTemplate(
"""
if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType,
})
return Err(#{RuntimeError}::UnsupportedMediaType)
}
""",
*codegenScope,
Expand All @@ -230,12 +223,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
#{parse_request}(req)
.await
.map($inputName)
.map_err(
|err| #{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: err.into()
}
)
.map_err(Into::into)
}
}
Expand Down Expand Up @@ -282,12 +270,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
Self::Output(o) => {
match #{serialize_response}(o) {
Ok(response) => response,
Err(e) => {
#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: e.into()
}.into_response()
}
Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e))
}
},
Self::Error(err) => {
Expand All @@ -296,12 +279,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(err.name()));
response
},
Err(e) => {
#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: e.into()
}.into_response()
}
Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e))
}
}
}
Expand Down Expand Up @@ -346,12 +324,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
"""
match #{serialize_response}(self.0) {
Ok(response) => response,
Err(e) => {
#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: e.into()
}.into_response()
}
Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e))
}
""".trimIndent()

Expand Down
Loading

0 comments on commit 78022d6

Please sign in to comment.