Skip to content

Commit

Permalink
merge fix
Browse files Browse the repository at this point in the history
  • Loading branch information
david-perez committed Jun 4, 2024
1 parent 5f8330f commit 9531447
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
Expand Down Expand Up @@ -173,8 +175,14 @@ class DefaultProtocolTestGenerator(
}
testModuleWriter.write("Test ID: ${testCase.id}")
testModuleWriter.newlinePrefix = ""

Attribute.TokioTest.render(testModuleWriter)
Attribute.TracedTest.render(testModuleWriter)
// The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in
// case the code rendered by the test does not make use of `tracing` at all.
val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) }
testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency)

val action = when (testCase) {
is HttpResponseTestCase -> Action.Response
is HttpRequestTestCase -> Action.Request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.protocols.traits.Rpcv2Trait
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
Expand Down Expand Up @@ -42,7 +42,7 @@ class ClientProtocolLoader(supportedProtocols: ProtocolMap<OperationGenerator, C
Ec2QueryTrait.ID to ClientEc2QueryFactory(),
RestJson1Trait.ID to ClientRestJsonFactory(),
RestXmlTrait.ID to ClientRestXmlFactory(),
Rpcv2Trait.ID to ClientRpcV2Factory(),
Rpcv2CborTrait.ID to ClientRpcV2CborFactory(),
)
val Default = ClientProtocolLoader(DefaultProtocols)
}
Expand Down Expand Up @@ -120,12 +120,11 @@ class ClientRestXmlFactory(
override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT
}

// TODO(rpcv2): Implement `ClientRpcV2Factory`
class ClientRpcV2Factory() : ProtocolGeneratorFactory<HttpBoundProtocolGenerator, ClientCodegenContext> {
class ClientRpcV2CborFactory : ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2(codegenContext)

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): HttpBoundProtocolGenerator =
HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))
override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
OperationGenerator(codegenContext, protocol(codegenContext))

override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ open class AwsJson(
rustTemplate(
"""
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> {
#{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new())
#{json_errors}::parse_error_metadata(payload, &#{Headers}::new())
}
""",
*errorScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol {
rustTemplate(
"""
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> {
#{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new())
#{json_errors}::parse_error_metadata(payload, &#{Headers}::new())
}
""",
*errorScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
Expand All @@ -42,6 +43,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
Expand Down Expand Up @@ -280,6 +282,10 @@ class ServerProtocolTestGenerator(

Attribute.TokioTest.render(testModuleWriter)
Attribute.TracedTest.render(testModuleWriter)
// The `#[traced_test]` macro desugars to using `tracing`, so we need to depend on the latter explicitly in
// case the code rendered by the test does not make use of `tracing` at all.
val tracingDevDependency = testDependenciesOnly { addDependency(CargoDependency.Tracing.toDevDependency()) }
testModuleWriter.rustTemplate("#{TracingDevDependency:W}", "TracingDevDependency" to tracingDevDependency)

if (expectFail(testCase)) {
testModuleWriter.writeWithNoFormatting("#[should_panic]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator

class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() {
class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomization() {
override fun section(section: ServerHttpBoundProtocolSection): Writable =
when (section) {
is ServerHttpBoundProtocolSection.WrapStreamPayload ->
Expand Down Expand Up @@ -81,7 +81,7 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap<ServerProtocolGenerat
additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()),
),
// TODO `StreamPayloadSerializerCustomization`
Rpcv2CborTrait.ID to ServerRpcV2Factory(),
Rpcv2CborTrait.ID to ServerRpcV2CborFactory(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGenerat
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2Protocol

class ServerRpcV2Factory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
class ServerRpcV2CborFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
override fun protocol(codegenContext: ServerCodegenContext): Protocol =
ServerRpcV2Protocol(codegenContext)

override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, ServerRpcV2Protocol(codegenContext))

override fun support(): ProtocolSupport {
// TODO(): Implement `ServerRpcV2Factory.support`
return ProtocolSupport(
/* Client support */
requestSerialization = false,
Expand Down

0 comments on commit 9531447

Please sign in to comment.