From a81109b213b9d1615e719986fd6b356885f8451c Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Fri, 15 Sep 2023 16:27:27 -0400 Subject: [PATCH] Extract builderInstantiator interface to prepare for nullability changes --- .../client/smithy/ClientCodegenContext.kt | 5 ++ .../generators/ClientBuilderInstantiator.kt | 41 +++++++++++++ .../smithy/protocols/ClientProtocolLoader.kt | 4 +- .../codegen/core/smithy/CodegenContext.kt | 5 +- .../smithy/generators/BuilderInstantiator.kt | 29 ++++++++++ .../codegen/core/smithy/protocols/AwsJson.kt | 2 + .../parse/EventStreamUnmarshallerGenerator.kt | 14 ++++- .../protocols/parse/JsonParserGenerator.kt | 23 +++++--- .../parse/XmlBindingTraitParserGenerator.kt | 58 +++++++++++++------ .../core/testutil/EventStreamTestModels.kt | 2 +- .../rust/codegen/core/testutil/TestHelpers.kt | 16 ++++- .../generators/TestBuilderInstantiator.kt | 6 ++ .../server/smithy/ServerCodegenContext.kt | 9 ++- .../smithy/generators/ServerInstantiator.kt | 40 +++++++++++++ .../generators/protocol/ServerProtocol.kt | 19 +++++- 15 files changed, 235 insertions(+), 38 deletions(-) create mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt create mode 100644 codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestBuilderInstantiator.kt diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt index 0e12986ff3f..101d9632551 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt @@ -9,10 +9,12 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientBuilderInstantiator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol /** @@ -36,4 +38,7 @@ data class ClientCodegenContext( model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT, ) { val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins + override fun builderInstantiator(): BuilderInstantiator { + return ClientBuilderInstantiator(symbolProvider) + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt new file mode 100644 index 00000000000..d114a3c9ee1 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt @@ -0,0 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.map +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator + +fun ClientCodegenContext.builderInstantiator(): BuilderInstantiator = ClientBuilderInstantiator(symbolProvider) + +class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider) : BuilderInstantiator { + override fun setField(builder: String, value: Writable, field: MemberShape): Writable { + return setFieldBase(builder, value, field) + } + + override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable { + if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { + rustTemplate( + "$builder.build()#{mapErr}?", + "mapErr" to ( + mapErr?.map { + rust(".map_err(#T)", it) + } ?: writable { } + ), + ) + } else { + rust("$builder.build()") + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index cb9c7db7e5a..d7cdd7594c7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -63,9 +63,9 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ClientCodegenContext): Protocol = if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) { - AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version)) + AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version, codegenContext.builderInstantiator())) } else { - AwsJson(codegenContext, version) + AwsJson(codegenContext, version, codegenContext.builderInstantiator()) } override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt index 3fa6f688e32..ea427fc65f8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator /** * [CodegenContext] contains code-generation context that is _common to all_ smithy-rs plugins. @@ -17,7 +18,7 @@ import software.amazon.smithy.model.shapes.ShapeId * If your data is specific to the `rust-client-codegen` client plugin, put it in [ClientCodegenContext] instead. * If your data is specific to the `rust-server-codegen` server plugin, put it in [ServerCodegenContext] instead. */ -open class CodegenContext( +abstract class CodegenContext( /** * The smithy model. * @@ -89,4 +90,6 @@ open class CodegenContext( fun expectModuleDocProvider(): ModuleDocProvider = checkNotNull(moduleDocProvider) { "A ModuleDocProvider must be set on the CodegenContext" } + + abstract fun builderInstantiator(): BuilderInstantiator } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt new file mode 100644 index 00000000000..13a26c51bdb --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.generators + +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable + +/** Abstraction for instantiating a builders. + * + * Builder abstractions vary—clients MAY use `build_with_error_correction`, e.g., and builders can vary in fallibility. + * */ +interface BuilderInstantiator { + /** Set a field on a builder. */ + fun setField(builder: String, value: Writable, field: MemberShape): Writable + + /** Finalize a builder, turning into a built object (or in the case of builders-of-builders, return the builder directly).*/ + fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable? = null): Writable + + /** Set a field on a builder using the `$setterName` method. $value will be passed directly. */ + fun setFieldBase(builder: String, value: Writable, field: MemberShape) = writable { + rustTemplate("$builder = $builder.${field.setterName()}(#{value})", "value" to value) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 0a534225522..0b3178095c1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -122,6 +123,7 @@ class AwsJsonSerializerGenerator( open class AwsJson( val codegenContext: CodegenContext, val awsJsonVersion: AwsJsonVersion, + val builderInstantiator: BuilderInstantiator, ) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 6b706a1e098..457bf5a53b9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -53,6 +53,7 @@ class EventStreamUnmarshallerGenerator( private val unionShape: UnionShape, ) { private val model = codegenContext.model + private val builderInstantiator = codegenContext.builderInstantiator() private val symbolProvider = codegenContext.symbolProvider private val codegenTarget = codegenContext.target private val runtimeConfig = codegenContext.runtimeConfig @@ -339,6 +340,7 @@ class EventStreamUnmarshallerGenerator( // TODO(EventStream): Errors on the operation can be disjoint with errors in the union, // so we need to generate a new top-level Error type for each event stream union. when (codegenTarget) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1970) It should be possible to unify these branches now CodegenTarget.CLIENT -> { val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser().errorParser(target) @@ -352,9 +354,19 @@ class EventStreamUnmarshallerGenerator( })?; builder.set_meta(Some(generic)); return Ok(#{UnmarshalledMessage}::Error( - #{OpError}::${member.target.name}(builder.build()) + #{OpError}::${member.target.name}( + #{build} + ) )) """, + "build" to builderInstantiator.finalizeBuilder( + "builder", target, + mapErr = { + rustTemplate( + """|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope, + ) + }, + ), "parser" to parser, *codegenScope, ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 01c2b64e6a2..3b05b076942 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -59,13 +59,16 @@ import software.amazon.smithy.utils.StringUtils * Class describing a JSON parser section that can be used in a customization. */ sealed class JsonParserSection(name: String) : Section(name) { - data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember") + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : + JsonParserSection("BeforeBoxingDeserializedMember") - data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember") + data class AfterTimestampDeserializedMember(val shape: MemberShape) : + JsonParserSection("AfterTimestampDeserializedMember") data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember") - data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember") + data class AfterDocumentDeserializedMember(val shape: MemberShape) : + JsonParserSection("AfterDocumentDeserializedMember") } /** @@ -100,6 +103,7 @@ class JsonParserGenerator( private val codegenTarget = codegenContext.target private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() private val protocolFunctions = ProtocolFunctions(codegenContext) + private val builderInstantiator = codegenContext.builderInstantiator() private val codegenScope = arrayOf( "Error" to smithyJson.resolve("deserialize::error::DeserializeError"), "expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"), @@ -251,6 +255,7 @@ class JsonParserGenerator( deserializeMember(member) } } + CodegenTarget.SERVER -> { if (symbolProvider.toSymbol(member).isOptional()) { withBlock("builder = builder.${member.setterName()}(", ");") { @@ -508,12 +513,14 @@ class JsonParserGenerator( "Builder" to symbolProvider.symbolForBuilder(shape), ) deserializeStructInner(shape.members()) - // Only call `build()` if the builder is not fallible. Otherwise, return the builder. - if (returnSymbolToParse.isUnconstrained) { - rust("Ok(Some(builder))") - } else { - rust("Ok(Some(builder.build()))") + val builder = builderInstantiator.finalizeBuilder( + "builder", shape, + ) { + rustTemplate( + """|err|#{Error}::custom_source("Response was invalid", err)""", *codegenScope, + ) } + rust("Ok(Some(#T))", builder) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index d083d0e901b..b7205562d19 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName @@ -101,6 +100,7 @@ class XmlBindingTraitParserGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val protocolFunctions = ProtocolFunctions(codegenContext) private val codegenTarget = codegenContext.target + private val builderInstantiator = codegenContext.builderInstantiator() // The symbols we want all the time private val codegenScope = arrayOf( @@ -159,6 +159,7 @@ class XmlBindingTraitParserGenerator( is StructureShape -> { parseStructure(shape, ctx) } + is UnionShape -> parseUnion(shape, ctx) } } @@ -294,7 +295,10 @@ class XmlBindingTraitParserGenerator( } rust("$builder = $builder.${member.setterName()}($temp);") } - rustTemplate("_ => return Err(#{XmlDecodeError}::custom(\"expected ${member.xmlName()} tag\"))", *codegenScope) + rustTemplate( + "_ => return Err(#{XmlDecodeError}::custom(\"expected ${member.xmlName()} tag\"))", + *codegenScope, + ) } } @@ -359,19 +363,23 @@ class XmlBindingTraitParserGenerator( parsePrimitiveInner(memberShape) { rustTemplate("#{try_data}(&mut ${ctx.tag})?.as_ref()", *codegenScope) } + is MapShape -> if (memberShape.isFlattened()) { parseFlatMap(target, ctx) } else { parseMap(target, ctx) } + is CollectionShape -> if (memberShape.isFlattened()) { parseFlatList(target, ctx) } else { parseList(target, ctx) } + is StructureShape -> { parseStructure(target, ctx) } + is UnionShape -> parseUnion(target, ctx) else -> PANIC("Unhandled: $target") } @@ -436,10 +444,16 @@ class XmlBindingTraitParserGenerator( } when (target.renderUnknownVariant()) { true -> rust("_unknown => base = Some(#T::${UnionGenerator.UnknownVariantName}),", symbol) - false -> rustTemplate("""variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", *codegenScope) + false -> rustTemplate( + """variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", + *codegenScope, + ) } } - rustTemplate("""base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", *codegenScope) + rustTemplate( + """base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", + *codegenScope, + ) } } rust("#T(&mut ${ctx.tag})", nestedParser) @@ -474,17 +488,17 @@ class XmlBindingTraitParserGenerator( } else { rust("let _ = decoder;") } - withBlock("Ok(builder.build()", ")") { - if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { - // NOTE:(rcoh) This branch is unreachable given the current nullability rules. - // Only synthetic inputs can have fallible builders, but synthetic inputs can never be parsed - // (because they're inputs, only outputs will be parsed!) - - // I'm leaving this branch here so that the binding trait parser generator would work for a server - // side implementation in the future. - rustTemplate(""".map_err(|_|#{XmlDecodeError}::custom("missing field"))?""", *codegenScope) - } - } + val builder = builderInstantiator.finalizeBuilder( + "builder", + shape, + mapErr = { + rustTemplate( + """.map_err(|_|#{XmlDecodeError}::custom("missing field"))?""", + *codegenScope, + ) + }, + ) + rust("Ok(#T)", builder) } } rust("#T(&mut ${ctx.tag})", nestedParser) @@ -622,6 +636,7 @@ class XmlBindingTraitParserGenerator( ) } } + is TimestampShape -> { val timestampFormat = index.determineTimestampFormat( @@ -629,7 +644,8 @@ class XmlBindingTraitParserGenerator( HttpBinding.Location.DOCUMENT, TimestampFormatTrait.Format.DATE_TIME, ) - val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat) + val timestampFormatType = + RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat) withBlock("#T::from_str(", ")", RuntimeType.dateTime(runtimeConfig)) { provider() rust(", #T", timestampFormatType) @@ -639,6 +655,7 @@ class XmlBindingTraitParserGenerator( *codegenScope, ) } + is BlobShape -> { withBlock("#T(", ")", RuntimeType.base64Decode(runtimeConfig)) { provider() @@ -648,6 +665,7 @@ class XmlBindingTraitParserGenerator( *codegenScope, ) } + else -> PANIC("unexpected shape: $shape") } } @@ -660,7 +678,10 @@ class XmlBindingTraitParserGenerator( withBlock("#T::try_from(", ")", enumSymbol) { provider() } - rustTemplate(""".map_err(|e| #{XmlDecodeError}::custom(format!("unknown variant {}", e)))?""", *codegenScope) + rustTemplate( + """.map_err(|e| #{XmlDecodeError}::custom(format!("unknown variant {}", e)))?""", + *codegenScope, + ) } else { withBlock("#T::from(", ")", enumSymbol) { provider() @@ -674,7 +695,8 @@ class XmlBindingTraitParserGenerator( } } - private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait() + private fun convertsToEnumInServer(shape: StringShape) = + target == CodegenTarget.SERVER && shape.hasTrait() private fun MemberShape.xmlName(): XmlName { return XmlName(xmlIndex.memberName(this)) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index e944a552a08..d6b43b97cbb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -145,7 +145,7 @@ object EventStreamTestModels { validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { AwsJson(it, AwsJsonVersion.Json11) }, + ) { AwsJson(it, AwsJsonVersion.Json11, builderInstantiator = DefaultBuilderInstantiator()) }, // // restXml diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index b80f211c766..ff4f017b308 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -36,6 +36,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait @@ -88,7 +89,11 @@ private object CodegenCoreTestModules { eventStream: UnionShape, ): RustModule.LeafModule = ErrorsTestModule - override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule { + override fun moduleForBuilder( + context: ModuleProviderContext, + shape: Shape, + symbol: Symbol, + ): RustModule.LeafModule { val builderNamespace = RustReservedWords.escapeIfNeeded("test_" + symbol.name.toSnakeCase()) return RustModule.new( builderNamespace, @@ -161,7 +166,7 @@ internal fun testCodegenContext( serviceShape: ServiceShape? = null, settings: CoreRustSettings = testRustSettings(), codegenTarget: CodegenTarget = CodegenTarget.CLIENT, -): CodegenContext = CodegenContext( +): CodegenContext = object : CodegenContext( model, testSymbolProvider(model), TestModuleDocProvider, @@ -171,11 +176,16 @@ internal fun testCodegenContext( ShapeId.from("test#Protocol"), settings, codegenTarget, -) +) { + override fun builderInstantiator(): BuilderInstantiator { + return DefaultBuilderInstantiator() + } +} /** * In tests, we frequently need to generate a struct, a builder, and an impl block to access said builder. */ + fun StructureShape.renderWithModelBuilder( model: Model, symbolProvider: RustSymbolProvider, diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestBuilderInstantiator.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestBuilderInstantiator.kt new file mode 100644 index 00000000000..1229d59c43a --- /dev/null +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestBuilderInstantiator.kt @@ -0,0 +1,6 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.generators diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt index d5c2a7084ad..d952a7771b0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt @@ -12,6 +12,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.returnSymbolToParseFn /** * [ServerCodegenContext] contains code-generation context that is _specific_ to the [RustServerCodegenPlugin] plugin @@ -35,4 +38,8 @@ data class ServerCodegenContext( val pubCrateConstrainedShapeSymbolProvider: PubCrateConstrainedShapeSymbolProvider, ) : CodegenContext( model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.SERVER, -) +) { + override fun builderInstantiator(): BuilderInstantiator { + return ServerBuilderInstantiator(symbolProvider, returnSymbolToParseFn(this)) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 3d3105f9c70..288b3b934e7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -7,15 +7,20 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.InstantiatorCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.InstantiatorSection import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput @@ -86,3 +91,38 @@ fun serverInstantiator(codegenContext: CodegenContext) = defaultsForRequiredFields = true, customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), ) + +class ServerBuilderInstantiator( + private val symbolProvider: RustSymbolProvider, + private val symbolParseFn: (Shape) -> ReturnSymbolToParse, +) : + BuilderInstantiator { + override fun setField(builder: String, value: Writable, field: MemberShape): Writable { + // Server builders have the ability to have non-optional fields. When one of these fields is used, + // we need to use `if let(...)` to only set the field when it is present. + return if (!symbolProvider.toSymbol(field).isOptional()) { + writable { + val n = safeName() + rustTemplate( + """ + if let Some($n) = #{value} { + #{setter} + } + """, + "value" to value, "setter" to setFieldBase(builder, writable(n), field), + ) + } + } else { + setFieldBase(builder, value, field) + } + } + + override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable { + val returnSymbolToParse = symbolParseFn(shape) + if (returnSymbolToParse.isUnconstrained) { + rust(builder) + } else { + rust("$builder.build()") + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 77ba711fd25..622123cd4a0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -122,7 +122,7 @@ class ServerAwsJsonProtocol( private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, private val additionalParserCustomizations: List = listOf(), -) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol { +) : AwsJson(serverCodegenContext, awsJsonVersion, serverCodegenContext.builderInstantiator()), ServerProtocol { private val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath: String @@ -132,7 +132,12 @@ class ServerAwsJsonProtocol( } override fun structuredDataParser(): StructuredDataParserGenerator = - jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::awsJsonFieldName, additionalParserCustomizations) + jsonParserGenerator( + serverCodegenContext, + httpBindingResolver, + ::awsJsonFieldName, + additionalParserCustomizations, + ) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion) @@ -171,9 +176,11 @@ class ServerAwsJsonProtocol( override fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType = ServerCargoDependency.smithyHttpServer(runtimeConfig) .toType().resolve("protocol::aws_json::rejection::RequestRejection") + override fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType = ServerCargoDependency.smithyHttpServer(runtimeConfig) .toType().resolve("protocol::aws_json::rejection::ResponseRejection") + override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType = ServerCargoDependency.smithyHttpServer(runtimeConfig) .toType().resolve("protocol::aws_json::runtime_error::RuntimeError") @@ -192,7 +199,12 @@ class ServerRestJsonProtocol( override val protocolModulePath: String = "rest_json_1" override fun structuredDataParser(): StructuredDataParserGenerator = - jsonParserGenerator(serverCodegenContext, httpBindingResolver, ::restJsonFieldName, additionalParserCustomizations) + jsonParserGenerator( + serverCodegenContext, + httpBindingResolver, + ::restJsonFieldName, + additionalParserCustomizations, + ) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver) @@ -257,6 +269,7 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa rust(".map(|x| x.into())") } } + else -> emptySection } }