Skip to content

Commit

Permalink
Extract builderInstantiator interface to prepare for nullability changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Sep 18, 2023
1 parent cf8c834 commit eacf466
Show file tree
Hide file tree
Showing 16 changed files with 262 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientBuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol

/**
Expand All @@ -36,4 +38,7 @@ data class ClientCodegenContext(
model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT,
) {
val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins
override fun builderInstantiator(): BuilderInstantiator {
return ClientBuilderInstantiator(symbolProvider)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

fun ClientCodegenContext.builderInstantiator(): BuilderInstantiator = ClientBuilderInstantiator(symbolProvider)

class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider) : BuilderInstantiator {
override fun setField(builder: String, value: Writable, field: MemberShape): Writable {
return setFieldBase(builder, value, field)
}

override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable {
if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
rustTemplate(
"$builder.build()#{mapErr}?",
"mapErr" to (
mapErr?.map {
rust(".map_err(#T)", it)
} ?: writable { }
),
)
} else {
rust("$builder.build()")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) :
ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol =
if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) {
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version))
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version, codegenContext.builderInstantiator()))
} else {
AwsJson(codegenContext, version)
AwsJson(codegenContext, version, codegenContext.builderInstantiator())
}

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

/**
* [CodegenContext] contains code-generation context that is _common to all_ smithy-rs plugins.
Expand All @@ -17,7 +18,7 @@ import software.amazon.smithy.model.shapes.ShapeId
* If your data is specific to the `rust-client-codegen` client plugin, put it in [ClientCodegenContext] instead.
* If your data is specific to the `rust-server-codegen` server plugin, put it in [ServerCodegenContext] instead.
*/
open class CodegenContext(
abstract class CodegenContext(
/**
* The smithy model.
*
Expand Down Expand Up @@ -89,4 +90,6 @@ open class CodegenContext(
fun expectModuleDocProvider(): ModuleDocProvider = checkNotNull(moduleDocProvider) {
"A ModuleDocProvider must be set on the CodegenContext"
}

abstract fun builderInstantiator(): BuilderInstantiator
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable

/** Abstraction for instantiating a builders.
*
* Builder abstractions vary—clients MAY use `build_with_error_correction`, e.g., and builders can vary in fallibility.
* */
interface BuilderInstantiator {
/** Set a field on a builder. */
fun setField(builder: String, value: Writable, field: MemberShape): Writable

/** Finalize a builder, turning into a built object (or in the case of builders-of-builders, return the builder directly).*/
fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable? = null): Writable

/** Set a field on a builder using the `$setterName` method. $value will be passed directly. */
fun setFieldBase(builder: String, value: Writable, field: MemberShape) = writable {
rustTemplate("$builder = $builder.${field.setterName()}(#{value})", "value" to value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
Expand Down Expand Up @@ -122,6 +123,7 @@ class AwsJsonSerializerGenerator(
open class AwsJson(
val codegenContext: CodegenContext,
val awsJsonVersion: AwsJsonVersion,
val builderInstantiator: BuilderInstantiator,
) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class EventStreamUnmarshallerGenerator(
private val unionShape: UnionShape,
) {
private val model = codegenContext.model
private val builderInstantiator = codegenContext.builderInstantiator()
private val symbolProvider = codegenContext.symbolProvider
private val codegenTarget = codegenContext.target
private val runtimeConfig = codegenContext.runtimeConfig
Expand Down Expand Up @@ -339,6 +340,7 @@ class EventStreamUnmarshallerGenerator(
// TODO(EventStream): Errors on the operation can be disjoint with errors in the union,
// so we need to generate a new top-level Error type for each event stream union.
when (codegenTarget) {
// TODO(https://github.com/awslabs/smithy-rs/issues/1970) It should be possible to unify these branches now
CodegenTarget.CLIENT -> {
val target = model.expectShape(member.target, StructureShape::class.java)
val parser = protocol.structuredDataParser().errorParser(target)
Expand All @@ -352,9 +354,19 @@ class EventStreamUnmarshallerGenerator(
})?;
builder.set_meta(Some(generic));
return Ok(#{UnmarshalledMessage}::Error(
#{OpError}::${member.target.name}(builder.build())
#{OpError}::${member.target.name}(
#{build}
)
))
""",
"build" to builderInstantiator.finalizeBuilder(
"builder", target,
mapErr = {
rustTemplate(
"""|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope,
)
},
),
"parser" to parser,
*codegenScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,16 @@ import software.amazon.smithy.utils.StringUtils
* Class describing a JSON parser section that can be used in a customization.
*/
sealed class JsonParserSection(name: String) : Section(name) {
data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember")
data class BeforeBoxingDeserializedMember(val shape: MemberShape) :
JsonParserSection("BeforeBoxingDeserializedMember")

data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember")
data class AfterTimestampDeserializedMember(val shape: MemberShape) :
JsonParserSection("AfterTimestampDeserializedMember")

data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember")

data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember")
data class AfterDocumentDeserializedMember(val shape: MemberShape) :
JsonParserSection("AfterDocumentDeserializedMember")
}

/**
Expand Down Expand Up @@ -100,6 +103,7 @@ class JsonParserGenerator(
private val codegenTarget = codegenContext.target
private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType()
private val protocolFunctions = ProtocolFunctions(codegenContext)
private val builderInstantiator = codegenContext.builderInstantiator()
private val codegenScope = arrayOf(
"Error" to smithyJson.resolve("deserialize::error::DeserializeError"),
"expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"),
Expand Down Expand Up @@ -251,6 +255,7 @@ class JsonParserGenerator(
deserializeMember(member)
}
}

CodegenTarget.SERVER -> {
if (symbolProvider.toSymbol(member).isOptional()) {
withBlock("builder = builder.${member.setterName()}(", ");") {
Expand Down Expand Up @@ -508,12 +513,14 @@ class JsonParserGenerator(
"Builder" to symbolProvider.symbolForBuilder(shape),
)
deserializeStructInner(shape.members())
// Only call `build()` if the builder is not fallible. Otherwise, return the builder.
if (returnSymbolToParse.isUnconstrained) {
rust("Ok(Some(builder))")
} else {
rust("Ok(Some(builder.build()))")
val builder = builderInstantiator.finalizeBuilder(
"builder", shape,
) {
rustTemplate(
"""|err|#{Error}::custom_source("Response was invalid", err)""", *codegenScope,
)
}
rust("Ok(Some(#T))", builder)
}
}
}
Expand Down
Loading

0 comments on commit eacf466

Please sign in to comment.