Skip to content

Commit

Permalink
Remove parameter from Protocols structuredDataParser, `structured…
Browse files Browse the repository at this point in the history
…DataSerializer` (#2536)

No implementation of the `Protocol` interface makes use of the
`OperationShape` parameter in the `structuredDataParser` and
`structuredDataSerializer` methods.
  • Loading branch information
david-perez authored Apr 4, 2023
1 parent 8bc93fc commit 7e6f2c9
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class ProtocolParserGenerator(
customizations: List<OperationCustomization>,
) {
val httpBindingGenerator = ResponseBindingGenerator(protocol, codegenContext, operationShape)
val structuredDataParser = protocol.structuredDataParser(operationShape)
val structuredDataParser = protocol.structuredDataParser()
Attribute.AllowUnusedMut.render(this)
rust("let mut output = #T::default();", symbolProvider.symbolForBuilder(outputShape))
if (outputShape.id == operationShape.output.get()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,14 @@ open class AwsJson(
override fun additionalRequestHeaders(operationShape: OperationShape): List<Pair<String, String>> =
listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}")

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return JsonParserGenerator(
override fun structuredDataParser(): StructuredDataParserGenerator =
JsonParserGenerator(
codegenContext,
httpBindingResolver,
::awsJsonFieldName,
)
}

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
AwsJsonSerializerGenerator(codegenContext, httpBindingResolver)

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol {

override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
override fun structuredDataParser(): StructuredDataParserGenerator =
AwsQueryParserGenerator(codegenContext, awsQueryErrors)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
AwsQuerySerializerGenerator(codegenContext)

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ class AwsQueryCompatible(

override val defaultTimestampFormat = awsJson.defaultTimestampFormat

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
awsJson.structuredDataParser(operationShape)
override fun structuredDataParser(): StructuredDataParserGenerator =
awsJson.structuredDataParser()

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
awsJson.structuredDataSerializer(operationShape)
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
awsJson.structuredDataSerializer()

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol {

override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return Ec2QueryParserGenerator(codegenContext, ec2QueryErrors)
}
override fun structuredDataParser(): StructuredDataParserGenerator =
Ec2QueryParserGenerator(codegenContext, ec2QueryErrors)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
Ec2QuerySerializerGenerator(codegenContext)

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class HttpBoundProtocolPayloadGenerator(
val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

if (payloadMemberName == null) {
val serializerGenerator = protocol.structuredDataSerializer(operationShape)
val serializerGenerator = protocol.structuredDataSerializer()
generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape))
} else {
generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
Expand All @@ -113,7 +113,7 @@ class HttpBoundProtocolPayloadGenerator(
val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

if (payloadMemberName == null) {
val serializerGenerator = protocol.structuredDataSerializer(operationShape)
val serializerGenerator = protocol.structuredDataSerializer()
generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape))
} else {
generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
Expand All @@ -126,7 +126,7 @@ class HttpBoundProtocolPayloadGenerator(
operationShape: OperationShape,
payloadMemberName: String,
) {
val serializerGenerator = protocol.structuredDataSerializer(operationShape)
val serializerGenerator = protocol.structuredDataSerializer()

if (operationShape.isEventStream(model)) {
if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ interface Protocol {
fun additionalErrorResponseHeaders(errorShape: StructureShape): List<Pair<String, String>> = emptyList()

/** Returns a deserialization code generator for this protocol */
fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator
fun structuredDataParser(): StructuredDataParserGenerator

/** Returns a serialization code generator for this protocol */
fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator
fun structuredDataSerializer(): StructuredDataSerializerGenerator

/**
* Generates a function signature like the following:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol {
override fun additionalErrorResponseHeaders(errorShape: StructureShape): List<Pair<String, String>> =
listOf("x-amzn-errortype" to errorShape.id.toString())

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
}
override fun structuredDataParser(): StructuredDataParserGenerator =
JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol {
override val defaultTimestampFormat: TimestampFormatTrait.Format =
TimestampFormatTrait.Format.DATE_TIME

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
return RestXmlParserGenerator(codegenContext, restXmlErrors)
}
override fun structuredDataParser(): StructuredDataParserGenerator =
RestXmlParserGenerator(codegenContext, restXmlErrors)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
return XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver)
}
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver)

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class EventStreamUnmarshallerGenerator(

private fun RustWriter.renderParseProtocolPayload(member: MemberShape) {
val memberName = symbolProvider.toMemberName(member)
val parser = protocol.structuredDataParser(operationShape).payloadParser(member)
val parser = protocol.structuredDataParser().payloadParser(member)
rustTemplate(
"""
#{parser}(&message.payload()[..])
Expand Down Expand Up @@ -341,7 +341,7 @@ class EventStreamUnmarshallerGenerator(
when (codegenTarget) {
CodegenTarget.CLIENT -> {
val target = model.expectShape(member.target, StructureShape::class.java)
val parser = protocol.structuredDataParser(operationShape).errorParser(target)
val parser = protocol.structuredDataParser().errorParser(target)
if (parser != null) {
rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(target))
rustTemplate(
Expand All @@ -363,7 +363,7 @@ class EventStreamUnmarshallerGenerator(

CodegenTarget.SERVER -> {
val target = model.expectShape(member.target, StructureShape::class.java)
val parser = protocol.structuredDataParser(operationShape).errorParser(target)
val parser = protocol.structuredDataParser().errorParser(target)
val mut = if (parser != null) { " mut" } else { "" }
rust("let$mut builder = #T::default();", symbolProvider.symbolForBuilder(target))
if (parser != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
Expand All @@ -16,6 +17,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
Expand Down Expand Up @@ -90,6 +92,31 @@ interface ServerProtocol : Protocol {
.toType().resolve("proto::$protocolModulePath::runtime_error::RuntimeError")
}

fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) {
ReturnSymbolToParse(codegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true)
} else {
ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false)
}
return ::returnSymbolToParse
}

fun jsonParserGenerator(
codegenContext: ServerCodegenContext,
httpBindingResolver: HttpBindingResolver,
jsonName: (MemberShape) -> String,
): JsonParserGenerator =
JsonParserGenerator(
codegenContext,
httpBindingResolver,
jsonName,
returnSymbolToParseFn(codegenContext),
listOf(
ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext),
),
)

class ServerAwsJsonProtocol(
private val serverCodegenContext: ServerCodegenContext,
awsJsonVersion: AwsJsonVersion,
Expand All @@ -102,25 +129,10 @@ class ServerAwsJsonProtocol(
is AwsJsonVersion.Json11 -> "aws_json_11"
}

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) {
ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true)
} else {
ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false)
}
return JsonParserGenerator(
codegenContext,
httpBindingResolver,
::awsJsonFieldName,
::returnSymbolToParse,
listOf(
ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(serverCodegenContext),
),
)
}
override fun structuredDataParser(): StructuredDataParserGenerator =
jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion)

override fun markerStruct(): RuntimeType {
Expand Down Expand Up @@ -176,27 +188,10 @@ class ServerRestJsonProtocol(

override val protocolModulePath: String = "rest_json_1"

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) {
ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true)
} else {
ReturnSymbolToParse(serverCodegenContext.symbolProvider.toSymbol(shape), false)
}
return JsonParserGenerator(
codegenContext,
httpBindingResolver,
::restJsonFieldName,
::returnSymbolToParse,
listOf(
ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(
serverCodegenContext,
),
),
)
}
override fun structuredDataParser(): StructuredDataParserGenerator =
jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", protocolModulePath, runtimeConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
errorSymbol: Symbol,
) {
val operationName = symbolProvider.toSymbol(operationShape).name
val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
val structuredDataSerializer = protocol.structuredDataSerializer()
withBlock("match error {", "}") {
val errors = operationShape.operationErrors(model)
errors.forEach {
Expand Down Expand Up @@ -612,7 +612,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
bindings: List<HttpBindingDescriptor>,
) {
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
val structuredDataParser = protocol.structuredDataParser(operationShape)
val structuredDataParser = protocol.structuredDataParser()
Attribute.AllowUnusedMut.render(this)
rust(
"let mut input = #T::default();",
Expand Down

0 comments on commit 7e6f2c9

Please sign in to comment.