Skip to content

Commit

Permalink
Refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Aug 31, 2023
1 parent 0c62809 commit b403adf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError
import software.amazon.smithy.rust.codegen.core.smithy.generators.enforceRequired
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

Expand Down Expand Up @@ -136,11 +135,10 @@ class EndpointParamsInterceptorGenerator(
// lastly, allow these to be overridden by members
memberParams.forEach { (memberShape, param) ->
val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
val member = writable("_input.$memberName.clone()").letIf(memberShape.isRequired) { ref ->
OperationBuildError(codegenContext.runtimeConfig).emptyOrUnset(symbolProvider, ref, memberShape)
}
val member = memberShape.enforceRequired(writable("_input.$memberName.clone()"), codegenContext)

rustTemplate(
".${EndpointParamsGenerator.setterName(param.name)}(dbg!(#{member}))", "member" to member,
".${EndpointParamsGenerator.setterName(param.name)}(#{member})", "member" to member,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
Expand All @@ -30,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand All @@ -47,6 +49,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

Expand Down Expand Up @@ -77,35 +80,54 @@ abstract class BuilderCustomization : NamedCustomization<BuilderSection>()
fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).resolve("error::BuildError")
fun RuntimeConfig.serializationError() = RuntimeType.operationModule(this).resolve("error::SerializationError")

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {

fun emptyOrUnset(symbolProvider: SymbolProvider, ref: Writable, member: MemberShape): Writable {
val checkSet = RuntimeType.forInlineFun("check_set", RustModule.private("serde_util")) {
fun MemberShape.enforceRequired(
field: Writable,
codegenContext: CodegenContext,
produceOption: Boolean = true,
): Writable {
if (!this.isRequired) {
return field
}
val shape = this
val ctx = arrayOf(
"checkSetString" to checkSetString,
"error" to OperationBuildError(codegenContext.runtimeConfig).missingField(
codegenContext.symbolProvider.toMemberName(shape), "A required field was not set",
),
"field" to field,
)
val unwrapped = when (codegenContext.model.expectShape(this.target)) {
is StringShape -> writable {
rustTemplate(
"""
pub(crate) fn check_set<T: #{Default} + #{PartialEq}>(field: Option<T>, name: &'static str) -> Result<T, #{BuildError}> {
if let Some(field) = field {
if field != Default::default() {
return Ok(field)
}
}
Err(#{BuildError}::missing_field(name, "field was required"))
}
""",
"BuildError" to runtimeConfig.operationBuildError(), *preludeScope,
"#{checkSetString}(#{field}).ok_or_else(||#{error})?",
*ctx,
)
}

val fieldName = symbolProvider.toMemberName(member)
return writable {
rustTemplate(
"Some(#{checkSet}(#{ref}, ${fieldName.dq()})?)",
"checkSet" to checkSet,
"ref" to ref,
)
else -> writable {
rustTemplate("#{field}.ok_or_else(||#{error})?", *ctx)
}
}
return unwrapped.letIf(produceOption) { writable { rust("Some(#T)", it) } }
}

private val checkSetString = RuntimeType.forInlineFun("non_empty_str", RustModule.private("serde_util")) {
rustTemplate(
"""
pub (crate) fn non_empty_str<T: AsRef<str>>(field: Option<T>) -> Option<T> {
if let Some(field) = field {
if field.as_ref() != "" {
return Some(field)
}
}
None
}
""",
)
}

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {

fun missingField(field: String, details: String) = writable {
rust("#T::missing_field(${field.dq()}, ${details.dq()})", runtimeConfig.operationBuildError())
Expand Down

0 comments on commit b403adf

Please sign in to comment.