Skip to content

Commit

Permalink
Merge branch 'harryb/move-protocol-to-server-protocol2' into harryb/m…
Browse files Browse the repository at this point in the history
…ove-protocol-to-server-protocol
  • Loading branch information
Harry Barber committed Sep 13, 2022
2 parents 765260e + 20974d3 commit 54cc60e
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class AwsJsonSerializerGenerator(

open class AwsJson(
val coreCodegenContext: CoreCodegenContext,
private val awsJsonVersion: AwsJsonVersion,
val awsJsonVersion: AwsJsonVersion,
) : Protocol {
private val runtimeConfig = coreCodegenContext.runtimeConfig
private val errorScope = arrayOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJsonVersio
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator

private fun allOperations(coreCodegenContext: CoreCodegenContext): List<OperationShape> {
val index = TopDownIndex.of(coreCodegenContext.model)
Expand Down Expand Up @@ -58,6 +60,13 @@ interface ServerProtocol : Protocol {
requestSpecModule: RuntimeType,
): Writable

/**
* In some protocols, such as restJson1,
* when there is no modeled body input, content type must not be set and the body must be empty.
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false

companion object {
/** Upgrades the core protocol to a `ServerProtocol`. */
fun fromCoreProtocol(protocol: Protocol): ServerProtocol = when (protocol) {
Expand All @@ -80,6 +89,9 @@ class ServerAwsJsonProtocol(
private val symbolProvider = coreCodegenContext.symbolProvider
private val service = coreCodegenContext.serviceShape

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerAwsJsonSerializerGenerator(coreCodegenContext, httpBindingResolver, awsJsonVersion)

companion object {
fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = ServerAwsJsonProtocol(awsJson.coreCodegenContext, awsJson.version)
}
Expand Down Expand Up @@ -214,6 +226,8 @@ class ServerRestJsonProtocol(
): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape)

override fun serverRouterRuntimeConstructor() = "new_rest_json_router"

override fun serverContentTypeCheckNoModeledInput() = true
}

class ServerRestXmlProtocol(
Expand Down Expand Up @@ -241,4 +255,6 @@ class ServerRestXmlProtocol(
): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape)

override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"

override fun serverContentTypeCheckNoModeledInput() = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,6 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithoutBodyExpectsEmptyContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedListNullItem", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedMapNullValue", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonMalformedSetDuplicateItems", TestType.MalformedRequest),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.client.rustlang.Writable
import software.amazon.smithy.rust.codegen.client.rustlang.escape
Expand All @@ -14,25 +13,25 @@ import software.amazon.smithy.rust.codegen.client.rustlang.writable
import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.client.smithy.protocols.awsJsonFieldName
import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.JsonCustomization
import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.JsonSection
import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.client.util.hasTrait
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerAwsJsonProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator]
* with AwsJson specific configurations.
*/
class ServerAwsJsonFactory(private val version: AwsJsonVersion) :
ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator, ServerCodegenContext> {
override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerAwsJson(codegenContext, version)
override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = ServerAwsJsonProtocol(codegenContext, version)

override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))
Expand Down Expand Up @@ -93,11 +92,3 @@ class ServerAwsJsonSerializerGenerator(
customizations = listOf(ServerAwsJsonError(awsJsonVersion)),
),
) : StructuredDataSerializerGenerator by jsonSerializerGenerator

class ServerAwsJson(
coreCodegenContext: CoreCodegenContext,
private val awsJsonVersion: AwsJsonVersion,
) : AwsJson(coreCodegenContext, awsJsonVersion) {
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerAwsJsonSerializerGenerator(coreCodegenContext, httpBindingResolver, awsJsonVersion)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.rust.codegen.client.rustlang.Attribute
import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.client.rustlang.RustModule
import software.amazon.smithy.rust.codegen.client.rustlang.RustType
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.rustlang.Writable
import software.amazon.smithy.rust.codegen.client.rustlang.asType
import software.amazon.smithy.rust.codegen.client.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.client.rustlang.render
import software.amazon.smithy.rust.codegen.client.rustlang.rust
import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock
Expand All @@ -55,16 +58,17 @@ import software.amazon.smithy.rust.codegen.client.smithy.isOptional
import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.client.smithy.toOptional
import software.amazon.smithy.rust.codegen.client.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.client.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.client.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.client.util.dq
import software.amazon.smithy.rust.codegen.client.util.expectTrait
import software.amazon.smithy.rust.codegen.client.util.findStreamingMember
import software.amazon.smithy.rust.codegen.client.util.getTrait
import software.amazon.smithy.rust.codegen.client.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.client.util.hasTrait
import software.amazon.smithy.rust.codegen.client.util.inputShape
import software.amazon.smithy.rust.codegen.client.util.isStreaming
import software.amazon.smithy.rust.codegen.client.util.outputShape
Expand All @@ -84,7 +88,7 @@ import java.util.logging.Logger
*/
class ServerHttpBoundProtocolGenerator(
codegenContext: ServerCodegenContext,
protocol: Protocol,
protocol: ServerProtocol,
) : ProtocolGenerator(
codegenContext,
protocol,
Expand All @@ -110,7 +114,7 @@ class ServerHttpBoundProtocolGenerator(
*/
private class ServerHttpBoundProtocolTraitImplGenerator(
private val codegenContext: ServerCodegenContext,
private val protocol: Protocol,
private val protocol: ServerProtocol,
) : ProtocolTraitImplGenerator {
private val logger = Logger.getLogger(javaClass.name)
private val symbolProvider = codegenContext.symbolProvider
Expand Down Expand Up @@ -168,7 +172,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

val verifyResponseContentType = writable {
val verifyAcceptHeader = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
Expand All @@ -183,6 +187,30 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
)
}
}
val verifyRequestContentTypeHeader = writable {
operationShape
.inputShape(model)
.members()
.find { it.hasTrait<HttpPayloadTrait>() }
?.let { payload ->
val target = model.expectShape(payload.target)
if (!target.isBlobShape || target.hasTrait<MediaTypeTrait>()) {
val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
?.let { "Some(${it.dq()})" } ?: "None"
rustTemplate(
"""
if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType,
})
}
""",
*codegenScope,
)
}
}
}

// Implement `from_request` trait for input types.
rustTemplate(
Expand All @@ -197,7 +225,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
B::Data: Send,
#{RequestRejection} : From<<B as #{SmithyHttpServer}::body::HttpBody>::Error>
{
#{verify_response_content_type:W}
#{verifyAcceptHeader:W}
#{verifyRequestContentTypeHeader:W}
#{parse_request}(req)
.await
.map($inputName)
Expand Down Expand Up @@ -235,7 +264,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
"I" to inputSymbol,
"Marker" to serverProtocol.markerStruct(),
"parse_request" to serverParseRequest(operationShape),
"verify_response_content_type" to verifyResponseContentType,
"verifyAcceptHeader" to verifyAcceptHeader,
"verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader,
)

// Implement `into_response` for output types.
Expand Down Expand Up @@ -711,16 +741,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
Attribute.AllowUnusedMut.render(this)
rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
val parser = structuredDataParser.serverInputParser(operationShape)
val noInputs = model.expectShape(operationShape.inputShape).expectTrait<SyntheticInputTrait>().originalId == null
if (parser != null) {
val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
rustTemplate(
"""
let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?;
let bytes = #{Hyper}::body::to_bytes(body).await?;
if !bytes.is_empty() {
static EXPECTED_CONTENT_TYPE: #{OnceCell}::sync::Lazy<#{Mime}::Mime> =
#{OnceCell}::sync::Lazy::new(|| "$expectedRequestContentType".parse::<#{Mime}::Mime>().unwrap());
#{SmithyHttpServer}::protocols::check_content_type(request, &EXPECTED_CONTENT_TYPE)?;
input = #{parser}(bytes.as_ref(), input)?;
}
""",
Expand All @@ -740,6 +767,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
serverRenderUriPathParser(this, operationShape)
serverRenderQueryStringParser(this, operationShape)

if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) {
conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) {
rustTemplate(
"""
#{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?;
""",
*codegenScope,
)
}
}
val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) {
"?"
} else ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Pro
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol

/**
* RestJson1 server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
Expand All @@ -19,7 +20,7 @@ class ServerRestJsonFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGe
override fun protocol(codegenContext: ServerCodegenContext): Protocol = RestJson(codegenContext)

override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, RestJson(codegenContext))
ServerHttpBoundProtocolGenerator(codegenContext, ServerRestJsonProtocol(codegenContext))

override fun support(): ProtocolSupport {
return ProtocolSupport(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Pro
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestXmlProtocol

/*
* RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
Expand All @@ -19,7 +20,7 @@ class ServerRestXmlFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGen
override fun protocol(codegenContext: ServerCodegenContext): Protocol = RestXml(codegenContext)

override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext))
ServerHttpBoundProtocolGenerator(codegenContext, ServerRestXmlProtocol(codegenContext))

override fun support(): ProtocolSupport {
return ProtocolSupport(
Expand Down
Loading

0 comments on commit 54cc60e

Please sign in to comment.