diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt index c458a03706a..e7114433b4a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt @@ -28,11 +28,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope -import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError +import software.amazon.smithy.rust.codegen.core.smithy.generators.enforceRequired import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.inputShape -import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.toPascalCase @@ -136,11 +135,10 @@ class EndpointParamsInterceptorGenerator( // lastly, allow these to be overridden by members memberParams.forEach { (memberShape, param) -> val memberName = codegenContext.symbolProvider.toMemberName(memberShape) - val member = writable("_input.$memberName.clone()").letIf(memberShape.isRequired) { ref -> - OperationBuildError(codegenContext.runtimeConfig).emptyOrUnset(symbolProvider, ref, memberShape) - } + val member = memberShape.enforceRequired(writable("_input.$memberName.clone()"), codegenContext) + rustTemplate( - ".${EndpointParamsGenerator.setterName(param.name)}(dbg!(#{member}))", "member" to member, + ".${EndpointParamsGenerator.setterName(param.name)}(#{member})", "member" to member, ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 0fa8dfa919c..cc0a1c19d2f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive @@ -30,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -47,6 +49,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary import software.amazon.smithy.rust.codegen.core.util.toSnakeCase @@ -77,35 +80,54 @@ abstract class BuilderCustomization : NamedCustomization() fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).resolve("error::BuildError") fun RuntimeConfig.serializationError() = RuntimeType.operationModule(this).resolve("error::SerializationError") -class OperationBuildError(private val runtimeConfig: RuntimeConfig) { - - fun emptyOrUnset(symbolProvider: SymbolProvider, ref: Writable, member: MemberShape): Writable { - val checkSet = RuntimeType.forInlineFun("check_set", RustModule.private("serde_util")) { +fun MemberShape.enforceRequired( + field: Writable, + codegenContext: CodegenContext, + produceOption: Boolean = true, +): Writable { + if (!this.isRequired) { + return field + } + val shape = this + val ctx = arrayOf( + "checkSetString" to checkSetString, + "error" to OperationBuildError(codegenContext.runtimeConfig).missingField( + codegenContext.symbolProvider.toMemberName(shape), "A required field was not set", + ), + "field" to field, + ) + val unwrapped = when (codegenContext.model.expectShape(this.target)) { + is StringShape -> writable { rustTemplate( - """ - pub(crate) fn check_set(field: Option, name: &'static str) -> Result { - if let Some(field) = field { - if field != Default::default() { - return Ok(field) - } - } - Err(#{BuildError}::missing_field(name, "field was required")) - } - - """, - "BuildError" to runtimeConfig.operationBuildError(), *preludeScope, + "#{checkSetString}(#{field}).ok_or_else(||#{error})?", + *ctx, ) } - val fieldName = symbolProvider.toMemberName(member) - return writable { - rustTemplate( - "Some(#{checkSet}(#{ref}, ${fieldName.dq()})?)", - "checkSet" to checkSet, - "ref" to ref, - ) + else -> writable { + rustTemplate("#{field}.ok_or_else(||#{error})?", *ctx) } } + return unwrapped.letIf(produceOption) { writable { rust("Some(#T)", it) } } +} + +private val checkSetString = RuntimeType.forInlineFun("non_empty_str", RustModule.private("serde_util")) { + rustTemplate( + """ + pub (crate) fn non_empty_str>(field: Option) -> Option { + if let Some(field) = field { + if field.as_ref() != "" { + return Some(field) + } + } + None + } + + """, + ) +} + +class OperationBuildError(private val runtimeConfig: RuntimeConfig) { fun missingField(field: String, details: String) = writable { rust("#T::missing_field(${field.dq()}, ${details.dq()})", runtimeConfig.operationBuildError())