Skip to content

Commit 1bb3243

Browse files
committed
Add support for error-correcting builders
1 parent a5c1ced commit 1bb3243

File tree

6 files changed

+232
-11
lines changed

6 files changed

+232
-11
lines changed

codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ data class ClientCodegenContext(
3939
) {
4040
val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins
4141
override fun builderInstantiator(): BuilderInstantiator {
42-
return ClientBuilderInstantiator(symbolProvider)
42+
return ClientBuilderInstantiator(this)
4343
}
4444
}

codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt

+11-7
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,33 @@ import software.amazon.smithy.rust.codegen.core.rustlang.map
1313
import software.amazon.smithy.rust.codegen.core.rustlang.rust
1414
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
1515
import software.amazon.smithy.rust.codegen.core.rustlang.writable
16-
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
1716
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
1817
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
1918

20-
fun ClientCodegenContext.builderInstantiator(): BuilderInstantiator = ClientBuilderInstantiator(symbolProvider)
21-
22-
class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider) : BuilderInstantiator {
19+
class ClientBuilderInstantiator(private val clientCodegenContext: ClientCodegenContext) : BuilderInstantiator {
2320
override fun setField(builder: String, value: Writable, field: MemberShape): Writable {
2421
return setFieldWithSetter(builder, value, field)
2522
}
2623

24+
/**
25+
* For the client, we finalize builders with error correction enabled
26+
*/
2727
override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable {
28-
if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
28+
if (BuilderGenerator.hasFallibleBuilder(shape, clientCodegenContext.symbolProvider)) {
2929
rustTemplate(
30-
"$builder.build()#{mapErr}?",
30+
"#{correct_errors}($builder).build()#{mapErr}?",
31+
"correct_errors" to clientCodegenContext.correctErrors(shape),
3132
"mapErr" to (
3233
mapErr?.map {
3334
rust(".map_err(#T)", it)
3435
} ?: writable { }
3536
),
3637
)
3738
} else {
38-
rust("$builder.build()")
39+
rustTemplate(
40+
"#{correct_errors}($builder).build()",
41+
"correct_errors" to clientCodegenContext.correctErrors(shape),
42+
)
3943
}
4044
}
4145
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package software.amazon.smithy.rust.codegen.client.smithy.generators
7+
8+
import software.amazon.smithy.model.node.Node
9+
import software.amazon.smithy.model.shapes.BlobShape
10+
import software.amazon.smithy.model.shapes.BooleanShape
11+
import software.amazon.smithy.model.shapes.DocumentShape
12+
import software.amazon.smithy.model.shapes.EnumShape
13+
import software.amazon.smithy.model.shapes.ListShape
14+
import software.amazon.smithy.model.shapes.MapShape
15+
import software.amazon.smithy.model.shapes.MemberShape
16+
import software.amazon.smithy.model.shapes.NumberShape
17+
import software.amazon.smithy.model.shapes.StringShape
18+
import software.amazon.smithy.model.shapes.StructureShape
19+
import software.amazon.smithy.model.shapes.TimestampShape
20+
import software.amazon.smithy.model.shapes.UnionShape
21+
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
22+
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
23+
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
24+
import software.amazon.smithy.rust.codegen.core.rustlang.map
25+
import software.amazon.smithy.rust.codegen.core.rustlang.rust
26+
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
27+
import software.amazon.smithy.rust.codegen.core.rustlang.some
28+
import software.amazon.smithy.rust.codegen.core.rustlang.writable
29+
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
30+
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
31+
import software.amazon.smithy.rust.codegen.core.smithy.generators.PrimitiveInstantiator
32+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.shapeFunctionName
33+
import software.amazon.smithy.rust.codegen.core.util.isEventStream
34+
import software.amazon.smithy.rust.codegen.core.util.isStreaming
35+
36+
/**
37+
* For AWS-services, the spec defines error correction semantics to recover from missing default values for required members:
38+
* https://smithy.io/2.0/spec/aggregate-types.html?highlight=error%20correction#client-error-correction
39+
*/
40+
41+
private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Writable? {
42+
if (!member.isRequired) {
43+
return null
44+
}
45+
symbolProvider.toSymbol(member)
46+
val target = model.expectShape(member.target)
47+
val targetSymbol = symbolProvider.toSymbol(target)
48+
if (member.isEventStream(model) || member.isStreaming(model)) {
49+
return null
50+
}
51+
val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)
52+
return writable {
53+
when (target) {
54+
is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to targetSymbol)
55+
is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())")
56+
is StructureShape -> rustTemplate(
57+
"{ let builder = #{Builder}::default(); #{instantiate} }",
58+
"Builder" to symbolProvider.symbolForBuilder(target),
59+
"instantiate" to builderInstantiator().finalizeBuilder("builder", target).map {
60+
if (BuilderGenerator.hasFallibleBuilder(target, symbolProvider)) {
61+
rust("#T.ok()", it)
62+
} else {
63+
it.some()(this)
64+
}
65+
},
66+
)
67+
68+
is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
69+
is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)
70+
is UnionShape -> rust("Some(#T::Unknown)", targetSymbol)
71+
}
72+
}
73+
}
74+
75+
fun ClientCodegenContext.correctErrors(shape: StructureShape): RuntimeType {
76+
val name = symbolProvider.shapeFunctionName(serviceShape, shape) + "_correct_errors"
77+
val corrections = writable {
78+
shape.members().forEach { member ->
79+
val memberName = symbolProvider.toMemberName(member)
80+
errorCorrectedDefault(member)?.also { default ->
81+
rustTemplate(
82+
"""if builder.$memberName.is_none() { builder.$memberName = #{default} }""",
83+
"default" to default,
84+
)
85+
}
86+
}
87+
}
88+
89+
return RuntimeType.forInlineFun(name, RustModule.private("serde_util")) {
90+
rustTemplate(
91+
"""
92+
pub(crate) fn $name(mut builder: #{Builder}) -> #{Builder} {
93+
#{corrections}
94+
builder
95+
}
96+
97+
""",
98+
"Builder" to symbolProvider.symbolForBuilder(shape),
99+
"corrections" to corrections,
100+
)
101+
}
102+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package software.amazon.smithy.rust.codegen.client.smithy.generators
6+
7+
import org.junit.jupiter.api.Test
8+
import software.amazon.smithy.model.shapes.StructureShape
9+
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
10+
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
11+
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
12+
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
13+
import software.amazon.smithy.rust.codegen.core.util.lookup
14+
15+
class ErrorCorrectionTest {
16+
private val model = """
17+
namespace com.example
18+
use aws.protocols#awsJson1_0
19+
20+
@awsJson1_0
21+
service HelloService {
22+
operations: [SayHello],
23+
version: "1"
24+
}
25+
26+
operation SayHello { input: TestInput }
27+
structure TestInput { nested: TestStruct }
28+
structure TestStruct {
29+
@required
30+
foo: String,
31+
@required
32+
byteValue: Byte,
33+
@required
34+
listValue: StringList,
35+
@required
36+
mapValue: ListMap,
37+
@required
38+
doubleListValue: DoubleList
39+
@required
40+
document: Document
41+
@required
42+
nested: Nested
43+
@required
44+
blob: Blob
45+
@required
46+
enum: Enum
47+
@required
48+
union: U
49+
notRequired: String
50+
}
51+
52+
enum Enum {
53+
A,
54+
B,
55+
C
56+
}
57+
58+
union U {
59+
A: Integer,
60+
B: String,
61+
C: Unit
62+
}
63+
64+
structure Nested {
65+
@required
66+
a: String
67+
}
68+
69+
list StringList {
70+
member: String
71+
}
72+
73+
list DoubleList {
74+
member: StringList
75+
}
76+
77+
map ListMap {
78+
key: String,
79+
value: StringList
80+
}
81+
""".asSmithyModel(smithyVersion = "2.0")
82+
83+
@Test
84+
fun correctMissingFields() {
85+
val shape = model.lookup<StructureShape>("com.example#TestStruct")
86+
clientIntegrationTest(model) { ctx, crate ->
87+
crate.lib {
88+
val codegenCtx =
89+
arrayOf("correct_errors" to ctx.correctErrors(shape), "Shape" to ctx.symbolProvider.toSymbol(shape))
90+
rustTemplate(
91+
"""
92+
/// docs
93+
pub fn use_fn_publicly() { #{correct_errors}(#{Shape}::builder()); } """,
94+
*codegenCtx,
95+
)
96+
unitTest("test_default_builder") {
97+
rustTemplate(
98+
"""
99+
let builder = #{correct_errors}(#{Shape}::builder().foo("abcd"));
100+
let shape = builder.build();
101+
// don't override a field already set
102+
assert_eq!(shape.foo(), Some("abcd"));
103+
// set nested fields
104+
assert_eq!(shape.nested().unwrap().a(), Some(""));
105+
// don't default non-required fields
106+
assert_eq!(shape.not_required(), None);
107+
assert_eq!(shape.blob().unwrap().as_ref(), &[]);
108+
""",
109+
*codegenCtx,
110+
111+
)
112+
}
113+
}
114+
}
115+
}
116+
}

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ internal fun RustSymbolProvider.shapeModuleName(serviceShape: ServiceShape?, sha
138138
)
139139

140140
/** Creates a unique name for a ser/de function. */
141-
internal fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String {
141+
fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String {
142142
val containerName = when (shape) {
143143
is MemberShape -> model.expectShape(shape.container).contextName(serviceShape).toSnakeCase()
144144
else -> shape.contextName(serviceShape).toSnakeCase()

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ internal class BuilderGeneratorTest {
150150
@Test
151151
fun `it supports nonzero defaults`() {
152152
val model = """
153-
${"$"}version: "2.0"
154153
namespace com.test
155154
structure MyStruct {
156155
@default(0)
@@ -180,7 +179,7 @@ internal class BuilderGeneratorTest {
180179
}
181180
@default(1)
182181
integer OneDefault
183-
""".asSmithyModel()
182+
""".asSmithyModel(smithyVersion = "2.0")
184183

185184
val provider = testSymbolProvider(
186185
model,

0 commit comments

Comments
 (0)