From 1f08e010d41c7c3ef6004d5b545d9089f6919480 Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Tue, 12 Sep 2023 15:36:50 -0400 Subject: [PATCH 1/4] Use the default(..) trait to source default information --- .../smithy/endpoint/EndpointTypesGenerator.kt | 2 - .../smithy/generators/ClientInstantiator.kt | 8 - .../rust/codegen/core/rustlang/Writable.kt | 5 + .../rust/codegen/core/smithy/RuntimeType.kt | 2 +- .../rust/codegen/core/smithy/SymbolExt.kt | 3 + .../rust/codegen/core/smithy/SymbolVisitor.kt | 36 +++- .../smithy/generators/BuilderGenerator.kt | 62 +++++-- .../generators/DefaultValueGenerator.kt | 76 ++++++++ .../core/smithy/generators/Instantiator.kt | 170 ++++++++++-------- .../protocols/parse/JsonParserGenerator.kt | 33 ++-- .../parse/XmlBindingTraitParserGenerator.kt | 49 +++-- .../rust/codegen/core/testutil/TestHelpers.kt | 5 +- .../smithy/generators/InstantiatorTest.kt | 17 +- .../generators/StructureGeneratorTest.kt | 149 +++++++++++++++ .../parse/JsonParserGeneratorTest.kt | 14 +- .../smithy/generators/ServerInstantiator.kt | 14 -- rust-runtime/aws-smithy-types/src/blob.rs | 2 +- 17 files changed, 474 insertions(+), 173 deletions(-) create mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt index 889b741ebb..8a1de0151d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt @@ -27,7 +27,6 @@ class EndpointTypesGenerator( val tests: List, ) { val params: Parameters = rules?.parameters ?: Parameters.builder().build() - private val runtimeConfig = codegenContext.runtimeConfig private val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext) private val stdlib = customizations .flatMap { it.customRuntimeFunctions(codegenContext) } @@ -41,7 +40,6 @@ class EndpointTypesGenerator( } fun paramsStruct(): RuntimeType = EndpointParamsGenerator(codegenContext, params).paramsStruct() - fun paramsBuilder(): RuntimeType = EndpointParamsGenerator(codegenContext, params).paramsBuilder() fun defaultResolver(): RuntimeType? = rules?.let { EndpointResolverGenerator(codegenContext, stdlib).defaultEndpointResolver(it) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt index a065fe3b96..ca39264a1c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.MemberShape @@ -14,18 +13,12 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName -private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable { - rust("#T::from($data)", enumSymbol) -} - class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { override fun hasFallibleBuilder(shape: StructureShape): Boolean = BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider) @@ -40,7 +33,6 @@ class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Ins codegenContext.model, codegenContext.runtimeConfig, ClientBuilderKindBehavior(codegenContext), - ::enumFromStringFn, ) { fun renderFluentCall( writer: RustWriter, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt index e1b5ee64c2..a0c2c2f245 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt @@ -30,6 +30,11 @@ fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable { return writable { f(self) } } +/** Returns Some(..arg) */ +fun Writable.some(): Writable { + return this.map { rust("Some(#T)", it) } +} + fun Writable.isNotEmpty(): Boolean = !this.isEmpty() operator fun Writable.plus(other: Writable): Writable { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index c58596ad94..23f4ffa810 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -204,7 +204,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) * * Prelude docs: https://doc.rust-lang.org/std/prelude/index.html#prelude-contents */ - val preludeScope by lazy { + val preludeScope: Array> by lazy { arrayOf( // Rust 1.0 "Copy" to std.resolve("marker::Copy"), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt index 3b9307ab7e..68f4c15764 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType @@ -102,6 +103,8 @@ sealed class Default { * This symbol should use the Rust `std::default::Default` when unset */ object RustDefault : Default() + + data class NonZeroDefault(val value: Node) : Default() } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index b2426c602c..9677cd7f0a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode +import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.BigDecimalShape import software.amazon.smithy.model.shapes.BigIntegerShape import software.amazon.smithy.model.shapes.BlobShape @@ -37,6 +38,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.DefaultTrait import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute @@ -48,6 +50,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import kotlin.reflect.KClass @@ -79,16 +82,18 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?) /** * Make the return [value] optional if the [member] symbol is as well optional. */ -fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) { - "Some($value)" -} +fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = + value.letIf(toSymbol(member).isOptional()) { + "Some($value)" + } /** * Make the return [value] optional if the [member] symbol is not optional. */ -fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) { - "Some($value)" -} +fun SymbolProvider.toOptional(member: MemberShape, value: String): String = + value.letIf(!toSymbol(member).isOptional()) { + "Some($value)" + } /** * Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service @@ -170,7 +175,7 @@ open class SymbolVisitor( } private fun simpleShape(shape: SimpleShape): Symbol { - return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build() + return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build() } override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape) @@ -263,13 +268,21 @@ open class SymbolVisitor( override fun memberShape(shape: MemberShape): Symbol { val target = model.expectShape(shape.target) + val defaultValue = shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait -> + when (val value = trait.toNode()) { + Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault + Node.nullNode() -> Default.NoDefault + else -> { Default.NonZeroDefault(value) + } + } + } ?: Default.NoDefault // Handle boxing first, so we end up with Option>, not Box>. return handleOptionality( handleRustBoxing(toSymbol(target), shape), shape, nullableIndex, config.nullabilityCheckMode, - ) + ).toBuilder().setDefault(defaultValue).build() } override fun timestampShape(shape: TimestampShape?): Symbol { @@ -297,7 +310,12 @@ fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder = // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation .definitionFile("thisisabug.rs") -fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol = +fun handleOptionality( + symbol: Symbol, + member: MemberShape, + nullableIndex: NullableIndex, + nullabilityCheckMode: CheckMode, +): Symbol = symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 9697cc624f..816ce25c51 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlockTemplat import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.rustlang.map import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -32,7 +33,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope @@ -41,7 +41,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.defaultValue import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional @@ -181,14 +180,15 @@ class BuilderGenerator( private fun renderBuildFn(implBlockWriter: RustWriter) { val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider) val outputSymbol = symbolProvider.toSymbol(shape) + val fb = + "#{Result}<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" val returnType = when (fallibleBuilder) { - true -> "#{Result}<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" + true -> fb false -> implBlockWriter.format(outputSymbol) } implBlockWriter.docs("Consumes the builder and constructs a #D.", outputSymbol) implBlockWriter.rustBlockTemplate("pub fn build(self) -> $returnType", *preludeScope) { conditionalBlockTemplate("#{Ok}(", ")", conditional = fallibleBuilder, *preludeScope) { - // If a wrapper is specified, use the `::new` associated function to construct the wrapper coreBuilder(this) } } @@ -369,6 +369,27 @@ class BuilderGenerator( } } + internal fun errorCorrectingBuilder(writer: RustWriter) { + val outputSymbol = symbolProvider.toSymbol(shape) + val fb = + "#{Result}<${writer.format(outputSymbol)}, ${writer.format(runtimeConfig.operationBuildError())}>" + writer.rustBlockTemplate( + "pub(crate) fn build_with_error_correction(self) -> $fb", + *preludeScope, + ) { + val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider) + if (fallibleBuilder) { + rustTemplate( + "#{Ok}(#{Builder})", + *preludeScope, + "Builder" to writable { coreBuilder(this, errorCorrection = true) }, + ) + } else { + rustTemplate("#{Ok}(self.build())", *preludeScope) + } + } + } + /** * The core builder of the inner type. If the structure requires a fallible builder, this may use `?` to return * errors. @@ -380,20 +401,27 @@ class BuilderGenerator( * } * ``` */ - private fun coreBuilder(writer: RustWriter) { + private fun coreBuilder(writer: RustWriter, errorCorrection: Boolean = false) { writer.rustBlock("#T", structureSymbol) { members.forEach { member -> val memberName = symbolProvider.toMemberName(member) val memberSymbol = symbolProvider.toSymbol(member) - val default = memberSymbol.defaultValue() withBlock("$memberName: self.$memberName", ",") { - // Write the modifier - when { - !memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()") - !memberSymbol.isOptional() -> withBlock( - ".ok_or_else(||", - ")?", - ) { missingRequiredField(memberName) } + // Resolve a default value or return null + val generator = DefaultValueGenerator(runtimeConfig, symbolProvider, model) + val default = generator.defaultValue(member) + if (!memberSymbol.isOptional()) { + if (default != null) { + rust(".unwrap_or_else(#T)", default) + } else { + if (errorCorrection) { + generator.errorCorrection(member)?.also { correction -> rust(".or_else(||#T)", correction) } + } + withBlock( + ".ok_or_else(||", + ")?", + ) { missingRequiredField(memberName) } + } } } } @@ -401,3 +429,11 @@ class BuilderGenerator( } } } + +fun errorCorrectingBuilder(shape: StructureShape, symbolProvider: RustSymbolProvider, model: Model): RuntimeType { + return RuntimeType.forInlineFun("${symbolProvider.symbolForBuilder(shape).name}::build_with_error_correction", symbolProvider.moduleForBuilder(shape)) { + implBlock(symbolProvider.symbolForBuilder(shape)) { + BuilderGenerator(model, symbolProvider, shape, listOf()).errorCorrectingBuilder(this) + } + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt new file mode 100644 index 0000000000..f4b235c3b8 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.generators + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.SimpleShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.some +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.Default +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.defaultValue +import software.amazon.smithy.rust.codegen.core.util.isEventStream +import software.amazon.smithy.rust.codegen.core.util.isStreaming + +class DefaultValueGenerator( + runtimeConfig: RuntimeConfig, + private val symbolProvider: RustSymbolProvider, + private val model: Model, +) { + private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider) + + /** Returns the default value as set by the defaultValue trait */ + fun defaultValue(member: MemberShape): Writable? { + val target = model.expectShape(member.target) + return when (val default = symbolProvider.toSymbol(member).defaultValue()) { + is Default.NoDefault -> null + is Default.RustDefault -> writable("Default::default") + is Default.NonZeroDefault -> { + val instantiation = instantiator.instantiate(target as SimpleShape, default.value) + writable { rust("||#T", instantiation) } + } + } + } + + fun errorCorrection(member: MemberShape): Writable? { + val symbol = symbolProvider.toSymbol(member) + val target = model.expectShape(member.target) + if (member.isEventStream(model) || member.isStreaming(model)) { + return null + } + return writable { + when (target) { + is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to symbol) + is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())") + is StructureShape -> rust( + "#T::default().build_with_error_correction().ok()", + symbolProvider.symbolForBuilder(target), + ) + is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this) + is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this) + + is UnionShape -> rust("Some(#T::Unknown)", symbol) + } + } + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index eed196e489..fcae77ae19 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -6,7 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.codegen.core.CodegenException -import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ArrayNode import software.amazon.smithy.model.node.Node @@ -18,17 +18,18 @@ import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpHeaderTrait import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait @@ -84,11 +85,6 @@ open class Instantiator( private val runtimeConfig: RuntimeConfig, /** Behavior of the builder type used for structure shapes. */ private val builderKindBehavior: BuilderKindBehavior, - /** - * A function that given a symbol for an enum shape and a string, returns a writable to instantiate the enum with - * the string value. - **/ - private val enumFromStringFn: (Symbol, String) -> Writable, /** Fill out required fields with a default value. **/ private val defaultsForRequiredFields: Boolean = false, private val customizations: List = listOf(), @@ -131,64 +127,7 @@ open class Instantiator( // Members, supporting potentially optional members is MemberShape -> renderMember(writer, shape, data, ctx) - // Wrapped Shapes - is TimestampShape -> { - val node = (data as NumberNode) - val num = BigDecimal(node.toString()) - val wholePart = num.toInt() - val fractionalPart = num.remainder(BigDecimal.ONE) - writer.rust( - "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), - ) - } - - /** - * ```rust - * Blob::new("arg") - * ``` - */ - is BlobShape -> if (shape.hasTrait()) { - writer.rust( - "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig), - ) - } else { - writer.rust( - "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig), - ) - } - - // Simple Shapes - is StringShape -> renderString(writer, shape, data as StringNode) - is NumberShape -> when (data) { - is StringNode -> { - val numberSymbol = symbolProvider.toSymbol(shape) - // support Smithy custom values, such as Infinity - writer.rust( - """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", - numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), - ) - } - - is NumberNode -> writer.write(data.value) - } - - is BooleanShape -> writer.rust(data.asBooleanNode().get().toString()) - is DocumentShape -> writer.rustBlock("") { - val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() - rustTemplate( - """ - let json_bytes = br##"${Node.prettyPrintJson(data)}"##; - let mut tokens = #{json_token_iter}(json_bytes).peekable(); - #{expect_document}(&mut tokens).expect("well formed json") - """, - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), - ) - } + is SimpleShape -> PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate(shape, data)(writer) else -> writer.writeWithNoFormatting("todo!() /* $shape $data */") } @@ -214,7 +153,11 @@ open class Instantiator( ")", // The conditions are not commutative: note client builders always take in `Option`. conditional = symbol.isOptional() || - (model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption(memberShape)), + ( + model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption( + memberShape, + ) + ), *preludeScope, ) { writer.conditionalBlockTemplate( @@ -238,7 +181,8 @@ open class Instantiator( } } - private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) = renderList(writer, shape, data, ctx) + private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) = + renderList(writer, shape, data, ctx) /** * ```rust @@ -317,22 +261,18 @@ open class Instantiator( } } - private fun renderString(writer: RustWriter, shape: StringShape, arg: StringNode) { - val data = writer.escape(arg.value).dq() - if (!shape.hasTrait()) { - writer.rust("$data.to_owned()") - } else { - val enumSymbol = symbolProvider.toSymbol(shape) - writer.rustTemplate("#{EnumFromStringFn:W}", "EnumFromStringFn" to enumFromStringFn(enumSymbol, data)) - } - } - /** * ```rust * MyStruct::builder().field_1("hello").field_2(5).build() * ``` */ - private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, headers: Map, ctx: Ctx) { + private fun renderStructure( + writer: RustWriter, + shape: StructureShape, + data: ObjectNode, + headers: Map, + ctx: Ctx, + ) { writer.rust("#T::builder()", symbolProvider.toSymbol(shape)) renderStructureMembers(writer, shape, data, headers, ctx) @@ -416,3 +356,77 @@ open class Instantiator( else -> throw CodegenException("Unrecognized shape `$shape`") } } + +class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) { + fun instantiate(shape: SimpleShape, data: Node): Writable = writable { + when (shape) { + // Simple Shapes + is TimestampShape -> { + val node = (data as NumberNode) + val num = BigDecimal(node.toString()) + val wholePart = num.toInt() + val fractionalPart = num.remainder(BigDecimal.ONE) + rust( + "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", + RuntimeType.dateTime(runtimeConfig), + ) + } + + /** + * ```rust + * Blob::new("arg") + * ``` + */ + is BlobShape -> if (shape.hasTrait()) { + rust( + "#T::from_static(b${(data as StringNode).value.dq()})", + RuntimeType.byteStream(runtimeConfig), + ) + } else { + rust( + "#T::new(${(data as StringNode).value.dq()})", + RuntimeType.blob(runtimeConfig), + ) + } + + is StringShape -> renderString(shape, data as StringNode)(this) + is NumberShape -> when (data) { + is StringNode -> { + val numberSymbol = symbolProvider.toSymbol(shape) + // support Smithy custom values, such as Infinity + rust( + """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", + numberSymbol, + RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + ) + } + + is NumberNode -> write(data.value) + } + + is BooleanShape -> rust(data.asBooleanNode().get().toString()) + is DocumentShape -> rustBlock("") { + val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() + rustTemplate( + """ + let json_bytes = br##"${Node.prettyPrintJson(data)}"##; + let mut tokens = #{json_token_iter}(json_bytes).peekable(); + #{expect_document}(&mut tokens).expect("well formed json") + """, + "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), + "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + ) + } + } + } + + private fun renderString(shape: StringShape, arg: StringNode): Writable = { + val data = escape(arg.value).dq() + if (shape !is EnumShape) { + rust("$data.to_owned()") + } else { + val enumSymbol = symbolProvider.toSymbol(shape) + rust("$data.parse::<#T>().unwrap()", enumSymbol) + } + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 01c2b64e6a..28c96d9da5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -36,10 +36,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.errorCorrectingBuilder import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -59,13 +59,16 @@ import software.amazon.smithy.utils.StringUtils * Class describing a JSON parser section that can be used in a customization. */ sealed class JsonParserSection(name: String) : Section(name) { - data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember") + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : + JsonParserSection("BeforeBoxingDeserializedMember") - data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember") + data class AfterTimestampDeserializedMember(val shape: MemberShape) : + JsonParserSection("AfterTimestampDeserializedMember") data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember") - data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember") + data class AfterDocumentDeserializedMember(val shape: MemberShape) : + JsonParserSection("AfterDocumentDeserializedMember") } /** @@ -251,6 +254,7 @@ class JsonParserGenerator( deserializeMember(member) } } + CodegenTarget.SERVER -> { if (symbolProvider.toSymbol(member).isOptional()) { withBlock("builder = builder.${member.setterName()}(", ");") { @@ -512,7 +516,14 @@ class JsonParserGenerator( if (returnSymbolToParse.isUnconstrained) { rust("Ok(Some(builder))") } else { - rust("Ok(Some(builder.build()))") + rustTemplate( + """ + Ok(Some( + #{correct_errors}(builder) + .map_err(|err|#{Error}::custom_source("Response was invalid", err))? + ))""", + "correct_errors" to errorCorrectingBuilder(shape, symbolProvider, model), *codegenScope, + ) } } } @@ -604,14 +615,10 @@ class JsonParserGenerator( } private fun RustWriter.unwrapOrDefaultOrError(member: MemberShape, checkValueSet: Boolean) { - if (symbolProvider.toSymbol(member).canUseDefault() && !checkValueSet) { - rust(".unwrap_or_default()") - } else { - rustTemplate( - ".ok_or_else(|| #{Error}::custom(\"value for '${escape(member.memberName)}' cannot be null\"))?", - *codegenScope, - ) - } + rustTemplate( + ".ok_or_else(|| #{Error}::custom(\"value for '${escape(member.memberName)}' cannot be null\"))?", + *codegenScope, + ) } private fun RustWriter.objectKeyLoop(hasMembers: Boolean, inner: Writable) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index d083d0e901..594c1166fc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -39,8 +39,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.errorCorrectingBuilder import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -159,6 +159,7 @@ class XmlBindingTraitParserGenerator( is StructureShape -> { parseStructure(shape, ctx) } + is UnionShape -> parseUnion(shape, ctx) } } @@ -294,7 +295,10 @@ class XmlBindingTraitParserGenerator( } rust("$builder = $builder.${member.setterName()}($temp);") } - rustTemplate("_ => return Err(#{XmlDecodeError}::custom(\"expected ${member.xmlName()} tag\"))", *codegenScope) + rustTemplate( + "_ => return Err(#{XmlDecodeError}::custom(\"expected ${member.xmlName()} tag\"))", + *codegenScope, + ) } } @@ -359,19 +363,23 @@ class XmlBindingTraitParserGenerator( parsePrimitiveInner(memberShape) { rustTemplate("#{try_data}(&mut ${ctx.tag})?.as_ref()", *codegenScope) } + is MapShape -> if (memberShape.isFlattened()) { parseFlatMap(target, ctx) } else { parseMap(target, ctx) } + is CollectionShape -> if (memberShape.isFlattened()) { parseFlatList(target, ctx) } else { parseList(target, ctx) } + is StructureShape -> { parseStructure(target, ctx) } + is UnionShape -> parseUnion(target, ctx) else -> PANIC("Unhandled: $target") } @@ -436,10 +444,16 @@ class XmlBindingTraitParserGenerator( } when (target.renderUnknownVariant()) { true -> rust("_unknown => base = Some(#T::${UnionGenerator.UnknownVariantName}),", symbol) - false -> rustTemplate("""variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", *codegenScope) + false -> rustTemplate( + """variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", + *codegenScope, + ) } } - rustTemplate("""base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", *codegenScope) + rustTemplate( + """base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", + *codegenScope, + ) } } rust("#T(&mut ${ctx.tag})", nestedParser) @@ -474,17 +488,8 @@ class XmlBindingTraitParserGenerator( } else { rust("let _ = decoder;") } - withBlock("Ok(builder.build()", ")") { - if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { - // NOTE:(rcoh) This branch is unreachable given the current nullability rules. - // Only synthetic inputs can have fallible builders, but synthetic inputs can never be parsed - // (because they're inputs, only outputs will be parsed!) - - // I'm leaving this branch here so that the binding trait parser generator would work for a server - // side implementation in the future. - rustTemplate(""".map_err(|_|#{XmlDecodeError}::custom("missing field"))?""", *codegenScope) - } - } + val correcting = errorCorrectingBuilder(shape, symbolProvider, model) + rustTemplate("Ok(#{correcting}(builder).map_err(|_|#{XmlDecodeError}::custom(\"missing field\"))?)", "correcting" to correcting, *codegenScope) } } rust("#T(&mut ${ctx.tag})", nestedParser) @@ -622,6 +627,7 @@ class XmlBindingTraitParserGenerator( ) } } + is TimestampShape -> { val timestampFormat = index.determineTimestampFormat( @@ -629,7 +635,8 @@ class XmlBindingTraitParserGenerator( HttpBinding.Location.DOCUMENT, TimestampFormatTrait.Format.DATE_TIME, ) - val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat) + val timestampFormatType = + RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat) withBlock("#T::from_str(", ")", RuntimeType.dateTime(runtimeConfig)) { provider() rust(", #T", timestampFormatType) @@ -639,6 +646,7 @@ class XmlBindingTraitParserGenerator( *codegenScope, ) } + is BlobShape -> { withBlock("#T(", ")", RuntimeType.base64Decode(runtimeConfig)) { provider() @@ -648,6 +656,7 @@ class XmlBindingTraitParserGenerator( *codegenScope, ) } + else -> PANIC("unexpected shape: $shape") } } @@ -660,7 +669,10 @@ class XmlBindingTraitParserGenerator( withBlock("#T::try_from(", ")", enumSymbol) { provider() } - rustTemplate(""".map_err(|e| #{XmlDecodeError}::custom(format!("unknown variant {}", e)))?""", *codegenScope) + rustTemplate( + """.map_err(|e| #{XmlDecodeError}::custom(format!("unknown variant {}", e)))?""", + *codegenScope, + ) } else { withBlock("#T::from(", ")", enumSymbol) { provider() @@ -674,7 +686,8 @@ class XmlBindingTraitParserGenerator( } } - private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait() + private fun convertsToEnumInServer(shape: StringShape) = + target == CodegenTarget.SERVER && shape.hasTrait() private fun MemberShape.xmlName(): XmlName { return XmlName(xmlIndex.memberName(this)) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index b80f211c76..16cc352a50 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -94,7 +94,7 @@ private object CodegenCoreTestModules { builderNamespace, visibility = Visibility.PUBLIC, parent = symbol.module(), - inline = true, + inline = false, ) } } @@ -142,11 +142,12 @@ fun String.asSmithyModel(sourceLocation: String? = null, smithyVersion: String = internal fun testSymbolProvider( model: Model, rustReservedWordConfig: RustReservedWordConfig? = null, + config: RustSymbolProviderConfig = TestRustSymbolProviderConfig, ): RustSymbolProvider = SymbolVisitor( testRustSettings(), model, ServiceShape.builder().version("test").id("test#Service").build(), - TestRustSymbolProviderConfig, + config, ).let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(Attribute.NonExhaustive)) } .let { RustReservedWordSymbolProvider( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index 5232fb1df2..aa89acbb2b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import org.junit.jupiter.api.Test -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.NumberNode import software.amazon.smithy.model.node.StringNode @@ -19,7 +18,6 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer @@ -102,14 +100,11 @@ class InstantiatorTest { override fun doesSetterTakeInOption(memberShape: MemberShape) = true } - // This can be empty since the actual behavior is tested in `ClientInstantiatorTest` and `ServerInstantiatorTest`. - private fun enumFromStringFn(symbol: Symbol, data: String) = writable { } - @Test fun `generate unions`() { val union = model.lookup("com.test#MyUnion") val sut = - Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext)) val data = Node.parse("""{ "stringVariant": "ok!" }""") val project = TestWorkspace.testProject(model) @@ -129,7 +124,7 @@ class InstantiatorTest { fun `generate struct builders`() { val structure = model.lookup("com.test#MyStruct") val sut = - Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext)) val data = Node.parse("""{ "bar": 10, "foo": "hello" }""") val project = TestWorkspace.testProject(model) @@ -154,7 +149,7 @@ class InstantiatorTest { fun `generate builders for boxed structs`() { val structure = model.lookup("com.test#WithBox") val sut = - Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext)) val data = Node.parse( """ { @@ -193,7 +188,7 @@ class InstantiatorTest { fun `generate lists`() { val data = Node.parse("""["bar", "foo"]""") val sut = - Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext)) val project = TestWorkspace.testProject() project.lib { @@ -214,7 +209,6 @@ class InstantiatorTest { model, runtimeConfig, BuilderKindBehavior(codegenContext), - ::enumFromStringFn, ) val project = TestWorkspace.testProject(model) @@ -245,7 +239,6 @@ class InstantiatorTest { model, runtimeConfig, BuilderKindBehavior(codegenContext), - ::enumFromStringFn, ) val inner = model.lookup("com.test#Inner") @@ -278,7 +271,6 @@ class InstantiatorTest { model, runtimeConfig, BuilderKindBehavior(codegenContext), - ::enumFromStringFn, ) val project = TestWorkspace.testProject(model) @@ -306,7 +298,6 @@ class InstantiatorTest { model, runtimeConfig, BuilderKindBehavior(codegenContext), - ::enumFromStringFn, ) val project = TestWorkspace.testProject(model) project.testModule { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index 77932ddfac..4cb6b166a2 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import io.kotest.matchers.string.shouldContainInOrder import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Test +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -15,7 +16,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer +import software.amazon.smithy.rust.codegen.core.testutil.TestRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -435,4 +438,150 @@ class StructureGeneratorTest { writer.toString().shouldNotContain("#[doc(hidden)]") } } + + @Test + fun `it supports error correction`() { + // TODO + val model = """ + ${"$"}version: "2.0" + namespace com.test + structure MyStruct { + @required + int: Integer + + @required + string: String + + @required + list: StringList + + @required + doc: Document + + @required + bool: Boolean + + @required + ts: Timestamp + + @required + blob: Blob + } + + list StringList { + member: String + } + """.asSmithyModel() + + val provider = testSymbolProvider( + model, + rustReservedWordConfig = rustReservedWordConfig, + config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), + ) + val project = TestWorkspace.testProject(provider) + val shape: StructureShape = model.lookup("com.test#MyStruct") + project.useShapeWriter(shape) { + StructureGenerator(model, provider, this, shape, listOf()).render() + BuilderGenerator(model, provider, shape, listOf()).render(this) + unitTest("error_correction") { + rustTemplate( + """ + + Builder::default().build().expect_err("no default set for many fields"); + let corrected = Builder::default().build_with_error_correction().expect("all errors corrected"); + assert_eq!(corrected.int(), 0); + assert_eq!(corrected.string(), ""); + """, + ) + } + } + project.compileAndTest() + } + + @Test + fun `it supports nonzero defaults`() { + // TODO + val model = """ + ${"$"}version: "2.0" + namespace com.test + structure MyStruct { + @default(0) + @required + zeroDefault: Integer + + @required + @default(1) + oneDefault: OneDefault + + @required + @default("") + defaultEmpty: String + + @required + @default("some-value") + defaultValue: String + + @required + anActuallyRequiredField: Integer + + @required + @default([]) + emptyList: StringList + + noDefault: String + + @default(true) + @required + defaultDocument: Document + } + + list StringList { + member: String + } + + @default(1) + integer OneDefault + """.asSmithyModel() + + val provider = testSymbolProvider( + model, + rustReservedWordConfig = rustReservedWordConfig, + config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), + ) + val project = TestWorkspace.testProject(provider) + val shape: StructureShape = model.lookup("com.test#MyStruct") + project.useShapeWriter(shape) { + StructureGenerator(model, provider, this, shape, listOf()).render() + BuilderGenerator(model, provider, shape, listOf()).render(this) + unitTest("test_defaults") { + rustTemplate( + """ + let s = Builder::default().an_actually_required_field(5).build().unwrap(); + assert_eq!(s.zero_default(), 0); + assert_eq!(s.default_empty(), ""); + assert_eq!(s.default_value(), "some-value"); + assert_eq!(s.one_default(), 1); + assert!(s.empty_list().is_empty()); + assert_eq!(s.an_actually_required_field(), 5); + assert_eq!(s.no_default(), None); + assert_eq!(s.default_document().as_bool().unwrap(), true); + + """, + "Struct" to provider.toSymbol(shape), + ) + } + unitTest("error_correction") { + rustTemplate( + """ + + Builder::default().build().expect_err("no default set"); + let corrected = Builder::default().build_with_error_correction().expect("all errors corrected"); + assert_eq!(corrected.an_actually_required_field(), 0); + assert_eq!(Builder::default().an_actually_required_field(0).build().unwrap(), corrected); + """, + ) + } + } + project.compileAndTest() + } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index 79435b5e9b..2b44e0cbf1 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -48,8 +48,12 @@ class JsonParserGeneratorTest { s: String, top: Top, unit: Unit, + defaultString: DefaultString } + @default("Foo") + string DefaultString + @enum([{name: "FOO", value: "FOO"}]) string FooEnum @@ -108,7 +112,7 @@ class JsonParserGeneratorTest { output: OpOutput, errors: [Error] } - """.asSmithyModel() + """.asSmithyModel(smithyVersion = "2.0") @Test fun `generates valid deserializers`() { @@ -187,6 +191,14 @@ class JsonParserGeneratorTest { assert_eq!(error_output.message.expect("message should be set"), "hello"); """, ) + + unitTest( + "union_default", + """ + let input = br#"{ "top": { "choice": { "defaultString": null } } }"#; + let output = ${format(operationGenerator)}(input, test_output::OpOutput::builder()).expect_err("cannot be null"); + """, + ) } model.lookup("test#Top").also { top -> top.renderWithModelBuilder(model, symbolProvider, project) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 3d3105f9c7..09cc9cdc71 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -20,18 +19,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput -/** - * Server enums do not have an `Unknown` variant like client enums do, so constructing an enum from - * a string is a fallible operation (hence `try_from`). It's ok to panic here if construction fails, - * since this is only used in protocol tests. - */ -private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable { - rust( - """#T::try_from($data).expect("this is only used in tests")""", - enumSymbol, - ) -} - class ServerAfterInstantiatingValueConstrainItIfNecessary(val codegenContext: CodegenContext) : InstantiatorCustomization() { @@ -82,7 +69,6 @@ fun serverInstantiator(codegenContext: CodegenContext) = codegenContext.model, codegenContext.runtimeConfig, ServerBuilderKindBehavior(codegenContext), - ::enumFromStringFn, defaultsForRequiredFields = true, customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), ) diff --git a/rust-runtime/aws-smithy-types/src/blob.rs b/rust-runtime/aws-smithy-types/src/blob.rs index 5365b91249..a04391ad7e 100644 --- a/rust-runtime/aws-smithy-types/src/blob.rs +++ b/rust-runtime/aws-smithy-types/src/blob.rs @@ -6,7 +6,7 @@ /// Binary Blob Type /// /// Blobs represent protocol-agnostic binary content. -#[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Default)] pub struct Blob { inner: Vec, } From 9d76a22486b194bac0c84bd6287e0b674aa09d3a Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Wed, 13 Sep 2023 15:43:27 -0400 Subject: [PATCH 2/4] Only use error correction when required --- .../core/smithy/generators/BuilderGenerator.kt | 5 ++++- .../protocols/parse/JsonParserGenerator.kt | 18 ++++++++++-------- .../parse/XmlBindingTraitParserGenerator.kt | 12 ++++++++++-- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 816ce25c51..5c1f72dfb3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -430,7 +430,10 @@ class BuilderGenerator( } } -fun errorCorrectingBuilder(shape: StructureShape, symbolProvider: RustSymbolProvider, model: Model): RuntimeType { +fun errorCorrectingBuilder(shape: StructureShape, symbolProvider: RustSymbolProvider, model: Model): RuntimeType? { + if (!BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { + return null + } return RuntimeType.forInlineFun("${symbolProvider.symbolForBuilder(shape).name}::build_with_error_correction", symbolProvider.moduleForBuilder(shape)) { implBlock(symbolProvider.symbolForBuilder(shape)) { BuilderGenerator(model, symbolProvider, shape, listOf()).errorCorrectingBuilder(this) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 28c96d9da5..91cd8c9efd 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -516,14 +516,16 @@ class JsonParserGenerator( if (returnSymbolToParse.isUnconstrained) { rust("Ok(Some(builder))") } else { - rustTemplate( - """ - Ok(Some( - #{correct_errors}(builder) - .map_err(|err|#{Error}::custom_source("Response was invalid", err))? - ))""", - "correct_errors" to errorCorrectingBuilder(shape, symbolProvider, model), *codegenScope, - ) + val errorCorrection = errorCorrectingBuilder(shape, symbolProvider, model) + if (errorCorrection != null) { + rustTemplate( + """ + Ok(Some(#{correct_errors}(builder).map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""", + "correct_errors" to errorCorrection, *codegenScope, + ) + } else { + rust("Ok(Some(builder.build()))") + } } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 594c1166fc..a8167ed2a6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -488,8 +488,16 @@ class XmlBindingTraitParserGenerator( } else { rust("let _ = decoder;") } - val correcting = errorCorrectingBuilder(shape, symbolProvider, model) - rustTemplate("Ok(#{correcting}(builder).map_err(|_|#{XmlDecodeError}::custom(\"missing field\"))?)", "correcting" to correcting, *codegenScope) + val errorCorrection = errorCorrectingBuilder(shape, symbolProvider, model) + if (errorCorrection != null) { + rustTemplate( + """ + Ok(#{correct_errors}(builder).map_err(|err|#{XmlDecodeError}::custom_source("Response was invalid", err))?)""", + "correct_errors" to errorCorrection, *codegenScope, + ) + } else { + rust("Ok(builder.build())") + } } } rust("#T(&mut ${ctx.tag})", nestedParser) From 0db25dc16db440ed13fca256acb2d475a8f19256 Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Thu, 14 Sep 2023 09:36:47 -0400 Subject: [PATCH 3/4] Fix more tests --- .../core/smithy/generators/EnumGenerator.kt | 3 + .../core/smithy/generators/Instantiator.kt | 7 +- .../smithy/generators/BuilderGeneratorTest.kt | 139 ++++++++++++++++ .../generators/StructureGeneratorTest.kt | 149 ------------------ 4 files changed, 146 insertions(+), 152 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index a84820ad26..ae4e66600f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -232,6 +232,8 @@ open class EnumGenerator( }, ) + enumType.implFromStr(context)(this) + rustTemplate( """ impl #{From} for ${context.enumName} where T: #{AsRef} { @@ -239,6 +241,7 @@ open class EnumGenerator( ${context.enumName}(s.as_ref().to_owned()) } } + """, *preludeScope, ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index fcae77ae19..e9eeefc3d2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -30,6 +30,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpHeaderTrait import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait @@ -422,11 +423,11 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va private fun renderString(shape: StringShape, arg: StringNode): Writable = { val data = escape(arg.value).dq() - if (shape !is EnumShape) { - rust("$data.to_owned()") - } else { + if (shape.hasTrait() || shape is EnumShape) { val enumSymbol = symbolProvider.toSymbol(shape) rust("$data.parse::<#T>().unwrap()", enumSymbol) + } else { + rust("$data.to_owned()") } } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt index fcec11601e..46532bd53d 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt @@ -7,17 +7,24 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.AllowDeprecated import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.setDefault +import software.amazon.smithy.rust.codegen.core.testutil.TestRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup internal class BuilderGeneratorTest { private val model = StructureGeneratorTest.model @@ -140,4 +147,136 @@ internal class BuilderGeneratorTest { } project.compileAndTest() } + + @Test + fun `it supports error correction`() { + val model = """ + ${"$"}version: "2.0" + namespace com.test + structure MyStruct { + @required + int: Integer + + @required + string: String + + @required + list: StringList + + @required + doc: Document + + @required + bool: Boolean + + @required + ts: Timestamp + + @required + blob: Blob + } + + list StringList { + member: String + } + """.asSmithyModel() + + val provider = testSymbolProvider( + model, + rustReservedWordConfig = StructureGeneratorTest.rustReservedWordConfig, + config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), + ) + val project = TestWorkspace.testProject(provider) + val shape: StructureShape = model.lookup("com.test#MyStruct") + shape.renderWithModelBuilder(model, provider, project) + project.unitTest { + rustTemplate( + """ + + #{Builder}::default().build().expect_err("no default set for many fields"); + let corrected = #{correct_errors}(#{Builder}::default()).expect("all errors corrected"); + assert_eq!(corrected.int(), 0); + assert_eq!(corrected.string(), ""); + """, + "Builder" to provider.symbolForBuilder(shape), + "correct_errors" to errorCorrectingBuilder(shape, provider, model)!!, + ) + } + project.compileAndTest() + } + + @Test + fun `it supports nonzero defaults`() { + val model = """ + ${"$"}version: "2.0" + namespace com.test + structure MyStruct { + @default(0) + @required + zeroDefault: Integer + + @required + @default(1) + oneDefault: OneDefault + + @required + @default("") + defaultEmpty: String + + @required + @default("some-value") + defaultValue: String + + @required + anActuallyRequiredField: Integer + + @required + @default([]) + emptyList: StringList + + noDefault: String + + @default(true) + @required + defaultDocument: Document + } + + list StringList { + member: String + } + + @default(1) + integer OneDefault + """.asSmithyModel() + + val provider = testSymbolProvider( + model, + rustReservedWordConfig = StructureGeneratorTest.rustReservedWordConfig, + config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), + ) + val project = TestWorkspace.testProject(provider) + val shape: StructureShape = model.lookup("com.test#MyStruct") + project.useShapeWriter(shape) { + StructureGenerator(model, provider, this, shape, listOf()).render() + BuilderGenerator(model, provider, shape, listOf()).render(this) + unitTest("test_defaults") { + rustTemplate( + """ + let s = Builder::default().an_actually_required_field(5).build().unwrap(); + assert_eq!(s.zero_default(), 0); + assert_eq!(s.default_empty(), ""); + assert_eq!(s.default_value(), "some-value"); + assert_eq!(s.one_default(), 1); + assert!(s.empty_list().is_empty()); + assert_eq!(s.an_actually_required_field(), 5); + assert_eq!(s.no_default(), None); + assert_eq!(s.default_document().as_bool().unwrap(), true); + + """, + "Struct" to provider.toSymbol(shape), + ) + } + } + project.compileAndTest() + } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index 4cb6b166a2..77932ddfac 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import io.kotest.matchers.string.shouldContainInOrder import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Test -import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -16,9 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer -import software.amazon.smithy.rust.codegen.core.testutil.TestRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -438,150 +435,4 @@ class StructureGeneratorTest { writer.toString().shouldNotContain("#[doc(hidden)]") } } - - @Test - fun `it supports error correction`() { - // TODO - val model = """ - ${"$"}version: "2.0" - namespace com.test - structure MyStruct { - @required - int: Integer - - @required - string: String - - @required - list: StringList - - @required - doc: Document - - @required - bool: Boolean - - @required - ts: Timestamp - - @required - blob: Blob - } - - list StringList { - member: String - } - """.asSmithyModel() - - val provider = testSymbolProvider( - model, - rustReservedWordConfig = rustReservedWordConfig, - config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), - ) - val project = TestWorkspace.testProject(provider) - val shape: StructureShape = model.lookup("com.test#MyStruct") - project.useShapeWriter(shape) { - StructureGenerator(model, provider, this, shape, listOf()).render() - BuilderGenerator(model, provider, shape, listOf()).render(this) - unitTest("error_correction") { - rustTemplate( - """ - - Builder::default().build().expect_err("no default set for many fields"); - let corrected = Builder::default().build_with_error_correction().expect("all errors corrected"); - assert_eq!(corrected.int(), 0); - assert_eq!(corrected.string(), ""); - """, - ) - } - } - project.compileAndTest() - } - - @Test - fun `it supports nonzero defaults`() { - // TODO - val model = """ - ${"$"}version: "2.0" - namespace com.test - structure MyStruct { - @default(0) - @required - zeroDefault: Integer - - @required - @default(1) - oneDefault: OneDefault - - @required - @default("") - defaultEmpty: String - - @required - @default("some-value") - defaultValue: String - - @required - anActuallyRequiredField: Integer - - @required - @default([]) - emptyList: StringList - - noDefault: String - - @default(true) - @required - defaultDocument: Document - } - - list StringList { - member: String - } - - @default(1) - integer OneDefault - """.asSmithyModel() - - val provider = testSymbolProvider( - model, - rustReservedWordConfig = rustReservedWordConfig, - config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), - ) - val project = TestWorkspace.testProject(provider) - val shape: StructureShape = model.lookup("com.test#MyStruct") - project.useShapeWriter(shape) { - StructureGenerator(model, provider, this, shape, listOf()).render() - BuilderGenerator(model, provider, shape, listOf()).render(this) - unitTest("test_defaults") { - rustTemplate( - """ - let s = Builder::default().an_actually_required_field(5).build().unwrap(); - assert_eq!(s.zero_default(), 0); - assert_eq!(s.default_empty(), ""); - assert_eq!(s.default_value(), "some-value"); - assert_eq!(s.one_default(), 1); - assert!(s.empty_list().is_empty()); - assert_eq!(s.an_actually_required_field(), 5); - assert_eq!(s.no_default(), None); - assert_eq!(s.default_document().as_bool().unwrap(), true); - - """, - "Struct" to provider.toSymbol(shape), - ) - } - unitTest("error_correction") { - rustTemplate( - """ - - Builder::default().build().expect_err("no default set"); - let corrected = Builder::default().build_with_error_correction().expect("all errors corrected"); - assert_eq!(corrected.an_actually_required_field(), 0); - assert_eq!(Builder::default().an_actually_required_field(0).build().unwrap(), corrected); - """, - ) - } - } - project.compileAndTest() - } } From a484e5ed0eddcb8008bb00bfc38b45e37c07001d Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Thu, 14 Sep 2023 12:36:53 -0400 Subject: [PATCH 4/4] control error correction on a client/server basis --- .../smithy/generators/BuilderGenerator.kt | 6 ++++- .../generators/DefaultValueGenerator.kt | 16 +++++++++----- .../codegen/core/smithy/protocols/AwsJson.kt | 2 ++ .../codegen/core/smithy/protocols/RestJson.kt | 4 ++-- .../protocols/parse/JsonParserGenerator.kt | 22 ++++++++++++------- .../core/testutil/EventStreamTestModels.kt | 4 ++-- .../generators/protocol/ServerProtocol.kt | 5 +++-- 7 files changed, 38 insertions(+), 21 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 5c1f72dfb3..4417d5913d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -412,7 +412,11 @@ class BuilderGenerator( val default = generator.defaultValue(member) if (!memberSymbol.isOptional()) { if (default != null) { - rust(".unwrap_or_else(#T)", default) + if (default.isRustDefault) { + rust(".unwrap_or_default()") + } else { + rust(".unwrap_or_else(#T)", default) + } } else { if (errorCorrection) { generator.errorCorrection(member)?.also { correction -> rust(".or_else(||#T)", correction) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt index f4b235c3b8..0f19f66de7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/DefaultValueGenerator.kt @@ -39,15 +39,17 @@ class DefaultValueGenerator( ) { private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider) + data class DefaultValue(val isRustDefault: Boolean, val expr: Writable) + /** Returns the default value as set by the defaultValue trait */ - fun defaultValue(member: MemberShape): Writable? { + fun defaultValue(member: MemberShape): DefaultValue? { val target = model.expectShape(member.target) return when (val default = symbolProvider.toSymbol(member).defaultValue()) { is Default.NoDefault -> null - is Default.RustDefault -> writable("Default::default") + is Default.RustDefault -> DefaultValue(isRustDefault = true, writable("Default::default")) is Default.NonZeroDefault -> { val instantiation = instantiator.instantiate(target as SimpleShape, default.value) - writable { rust("||#T", instantiation) } + DefaultValue(isRustDefault = false, writable { rust("||#T", instantiation) }) } } } @@ -62,10 +64,12 @@ class DefaultValueGenerator( when (target) { is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to symbol) is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())") - is StructureShape -> rust( - "#T::default().build_with_error_correction().ok()", - symbolProvider.symbolForBuilder(target), + is StructureShape -> rustTemplate( + "#{error_correct}(#{Builder}::default()).ok()", + "Builder" to symbolProvider.symbolForBuilder(target), + "error_correct" to errorCorrectingBuilder(target, symbolProvider, model)!!, ) + is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this) is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 0a53422552..59b957aef9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -122,6 +122,7 @@ class AwsJsonSerializerGenerator( open class AwsJson( val codegenContext: CodegenContext, val awsJsonVersion: AwsJsonVersion, + val enableErrorCorrection: Boolean, ) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( @@ -148,6 +149,7 @@ open class AwsJson( codegenContext, httpBindingResolver, ::awsJsonFieldName, + enableErrorCorrection = enableErrorCorrection, ) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index ba526d0c87..cebd92ca3d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -59,7 +59,7 @@ class RestJsonHttpBindingResolver( } } -open class RestJson(val codegenContext: CodegenContext) : Protocol { +open class RestJson(val codegenContext: CodegenContext, private val enableErrorCorrection: Boolean) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, @@ -95,7 +95,7 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { listOf("x-amzn-errortype" to errorShape.id.name) override fun structuredDataParser(): StructuredDataParserGenerator = - JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) + JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName, enableErrorCorrection = enableErrorCorrection) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 91cd8c9efd..c0261e2571 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -33,11 +33,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.errorCorrectingBuilder import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant @@ -95,6 +97,7 @@ class JsonParserGenerator( private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) }, + private val enableErrorCorrection: Boolean, private val customizations: List = listOf(), ) : StructuredDataParserGenerator { private val model = codegenContext.model @@ -515,17 +518,20 @@ class JsonParserGenerator( // Only call `build()` if the builder is not fallible. Otherwise, return the builder. if (returnSymbolToParse.isUnconstrained) { rust("Ok(Some(builder))") - } else { + } else if (enableErrorCorrection && BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { val errorCorrection = errorCorrectingBuilder(shape, symbolProvider, model) - if (errorCorrection != null) { - rustTemplate( - """ - Ok(Some(#{correct_errors}(builder).map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""", - "correct_errors" to errorCorrection, *codegenScope, - ) + val buildExpr = if (errorCorrection != null) { + writable { rustTemplate("#{correct_errors}(builder)", "correctErrors" to errorCorrection) } } else { - rust("Ok(Some(builder.build()))") + writable { rustTemplate("builder.build()") } } + rustTemplate( + """Ok(Some(#{build}.map_err(|err|#{Error}::custom_source("Response was invalid", err))?))""", + "build" to buildExpr, + *codegenScope, + ) + } else { + rust("Ok(Some(builder.build()))") } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index e944a552a0..fd2817424d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -129,7 +129,7 @@ object EventStreamTestModels { validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { RestJson(it) }, + ) { RestJson(it, enableErrorCorrection = false) }, // // awsJson1_1 @@ -145,7 +145,7 @@ object EventStreamTestModels { validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { AwsJson(it, AwsJsonVersion.Json11) }, + ) { AwsJson(it, AwsJsonVersion.Json11, enableErrorCorrection = false) }, // // restXml diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 77ba711fd2..bc472c5ceb 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -113,6 +113,7 @@ fun jsonParserGenerator( httpBindingResolver, jsonName, returnSymbolToParseFn(codegenContext), + enableErrorCorrection = false, listOf( ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext), ) + additionalParserCustomizations, @@ -122,7 +123,7 @@ class ServerAwsJsonProtocol( private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, private val additionalParserCustomizations: List = listOf(), -) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol { +) : AwsJson(serverCodegenContext, awsJsonVersion, enableErrorCorrection = false), ServerProtocol { private val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath: String @@ -186,7 +187,7 @@ private fun restRouterType(runtimeConfig: RuntimeConfig) = class ServerRestJsonProtocol( private val serverCodegenContext: ServerCodegenContext, private val additionalParserCustomizations: List = listOf(), -) : RestJson(serverCodegenContext), ServerProtocol { +) : RestJson(serverCodegenContext, enableErrorCorrection = false), ServerProtocol { val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath: String = "rest_json_1"