diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt index 616325ee3e..8b247cd917 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt @@ -232,7 +232,7 @@ class ProtocolParserGenerator( customizations: List, ) { val httpBindingGenerator = ResponseBindingGenerator(protocol, codegenContext, operationShape) - val structuredDataParser = protocol.structuredDataParser(operationShape) + val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust("let mut output = #T::default();", symbolProvider.symbolForBuilder(outputShape)) if (outputShape.id == operationShape.output.get()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 7812f917bb..0a53422552 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -143,15 +143,14 @@ open class AwsJson( override fun additionalRequestHeaders(operationShape: OperationShape): List> = listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}") - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return JsonParserGenerator( + override fun structuredDataParser(): StructuredDataParserGenerator = + JsonParserGenerator( codegenContext, httpBindingResolver, ::awsJsonFieldName, ) - } - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = AwsJsonSerializerGenerator(codegenContext, httpBindingResolver) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt index 9d12cfbb62..934b703eba 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt @@ -51,10 +51,10 @@ class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = + override fun structuredDataParser(): StructuredDataParserGenerator = AwsQueryParserGenerator(codegenContext, awsQueryErrors) - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = AwsQuerySerializerGenerator(codegenContext) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt index 88fb69d015..489f1de582 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt @@ -64,11 +64,11 @@ class AwsQueryCompatible( override val defaultTimestampFormat = awsJson.defaultTimestampFormat - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - awsJson.structuredDataParser(operationShape) + override fun structuredDataParser(): StructuredDataParserGenerator = + awsJson.structuredDataParser() - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - awsJson.structuredDataSerializer(operationShape) + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = + awsJson.structuredDataSerializer() override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName -> diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt index 01a530d46a..215865f4cb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt @@ -42,11 +42,10 @@ class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return Ec2QueryParserGenerator(codegenContext, ec2QueryErrors) - } + override fun structuredDataParser(): StructuredDataParserGenerator = + Ec2QueryParserGenerator(codegenContext, ec2QueryErrors) - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = Ec2QuerySerializerGenerator(codegenContext) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index bb55883a11..10d4b83cec 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -102,7 +102,7 @@ class HttpBoundProtocolPayloadGenerator( val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { - val serializerGenerator = protocol.structuredDataSerializer(operationShape) + val serializerGenerator = protocol.structuredDataSerializer() generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape)) } else { generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) @@ -113,7 +113,7 @@ class HttpBoundProtocolPayloadGenerator( val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { - val serializerGenerator = protocol.structuredDataSerializer(operationShape) + val serializerGenerator = protocol.structuredDataSerializer() generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape)) } else { generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) @@ -126,7 +126,7 @@ class HttpBoundProtocolPayloadGenerator( operationShape: OperationShape, payloadMemberName: String, ) { - val serializerGenerator = protocol.structuredDataSerializer(operationShape) + val serializerGenerator = protocol.structuredDataSerializer() if (operationShape.isEventStream(model)) { if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index ec9b944e3e..4a1339ca9a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -38,10 +38,10 @@ interface Protocol { fun additionalErrorResponseHeaders(errorShape: StructureShape): List> = emptyList() /** Returns a deserialization code generator for this protocol */ - fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator + fun structuredDataParser(): StructuredDataParserGenerator /** Returns a serialization code generator for this protocol */ - fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator + fun structuredDataSerializer(): StructuredDataSerializerGenerator /** * Generates a function signature like the following: diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 675a0b2126..9a3d12a993 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -89,11 +89,10 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun additionalErrorResponseHeaders(errorShape: StructureShape): List> = listOf("x-amzn-errortype" to errorShape.id.toString()) - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) - } + override fun structuredDataParser(): StructuredDataParserGenerator = + JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt index 3045cf0268..41df9fd52d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt @@ -40,13 +40,11 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return RestXmlParserGenerator(codegenContext, restXmlErrors) - } + override fun structuredDataParser(): StructuredDataParserGenerator = + RestXmlParserGenerator(codegenContext, restXmlErrors) - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { - return XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver) - } + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = + XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName -> diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 4fa98fa6ee..6e9826f054 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -294,7 +294,7 @@ class EventStreamUnmarshallerGenerator( private fun RustWriter.renderParseProtocolPayload(member: MemberShape) { val memberName = symbolProvider.toMemberName(member) - val parser = protocol.structuredDataParser(operationShape).payloadParser(member) + val parser = protocol.structuredDataParser().payloadParser(member) rustTemplate( """ #{parser}(&message.payload()[..]) @@ -341,7 +341,7 @@ class EventStreamUnmarshallerGenerator( when (codegenTarget) { CodegenTarget.CLIENT -> { val target = model.expectShape(member.target, StructureShape::class.java) - val parser = protocol.structuredDataParser(operationShape).errorParser(target) + val parser = protocol.structuredDataParser().errorParser(target) if (parser != null) { rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(target)) rustTemplate( @@ -363,7 +363,7 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.SERVER -> { val target = model.expectShape(member.target, StructureShape::class.java) - val parser = protocol.structuredDataParser(operationShape).errorParser(target) + val parser = protocol.structuredDataParser().errorParser(target) val mut = if (parser != null) { " mut" } else { "" } rust("let$mut builder = #T::default();", symbolProvider.symbolForBuilder(target)) if (parser != null) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 9abe7db0c3..0b466f95c2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape @@ -16,6 +17,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml @@ -90,6 +92,31 @@ interface ServerProtocol : Protocol { .toType().resolve("proto::$protocolModulePath::runtime_error::RuntimeError") } +fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse { + fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = + if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) { + ReturnSymbolToParse(codegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) + } else { + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + } + return ::returnSymbolToParse +} + +fun jsonParserGenerator( + codegenContext: ServerCodegenContext, + httpBindingResolver: HttpBindingResolver, + jsonName: (MemberShape) -> String, +): JsonParserGenerator = + JsonParserGenerator( + codegenContext, + httpBindingResolver, + jsonName, + returnSymbolToParseFn(codegenContext), + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext), + ), + ) + class ServerAwsJsonProtocol( private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, @@ -102,25 +129,10 @@ class ServerAwsJsonProtocol( is AwsJsonVersion.Json11 -> "aws_json_11" } - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = - if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) { - ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) - } else { - ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) - } - return JsonParserGenerator( - codegenContext, - httpBindingResolver, - ::awsJsonFieldName, - ::returnSymbolToParse, - listOf( - ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(serverCodegenContext), - ), - ) - } + override fun structuredDataParser(): StructuredDataParserGenerator = + jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName) - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion) override fun markerStruct(): RuntimeType { @@ -176,27 +188,10 @@ class ServerRestJsonProtocol( override val protocolModulePath: String = "rest_json_1" - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = - if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) { - ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) - } else { - ReturnSymbolToParse(serverCodegenContext.symbolProvider.toSymbol(shape), false) - } - return JsonParserGenerator( - codegenContext, - httpBindingResolver, - ::restJsonFieldName, - ::returnSymbolToParse, - listOf( - ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization( - serverCodegenContext, - ), - ), - ) - } + override fun structuredDataParser(): StructuredDataParserGenerator = + jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName) - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver) override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", protocolModulePath, runtimeConfig) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 6bf8dba7ba..9a7e45f112 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -390,7 +390,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( errorSymbol: Symbol, ) { val operationName = symbolProvider.toSymbol(operationShape).name - val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) + val structuredDataSerializer = protocol.structuredDataSerializer() withBlock("match error {", "}") { val errors = operationShape.operationErrors(model) errors.forEach { @@ -612,7 +612,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( bindings: List, ) { val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) - val structuredDataParser = protocol.structuredDataParser(operationShape) + val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust( "let mut input = #T::default();",