Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Aug 31, 2023
1 parent 86c8806 commit de61ad7
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 33 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ plugins { }

allprojects {
repositories {
mavenLocal()
/* mavenLocal() */
mavenCentral()
google()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ fun Writable.isEmpty(): Boolean {
return writer.toString() == RustWriter.root().toString()
}

fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable {
val self = this
return writable { f(self) }
}

fun Writable.isNotEmpty(): Boolean = !this.isEmpty()

operator fun Writable.plus(other: Writable): Writable {
Expand Down Expand Up @@ -108,10 +113,12 @@ fun rustTypeParameters(
"#{gg:W}",
"gg" to typeParameter.declaration(withAngleBrackets = false),
)

else -> {
// Check if it's a writer. If it is, invoke it; Else, throw a codegen error.
@Suppress("UNCHECKED_CAST")
val func = typeParameter as? Writable ?: throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
val func = typeParameter as? Writable
?: throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer")
func.invoke(this)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
Expand All @@ -23,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlockTemplat
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
Expand Down Expand Up @@ -89,42 +89,20 @@ fun MemberShape.enforceRequired(
return field
}
val shape = this
val ctx = arrayOf(
"checkSetString" to checkSetString,
"error" to OperationBuildError(codegenContext.runtimeConfig).missingField(
codegenContext.symbolProvider.toMemberName(shape), "A required field was not set",
),
"field" to field,
val error = OperationBuildError(codegenContext.runtimeConfig).missingField(
codegenContext.symbolProvider.toMemberName(shape), "A required field was not set",
)
val unwrapped = when (codegenContext.model.expectShape(this.target)) {
is StringShape -> writable {
rustTemplate(
"#{checkSetString}(#{field}).ok_or_else(||#{error})?",
*ctx,
"#{field}.filter(|f|!AsRef::<str>::as_ref(f).trim().is_empty())",
"field" to field,
)
}

else -> writable {
rustTemplate("#{field}.ok_or_else(||#{error})?", *ctx)
}
}
return unwrapped.letIf(produceOption) { writable { rust("Some(#T)", it) } }
}

private val checkSetString = RuntimeType.forInlineFun("non_empty_str", RustModule.private("serde_util")) {
rustTemplate(
"""
pub (crate) fn non_empty_str<T: AsRef<str>>(field: Option<T>) -> Option<T> {
if let Some(field) = field {
if field.as_ref() != "" {
return Some(field)
}
}
None
}
""",
)
else -> field
}.map { base -> rustTemplate("#{base}.ok_or_else(||#{error})?", "base" to base, "error" to error) }
return unwrapped.letIf(produceOption) { w -> w.map { rust("Some(#T)", it) } }
}

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.core.rustlang

import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldEndWith
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

Expand Down Expand Up @@ -106,4 +107,13 @@ internal class RustTypeParametersTest {
}.join(writable("+"))(writer)
writer.toString() shouldContain "A-B-CD+E+F"
}

@Test
fun `test map`() {
val writer = RustWriter.forModule("model")
val a = writable { rust("a") }
val b = a.map { rust("b(#T)", it) }
b(writer)
writer.toString().trim() shouldEndWith "b(a)"
}
}
2 changes: 1 addition & 1 deletion settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pluginManagement {

buildscript {
repositories {
mavenLocal()
/* mavenLocal() */
mavenCentral()
google()
}
Expand Down

0 comments on commit de61ad7

Please sign in to comment.