Skip to content

Commit

Permalink
Add support for error-correcting builders
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Sep 19, 2023
1 parent a5c1ced commit 34eafd8
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ data class ClientCodegenContext(
) {
val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins
override fun builderInstantiator(): BuilderInstantiator {
return ClientBuilderInstantiator(symbolProvider)
return ClientBuilderInstantiator(this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,40 @@ import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

fun ClientCodegenContext.builderInstantiator(): BuilderInstantiator = ClientBuilderInstantiator(symbolProvider)

class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider) : BuilderInstantiator {
class ClientBuilderInstantiator(private val clientCodegenContext: ClientCodegenContext) : BuilderInstantiator {
override fun setField(builder: String, value: Writable, field: MemberShape): Writable {
return setFieldWithSetter(builder, value, field)
}

/**
* For the client, we finalize builders with error correction enabled
*/
override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable {
if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
val correctErrors = clientCodegenContext.correctErrors(shape)
val builderW = writable {
when {
correctErrors != null -> rustTemplate("#{correctErrors}($builder)", "correctErrors" to correctErrors)
else -> rustTemplate(builder)
}
}
if (BuilderGenerator.hasFallibleBuilder(shape, clientCodegenContext.symbolProvider)) {
rustTemplate(
"$builder.build()#{mapErr}?",
"#{builder}.build()#{mapErr}?",
"builder" to builderW,
"mapErr" to (
mapErr?.map {
rust(".map_err(#T)", it)
} ?: writable { }
),
)
} else {
rust("$builder.build()")
rustTemplate(
"#{builder}.build()",
"builder" to builderW,
)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.isEmpty
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.plus
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.some
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.PrimitiveInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.protocols.shapeFunctionName
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.letIf

/**
* For AWS-services, the spec defines error correction semantics to recover from missing default values for required members:
* https://smithy.io/2.0/spec/aggregate-types.html?highlight=error%20correction#client-error-correction
*/

private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Writable? {
if (!member.isRequired) {
return null
}
symbolProvider.toSymbol(member)
val target = model.expectShape(member.target)
val memberSymbol = symbolProvider.toSymbol(member)
val targetSymbol = symbolProvider.toSymbol(target)
if (member.isEventStream(model) || member.isStreaming(model)) {
return null
}
val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)
return writable {
when (target) {
is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to targetSymbol)
is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())")
is StructureShape -> rustTemplate(
"{ let builder = #{Builder}::default(); #{instantiate} }",
"Builder" to symbolProvider.symbolForBuilder(target),
"instantiate" to builderInstantiator().finalizeBuilder("builder", target).map {
if (BuilderGenerator.hasFallibleBuilder(target, symbolProvider)) {
rust("#T.ok()", it)
} else {
it.some()(this)
}
}.letIf(memberSymbol.isRustBoxed()) {
it.plus { rustTemplate(".map(#{Box}::new)", *preludeScope) }
},
)

is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)
is UnionShape -> rust("Some(#T::Unknown)", targetSymbol)
}
}
}

fun ClientCodegenContext.correctErrors(shape: StructureShape): RuntimeType? {
val name = symbolProvider.shapeFunctionName(serviceShape, shape) + "_correct_errors"
val corrections = writable {
shape.members().forEach { member ->
val memberName = symbolProvider.toMemberName(member)
errorCorrectedDefault(member)?.also { default ->
rustTemplate(
"""if builder.$memberName.is_none() { builder.$memberName = #{default} }""",
"default" to default,
)
}
}
}

if (corrections.isEmpty()) {
return null
}

return RuntimeType.forInlineFun(name, RustModule.private("serde_util")) {
rustTemplate(
"""
pub(crate) fn $name(mut builder: #{Builder}) -> #{Builder} {
#{corrections}
builder
}
""",
"Builder" to symbolProvider.symbolForBuilder(shape),
"corrections" to corrections,
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup

class ErrorCorrectionTest {
private val model = """
namespace com.example
use aws.protocols#awsJson1_0
@awsJson1_0
service HelloService {
operations: [SayHello],
version: "1"
}
operation SayHello { input: TestInput }
structure TestInput { nested: TestStruct }
structure TestStruct {
@required
foo: String,
@required
byteValue: Byte,
@required
listValue: StringList,
@required
mapValue: ListMap,
@required
doubleListValue: DoubleList
@required
document: Document
@required
nested: Nested
@required
blob: Blob
@required
enum: Enum
@required
union: U
notRequired: String
}
enum Enum {
A,
B,
C
}
union U {
A: Integer,
B: String,
C: Unit
}
structure Nested {
@required
a: String
}
list StringList {
member: String
}
list DoubleList {
member: StringList
}
map ListMap {
key: String,
value: StringList
}
""".asSmithyModel(smithyVersion = "2.0")

@Test
fun correctMissingFields() {
val shape = model.lookup<StructureShape>("com.example#TestStruct")
clientIntegrationTest(model) { ctx, crate ->
crate.lib {
val codegenCtx =
arrayOf("correct_errors" to ctx.correctErrors(shape), "Shape" to ctx.symbolProvider.toSymbol(shape))
rustTemplate(
"""
/// docs
pub fn use_fn_publicly() { #{correct_errors}(#{Shape}::builder()); } """,
*codegenCtx,
)
unitTest("test_default_builder") {
rustTemplate(
"""
let builder = #{correct_errors}(#{Shape}::builder().foo("abcd"));
let shape = builder.build();
// don't override a field already set
assert_eq!(shape.foo(), Some("abcd"));
// set nested fields
assert_eq!(shape.nested().unwrap().a(), Some(""));
// don't default non-required fields
assert_eq!(shape.not_required(), None);
assert_eq!(shape.blob().unwrap().as_ref(), &[]);
""",
*codegenCtx,

)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ internal fun RustSymbolProvider.shapeModuleName(serviceShape: ServiceShape?, sha
)

/** Creates a unique name for a ser/de function. */
internal fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String {
fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String {
val containerName = when (shape) {
is MemberShape -> model.expectShape(shape.container).contextName(serviceShape).toSnakeCase()
else -> shape.contextName(serviceShape).toSnakeCase()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ internal class BuilderGeneratorTest {
@Test
fun `it supports nonzero defaults`() {
val model = """
${"$"}version: "2.0"
namespace com.test
structure MyStruct {
@default(0)
Expand Down Expand Up @@ -180,7 +179,7 @@ internal class BuilderGeneratorTest {
}
@default(1)
integer OneDefault
""".asSmithyModel()
""".asSmithyModel(smithyVersion = "2.0")

val provider = testSymbolProvider(
model,
Expand Down

0 comments on commit 34eafd8

Please sign in to comment.