Skip to content

Commit

Permalink
Support @sparse constrained map shapes and list shapes (#2213)
Browse files Browse the repository at this point in the history
Turns out we've never supported them, neither directly constrained nor
with constrained members, because of a lack of tests. Yet another data
point to prioritize working on code-generating `constraints.smithy` (see
#2101).

The implementation is simple: we just need to call the symbol provider
on the member symbols instead of on the target symbols so we get
`Option<T>` list members / map values if applicable, and handle the
wrapper when converting between unconstrained and constrained types with
help from `match` and `Option<T>::map`.
  • Loading branch information
david-perez authored Jan 17, 2023
1 parent 93f4c4f commit 9f90517
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 71 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@
# message = "Fix typos in module documentation for generated crates"
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[smithy-rs]]
message = "`@sparse` list shapes and map shapes with constraint traits and with constrained members are now supported"
references = ["smithy-rs#2213"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "server"}
author = "david-perez"
28 changes: 28 additions & 0 deletions codegen-core/common-test-models/constraints.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ structure ConA {
lengthMap: LengthMap,

mapOfMapOfListOfListOfConB: MapOfMapOfListOfListOfConB,
sparseMap: SparseMap,
sparseList: SparseList,
sparseLengthMap: SparseLengthMap,
sparseLengthList: SparseLengthList,

constrainedUnion: ConstrainedUnion,
enumString: EnumString,
Expand Down Expand Up @@ -543,6 +547,30 @@ structure ConA {
// lengthSetOfPatternString: LengthSetOfPatternString,
}

@sparse
map SparseMap {
key: String,
value: LengthString
}

@sparse
list SparseList {
member: LengthString
}

@sparse
@length(min: 69)
map SparseLengthMap {
key: String,
value: String
}

@sparse
@length(min: 69)
list SparseLengthList {
member: String
}

map MapOfLengthBlob {
key: String,
value: LengthBlob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ class PubCrateConstrainedShapeSymbolProvider(
}

is MemberShape -> {
require(model.expectShape(shape.container).isStructureShape) {
"This arm is only exercised by `ServerBuilderGenerator`"
}
require(!shape.hasConstraintTraitOrTargetHasConstraintTrait(model, base)) { errorMessage(shape) }

val targetShape = model.expectShape(shape.target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ConstrainedCollectionGenerator(
}

val name = constrainedShapeSymbolProvider.toSymbol(shape).name
val inner = "std::vec::Vec<#{ValueSymbol}>"
val inner = "std::vec::Vec<#{ValueMemberSymbol}>"
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)
val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
val constrainedTypeMetadata = RustMetadata(
Expand All @@ -79,7 +79,7 @@ class ConstrainedCollectionGenerator(
)

val codegenScope = arrayOf(
"ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.member.target)),
"ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.member),
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
"ConstraintViolation" to constraintViolation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ConstrainedMapGenerator(
val lengthTrait = shape.expectTrait<LengthTrait>()

val name = constrainedShapeSymbolProvider.toSymbol(shape).name
val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>"
val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>"
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)

val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
Expand All @@ -69,7 +69,7 @@ class ConstrainedMapGenerator(

val codegenScope = arrayOf(
"KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)),
"ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.value.target)),
"ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.value),
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
"ConstraintViolation" to constraintViolation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -54,14 +60,14 @@ class PubCrateConstrainedCollectionGenerator(
val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape)
val name = constrainedSymbol.name
val innerShape = model.expectShape(shape.member.target)
val innerConstrainedSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
val innerMemberSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member)
} else {
constrainedShapeSymbolProvider.toSymbol(innerShape)
constrainedShapeSymbolProvider.toSymbol(shape.member)
}

val codegenScope = arrayOf(
"InnerConstrainedSymbol" to innerConstrainedSymbol,
"InnerMemberSymbol" to innerMemberSymbol,
"ConstrainedTrait" to RuntimeType.ConstrainedTrait,
"UnconstrainedSymbol" to unconstrainedSymbol,
"Symbol" to symbol,
Expand All @@ -72,7 +78,7 @@ class PubCrateConstrainedCollectionGenerator(
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerConstrainedSymbol}>);
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerMemberSymbol}>);
impl #{ConstrainedTrait} for $name {
type Unconstrained = #{UnconstrainedSymbol};
Expand Down Expand Up @@ -130,22 +136,19 @@ class PubCrateConstrainedCollectionGenerator(
val innerNeedsConversion =
innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)

rustTemplate(
"""
impl #{From}<$name> for #{Symbol} {
fn from(v: $name) -> Self {
${
if (innerNeedsConversion) {
"v.0.into_iter().map(|item| item.into()).collect()"
} else {
"v.0"
}
}
rustBlockTemplate("impl #{From}<$name> for #{Symbol}", *codegenScope) {
rustBlock("fn from(v: $name) -> Self") {
if (innerNeedsConversion) {
withBlock("v.0.into_iter().map(|item| ", ").collect()") {
conditionalBlock("item.map(|item| ", ")", innerMemberSymbol.isOptional()) {
rust("item.into()")
}
}
} else {
rust("v.0")
}
}
""",
*codegenScope,
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -54,15 +60,15 @@ class PubCrateConstrainedMapGenerator(
val keyShape = model.expectShape(shape.key.target, StringShape::class.java)
val valueShape = model.expectShape(shape.value.target)
val keySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape)
val valueSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape)
val valueMemberSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value)
} else {
constrainedShapeSymbolProvider.toSymbol(valueShape)
constrainedShapeSymbolProvider.toSymbol(shape.value)
}

val codegenScope = arrayOf(
"KeySymbol" to keySymbol,
"ValueSymbol" to valueSymbol,
"ValueMemberSymbol" to valueMemberSymbol,
"ConstrainedTrait" to RuntimeType.ConstrainedTrait,
"UnconstrainedSymbol" to unconstrainedSymbol,
"Symbol" to symbol,
Expand All @@ -73,7 +79,7 @@ class PubCrateConstrainedMapGenerator(
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>);
pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>);
impl #{ConstrainedTrait} for $name {
type Unconstrained = #{UnconstrainedSymbol};
Expand Down Expand Up @@ -117,22 +123,27 @@ class PubCrateConstrainedMapGenerator(
val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)
val valueNeedsConversion = valueShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)

rustTemplate(
"""
impl #{From}<$name> for #{Symbol} {
fn from(v: $name) -> Self {
${ if (keyNeedsConversion || valueNeedsConversion) {
val keyConversion = if (keyNeedsConversion) { ".into()" } else { "" }
val valueConversion = if (valueNeedsConversion) { ".into()" } else { "" }
"v.0.into_iter().map(|(k, v)| (k$keyConversion, v$valueConversion)).collect()"
} else {
"v.0"
} }
rustBlockTemplate("impl #{From}<$name> for #{Symbol}", *codegenScope) {
rustBlock("fn from(v: $name) -> Self") {
if (keyNeedsConversion || valueNeedsConversion) {
withBlock("v.0.into_iter().map(|(k, v)| {", "}).collect()") {
if (keyNeedsConversion) {
rust("let k = k.into();")
}
if (valueNeedsConversion) {
withBlock("let v = {", "};") {
conditionalBlock("v.map(|v| ", ")", valueMemberSymbol.isOptional()) {
rust("v.into()")
}
}
}
rust("(k, v)")
}
} else {
rust("v.0")
}
}
""",
*codegenScope,
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
Expand Down Expand Up @@ -65,22 +68,21 @@ class UnconstrainedCollectionGenerator(
fun render() {
check(shape.canReachConstrainedShape(model, symbolProvider))

val innerShape = model.expectShape(shape.member.target)
val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape)
val innerMemberSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape.member)

unconstrainedModuleWriter.withInlineModule(symbol.module()) {
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerUnconstrainedSymbol}>);
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerMemberSymbol}>);
impl From<$name> for #{MaybeConstrained} {
fn from(value: $name) -> Self {
Self::Unconstrained(value)
}
}
""",
"InnerUnconstrainedSymbol" to innerUnconstrainedSymbol,
"InnerMemberSymbol" to innerMemberSymbol,
"MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(),
)

Expand All @@ -99,26 +101,35 @@ class UnconstrainedCollectionGenerator(
!innerShape.isDirectlyConstrained(symbolProvider) &&
innerShape !is StructureShape &&
innerShape !is UnionShape
val innerConstrainedSymbol = if (resolvesToNonPublicConstrainedValueType) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
val constrainedMemberSymbol = if (resolvesToNonPublicConstrainedValueType) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member)
} else {
constrainedShapeSymbolProvider.toSymbol(innerShape)
constrainedShapeSymbolProvider.toSymbol(shape.member)
}
val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape)

val constrainValueWritable = writable {
conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) {
rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))")
}
}

rustTemplate(
"""
let res: Result<std::vec::Vec<#{InnerConstrainedSymbol}>, (usize, #{InnerConstraintViolationSymbol})> = value
let res: Result<#{Vec}<#{ConstrainedMemberSymbol}>, (usize, #{InnerConstraintViolationSymbol}) > = value
.0
.into_iter()
.enumerate()
.map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation)))
.map(|(idx, inner)| {
#{ConstrainValueWritable:W}
})
.collect();
let inner = res.map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?;
""",
"InnerConstrainedSymbol" to innerConstrainedSymbol,
"Vec" to RuntimeType.Vec,
"ConstrainedMemberSymbol" to constrainedMemberSymbol,
"InnerConstraintViolationSymbol" to innerConstraintViolationSymbol,
"TryFrom" to RuntimeType.TryFrom,
"ConstrainValueWritable" to constrainValueWritable,
)
} else {
rust("let inner = value.0;")
Expand Down
Loading

0 comments on commit 9f90517

Please sign in to comment.