Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the default(..) trait to source default information #2982

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class EndpointTypesGenerator(
val tests: List<EndpointTestCase>,
) {
val params: Parameters = rules?.parameters ?: Parameters.builder().build()
private val runtimeConfig = codegenContext.runtimeConfig
private val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext)
private val stdlib = customizations
.flatMap { it.customRuntimeFunctions(codegenContext) }
Expand All @@ -41,7 +40,6 @@ class EndpointTypesGenerator(
}

fun paramsStruct(): RuntimeType = EndpointParamsGenerator(codegenContext, params).paramsStruct()
fun paramsBuilder(): RuntimeType = EndpointParamsGenerator(codegenContext, params).paramsBuilder()
fun defaultResolver(): RuntimeType? =
rules?.let { EndpointResolverGenerator(codegenContext, stdlib).defaultEndpointResolver(it) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
Expand All @@ -14,18 +13,12 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName

private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable {
rust("#T::from($data)", enumSymbol)
}

class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior {
override fun hasFallibleBuilder(shape: StructureShape): Boolean =
BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider)
Expand All @@ -40,7 +33,6 @@ class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Ins
codegenContext.model,
codegenContext.runtimeConfig,
ClientBuilderKindBehavior(codegenContext),
::enumFromStringFn,
) {
fun renderFluentCall(
writer: RustWriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable {
return writable { f(self) }
}

/** Returns Some(..arg) */
fun Writable.some(): Writable {
return this.map { rust("Some(#T)", it) }
}

fun Writable.isNotEmpty(): Boolean = !this.isEmpty()

operator fun Writable.plus(other: Writable): Writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
*
* Prelude docs: https://doc.rust-lang.org/std/prelude/index.html#prelude-contents
*/
val preludeScope by lazy {
val preludeScope: Array<Pair<String, Any>> by lazy {
arrayOf(
// Rust 1.0
"Copy" to std.resolve("marker::Copy"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
Expand Down Expand Up @@ -102,6 +103,8 @@ sealed class Default {
* This symbol should use the Rust `std::default::Default` when unset
*/
object RustDefault : Default()

data class NonZeroDefault(val value: Node) : Default()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BigDecimalShape
import software.amazon.smithy.model.shapes.BigIntegerShape
import software.amazon.smithy.model.shapes.BlobShape
Expand Down Expand Up @@ -37,6 +38,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
Expand All @@ -48,6 +50,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import kotlin.reflect.KClass
Expand Down Expand Up @@ -79,16 +82,18 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?)
/**
* Make the return [value] optional if the [member] symbol is as well optional.
*/
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) {
"Some($value)"
}
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String =
value.letIf(toSymbol(member).isOptional()) {
"Some($value)"
}

/**
* Make the return [value] optional if the [member] symbol is not optional.
*/
fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) {
"Some($value)"
}
fun SymbolProvider.toOptional(member: MemberShape, value: String): String =
value.letIf(!toSymbol(member).isOptional()) {
"Some($value)"
}

/**
* Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service
Expand Down Expand Up @@ -170,7 +175,7 @@ open class SymbolVisitor(
}

private fun simpleShape(shape: SimpleShape): Symbol {
return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build()
return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build()
}

override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape)
Expand Down Expand Up @@ -263,13 +268,21 @@ open class SymbolVisitor(

override fun memberShape(shape: MemberShape): Symbol {
val target = model.expectShape(shape.target)
val defaultValue = shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait ->
when (val value = trait.toNode()) {
Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault
Node.nullNode() -> Default.NoDefault
else -> { Default.NonZeroDefault(value)
}
}
} ?: Default.NoDefault
// Handle boxing first, so we end up with Option<Box<_>>, not Box<Option<_>>.
return handleOptionality(
handleRustBoxing(toSymbol(target), shape),
shape,
nullableIndex,
config.nullabilityCheckMode,
)
).toBuilder().setDefault(defaultValue).build()
}

override fun timestampShape(shape: TimestampShape?): Symbol {
Expand Down Expand Up @@ -297,7 +310,12 @@ fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder =
// If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation
.definitionFile("thisisabug.rs")

fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol =
fun handleOptionality(
symbol: Symbol,
member: MemberShape,
nullableIndex: NullableIndex,
nullabilityCheckMode: CheckMode,
): Symbol =
symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlockTemplat
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand All @@ -32,7 +33,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
Expand All @@ -41,7 +41,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
Expand Down Expand Up @@ -181,14 +180,15 @@ class BuilderGenerator(
private fun renderBuildFn(implBlockWriter: RustWriter) {
val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider)
val outputSymbol = symbolProvider.toSymbol(shape)
val fb =
"#{Result}<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>"
val returnType = when (fallibleBuilder) {
true -> "#{Result}<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>"
true -> fb
false -> implBlockWriter.format(outputSymbol)
}
implBlockWriter.docs("Consumes the builder and constructs a #D.", outputSymbol)
implBlockWriter.rustBlockTemplate("pub fn build(self) -> $returnType", *preludeScope) {
conditionalBlockTemplate("#{Ok}(", ")", conditional = fallibleBuilder, *preludeScope) {
// If a wrapper is specified, use the `::new` associated function to construct the wrapper
coreBuilder(this)
}
}
Expand Down Expand Up @@ -369,6 +369,27 @@ class BuilderGenerator(
}
}

internal fun errorCorrectingBuilder(writer: RustWriter) {
val outputSymbol = symbolProvider.toSymbol(shape)
val fb =
"#{Result}<${writer.format(outputSymbol)}, ${writer.format(runtimeConfig.operationBuildError())}>"
writer.rustBlockTemplate(
"pub(crate) fn build_with_error_correction(self) -> $fb",
*preludeScope,
) {
val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider)
if (fallibleBuilder) {
rustTemplate(
"#{Ok}(#{Builder})",
*preludeScope,
"Builder" to writable { coreBuilder(this, errorCorrection = true) },
)
} else {
rustTemplate("#{Ok}(self.build())", *preludeScope)
}
}
}

/**
* The core builder of the inner type. If the structure requires a fallible builder, this may use `?` to return
* errors.
Expand All @@ -380,24 +401,46 @@ class BuilderGenerator(
* }
* ```
*/
private fun coreBuilder(writer: RustWriter) {
private fun coreBuilder(writer: RustWriter, errorCorrection: Boolean = false) {
writer.rustBlock("#T", structureSymbol) {
members.forEach { member ->
val memberName = symbolProvider.toMemberName(member)
val memberSymbol = symbolProvider.toSymbol(member)
val default = memberSymbol.defaultValue()
withBlock("$memberName: self.$memberName", ",") {
// Write the modifier
when {
!memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()")
!memberSymbol.isOptional() -> withBlock(
".ok_or_else(||",
")?",
) { missingRequiredField(memberName) }
// Resolve a default value or return null
val generator = DefaultValueGenerator(runtimeConfig, symbolProvider, model)
val default = generator.defaultValue(member)
if (!memberSymbol.isOptional()) {
if (default != null) {
if (default.isRustDefault) {
rust(".unwrap_or_default()")
} else {
rust(".unwrap_or_else(#T)", default)
}
} else {
if (errorCorrection) {
generator.errorCorrection(member)?.also { correction -> rust(".or_else(||#T)", correction) }
}
withBlock(
".ok_or_else(||",
")?",
) { missingRequiredField(memberName) }
}
}
}
}
writeCustomizations(customizations, BuilderSection.AdditionalFieldsInBuild(shape))
}
}
}

fun errorCorrectingBuilder(shape: StructureShape, symbolProvider: RustSymbolProvider, model: Model): RuntimeType? {
if (!BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
return null
}
return RuntimeType.forInlineFun("${symbolProvider.symbolForBuilder(shape).name}::build_with_error_correction", symbolProvider.moduleForBuilder(shape)) {
implBlock(symbolProvider.symbolForBuilder(shape)) {
BuilderGenerator(model, symbolProvider, shape, listOf()).errorCorrectingBuilder(this)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.SimpleShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.some
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.core.util.isStreaming

class DefaultValueGenerator(
runtimeConfig: RuntimeConfig,
private val symbolProvider: RustSymbolProvider,
private val model: Model,
) {
private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)

data class DefaultValue(val isRustDefault: Boolean, val expr: Writable)

/** Returns the default value as set by the defaultValue trait */
fun defaultValue(member: MemberShape): DefaultValue? {
val target = model.expectShape(member.target)
return when (val default = symbolProvider.toSymbol(member).defaultValue()) {
is Default.NoDefault -> null
is Default.RustDefault -> DefaultValue(isRustDefault = true, writable("Default::default"))
is Default.NonZeroDefault -> {
val instantiation = instantiator.instantiate(target as SimpleShape, default.value)
DefaultValue(isRustDefault = false, writable { rust("||#T", instantiation) })
}
}
}

fun errorCorrection(member: MemberShape): Writable? {
val symbol = symbolProvider.toSymbol(member)
val target = model.expectShape(member.target)
if (member.isEventStream(model) || member.isStreaming(model)) {
return null
}
return writable {
when (target) {
is EnumShape -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to symbol)
is BooleanShape, is NumberShape, is StringShape, is DocumentShape, is ListShape, is MapShape -> rust("Some(Default::default())")
is StructureShape -> rustTemplate(
"#{error_correct}(#{Builder}::default()).ok()",
"Builder" to symbolProvider.symbolForBuilder(target),
"error_correct" to errorCorrectingBuilder(target, symbolProvider, model)!!,
)

is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)

is UnionShape -> rust("Some(#T::Unknown)", symbol)
}
}
}
}
Loading