Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support @sparse constrained map shapes and list shapes #2213

Merged
merged 5 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@
# message = "Fix typos in module documentation for generated crates"
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[smithy-rs]]
message = "`@sparse` list shapes and map shapes with constraint traits and with constrained members are now supported"
references = ["smithy-rs#2213"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "server"}
author = "david-perez"
28 changes: 28 additions & 0 deletions codegen-core/common-test-models/constraints.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ structure ConA {
lengthMap: LengthMap,

mapOfMapOfListOfListOfConB: MapOfMapOfListOfListOfConB,
sparseMap: SparseMap,
sparseList: SparseList,
sparseLengthMap: SparseLengthMap,
sparseLengthList: SparseLengthList,

constrainedUnion: ConstrainedUnion,
enumString: EnumString,
Expand Down Expand Up @@ -543,6 +547,30 @@ structure ConA {
// lengthSetOfPatternString: LengthSetOfPatternString,
}

@sparse
map SparseMap {
key: String,
value: LengthString
}

@sparse
list SparseList {
member: LengthString
}

@sparse
@length(min: 69)
map SparseLengthMap {
key: String,
value: String
}

@sparse
@length(min: 69)
list SparseLengthList {
member: String
}

map MapOfLengthBlob {
key: String,
value: LengthBlob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ class PubCrateConstrainedShapeSymbolProvider(
}

is MemberShape -> {
require(model.expectShape(shape.container).isStructureShape) {
"This arm is only exercised by `ServerBuilderGenerator`"
}
require(!shape.hasConstraintTraitOrTargetHasConstraintTrait(model, base)) { errorMessage(shape) }

val targetShape = model.expectShape(shape.target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ConstrainedCollectionGenerator(
}

val name = constrainedShapeSymbolProvider.toSymbol(shape).name
val inner = "std::vec::Vec<#{ValueSymbol}>"
val inner = "std::vec::Vec<#{ValueMemberSymbol}>"
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)
val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
val constrainedTypeMetadata = RustMetadata(
Expand All @@ -79,7 +79,7 @@ class ConstrainedCollectionGenerator(
)

val codegenScope = arrayOf(
"ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.member.target)),
"ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.member),
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
"ConstraintViolation" to constraintViolation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ConstrainedMapGenerator(
val lengthTrait = shape.expectTrait<LengthTrait>()

val name = constrainedShapeSymbolProvider.toSymbol(shape).name
val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>"
val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>"
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)

val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
Expand All @@ -69,7 +69,7 @@ class ConstrainedMapGenerator(

val codegenScope = arrayOf(
"KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)),
"ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.value.target)),
"ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.value),
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
"ConstraintViolation" to constraintViolation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -54,14 +60,14 @@ class PubCrateConstrainedCollectionGenerator(
val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape)
val name = constrainedSymbol.name
val innerShape = model.expectShape(shape.member.target)
val innerConstrainedSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
val innerMemberSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member)
} else {
constrainedShapeSymbolProvider.toSymbol(innerShape)
constrainedShapeSymbolProvider.toSymbol(shape.member)
}

val codegenScope = arrayOf(
"InnerConstrainedSymbol" to innerConstrainedSymbol,
"InnerMemberSymbol" to innerMemberSymbol,
"ConstrainedTrait" to RuntimeType.ConstrainedTrait,
"UnconstrainedSymbol" to unconstrainedSymbol,
"Symbol" to symbol,
Expand All @@ -72,7 +78,7 @@ class PubCrateConstrainedCollectionGenerator(
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerConstrainedSymbol}>);
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerMemberSymbol}>);
impl #{ConstrainedTrait} for $name {
type Unconstrained = #{UnconstrainedSymbol};
Expand Down Expand Up @@ -130,22 +136,19 @@ class PubCrateConstrainedCollectionGenerator(
val innerNeedsConversion =
innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)

rustTemplate(
"""
impl #{From}<$name> for #{Symbol} {
fn from(v: $name) -> Self {
${
if (innerNeedsConversion) {
"v.0.into_iter().map(|item| item.into()).collect()"
} else {
"v.0"
}
}
rustBlockTemplate("impl #{From}<$name> for #{Symbol}", *codegenScope) {
rustBlock("fn from(v: $name) -> Self") {
if (innerNeedsConversion) {
withBlock("v.0.into_iter().map(|item| ", ").collect()") {
conditionalBlock("item.map(|item| ", ")", innerMemberSymbol.isOptional()) {
rust("item.into()")
}
}
} else {
rust("v.0")
}
}
""",
*codegenScope,
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -54,15 +60,15 @@ class PubCrateConstrainedMapGenerator(
val keyShape = model.expectShape(shape.key.target, StringShape::class.java)
val valueShape = model.expectShape(shape.value.target)
val keySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape)
val valueSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape)
val valueMemberSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value)
} else {
constrainedShapeSymbolProvider.toSymbol(valueShape)
constrainedShapeSymbolProvider.toSymbol(shape.value)
}

val codegenScope = arrayOf(
"KeySymbol" to keySymbol,
"ValueSymbol" to valueSymbol,
"ValueMemberSymbol" to valueMemberSymbol,
"ConstrainedTrait" to RuntimeType.ConstrainedTrait,
"UnconstrainedSymbol" to unconstrainedSymbol,
"Symbol" to symbol,
Expand All @@ -73,7 +79,7 @@ class PubCrateConstrainedMapGenerator(
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>);
pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>);

impl #{ConstrainedTrait} for $name {
type Unconstrained = #{UnconstrainedSymbol};
Expand Down Expand Up @@ -117,22 +123,27 @@ class PubCrateConstrainedMapGenerator(
val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)
val valueNeedsConversion = valueShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)

rustTemplate(
"""
impl #{From}<$name> for #{Symbol} {
fn from(v: $name) -> Self {
${ if (keyNeedsConversion || valueNeedsConversion) {
val keyConversion = if (keyNeedsConversion) { ".into()" } else { "" }
val valueConversion = if (valueNeedsConversion) { ".into()" } else { "" }
"v.0.into_iter().map(|(k, v)| (k$keyConversion, v$valueConversion)).collect()"
} else {
"v.0"
} }
rustBlockTemplate("impl #{From}<$name> for #{Symbol}", *codegenScope) {
rustBlock("fn from(v: $name) -> Self") {
if (keyNeedsConversion || valueNeedsConversion) {
withBlock("v.0.into_iter().map(|(k, v)| {", "}).collect()") {
if (keyNeedsConversion) {
rust("let k = k.into();")
}
if (valueNeedsConversion) {
withBlock("let v = {", "};") {
conditionalBlock("v.map(|v| ", ")", valueMemberSymbol.isOptional()) {
rust("v.into()")
}
}
}
rust("(k, v)")
}
} else {
rust("v.0")
}
}
""",
*codegenScope,
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
Expand Down Expand Up @@ -65,22 +68,21 @@ class UnconstrainedCollectionGenerator(
fun render() {
check(shape.canReachConstrainedShape(model, symbolProvider))

val innerShape = model.expectShape(shape.member.target)
val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape)
val innerMemberSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape.member)

unconstrainedModuleWriter.withInlineModule(symbol.module()) {
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerUnconstrainedSymbol}>);
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerMemberSymbol}>);
impl From<$name> for #{MaybeConstrained} {
fn from(value: $name) -> Self {
Self::Unconstrained(value)
}
}
""",
"InnerUnconstrainedSymbol" to innerUnconstrainedSymbol,
"InnerMemberSymbol" to innerMemberSymbol,
"MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(),
)

Expand All @@ -99,26 +101,35 @@ class UnconstrainedCollectionGenerator(
!innerShape.isDirectlyConstrained(symbolProvider) &&
innerShape !is StructureShape &&
innerShape !is UnionShape
val innerConstrainedSymbol = if (resolvesToNonPublicConstrainedValueType) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
val constrainedMemberSymbol = if (resolvesToNonPublicConstrainedValueType) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member)
} else {
constrainedShapeSymbolProvider.toSymbol(innerShape)
constrainedShapeSymbolProvider.toSymbol(shape.member)
}
val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape)

val constrainValueWritable = writable {
conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) {
rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))")
}
}

rustTemplate(
"""
let res: Result<std::vec::Vec<#{InnerConstrainedSymbol}>, (usize, #{InnerConstraintViolationSymbol})> = value
let res: Result<#{Vec}<#{ConstrainedMemberSymbol}>, (usize, #{InnerConstraintViolationSymbol}) > = value
.0
.into_iter()
.enumerate()
.map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation)))
.map(|(idx, inner)| {
#{ConstrainValueWritable:W}
})
.collect();
let inner = res.map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?;
""",
"InnerConstrainedSymbol" to innerConstrainedSymbol,
"Vec" to RuntimeType.Vec,
"ConstrainedMemberSymbol" to constrainedMemberSymbol,
"InnerConstraintViolationSymbol" to innerConstraintViolationSymbol,
"TryFrom" to RuntimeType.TryFrom,
"ConstrainValueWritable" to constrainValueWritable,
)
} else {
rust("let inner = value.0;")
Expand Down
Loading